controller/lib.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#![feature(assert_matches)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// unsafe-op-in-unsafe-fn warnings.
#![allow(unsafe_op_in_unsafe_fn)]

pub mod bootstrap;
pub mod history;

use std::collections::HashMap;
use std::collections::HashSet;
use std::time::Duration;

use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Context;
use hyperactor::GangId;
use hyperactor::GangRef;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::actor::ActorHandle;
use hyperactor::actor::ActorStatus;
use hyperactor::cap;
use hyperactor::channel::ChannelAddr;
use hyperactor::clock::Clock;
use hyperactor::data::Serialized;
use hyperactor_mesh::comm::CommActor;
use hyperactor_mesh::comm::CommActorMode;
use hyperactor_mesh::comm::multicast::CastMessage;
use hyperactor_mesh::comm::multicast::CastMessageEnvelope;
use hyperactor_mesh::comm::multicast::DestinationPort;
use hyperactor_mesh::comm::multicast::Uslice;
use hyperactor_mesh::reference::ActorMeshId;
use hyperactor_mesh::reference::ProcMeshId;
use hyperactor_multiprocess::proc_actor::ProcActor;
use hyperactor_multiprocess::proc_actor::spawn;
use hyperactor_multiprocess::supervision::WorldSupervisionMessageClient;
use hyperactor_multiprocess::supervision::WorldSupervisor;
use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
use monarch_messages::client::ClientActor;
use monarch_messages::client::ClientMessageClient;
use monarch_messages::client::Exception;
use monarch_messages::client::LogLevel;
use monarch_messages::controller::ControllerMessage;
use monarch_messages::controller::ControllerMessageHandler;
use monarch_messages::controller::DeviceFailure;
use monarch_messages::controller::Ranks;
use monarch_messages::controller::Seq;
use monarch_messages::controller::WorkerError;
use monarch_messages::debugger::DebuggerAction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::WorkerActor;
use monarch_messages::worker::WorkerMessage;
use ndslice::Selection;
use ndslice::Shape;
use ndslice::Slice;
use ndslice::reshape::Limit;
use ndslice::reshape::ReshapeShapeExt;
use ndslice::selection::dsl;
use ndslice::shape::Range;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::OnceCell;

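/// Reshape limit applied to the flat worker slice when casting (see
/// `ControllerMessageHandler::send`): it bounds each reshaped dimension, and
/// hence the cast fanout, to this size.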
const CASTING_FANOUT_SIZE: usize = 8;

/// A controller for the workers that will be leveraged by the client to do the actual
/// compute tasks. It acts as a proxy, managing comms with the workers and handling things
/// like history, data dependencies, and worker lifecycles on behalf of the client, abstracting them away.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        ControllerMessage,
    ],
)]
pub(crate) struct ControllerActor {
    client_actor_ref: OnceCell<ActorRef<ClientActor>>,
    comm_actor_ref: ActorRef<CommActor>,
    worker_gang_ref: GangRef<WorkerActor>,
    history: history::History,
    supervision_query_interval: Duration,
    system_supervision_actor_ref: ActorRef<WorldSupervisor>,
    worker_progress_check_interval: Duration,
    operation_timeout: Duration,
    operations_per_worker_progress_request: u64,
    // The Seq and time we last sent out a WorkerMessage::RequestStatus.
    last_controller_request_status: Option<(Seq, tokio::time::Instant)>,
    fail_on_worker_timeout: bool,
    world_size: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) struct ControllerParams {
    /// The world size, i.e. the total number of workers.
    pub(crate) world_size: usize,

    /// Reference to the comm actor. The controller takes "ownership" of this
    /// actor: it is immediately configured to target the worker gang.
    /// This is a temporary workaround until we are fully on meshes.
    pub(crate) comm_actor_ref: ActorRef<CommActor>,

    /// Reference to the workers to send commands to.
    pub(crate) worker_gang_ref: GangRef<WorkerActor>,

    // How often to query world supervision status against the system actor.
    pub(crate) supervision_query_interval: Duration,

    // How often to check whether workers are making progress.
    pub(crate) worker_progress_check_interval: Duration,

    // How long to wait for an operation to complete before considering it timed out.
    pub(crate) operation_timeout: Duration,

    // How many operations may be enqueued before we request a progress update from workers.
    pub(crate) operations_per_worker_progress_request: u64,

    // Whether a failure should be propagated back to the client when workers are detected to be stuck.
    pub(crate) fail_on_worker_timeout: bool,
}

#[async_trait]
impl Actor for ControllerActor {
    type Params = ControllerParams;

    async fn new(params: ControllerParams) -> Result<Self, anyhow::Error> {
        Ok(Self {
            client_actor_ref: OnceCell::new(),
            comm_actor_ref: params.comm_actor_ref,
            worker_gang_ref: params.worker_gang_ref,
            history: history::History::new(params.world_size),
            supervision_query_interval: params.supervision_query_interval,
            system_supervision_actor_ref: ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone()),
            worker_progress_check_interval: params.worker_progress_check_interval,
            operation_timeout: params.operation_timeout,
            operations_per_worker_progress_request: params.operations_per_worker_progress_request,
            last_controller_request_status: None,
            fail_on_worker_timeout: params.fail_on_worker_timeout,
            world_size: params.world_size,
        })
    }

    async fn init(&mut self, cx: &hyperactor::Instance<Self>) -> Result<(), anyhow::Error> {
        self.comm_actor_ref.send(
            cx,
            CommActorMode::ImplicitWithWorldId(self.worker_gang_ref.gang_id().world_id().clone()),
        )?;
        Ok(())
    }
}

impl ControllerActor {
    /// Bootstrap the controller actor. This will create a new proc, join the system at `bootstrap_addr`,
    /// and spawn the controller actor into the proc. `labels` is an arbitrary set of name/value pairs
    /// to be attached to the proc in the system registry; it can be used later to query and find the
    /// proc(s) using the system's snapshot API.
    pub async fn bootstrap(
        controller_id: ActorId,
        listen_addr: ChannelAddr,
        bootstrap_addr: ChannelAddr,
        params: ControllerParams,
        supervision_update_interval: Duration,
        labels: HashMap<String, String>,
    ) -> Result<(ActorHandle<ProcActor>, ActorRef<ControllerActor>), anyhow::Error> {
        let bootstrap = ProcActor::bootstrap(
            controller_id.proc_id().clone(),
            controller_id
                .proc_id()
                .world_id()
                .expect("multiprocess supports only ranked procs")
                .clone(), // REFACTOR(marius): make world_id a parameter of ControllerActor::bootstrap
            listen_addr,
            bootstrap_addr.clone(),
            supervision_update_interval,
            labels,
            ProcLifecycleMode::ManagedBySystem,
        )
        .await?;

        let mut system = hyperactor_multiprocess::System::new(bootstrap_addr);
        let client = system.attach().await?;

        let controller_actor_ref = spawn::<ControllerActor>(
            &client,
            &bootstrap.proc_actor.bind(),
            controller_id.clone().name(),
            &ControllerParams {
                comm_actor_ref: bootstrap.comm_actor.bind(),
                ..params
            },
        )
        .await?;

        Ok((bootstrap.proc_actor, controller_actor_ref))
    }

    fn client(&self) -> Result<ActorRef<ClientActor>, anyhow::Error> {
        self.client_actor_ref
            .get()
            .ok_or_else(|| anyhow::anyhow!("client actor ref not set"))
            .cloned()
    }

    // Send a request_status for the seq we expect to complete by our next deadline if it is more
    // than N ops ahead of our last request_status, or if M seconds have passed, where:
    //
    // N = self.operations_per_worker_progress_request
    // M = self.worker_progress_check_interval
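    //
    // For example (illustrative): with N = 100 and a last request at seq 40,
    // a deadline seq of 150 triggers a new request (150 >= 40 + 100), while a
    // deadline seq of 90 triggers one only once M has elapsed since the last
    // request.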
    async fn request_status_if_needed(
        &mut self,
        cx: &Context<'_, Self>,
    ) -> Result<(), anyhow::Error> {
        if let Some((expected_seq, ..)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if self.last_controller_request_status.is_none_or(
                |(last_requested_seq, last_requested_time)| {
                    (expected_seq
                        >= (u64::from(last_requested_seq)
                            + self.operations_per_worker_progress_request)
                            .into()
                        || last_requested_time.elapsed() > self.worker_progress_check_interval)
                        && last_requested_seq != expected_seq
                },
            ) {
                // Send to all workers.
                self.send(
                    cx,
                    Ranks::Slice(
                        ndslice::Slice::new(0, vec![self.history.world_size()], vec![1]).unwrap(),
                    ),
                    Serialized::serialize(&WorkerMessage::RequestStatus {
                        seq: expected_seq.clone(),
                        controller: true,
                    })
                    .unwrap(),
                )
                .await?;

                self.last_controller_request_status =
                    Some((expected_seq.clone(), cx.clock().now()));
            }
        }

        Ok(())
    }
}

#[derive(Debug)]
struct CheckWorkerProgress;

#[async_trait]
impl Handler<CheckWorkerProgress> for ControllerActor {
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        _check_worker_progress: CheckWorkerProgress,
    ) -> Result<(), anyhow::Error> {
        let client = self.client()?;

        if let Some((expected_seq, deadline, reported)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if !reported
                && cx.clock().now() > deadline
                && expected_seq >= self.history.min_incomplete_seq_reported()
            {
                let timed_out_ranks = self
                    .history
                    .first_incomplete_seqs_controller()
                    .iter()
                    .enumerate()
                    .filter(|(_, seq)| seq <= &&expected_seq)
                    .map(|(rank, _)| rank)
                    .collect::<Vec<_>>();

                let failed_rank = timed_out_ranks.first().unwrap().clone();

                let timed_out_ranks_string = timed_out_ranks
                    .into_iter()
                    .map(|rank| rank.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");

                let message = format!(
                    "ranks {} have operations that have not completed after {} seconds",
                    timed_out_ranks_string,
                    self.operation_timeout.as_secs()
                );
                if client
                    .log(cx, LogLevel::Warn, message.clone())
                    .await
                    .is_ok()
                {
                    self.history.report_deadline_missed();
                }

                if self.fail_on_worker_timeout {
                    client
                        .result(
                            cx,
                            expected_seq,
                            Some(Err(Exception::Failure(DeviceFailure {
                                actor_id: self.worker_gang_ref.rank(failed_rank).actor_id().clone(),
                                address: "unknown".into(),
                                backtrace: message,
                            }))),
                        )
                        .await?;
                }
            }
            self.request_status_if_needed(cx).await?;
        }

        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }
}

/// Hacky translation from a sub-`Slice` to a `Selection`.
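///
/// A sketch of the mapping (illustrative; `Slice::new(offset, sizes, strides)`
/// as used elsewhere in this file):
///
/// ```ignore
/// // Ranks {4, 6, 8}: offset 4, one dimension of size 3, stride 2.
/// let slice = Slice::new(4, vec![3], vec![2]).unwrap();
/// // Hits the trivial-range case below, yielding roughly
/// // `range(Range(4, Some(10), 2), true_())`.
/// let selection = slice_to_selection(slice);
/// ```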
fn slice_to_selection(slice: Slice) -> Selection {
    match (slice.sizes(), slice.strides()) {
        // Special case exact rank `Selection`.
        ([], []) => dsl::range(slice.offset()..=slice.offset(), dsl::true_()),
        // Special case trivial range `Selection`.
        ([size, rsizes @ ..], [stride, ..]) if rsizes.iter().all(|s| *s == 1) => dsl::range(
            Range(
                slice.offset(),
                Some(slice.offset() + *size * *stride),
                *stride,
            ),
            dsl::true_(),
        ),
        // Fallback to more heavy-weight translation for everything else.
        _ => {
            let mut selection = Selection::False;
            let mut selected_ranks = HashSet::new();
            for rank in slice.iter() {
                if !selected_ranks.insert(rank) {
                    continue;
                }
                selection = dsl::union(dsl::range(rank..=rank, dsl::true_()), selection);
            }
            selection
        }
    }
}

#[async_trait]
#[hyperactor::forward(ControllerMessage)]
impl ControllerMessageHandler for ControllerActor {
    async fn attach(
        &mut self,
        cx: &Context<Self>,
        client_actor: ActorRef<ClientActor>,
    ) -> Result<(), anyhow::Error> {
        tracing::debug!("attaching client actor {}", client_actor);
        self.client_actor_ref
            .set(client_actor)
            .map_err(|actor_ref| anyhow::anyhow!("client actor {} already attached", actor_ref))?;

        // Trigger periodic checking of supervision status and worker progress.
        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }

    async fn node(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        defs: Vec<Ref>,
        uses: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        let failures = self.history.add_invocation(seq, uses, defs);
        let client = self.client()?;
        for (seq, failure) in failures {
            let _ = client.result(cx, seq, failure).await;
        }
        self.request_status_if_needed(cx).await?;

        Ok(())
    }

    async fn drop_refs(
        &mut self,
        _cx: &Context<Self>,
        refs: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        self.history.delete_invocations_for_refs(refs);
        Ok(())
    }

    async fn send(
        &mut self,
        cx: &Context<Self>,
        ranks: Ranks,
        message: Serialized,
    ) -> Result<(), anyhow::Error> {
        let selection = match ranks {
            Ranks::Slice(slice) => {
                if slice.len() == self.world_size {
                    // All ranks are selected.
                    Selection::True
                } else {
                    slice_to_selection(slice)
                }
            }
            Ranks::SliceList(slices) => slices.into_iter().fold(dsl::false_(), |sel, slice| {
                dsl::union(sel, slice_to_selection(slice))
            }),
        };

        let slice = Slice::new(0usize, vec![self.world_size], vec![1])?;
        // Use a made-up label to create a fake shape. This shape is used by
        // the comm actor to determine the cast rank. The cast rank is not used
        // by DeviceMesh, but we still need a shape there to make the logic happy.
        let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())?
            .reshape(Limit::from(CASTING_FANOUT_SIZE))
            .shape;
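        // For example (illustrative): with a world size of 64 and
        // CASTING_FANOUT_SIZE = 8, the flat [64] slice reshapes to [8, 8],
        // bounding the fanout at each hop of the cast tree to 8.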

        let message = CastMessageEnvelope::from_serialized(
            ActorMeshId(
                ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
                self.worker_gang_ref.gang_id().name().to_string(),
            ),
            cx.self_id().clone(),
            DestinationPort::new::<WorkerActor, WorkerMessage>(
                // This is awkward, but goes away entirely with meshes.
                self.worker_gang_ref
                    .gang_id()
                    .actor_id(0)
                    .name()
                    .to_string(),
            ),
            made_up_shape,
            message,
        );

        self.comm_actor_ref.send(
            cx,
            CastMessage {
                dest: Uslice {
                    // TODO: pass both slice and selection from client side
                    slice,
                    selection,
                },
                message,
            },
        )?;
        Ok(())
    }

    async fn remote_function_failed(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        error: WorkerError,
    ) -> Result<(), anyhow::Error> {
        let rank = error.worker_actor_id.rank();
        self.history
            .propagate_exception(seq, Exception::Error(seq, seq, error.clone()));
        mark_worker_complete_and_propagate_exceptions(self, cx, rank, &seq).await?;
        Ok(())
    }

    async fn status(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: ActorId,
        controller: bool,
    ) -> Result<(), anyhow::Error> {
        let rank = worker_actor_id.rank();

        if controller {
            self.history.update_deadline_tracking(rank, seq);
        } else {
            mark_worker_complete_and_propagate_exceptions(self, cx, rank, &seq).await?;
        }
        Ok(())
    }

    async fn fetch_result(
        &mut self,
        _cx: &Context<Self>,
        seq: Seq,
        result: Result<Serialized, WorkerError>,
    ) -> Result<(), anyhow::Error> {
        self.history.set_result(seq, result);
        Ok(())
    }

    async fn check_supervision(&mut self, cx: &Context<Self>) -> Result<(), anyhow::Error> {
        let gang_id: GangId = self.worker_gang_ref.clone().into();
        let world_state = self
            .system_supervision_actor_ref
            .state(cx, gang_id.world_id().clone())
            .await?;

        if let Some(world_state) = world_state {
            if !world_state.procs.is_empty() {
                tracing::error!(
                    "found procs with failures in world {}, state: {:?}",
                    gang_id.world_id(),
                    world_state
                );

                // Arbitrarily pick a failed proc as the failed actor.
                let (_, failed_state) = world_state.procs.iter().next().unwrap();
                let (failed_actor, failure_reason) =
                    failed_state.failed_actors.first().map_or_else(
                        || {
                            let proc_id = &failed_state.proc_id;
                            (
                                ActorId(proc_id.clone(), "none".into(), 0),
                                format!(
                                    "proc is dead due to heartbeat timeout; no backtrace is \
                                    available; check the log of host {} running proc {} to \
                                    figure out the root cause",
                                    failed_state.proc_addr, proc_id
                                ),
                            )
                        },
                        |(actor, status)| {
                            (
                                actor.clone(),
                                match status {
                                    ActorStatus::Failed(msg) => msg.clone(),
                                    _ => format!("unexpected actor status {status}"),
                                },
                            )
                        },
                    );

                let exc = Exception::Failure(DeviceFailure {
                    actor_id: failed_actor,
                    address: failed_state.proc_addr.to_string(),
                    backtrace: failure_reason,
                });
                tracing::error!("Sending failure to client: {exc:?}");
                // Seq does not matter, as the client will raise a device error immediately before setting the results.
                self.client()?
                    .result(cx, Seq::default(), Some(Err(exc)))
                    .await?;
                tracing::error!("Failure successfully sent to client");

                // No need to set history failures as we are directly sending back failure results.
            }
        }

        // Schedule the next supervision check.
        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        Ok(())
    }

    async fn debugger_message(
        &mut self,
        cx: &Context<Self>,
        debugger_actor_id: ActorId,
        action: DebuggerAction,
    ) -> Result<(), anyhow::Error> {
        self.client()?
            .debugger_message(cx, debugger_actor_id, action)
            .await
    }

    #[cfg(test)]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        Ok(self.history.first_incomplete_seqs().to_vec())
    }

    #[cfg(not(test))]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        unimplemented!("get_first_incomplete_seqs_unit_tests_only is only for unit tests")
    }
}

async fn mark_worker_complete_and_propagate_exceptions(
    actor: &mut ControllerActor,
    instance: &(impl cap::CanSend + cap::CanOpenPort),
    rank: usize,
    seq: &Seq,
) -> Result<(), anyhow::Error> {
    let results = actor.history.rank_completed(rank, seq.clone());
    let client = actor.client()?;
    // Propagate the failures to the client.
    for (seq, result) in results.iter() {
        let _ = client.result(instance, seq.clone(), result.clone()).await;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use core::panic;
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;
    use std::collections::HashSet;
    use std::time::Duration;

    use hyperactor::HandleClient;
    use hyperactor::Handler;
    use hyperactor::RefClient;
    use hyperactor::channel;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::data::Named;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::DialMailboxRouter;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxClient;
    use hyperactor::mailbox::MailboxServer;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::message::IndexedErasedUnbound;
    use hyperactor::proc::Proc;
    use hyperactor::reference::GangId;
    use hyperactor::reference::ProcId;
    use hyperactor::reference::WorldId;
    use hyperactor::simnet;
    use hyperactor_mesh::comm::CommActorParams;
    use hyperactor_multiprocess::System;
    use hyperactor_multiprocess::proc_actor::ProcMessage;
    use hyperactor_multiprocess::supervision::ProcSupervisionMessage;
    use hyperactor_multiprocess::supervision::ProcSupervisor;
    use hyperactor_multiprocess::system_actor::SystemMessage;
    use monarch_messages::client::ClientMessage;
    use monarch_messages::controller::ControllerMessageClient;
    use monarch_messages::wire_value::WireValue;
    use monarch_messages::worker::CallFunctionParams;
    use monarch_messages::worker::WorkerMessage;
    use monarch_types::PyTree;
    use torch_sys::RValue;

    use super::*;

    #[tokio::test]
    async fn basic_controller() {
        // TODO: Add a proper multiworker test
        let proc = Proc::local();
        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker_ref.clone(), &worker)
            .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(1),
                    worker_progress_check_interval: Duration::from_secs(3),
                    operation_timeout: Duration::from_secs(30),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();
        controller_handle
            .node(&client, 1.into(), vec![1.into(), 2.into()], vec![0.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 20.into(), vec![3.into(), 4.into()], vec![])
            .await
            .unwrap();

        ControllerMessageClient::send(
            &controller_handle,
            &worker,
            Ranks::Slice(ndslice::Slice::new(0, vec![1], vec![1]).unwrap()),
            Serialized::serialize(&WorkerMessage::CallFunction(CallFunctionParams {
                seq: 1.into(),
                results: vec![Some(1.into()), Some(2.into())],
                mutates: vec![],
                function: "os.path.split".into(),
                args: vec![WireValue::String("/fbs/fbc/foo/bar".into())],
                kwargs: HashMap::new(),
                stream: 1.into(),
                remote_process_groups: vec![],
            }))
            .unwrap(),
        )
        .await
        .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            0.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();
        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 0.into());

        controller_handle
            .remote_function_failed(
                &worker,
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            2.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 2.into());

        controller_handle
            .fetch_result(
                &worker,
                20.into(),
                Ok(Serialized::serialize_anon(&PyTree::from(RValue::Int(42))).unwrap()),
            )
            .await
            .unwrap();

        // Only a status message can trigger sending a fetched result to the client.
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            21.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 21.into());

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;
        let worker_messages: Vec<WorkerMessage> = worker_rx.drain();
        assert_eq!(
            worker_messages
                .iter()
                .filter(|msg| !matches!(msg, WorkerMessage::RequestStatus { .. }))
                .count(),
            1
        );
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 3);
        let client_message = client_messages[1].clone().into_result().unwrap();
        assert_eq!(client_message.0, 1.into());
        assert_eq!(
            client_message.1,
            Some(Err(Exception::Error(
                1.into(),
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                }
            )))
        );

        let client_message = client_messages[2].clone().into_result().unwrap();
        assert_eq!(client_message.0, 20.into());
        assert_matches!(
            client_message
                .1
                .unwrap()
                .unwrap()
                .deserialized::<PyTree<RValue>>()
                .unwrap()
                .into_leaf()
                .unwrap(),
            RValue::Int(42),
        );
    }

    #[tokio::test]
    async fn worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker_ref.clone(), &worker)
            .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                // Simulate the worker's status response when joining streams takes less time
                // than the timeout.
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        // Should have no warnings
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                // Simulate the worker's status response when joining streams takes longer
                // than the timeout.
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 1);

        let (level, message) = client_messages[0].clone().into_log().unwrap();
        assert_matches!(level, LogLevel::Warn);
        assert_eq!(
            message,
            "ranks 0 have operations that have not completed after 3 seconds"
        );
    }

    #[tokio::test]
    async fn test_failure_on_worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();

        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker_ref.clone(), &worker)
            .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let world_id = WorldId(
            proc.proc_id()
                .world_name()
                .expect("only ranked actors are supported in the controller tests")
                .to_string(),
        );
        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(world_id, "worker".to_string()).into(),
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: true,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                // Simulate the worker's status response when joining streams takes less time
                // than the timeout.
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        // Should have no warnings
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                // Simulate the worker's status response when joining streams takes longer
                // than the timeout.
1058                for _ in 0..timeout_secs * 2 {
1059                    tokio::time::advance(Duration::from_secs(1)).await;
1060                }
1061
1062                ControllerMessageClient::status(
1063                    &controller_handle,
1064                    &worker,
1065                    2.into(),
1066                    worker_ref.actor_id().clone(),
1067                    true,
1068                )
1069                .await
1070                .unwrap();
1071            }
1072            _ => panic!("Expected request status message for seq 1"),
1073        }
1074
1075        let client_messages = client_rx.drain();
1076        assert_eq!(client_messages.len(), 2);
1077
1078        let (level, message) = client_messages[0].clone().into_log().unwrap();
1079        assert_matches!(level, LogLevel::Warn);
1080        assert_eq!(
1081            message,
1082            "ranks 0 have operations that have not completed after 3 seconds"
1083        );
1084
1085        let (seq, failure) = client_messages[1].clone().into_result().unwrap();
1086        assert_eq!(seq, 1.into());
1087        let DeviceFailure {
1088            backtrace,
1089            actor_id,
1090            ..
1091        } = failure
1092            .unwrap()
1093            .err()
1094            .unwrap()
1095            .as_failure()
1096            .unwrap()
1097            .clone();
1098        assert_eq!(actor_id, proc.proc_id().actor_id("worker", 0));
1099        assert!(
1100            backtrace.contains("ranks 0 have operations that have not completed after 3 seconds")
1101        );
1102    }
1103
1104    #[tokio::test]
1105    async fn failure_propagation() {
1106        // Serve a system.
1107        let server_handle = System::serve(
1108            ChannelAddr::any(ChannelTransport::Local),
1109            Duration::from_secs(10),
1110            Duration::from_secs(10),
1111        )
1112        .await
1113        .unwrap();
1114        let mut system = System::new(server_handle.local_addr().clone());
1115
1116        // Build a supervisor.
1117        let sup_mail = system.attach().await.unwrap();
1118        let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
1119        sup_tx.bind_to(ProcSupervisionMessage::port());
1120        let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.actor_id().clone());
1121
1122        // Construct a system sender.
1123        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
1124            channel::dial(server_handle.local_addr().clone()).unwrap(),
1125        ));
1126
1127        // Construct a proc forwarder in terms of the system sender.
1128        let listen_addr = ChannelAddr::any(ChannelTransport::Local);
1129        let proc_forwarder =
1130            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
1131
1132        // Bootstrap proc 'local[0]', join the system.
1133        let world_id = id!(local);
1134        let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
1135        let proc_actor_0 = ProcActor::bootstrap_for_proc(
1136            proc.clone(),
1137            world_id.clone(),
1138            listen_addr,
1139            server_handle.local_addr().clone(),
1140            sup_ref.clone(),
1141            Duration::from_secs(2),
1142            HashMap::new(),
1143            ProcLifecycleMode::ManagedBySystem,
1144        )
1145        .await
1146        .unwrap();
1147
1148        // Bootstrap proc 'local[1]', join the system.
1149        let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
1150        let _proc_actor_1 = ProcActor::bootstrap_for_proc(
1151            proc2.clone(),
1152            world_id.clone(),
1153            ChannelAddr::any(ChannelTransport::Local),
1154            server_handle.local_addr().clone(),
1155            sup_ref.clone(),
1156            Duration::from_secs(2),
1157            HashMap::new(),
1158            ProcLifecycleMode::ManagedBySystem,
1159        )
1160        .await
1161        .unwrap();
1162
1163        // Test
1164        let (client, client_ref, mut client_rx) = proc
1165            .attach_actor::<ClientActor, ClientMessage>("client")
1166            .unwrap();
1167        let (worker1, worker1_ref, _) = proc
1168            .attach_actor::<WorkerActor, WorkerMessage>("worker")
1169            .unwrap();
1170        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker1_ref.clone(), &worker1)
1171            .unwrap();
1172        let (worker2, worker2_ref, _) = proc2
1173            .attach_actor::<WorkerActor, WorkerMessage>("worker")
1174            .unwrap();
1175        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker2_ref.clone(), &worker2)
1176            .unwrap();
1177
1178        let controller_handle = proc
1179            .spawn::<ControllerActor>(
1180                "controller",
1181                ControllerParams {
1182                    world_size: 2,
1183                    comm_actor_ref: proc_actor_0.comm_actor.bind(),
1184                    worker_gang_ref: GangId(
1185                        WorldId(world_id.name().to_string()),
1186                        "worker".to_string(),
1187                    )
1188                    .into(),
1189                    supervision_query_interval: Duration::from_secs(1),
1190                    worker_progress_check_interval: Duration::from_secs(3),
1191                    operation_timeout: Duration::from_secs(30),
1192                    operations_per_worker_progress_request: 100,
1193                    fail_on_worker_timeout: false,
1194                },
1195            )
1196            .await
1197            .unwrap();
1198
1199        controller_handle.attach(&client, client_ref).await.unwrap();
1200
1201        controller_handle
1202            .node(&client, 0.into(), vec![1.into(), 2.into()], vec![])
1203            .await
1204            .unwrap();
1205        controller_handle
1206            .node(&client, 1.into(), vec![3.into()], vec![1.into()])
1207            .await
1208            .unwrap();
1209        controller_handle
1210            .node(&client, 2.into(), vec![4.into()], vec![3.into()])
1211            .await
1212            .unwrap();
1213        controller_handle
1214            .node(&client, 3.into(), vec![5.into()], vec![3.into()])
1215            .await
1216            .unwrap();
1217        controller_handle
1218            .node(&client, 4.into(), vec![6.into()], vec![3.into()])
1219            .await
1220            .unwrap();
1221        controller_handle
1222            .node(&client, 5.into(), vec![7.into()], vec![4.into()])
1223            .await
1224            .unwrap();
1225        controller_handle
1226            .node(&client, 6.into(), vec![8.into()], vec![4.into()])
1227            .await
1228            .unwrap();
1229
1230        ControllerMessageClient::status(
1231            &controller_handle,
1232            &worker1,
1233            1.into(),
1234            worker1_ref.actor_id().clone(),
1235            false,
1236        )
1237        .await
1238        .unwrap();
1239        ControllerMessageClient::status(
1240            &controller_handle,
1241            &worker2,
1242            1.into(),
1243            worker2_ref.actor_id().clone(),
1244            false,
1245        )
1246        .await
1247        .unwrap();
1248        controller_handle
1249            .remote_function_failed(
1250                &worker1,
1251                2.into(),
1252                WorkerError {
1253                    backtrace: "some failure happened!".to_string(),
1254                    worker_actor_id: worker1_ref.actor_id().clone(),
1255                },
1256            )
1257            .await
1258            .unwrap();
1259        controller_handle
1260            .remote_function_failed(
1261                &worker2,
1262                2.into(),
1263                WorkerError {
1264                    backtrace: "some failure happened!".to_string(),
1265                    worker_actor_id: worker2_ref.actor_id().clone(),
1266                },
1267            )
1268            .await
1269            .unwrap();
1270        for s in 3..=7 {
1271            ControllerMessageClient::status(
1272                &controller_handle,
1273                &worker1,
1274                s.into(),
1275                worker1_ref.actor_id().clone(),
1276                false,
1277            )
1278            .await
1279            .unwrap();
1280            ControllerMessageClient::status(
1281                &controller_handle,
1282                &worker2,
1283                s.into(),
1284                worker2_ref.actor_id().clone(),
1285                false,
1286            )
1287            .await
1288            .unwrap();
1289        }
1290
1291        controller_handle.drain_and_stop().unwrap();
1292        controller_handle.await;
1293        let mut client_messages = client_rx.drain();
1294        client_messages.sort_by_key(|msg| msg.clone().into_result().unwrap().0);
1295        assert_eq!(client_messages.len(), 7);
1296        let client_message = client_messages[2].clone().into_result().unwrap();
1297        assert_eq!(client_message.0, 2.into());
1298        assert_eq!(
1299            client_message.1,
1300            Some(Err(Exception::Error(
1301                2.into(),
1302                2.into(),
1303                WorkerError {
1304                    backtrace: "some failure happened!".to_string(),
1305                    worker_actor_id: worker1_ref.actor_id().clone(),
1306                }
1307            )))
1308        );
1309
1310        assert_eq!(
1311            client_messages
1312                .into_iter()
1313                .map(|msg| msg.into_result().unwrap().0)
1314                .collect::<HashSet<Seq>>(),
1315            HashSet::from([
1316                0.into(),
1317                3.into(),
1318                1.into(),
1319                4.into(),
1320                2.into(),
1321                5.into(),
1322                6.into()
1323            ])
1324        )
1325    }
1326
1327    #[tokio::test]
1328    async fn test_eager_failure_reporting() {
1329        // Serve a system.
1330        let server_handle = System::serve(
1331            ChannelAddr::any(ChannelTransport::Local),
1332            Duration::from_secs(10),
1333            Duration::from_secs(10),
1334        )
1335        .await
1336        .unwrap();
1337        let mut system = System::new(server_handle.local_addr().clone());
1338
1339        // Build a supervisor.
1340        let sup_mail = system.attach().await.unwrap();
1341        let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
1342        sup_tx.bind_to(ProcSupervisionMessage::port());
1343        let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.actor_id().clone());
1344
1345        // Construct a system sender.
1346        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
1347            channel::dial(server_handle.local_addr().clone()).unwrap(),
1348        ));
1349
1350        // Construct a proc forwarder in terms of the system sender.
1351        let listen_addr = ChannelAddr::any(ChannelTransport::Local);
1352        let proc_forwarder =
1353            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
1354
1355        // Bootstrap proc 'local[0]', join the system.
1356        let world_id = id!(local);
1357        let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
1358        let proc_actor_0 = ProcActor::bootstrap_for_proc(
1359            proc.clone(),
1360            world_id.clone(),
1361            listen_addr,
1362            server_handle.local_addr().clone(),
1363            sup_ref.clone(),
1364            Duration::from_secs(2),
1365            HashMap::new(),
1366            ProcLifecycleMode::ManagedBySystem,
1367        )
1368        .await
1369        .unwrap();
1370
1371        // Bootstrap proc 'local[1]', join the system.
1372        let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
1373        let _proc_actor_1 = ProcActor::bootstrap_for_proc(
1374            proc2.clone(),
1375            world_id.clone(),
1376            ChannelAddr::any(ChannelTransport::Local),
1377            server_handle.local_addr().clone(),
1378            sup_ref.clone(),
1379            Duration::from_secs(2),
1380            HashMap::new(),
1381            ProcLifecycleMode::ManagedBySystem,
1382        )
1383        .await
1384        .unwrap();
1385
1386        // Test
1387        let (client, client_ref, mut client_rx) = proc
1388            .attach_actor::<ClientActor, ClientMessage>("client")
1389            .unwrap();
1390        let (worker1, worker1_ref, _) = proc
1391            .attach_actor::<WorkerActor, WorkerMessage>("worker")
1392            .unwrap();
1393
1394        let controller_handle = proc
1395            .spawn::<ControllerActor>(
1396                "controller",
1397                ControllerParams {
1398                    world_size: 1,
1399                    comm_actor_ref: proc_actor_0.comm_actor.bind(),
1400                    worker_gang_ref: GangId(
1401                        WorldId(world_id.name().to_string()),
1402                        "worker".to_string(),
1403                    )
1404                    .into(),
1405                    supervision_query_interval: Duration::from_secs(1),
1406                    worker_progress_check_interval: Duration::from_secs(3),
1407                    operation_timeout: Duration::from_secs(30),
1408                    operations_per_worker_progress_request: 100,
1409                    fail_on_worker_timeout: false,
1410                },
1411            )
1412            .await
1413            .unwrap();
1414
1415        controller_handle.attach(&client, client_ref).await.unwrap();
1416
1417        controller_handle
1418            .node(&client, 0.into(), vec![1.into()], vec![])
1419            .await
1420            .unwrap();
1421
1422        controller_handle
1423            .node(&client, 1.into(), vec![2.into()], vec![1.into()])
1424            .await
1425            .unwrap();
1426
1427        controller_handle
1428            .node(&client, 2.into(), vec![3.into()], vec![2.into()])
1429            .await
1430            .unwrap();
1431
1432        controller_handle
1433            .node(&client, 3.into(), vec![], vec![3.into()])
1434            .await
1435            .unwrap();
1436
1437        controller_handle
1438            .node(&client, 4.into(), vec![], vec![])
1439            .await
1440            .unwrap();
1441
1442        controller_handle
1443            .remote_function_failed(
1444                &worker1,
1445                0.into(),
1446                WorkerError {
1447                    backtrace: "some failure happened!".to_string(),
1448                    worker_actor_id: worker1_ref.actor_id().clone(),
1449                },
1450            )
1451            .await
1452            .unwrap();
1453
        controller_handle
            .remote_function_failed(
                &worker1,
                3.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker1_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();

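        // Worker progress up to seq 5 should let the controller flush the
        // remaining (successful) result for seq 4.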
        ControllerMessageClient::status(
            &controller_handle,
            &worker1,
            5.into(),
            worker1_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;

        let client_messages = client_rx.drain();
        // No double-reported messages: each seq yields exactly one result.
        assert_eq!(client_messages.len(), 5);

        let (errors, successes) =
            client_messages
                .into_iter()
                .fold((0, 0), |(errors, successes), client_message| {
                    let (_, result) = client_message.into_result().unwrap();
                    match result {
                        Some(Err(Exception::Error(_, _, _))) => (errors + 1, successes),
                        None => (errors, successes + 1),
                        _ => {
                            panic!("should only be exceptions or no result");
                        }
                    }
                });

        // Assert that we have 4 error messages and 1 non-error message.
        assert_eq!(errors, 4);
        assert_eq!(successes, 1);
    }

    #[tokio::test]
    async fn test_bootstrap() {
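        // Bootstrap a controller on a fresh proc and verify the actor ref it reports.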
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            Duration::from_secs(10),
            Duration::from_secs(10),
        )
        .await
        .unwrap();

        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let (proc_handle, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(1),
                worker_progress_check_interval: Duration::from_secs(3),
                operation_timeout: Duration::from_secs(30),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(1),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        proc_handle.drain_and_stop().unwrap();
    }

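    /// Set up a detached mailbox and a bound message port for a fake proc, so
    /// tests can join the system without running a real proc actor.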
    async fn mock_proc_actor(
        idx: usize,
        rank: usize,
    ) -> (
        WorldId,
        ProcId,
        ChannelAddr,
        Mailbox,
        PortHandle<ProcMessage>,
        PortReceiver<ProcMessage>,
    ) {
        let world_id = id!(world);
        // Set up a local proc mailbox and serve it.
        let local_proc_id = world_id.proc_id(rank);
        let (local_proc_addr, local_proc_rx) =
            channel::serve(ChannelAddr::any(ChannelTransport::Local))
                .await
                .unwrap();
        let local_proc_mbox = Mailbox::new_detached(
            local_proc_id.actor_id(format!("test_dummy_proc{}", idx), 0),
        );
        let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
        local_proc_message_port.bind();

        let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
        (
            world_id,
            local_proc_id,
            local_proc_addr,
            local_proc_mbox,
            local_proc_message_port,
            local_proc_message_receiver,
        )
    }

    #[tokio::test]
    async fn test_sim_supervision_failure() {
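        // Exercises supervision failure under the simulated network: timeouts
        // that are very long in wall-clock terms elapse almost instantly in
        // simulated time.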
        // Start the simnet and mark the training script as waiting, allowing
        // simulated time to advance past pending timers.
        simnet::start();
        simnet::simnet_handle()
            .unwrap()
            .set_training_script_state(simnet::TrainingScriptState::Waiting);

        let system_sim_addr =
            ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));
        // Set very long supervision_update_timeout
        let server_handle = System::serve(
            system_sim_addr.clone(),
            Duration::from_secs(1000),
            Duration::from_secs(1000),
        )
        .await
        .unwrap();

        let mut system = System::new(server_handle.local_addr().clone());
        let client_mailbox = system.attach().await.unwrap();

        // Bootstrap the controller
        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let controller_proc_listen_addr =
            ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));

        let (_, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            controller_proc_listen_addr,
            system_sim_addr,
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(100),
                worker_progress_check_interval: Duration::from_secs(100),
                operation_timeout: Duration::from_secs(1000),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(100),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        actor_ref
            .attach(
                &client_mailbox,
                ActorRef::attest(client_mailbox.actor_id().clone()),
            )
            .await
            .unwrap();

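        // Bind a client port for ClientMessage so supervision reports reach this test.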
        let (client_supervision_tx, mut client_supervision_rx) =
            client_mailbox.open_port::<ClientMessage>();
        client_supervision_tx.bind_to(ClientMessage::port());

        // Mock a proc actor that never sends supervision updates.
        let (
            world_id,
            local_proc_id,
            local_proc_addr,
            _,
            local_proc_message_port,
            mut local_proc_message_receiver,
        ) = mock_proc_actor(0, 1).await;

        // Join the world.
        server_handle
            .system_actor_handle()
            .send(SystemMessage::Join {
                proc_id: local_proc_id.clone(),
                world_id,
                proc_message_port: local_proc_message_port.bind(),
                proc_addr: local_proc_addr,
                labels: HashMap::new(),
                lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
            })
            .unwrap();

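        // The system should acknowledge the join before the test starts waiting.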
        assert_matches!(
            local_proc_message_receiver.recv().await.unwrap(),
            ProcMessage::Joined()
        );

        // The 1000-second supervision timeout should be hit almost immediately,
        // since simulated time fast-forwards past it.
        let result = client_supervision_rx
            .recv()
            .await
            .unwrap()
            .into_result()
            .unwrap();
        assert_eq!(result.0, Seq::default());
        assert!(result.1.expect("result").is_err());

        let records = simnet::simnet_handle().unwrap().close().await.unwrap();
        eprintln!("{}", serde_json::to_string_pretty(&records).unwrap());
    }

    #[tokio::test]
    async fn test_supervision_failure() {
        // Start system actor.
        let timeout = Duration::from_secs(6);
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            timeout,
            timeout,
        )
        .await
        .unwrap();

        // Client actor.
        let mut system = System::new(server_handle.local_addr().clone());
        let client_mailbox = system.attach().await.unwrap();
        let (client_supervision_tx, mut client_supervision_rx) =
            client_mailbox.open_port::<ClientMessage>();
        client_supervision_tx.bind_to(ClientMessage::port());

        // Bootstrap the controller
        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let (_, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(1),
                worker_progress_check_interval: Duration::from_secs(3),
                operation_timeout: Duration::from_secs(30),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(1),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        actor_ref
            .attach(
                &client_mailbox,
                ActorRef::attest(client_mailbox.actor_id().clone()),
            )
            .await
            .unwrap();

        // Mock a proc actor that never sends supervision updates.
        let (
            world_id,
            local_proc_id,
            local_proc_addr,
            _,
            local_proc_message_port,
            mut local_proc_message_receiver,
        ) = mock_proc_actor(0, 1).await;

        // Join the world.
        server_handle
            .system_actor_handle()
            .send(SystemMessage::Join {
                proc_id: local_proc_id.clone(),
                world_id,
                proc_message_port: local_proc_message_port.bind(),
                proc_addr: local_proc_addr,
                labels: HashMap::new(),
                lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
            })
            .unwrap();

        assert_matches!(
            local_proc_message_receiver.recv().await.unwrap(),
            ProcMessage::Joined()
        );

        // Wait a bit; supervision update should time out.
        RealClock.sleep(2 * timeout).await;

        // Should've gotten the supervision message indicating supervision failure
        let result = client_supervision_rx
            .recv()
            .await
            .unwrap()
            .into_result()
            .unwrap();
        assert_eq!(result.0, Seq::default());
        assert!(result.1.expect("result").is_err());
    }

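    /// A message that makes the receiving actor panic with the supplied string,
    /// used to provoke a supervision fault.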
    #[derive(
        Handler,
        HandleClient,
        RefClient,
        Named,
        Debug,
        Clone,
        Serialize,
        Deserialize,
        PartialEq
    )]
    enum PanickingMessage {
        Panic(String),
    }

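    /// Test actor whose only behavior is to panic on demand.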
    #[derive(Debug, Default, Actor)]
    #[hyperactor::export(
        handlers = [
            PanickingMessage,
        ],
    )]
    struct PanickingActor;

    #[async_trait]
    #[hyperactor::forward(PanickingMessage)]
    impl PanickingMessageHandler for PanickingActor {
        async fn panic(
            &mut self,
            _cx: &Context<Self>,
            err_msg: String,
        ) -> Result<(), anyhow::Error> {
            panic!("{}", err_msg);
        }
    }

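    // Register PanickingActor so it can be spawned by name on a remote proc.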
    hyperactor::remote!(PanickingActor);

    #[tokio::test]
    async fn test_supervision_fault() {
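        // An actor panic should surface to the client as a supervision failure.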
        // Start system actor.
        let timeout = Duration::from_secs(6);
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            timeout,
            timeout,
        )
        .await
        .unwrap();

        // Client actor.
        let mut system = System::new(server_handle.local_addr().clone());
        let client_mailbox = system.attach().await.unwrap();
        let (client_supervision_tx, mut client_supervision_rx) =
            client_mailbox.open_port::<ClientMessage>();
        client_supervision_tx.bind_to(ClientMessage::port());

        // Bootstrap the controller
        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let (_, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(1),
                worker_progress_check_interval: Duration::from_secs(3),
                operation_timeout: Duration::from_secs(30),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(1),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        actor_ref
            .attach(
                &client_mailbox,
                ActorRef::attest(client_mailbox.actor_id().clone()),
            )
            .await
            .unwrap();

        // Bootstrap a proc and spawn an actor that panics on demand.
        let world_id = id!(world);
        let panic_proc_id = world_id.proc_id(1);
        let bootstrap = ProcActor::bootstrap(
            panic_proc_id,
            world_id,
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            Duration::from_secs(3),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();
        let actor_handle = spawn::<PanickingActor>(
            &client_mailbox,
            &bootstrap.proc_actor.bind(),
            "panicker",
            &(),
        )
        .await
        .unwrap();

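        // Trigger the panic; proc supervision should report it back to the client.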
        actor_handle
            .panic(&client_mailbox, "some random failure".to_string())
            .await
            .unwrap();

        // Get the supervision message with the panic
        let result = client_supervision_rx
            .recv()
            .await
            .unwrap()
            .into_result()
            .unwrap();
        assert_eq!(result.0, Seq::default());
        assert!(result.1.is_some() && result.1.as_ref().unwrap().is_err());
        let Exception::Failure(err) = result.1.unwrap().unwrap_err() else {
            panic!("Expected Failure exception");
        };
        assert!(err.backtrace.contains("some random failure"));
    }
}
1935}