controller/lib.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#![feature(assert_matches)]
// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger
// unsafe-op-in-unsafe-fn warnings.
#![allow(unsafe_op_in_unsafe_fn)]

pub mod bootstrap;
pub mod history;

use std::collections::HashMap;
use std::collections::HashSet;
use std::time::Duration;

use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Context;
use hyperactor::GangId;
use hyperactor::GangRef;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::actor::ActorHandle;
use hyperactor::actor::ActorStatus;
use hyperactor::channel::ChannelAddr;
use hyperactor::clock::Clock;
use hyperactor::context;
use hyperactor::data::Serialized;
use hyperactor_mesh::comm::CommActor;
use hyperactor_mesh::comm::CommActorMode;
use hyperactor_mesh::comm::multicast::CastMessage;
use hyperactor_mesh::comm::multicast::CastMessageEnvelope;
use hyperactor_mesh::comm::multicast::DestinationPort;
use hyperactor_mesh::comm::multicast::Uslice;
use hyperactor_mesh::reference::ActorMeshId;
use hyperactor_mesh::reference::ProcMeshId;
use hyperactor_multiprocess::proc_actor::ProcActor;
use hyperactor_multiprocess::proc_actor::spawn;
use hyperactor_multiprocess::supervision::WorldSupervisionMessageClient;
use hyperactor_multiprocess::supervision::WorldSupervisor;
use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
use monarch_messages::client::ClientActor;
use monarch_messages::client::ClientMessageClient;
use monarch_messages::client::Exception;
use monarch_messages::client::LogLevel;
use monarch_messages::controller::ControllerMessage;
use monarch_messages::controller::ControllerMessageHandler;
use monarch_messages::controller::DeviceFailure;
use monarch_messages::controller::Ranks;
use monarch_messages::controller::Seq;
use monarch_messages::controller::WorkerError;
use monarch_messages::debugger::DebuggerAction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::WorkerActor;
use monarch_messages::worker::WorkerMessage;
use ndslice::Selection;
use ndslice::Shape;
use ndslice::Slice;
use ndslice::reshape::Limit;
use ndslice::reshape::ReshapeShapeExt;
use ndslice::selection::dsl;
use ndslice::shape::Range;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::OnceCell;

const CASTING_FANOUT_SIZE: usize = 8;

/// A controller for the workers that will be leveraged by the client to do the actual
/// compute tasks. It acts as a proxy for the client, managing comms with the workers and
/// handling things like history, data dependencies, and worker lifecycles, abstracting
/// them away.
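///
/// A minimal driving sequence, mirroring the unit tests at the bottom of this
/// file (illustrative; the client/worker mailboxes and `controller_handle` are
/// assumed to be set up as in those tests):
///
/// ```ignore
/// controller_handle.attach(&client, client_ref).await?;
/// // Record an invocation: seq 0 defines ref 0 and uses nothing.
/// controller_handle.node(&client, 0.into(), vec![0.into()], vec![]).await?;
/// // Cast a serialized WorkerMessage to rank 0.
/// ControllerMessageClient::send(
///     &controller_handle,
///     &client,
///     Ranks::Slice(ndslice::Slice::new(0, vec![1], vec![1]).unwrap()),
///     message,
/// )
/// .await?;
/// ```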
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        ControllerMessage,
    ],
)]
pub(crate) struct ControllerActor {
    client_actor_ref: OnceCell<ActorRef<ClientActor>>,
    comm_actor_ref: ActorRef<CommActor>,
    worker_gang_ref: GangRef<WorkerActor>,
    history: history::History,
    supervision_query_interval: Duration,
    system_supervision_actor_ref: ActorRef<WorldSupervisor>,
    worker_progress_check_interval: Duration,
    operation_timeout: Duration,
    operations_per_worker_progress_request: u64,
    // The Seq and time we last sent out a WorkerMessage::RequestStatus.
    last_controller_request_status: Option<(Seq, tokio::time::Instant)>,
    fail_on_worker_timeout: bool,
    world_size: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) struct ControllerParams {
    /// The world size; tracks the total number of workers.
    pub(crate) world_size: usize,

    /// Reference to the comm actor. The controller takes "ownership" of this
    /// actor: on init it is immediately configured to target the worker gang.
    /// This is a temporary workaround until we are fully on meshes.
    pub(crate) comm_actor_ref: ActorRef<CommActor>,

    /// Reference to the workers to send commands to.
    pub(crate) worker_gang_ref: GangRef<WorkerActor>,

    /// How often to query world supervision status against the system actor.
    pub(crate) supervision_query_interval: Duration,

    /// How often to check whether workers are making progress.
    pub(crate) worker_progress_check_interval: Duration,

    /// How long to wait for an operation to complete before considering it timed out.
    pub(crate) operation_timeout: Duration,

    /// How many operations are enqueued before we request a progress update on workers.
    pub(crate) operations_per_worker_progress_request: u64,

    /// Whether a failure should be propagated back to the client when workers
    /// are detected to be stuck.
    pub(crate) fail_on_worker_timeout: bool,
}

#[async_trait]
impl Actor for ControllerActor {
    type Params = ControllerParams;

    async fn new(params: ControllerParams) -> Result<Self, anyhow::Error> {
        Ok(Self {
            client_actor_ref: OnceCell::new(),
            comm_actor_ref: params.comm_actor_ref,
            worker_gang_ref: params.worker_gang_ref,
            history: history::History::new(params.world_size),
            supervision_query_interval: params.supervision_query_interval,
            system_supervision_actor_ref: ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone()),
            worker_progress_check_interval: params.worker_progress_check_interval,
            operation_timeout: params.operation_timeout,
            operations_per_worker_progress_request: params.operations_per_worker_progress_request,
            last_controller_request_status: None,
            fail_on_worker_timeout: params.fail_on_worker_timeout,
            world_size: params.world_size,
        })
    }

    async fn init(&mut self, cx: &hyperactor::Instance<Self>) -> Result<(), anyhow::Error> {
        self.comm_actor_ref.send(
            cx,
            CommActorMode::ImplicitWithWorldId(self.worker_gang_ref.gang_id().world_id().clone()),
        )?;
        Ok(())
    }
}

impl ControllerActor {
    /// Bootstrap the controller actor. This will create a new proc, join the system at
    /// `bootstrap_addr`, and spawn the controller actor into the proc. `labels` is an
    /// arbitrary set of name/value pairs to be attached to the proc in the system
    /// registry, which can later be used to query and find the proc(s) through the
    /// system's snapshot API.
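    ///
    /// A minimal usage sketch, mirroring the `test_bootstrap` unit test below
    /// (`system_addr` and `params` are assumed to be provided by the caller):
    ///
    /// ```ignore
    /// let (proc_handle, controller_ref) = ControllerActor::bootstrap(
    ///     id!(controller[0].root),                   // controller actor id
    ///     ChannelAddr::any(ChannelTransport::Local), // listen_addr
    ///     system_addr,                               // bootstrap_addr
    ///     params,                                    // ControllerParams
    ///     Duration::from_secs(2),                    // supervision update interval
    ///     HashMap::new(),                            // labels
    /// )
    /// .await?;
    /// ```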
    pub async fn bootstrap(
        controller_id: ActorId,
        listen_addr: ChannelAddr,
        bootstrap_addr: ChannelAddr,
        params: ControllerParams,
        supervision_update_interval: Duration,
        labels: HashMap<String, String>,
    ) -> Result<(ActorHandle<ProcActor>, ActorRef<ControllerActor>), anyhow::Error> {
        let bootstrap = ProcActor::bootstrap(
            controller_id.proc_id().clone(),
            controller_id
                .proc_id()
                .world_id()
                .expect("multiprocess supports only ranked procs")
                .clone(), // REFACTOR(marius): make world_id a parameter of ControllerActor::bootstrap
            listen_addr,
            bootstrap_addr.clone(),
            supervision_update_interval,
            labels,
            ProcLifecycleMode::ManagedBySystem,
        )
        .await?;

        let mut system = hyperactor_multiprocess::System::new(bootstrap_addr);
        let client = system.attach().await?;

        let controller_actor_ref = spawn::<ControllerActor>(
            &client,
            &bootstrap.proc_actor.bind(),
            controller_id.clone().name(),
            &ControllerParams {
                comm_actor_ref: bootstrap.comm_actor.bind(),
                ..params
            },
        )
        .await?;

        Ok((bootstrap.proc_actor, controller_actor_ref))
    }

    fn client(&self) -> Result<ActorRef<ClientActor>, anyhow::Error> {
        self.client_actor_ref
            .get()
            .ok_or_else(|| anyhow::anyhow!("client actor ref not set"))
            .cloned()
    }

    // Send a request_status for the seq we expect to complete by our next deadline, if it
    // is more than N ops ahead of our last request_status or if M seconds have passed,
    // where:
    //
    // N = self.operations_per_worker_progress_request
    // M = self.worker_progress_check_interval
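    //
    // For example (illustrative numbers matching the unit tests below: N = 100,
    // M = 3s): after a request_status for seq 10, the next one is sent once the
    // expected seq reaches 110, or once more than 3 seconds have elapsed, and in
    // either case only if the expected seq has moved past 10.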
    async fn request_status_if_needed(
        &mut self,
        cx: &Context<'_, Self>,
    ) -> Result<(), anyhow::Error> {
        if let Some((expected_seq, ..)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if self.last_controller_request_status.is_none_or(
                |(last_requested_seq, last_requested_time)| {
                    (expected_seq
                        >= (u64::from(last_requested_seq)
                            + self.operations_per_worker_progress_request)
                            .into()
                        || last_requested_time.elapsed() > self.worker_progress_check_interval)
                        && last_requested_seq != expected_seq
                },
            ) {
                // Send to all workers.
                self.send(
                    cx,
                    Ranks::Slice(
                        ndslice::Slice::new(0, vec![self.history.world_size()], vec![1]).unwrap(),
                    ),
                    Serialized::serialize(&WorkerMessage::RequestStatus {
                        seq: expected_seq.clone(),
                        controller: true,
                    })
                    .unwrap(),
                )
                .await?;

                self.last_controller_request_status =
                    Some((expected_seq.clone(), cx.clock().now()));
            }
        }

        Ok(())
    }
}

#[derive(Debug)]
struct CheckWorkerProgress;

#[async_trait]
impl Handler<CheckWorkerProgress> for ControllerActor {
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        _check_worker_progress: CheckWorkerProgress,
    ) -> Result<(), anyhow::Error> {
        let client = self.client()?;

        if let Some((expected_seq, deadline, reported)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if !reported
                && cx.clock().now() > deadline
                && expected_seq >= self.history.min_incomplete_seq_reported()
            {
                let timed_out_ranks = self
                    .history
                    .first_incomplete_seqs_controller()
                    .iter()
                    .enumerate()
                    .filter(|(_, seq)| seq <= &&expected_seq)
                    .map(|(rank, _)| rank)
                    .collect::<Vec<_>>();

                let failed_rank = timed_out_ranks.first().unwrap().clone();

                let timed_out_ranks_string = timed_out_ranks
                    .into_iter()
                    .map(|rank| rank.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");

                let message = format!(
                    "ranks {} have operations that have not completed after {} seconds",
                    timed_out_ranks_string,
                    self.operation_timeout.as_secs()
                );
                if client
                    .log(cx, LogLevel::Warn, message.clone())
                    .await
                    .is_ok()
                {
                    self.history.report_deadline_missed();
                }

                if self.fail_on_worker_timeout {
                    client
                        .result(
                            cx,
                            expected_seq,
                            Some(Err(Exception::Failure(DeviceFailure {
                                actor_id: self.worker_gang_ref.rank(failed_rank).actor_id().clone(),
                                address: "unknown".into(),
                                backtrace: message,
                            }))),
                        )
                        .await?;
                }
            }
            self.request_status_if_needed(cx).await?;
        }

        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }
}

/// Hacky translation from a sub-`Slice` to a `Selection`.
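///
/// A sketch of the mapping, read off the match arms below (illustrative, not
/// normative):
///
/// ```ignore
/// // sizes [],       strides []       => range(offset..=offset)          (exact rank)
/// // sizes [n, 1..], strides [s, ..]  => range(offset..offset + n*s, s)  (trivial range)
/// // anything else                    => union of singleton ranges over slice.iter()
/// ```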
fn slice_to_selection(slice: Slice) -> Selection {
    match (slice.sizes(), slice.strides()) {
        // Special case exact rank `Selection`.
        ([], []) => dsl::range(slice.offset()..=slice.offset(), dsl::true_()),
        // Special case trivial range `Selection`.
        ([size, rsizes @ ..], [stride, ..]) if rsizes.iter().all(|s| *s == 1) => dsl::range(
            Range(
                slice.offset(),
                Some(slice.offset() + *size * *stride),
                *stride,
            ),
            dsl::true_(),
        ),
        // Fallback to more heavy-weight translation for everything else.
        _ => {
            let mut selection = Selection::False;
            let mut selected_ranks = HashSet::new();
            for rank in slice.iter() {
                if !selected_ranks.insert(rank) {
                    continue;
                }
                selection = dsl::union(dsl::range(rank..=rank, dsl::true_()), selection);
            }
            selection
        }
    }
}

#[async_trait]
#[hyperactor::forward(ControllerMessage)]
impl ControllerMessageHandler for ControllerActor {
    async fn attach(
        &mut self,
        cx: &Context<Self>,
        client_actor: ActorRef<ClientActor>,
    ) -> Result<(), anyhow::Error> {
        tracing::debug!("attaching client actor {}", client_actor);
        self.client_actor_ref
            .set(client_actor)
            .map_err(|actor_ref| anyhow::anyhow!("client actor {} already attached", actor_ref))?;

        // Trigger periodic checking of supervision status and worker progress.
        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }

    async fn node(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        defs: Vec<Ref>,
        uses: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        let failures = self.history.add_invocation(seq, uses, defs);
        let client = self.client()?;
        for (seq, failure) in failures {
            let _ = client.result(cx, seq, failure).await;
        }
        self.request_status_if_needed(cx).await?;

        Ok(())
    }

    async fn drop_refs(
        &mut self,
        _cx: &Context<Self>,
        refs: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        self.history.delete_invocations_for_refs(refs);
        Ok(())
    }

    async fn send(
        &mut self,
        cx: &Context<Self>,
        ranks: Ranks,
        message: Serialized,
    ) -> Result<(), anyhow::Error> {
        let selection = match ranks {
            Ranks::Slice(slice) => {
                if slice.len() == self.world_size {
                    // All ranks are selected.
                    Selection::True
                } else {
                    slice_to_selection(slice)
                }
            }
            Ranks::SliceList(slices) => slices.into_iter().fold(dsl::false_(), |sel, slice| {
                dsl::union(sel, slice_to_selection(slice))
            }),
        };

        let slice = Slice::new(0usize, vec![self.world_size], vec![1])?;
        // Use a made-up label to create a fake shape. This shape is used by
        // the comm actor to determine the cast rank. The cast rank is not used
        // by DeviceMesh, but we still need a shape there to make the logic happy.
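        //
        // For example (an assumption about how `reshape` treats the `Limit`
        // below): with a world size of 64 and CASTING_FANOUT_SIZE = 8, the
        // flat [64] shape would be refactored so that no dimension exceeds 8
        // (i.e. [8, 8]), bounding the comm actor's multicast fanout per hop.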
        let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())?
            .reshape(Limit::from(CASTING_FANOUT_SIZE))
            .shape;

        let message = CastMessageEnvelope::from_serialized(
            ActorMeshId::V0(
                ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
                self.worker_gang_ref.gang_id().name().to_string(),
            ),
            cx.self_id().clone(),
            DestinationPort::new::<WorkerActor, WorkerMessage>(
                // This is awkward, but goes away entirely with meshes.
                self.worker_gang_ref
                    .gang_id()
                    .actor_id(0)
                    .name()
                    .to_string(),
            ),
            made_up_shape,
            message,
        );

        self.comm_actor_ref.send(
            cx,
            CastMessage {
                dest: Uslice {
                    // TODO: pass both slice and selection from client side
                    slice,
                    selection,
                },
                message,
            },
        )?;
        Ok(())
    }

    async fn remote_function_failed(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        error: WorkerError,
    ) -> Result<(), anyhow::Error> {
        let rank = error.worker_actor_id.rank();
        self.history
            .propagate_exception(seq, Exception::Error(seq, seq, error.clone()));
        mark_worker_complete_and_propagate_exceptions(cx, self, rank, &seq).await?;
        Ok(())
    }

    async fn status(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: ActorId,
        controller: bool,
    ) -> Result<(), anyhow::Error> {
        let rank = worker_actor_id.rank();

        if controller {
            self.history.update_deadline_tracking(rank, seq);
        } else {
            mark_worker_complete_and_propagate_exceptions(cx, self, rank, &seq).await?;
        }
        Ok(())
    }

    async fn fetch_result(
        &mut self,
        _cx: &Context<Self>,
        seq: Seq,
        result: Result<Serialized, WorkerError>,
    ) -> Result<(), anyhow::Error> {
        self.history.set_result(seq, result);
        Ok(())
    }

    async fn check_supervision(&mut self, cx: &Context<Self>) -> Result<(), anyhow::Error> {
        let gang_id: GangId = self.worker_gang_ref.clone().into();
        let world_state = self
            .system_supervision_actor_ref
            .state(cx, gang_id.world_id().clone())
            .await?;

        if let Some(world_state) = world_state {
            if !world_state.procs.is_empty() {
                tracing::error!(
                    "found procs with failures in world {}, state: {:?}",
                    gang_id.world_id(),
                    world_state
                );

                // Arbitrarily pick one failed proc as the failed actor.
                let (_, failed_state) = world_state.procs.iter().next().unwrap();
                let (failed_actor, failure_reason) =
                    failed_state.failed_actors.first().map_or_else(
                        || {
                            let proc_id = &failed_state.proc_id;
                            (
                                ActorId(proc_id.clone(), "none".into(), 0),
                                format!(
                                    "proc is dead due to heartbeat timeout; no backtrace is \
                                    available; check the log of host {} running proc {} to \
                                    figure out the root cause",
                                    failed_state.proc_addr, proc_id
                                ),
                            )
                        },
                        |(actor, status)| {
                            (
                                actor.clone(),
                                match status {
                                    ActorStatus::Failed(msg) => msg.clone(),
                                    _ => format!("unexpected actor status {status}"),
                                },
                            )
                        },
                    );

                let exc = Exception::Failure(DeviceFailure {
                    actor_id: failed_actor,
                    address: failed_state.proc_addr.to_string(),
                    backtrace: failure_reason,
                });
                tracing::error!("Sending failure to client: {exc:?}");
                // The seq does not matter, as the client will raise a device error immediately before setting the results.
                self.client()?
                    .result(cx, Seq::default(), Some(Err(exc)))
                    .await?;
                tracing::error!("Failure successfully sent to client");

                // No need to set history failures as we are directly sending back failure results.
            }
        }

        // Schedule the next supervision check.
        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        Ok(())
    }

    async fn debugger_message(
        &mut self,
        cx: &Context<Self>,
        debugger_actor_id: ActorId,
        action: DebuggerAction,
    ) -> Result<(), anyhow::Error> {
        self.client()?
            .debugger_message(cx, debugger_actor_id, action)
            .await
    }

    #[cfg(test)]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        Ok(self.history.first_incomplete_seqs().to_vec())
    }

    #[cfg(not(test))]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        unimplemented!("get_first_incomplete_seqs_unit_tests_only is only for unit tests")
    }
}

async fn mark_worker_complete_and_propagate_exceptions(
    cx: &impl context::Actor,
    actor: &mut ControllerActor,
    rank: usize,
    seq: &Seq,
) -> Result<(), anyhow::Error> {
    let results = actor.history.rank_completed(rank, seq.clone());
    let client = actor.client()?;
    // Propagate the failures to the client.
    for (seq, result) in results.iter() {
        let _ = client.result(cx, seq.clone(), result.clone()).await;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use core::panic;
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;
    use std::collections::HashSet;
    use std::time::Duration;

    use hyperactor::HandleClient;
    use hyperactor::Handler;
    use hyperactor::RefClient;
    use hyperactor::channel;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::context::Mailbox as _;
    use hyperactor::data::Named;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::DialMailboxRouter;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxClient;
    use hyperactor::mailbox::MailboxServer;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::message::IndexedErasedUnbound;
    use hyperactor::proc::Proc;
    use hyperactor::reference::GangId;
    use hyperactor::reference::ProcId;
    use hyperactor::reference::WorldId;
    use hyperactor::simnet;
    use hyperactor_mesh::comm::CommActorParams;
    use hyperactor_multiprocess::System;
    use hyperactor_multiprocess::proc_actor::ProcMessage;
    use hyperactor_multiprocess::supervision::ProcSupervisionMessage;
    use hyperactor_multiprocess::supervision::ProcSupervisor;
    use hyperactor_multiprocess::system_actor::SystemMessage;
    use monarch_messages::client::ClientMessage;
    use monarch_messages::controller::ControllerMessageClient;
    use monarch_messages::wire_value::WireValue;
    use monarch_messages::worker::CallFunctionParams;
    use monarch_messages::worker::WorkerMessage;
    use monarch_types::PyTree;
    use torch_sys::RValue;

    use super::*;

    #[tokio::test]
    async fn basic_controller() {
        // TODO: Add a proper multiworker test
        let proc = Proc::local();
        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();

        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker_ref.clone(),
            worker.clone_for_py(),
            worker.mailbox().clone(),
        )
        .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(1),
                    worker_progress_check_interval: Duration::from_secs(3),
                    operation_timeout: Duration::from_secs(30),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();
        controller_handle
            .node(&client, 1.into(), vec![1.into(), 2.into()], vec![0.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 20.into(), vec![3.into(), 4.into()], vec![])
            .await
            .unwrap();

        ControllerMessageClient::send(
            &controller_handle,
            &worker,
            Ranks::Slice(ndslice::Slice::new(0, vec![1], vec![1]).unwrap()),
            Serialized::serialize(&WorkerMessage::CallFunction(CallFunctionParams {
                seq: 1.into(),
                results: vec![Some(1.into()), Some(2.into())],
                mutates: vec![],
                function: "os.path.split".into(),
                args: vec![WireValue::String("/fbs/fbc/foo/bar".into())],
                kwargs: HashMap::new(),
                stream: 1.into(),
                remote_process_groups: vec![],
            }))
            .unwrap(),
        )
        .await
        .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            0.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();
        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 0.into());

        controller_handle
            .remote_function_failed(
                &worker,
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            2.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 2.into());

        controller_handle
            .fetch_result(
                &worker,
                20.into(),
                Ok(Serialized::serialize(&PyTree::from(RValue::Int(42))).unwrap()),
            )
            .await
            .unwrap();

        // Only a status message can trigger a fetch result to the client.
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            21.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 21.into());

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;
        let worker_messages: Vec<WorkerMessage> = worker_rx.drain();
        assert_eq!(
            worker_messages
                .iter()
                .filter(|msg| !matches!(msg, WorkerMessage::RequestStatus { .. }))
                .count(),
            1
        );
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 3);
        let client_message = client_messages[1].clone().into_result().unwrap();
        assert_eq!(client_message.0, 1.into());
        assert_eq!(
            client_message.1,
            Some(Err(Exception::Error(
                1.into(),
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                }
            )))
        );

        let client_message = client_messages[2].clone().into_result().unwrap();
        assert_eq!(client_message.0, 20.into());
        assert_matches!(
            client_message
                .1
                .unwrap()
                .unwrap()
                .deserialized::<PyTree<RValue>>()
                .unwrap()
                .into_leaf()
                .unwrap(),
            RValue::Int(42),
        );
    }

    #[tokio::test]
    async fn worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker_ref.clone(),
            worker.clone_for_py(),
            worker.mailbox().clone(),
        )
        .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a
        // WorkerMessage::RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                // Simulate the worker's status response when joining streams
                // takes less time than the timeout.
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        // Should have no warnings.
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a
        // WorkerMessage::RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                // Simulate the worker's status response when joining streams
                // takes longer than the timeout.
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 1);

        let (level, message) = client_messages[0].clone().into_log().unwrap();
        assert_matches!(level, LogLevel::Warn);
        assert_eq!(
            message,
            "ranks 0 have operations that have not completed after 3 seconds"
        );
    }

    #[tokio::test]
    async fn test_failure_on_worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();

        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker_ref.clone(),
            worker.clone_for_py(),
            worker.mailbox().clone(),
        )
        .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let world_id = WorldId(
            proc.proc_id()
                .world_name()
                .expect("only ranked actors are supported in the controller tests")
                .to_string(),
        );
        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(world_id, "worker".to_string()).into(),
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: true,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a
        // WorkerMessage::RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                // Simulate the worker's status response when joining streams
                // takes less time than the timeout.
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        // Should have no warnings.
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        // Expect that our handler for CheckWorkerProgress will issue a
        // WorkerMessage::RequestStatus.
        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                // Simulate the worker's status response when joining streams
                // takes longer than the timeout.
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }
1088
1089        let client_messages = client_rx.drain();
1090        assert_eq!(client_messages.len(), 2);
1091
1092        let (level, message) = client_messages[0].clone().into_log().unwrap();
1093        assert_matches!(level, LogLevel::Warn);
1094        assert_eq!(
1095            message,
1096            "ranks 0 have operations that have not completed after 3 seconds"
1097        );
1098
1099        let (seq, failure) = client_messages[1].clone().into_result().unwrap();
1100        assert_eq!(seq, 1.into());
1101        let DeviceFailure {
1102            backtrace,
1103            actor_id,
1104            ..
1105        } = failure
1106            .unwrap()
1107            .err()
1108            .unwrap()
1109            .as_failure()
1110            .unwrap()
1111            .clone();
1112        assert_eq!(actor_id, proc.proc_id().actor_id("worker", 0));
1113        assert!(
1114            backtrace.contains("ranks 0 have operations that have not completed after 3 seconds")
1115        );
1116    }
1117
1118    #[tokio::test]
1119    async fn failure_propagation() {
1120        // Serve a system.
1121        let server_handle = System::serve(
1122            ChannelAddr::any(ChannelTransport::Local),
1123            Duration::from_secs(10),
1124            Duration::from_secs(10),
1125        )
1126        .await
1127        .unwrap();
1128        let mut system = System::new(server_handle.local_addr().clone());
1129
1130        // Build a supervisor.
1131        let sup_mail = system.attach().await.unwrap();
1132        let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
1133        sup_tx.bind_to(ProcSupervisionMessage::port());
1134        let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.self_id().clone());
1135
1136        // Construct a system sender.
1137        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
1138            channel::dial(server_handle.local_addr().clone()).unwrap(),
1139        ));
1140
1141        // Construct a proc forwarder in terms of the system sender.
1142        let listen_addr = ChannelAddr::any(ChannelTransport::Local);
1143        let proc_forwarder =
1144            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
1145
1146        // Bootstrap proc 'local[0]', join the system.
1147        let world_id = id!(local);
1148        let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
1149        let proc_actor_0 = ProcActor::bootstrap_for_proc(
1150            proc.clone(),
1151            world_id.clone(),
1152            listen_addr,
1153            server_handle.local_addr().clone(),
1154            sup_ref.clone(),
1155            Duration::from_secs(2),
1156            HashMap::new(),
1157            ProcLifecycleMode::ManagedBySystem,
1158        )
1159        .await
1160        .unwrap();
1161
1162        // Bootstrap proc 'local[1]', join the system.
1163        let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
1164        let _proc_actor_1 = ProcActor::bootstrap_for_proc(
1165            proc2.clone(),
1166            world_id.clone(),
1167            ChannelAddr::any(ChannelTransport::Local),
1168            server_handle.local_addr().clone(),
1169            sup_ref.clone(),
1170            Duration::from_secs(2),
1171            HashMap::new(),
1172            ProcLifecycleMode::ManagedBySystem,
1173        )
1174        .await
1175        .unwrap();
1176
1177        // Test
1178        let (client, client_ref, mut client_rx) = proc
1179            .attach_actor::<ClientActor, ClientMessage>("client")
1180            .unwrap();
1181        let (worker1, worker1_ref, _) = proc
1182            .attach_actor::<WorkerActor, WorkerMessage>("worker")
1183            .unwrap();
1184        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
1185            worker1_ref.clone(),
1186            worker1.clone_for_py(),
1187            worker1.mailbox().clone(),
1188        )
1189        .unwrap();
1190        let (worker2, worker2_ref, _) = proc2
1191            .attach_actor::<WorkerActor, WorkerMessage>("worker")
1192            .unwrap();
1193        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
1194            worker2_ref.clone(),
1195            worker2.clone_for_py(),
1196            worker2.mailbox().clone(),
1197        )
1198        .unwrap();
1199
1200        let controller_handle = proc
1201            .spawn::<ControllerActor>(
1202                "controller",
1203                ControllerParams {
1204                    world_size: 2,
1205                    comm_actor_ref: proc_actor_0.comm_actor.bind(),
1206                    worker_gang_ref: GangId(
1207                        WorldId(world_id.name().to_string()),
1208                        "worker".to_string(),
1209                    )
1210                    .into(),
1211                    supervision_query_interval: Duration::from_secs(1),
1212                    worker_progress_check_interval: Duration::from_secs(3),
1213                    operation_timeout: Duration::from_secs(30),
1214                    operations_per_worker_progress_request: 100,
1215                    fail_on_worker_timeout: false,
1216                },
1217            )
1218            .await
1219            .unwrap();
1220
1221        controller_handle.attach(&client, client_ref).await.unwrap();
1222
1223        controller_handle
1224            .node(&client, 0.into(), vec![1.into(), 2.into()], vec![])
1225            .await
1226            .unwrap();
1227        controller_handle
1228            .node(&client, 1.into(), vec![3.into()], vec![1.into()])
1229            .await
1230            .unwrap();
1231        controller_handle
1232            .node(&client, 2.into(), vec![4.into()], vec![3.into()])
1233            .await
1234            .unwrap();
1235        controller_handle
1236            .node(&client, 3.into(), vec![5.into()], vec![3.into()])
1237            .await
1238            .unwrap();
1239        controller_handle
1240            .node(&client, 4.into(), vec![6.into()], vec![3.into()])
1241            .await
1242            .unwrap();
1243        controller_handle
1244            .node(&client, 5.into(), vec![7.into()], vec![4.into()])
1245            .await
1246            .unwrap();
1247        controller_handle
1248            .node(&client, 6.into(), vec![8.into()], vec![4.into()])
1249            .await
1250            .unwrap();
1251
1252        ControllerMessageClient::status(
1253            &controller_handle,
1254            &worker1,
1255            1.into(),
1256            worker1_ref.actor_id().clone(),
1257            false,
1258        )
1259        .await
1260        .unwrap();
1261        ControllerMessageClient::status(
1262            &controller_handle,
1263            &worker2,
1264            1.into(),
1265            worker2_ref.actor_id().clone(),
1266            false,
1267        )
1268        .await
1269        .unwrap();
1270        controller_handle
1271            .remote_function_failed(
1272                &worker1,
1273                2.into(),
1274                WorkerError {
1275                    backtrace: "some failure happened!".to_string(),
1276                    worker_actor_id: worker1_ref.actor_id().clone(),
1277                },
1278            )
1279            .await
1280            .unwrap();
1281        controller_handle
1282            .remote_function_failed(
1283                &worker2,
1284                2.into(),
1285                WorkerError {
1286                    backtrace: "some failure happened!".to_string(),
1287                    worker_actor_id: worker2_ref.actor_id().clone(),
1288                },
1289            )
1290            .await
1291            .unwrap();
1292        for s in 3..=7 {
1293            ControllerMessageClient::status(
1294                &controller_handle,
1295                &worker1,
1296                s.into(),
1297                worker1_ref.actor_id().clone(),
1298                false,
1299            )
1300            .await
1301            .unwrap();
1302            ControllerMessageClient::status(
1303                &controller_handle,
1304                &worker2,
1305                s.into(),
1306                worker2_ref.actor_id().clone(),
1307                false,
1308            )
1309            .await
1310            .unwrap();
1311        }
1312
1313        controller_handle.drain_and_stop().unwrap();
1314        controller_handle.await;
1315        let mut client_messages = client_rx.drain();
1316        client_messages.sort_by_key(|msg| msg.clone().into_result().unwrap().0);
1317        assert_eq!(client_messages.len(), 7);
1318        let client_message = client_messages[2].clone().into_result().unwrap();
1319        assert_eq!(client_message.0, 2.into());
1320        assert_eq!(
1321            client_message.1,
1322            Some(Err(Exception::Error(
1323                2.into(),
1324                2.into(),
1325                WorkerError {
1326                    backtrace: "some failure happened!".to_string(),
1327                    worker_actor_id: worker1_ref.actor_id().clone(),
1328                }
1329            )))
1330        );
1331
1332        assert_eq!(
1333            client_messages
1334                .into_iter()
1335                .map(|msg| msg.into_result().unwrap().0)
1336                .collect::<HashSet<Seq>>(),
1337            HashSet::from([
1338                0.into(),
1339                3.into(),
1340                1.into(),
1341                4.into(),
1342                2.into(),
1343                5.into(),
1344                6.into()
1345            ])
1346        )
1347    }
1348
1349    #[tokio::test]
1350    async fn test_eager_failure_reporting() {
1351        // Serve a system.
1352        let server_handle = System::serve(
1353            ChannelAddr::any(ChannelTransport::Local),
1354            Duration::from_secs(10),
1355            Duration::from_secs(10),
1356        )
1357        .await
1358        .unwrap();
1359        let mut system = System::new(server_handle.local_addr().clone());
1360
1361        // Build a supervisor.
1362        let sup_mail = system.attach().await.unwrap();
1363        let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
1364        sup_tx.bind_to(ProcSupervisionMessage::port());
1365        let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.self_id().clone());
1366
1367        // Construct a system sender.
1368        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
1369            channel::dial(server_handle.local_addr().clone()).unwrap(),
1370        ));
1371
1372        // Construct a proc forwarder in terms of the system sender.
1373        let listen_addr = ChannelAddr::any(ChannelTransport::Local);
1374        let proc_forwarder =
1375            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
1376
1377        // Bootstrap proc 'local[0]', join the system.
1378        let world_id = id!(local);
1379        let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
1380        let proc_actor_0 = ProcActor::bootstrap_for_proc(
1381            proc.clone(),
1382            world_id.clone(),
1383            listen_addr,
1384            server_handle.local_addr().clone(),
1385            sup_ref.clone(),
1386            Duration::from_secs(2),
1387            HashMap::new(),
1388            ProcLifecycleMode::ManagedBySystem,
1389        )
1390        .await
1391        .unwrap();
1392
1393        // Bootstrap proc 'local[1]', join the system.
1394        let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
1395        let _proc_actor_1 = ProcActor::bootstrap_for_proc(
1396            proc2.clone(),
1397            world_id.clone(),
1398            ChannelAddr::any(ChannelTransport::Local),
1399            server_handle.local_addr().clone(),
1400            sup_ref.clone(),
1401            Duration::from_secs(2),
1402            HashMap::new(),
1403            ProcLifecycleMode::ManagedBySystem,
1404        )
1405        .await
1406        .unwrap();
1407
1408        // Test
1409        let (client, client_ref, mut client_rx) = proc
1410            .attach_actor::<ClientActor, ClientMessage>("client")
1411            .unwrap();
1412        let (worker1, worker1_ref, _) = proc
1413            .attach_actor::<WorkerActor, WorkerMessage>("worker")
1414            .unwrap();
1415
1416        let controller_handle = proc
1417            .spawn::<ControllerActor>(
1418                "controller",
1419                ControllerParams {
1420                    world_size: 1,
1421                    comm_actor_ref: proc_actor_0.comm_actor.bind(),
1422                    worker_gang_ref: GangId(
1423                        WorldId(world_id.name().to_string()),
1424                        "worker".to_string(),
1425                    )
1426                    .into(),
1427                    supervision_query_interval: Duration::from_secs(1),
1428                    worker_progress_check_interval: Duration::from_secs(3),
1429                    operation_timeout: Duration::from_secs(30),
1430                    operations_per_worker_progress_request: 100,
1431                    fail_on_worker_timeout: false,
1432                },
1433            )
1434            .await
1435            .unwrap();
1436
1437        controller_handle.attach(&client, client_ref).await.unwrap();
1438
1439        controller_handle
1440            .node(&client, 0.into(), vec![1.into()], vec![])
1441            .await
1442            .unwrap();
1443
1444        controller_handle
1445            .node(&client, 1.into(), vec![2.into()], vec![1.into()])
1446            .await
1447            .unwrap();
1448
1449        controller_handle
1450            .node(&client, 2.into(), vec![3.into()], vec![2.into()])
1451            .await
1452            .unwrap();
1453
1454        controller_handle
1455            .node(&client, 3.into(), vec![], vec![3.into()])
1456            .await
1457            .unwrap();
1458
1459        controller_handle
1460            .node(&client, 4.into(), vec![], vec![])
1461            .await
1462            .unwrap();
1463
1464        controller_handle
1465            .remote_function_failed(
1466                &worker1,
1467                0.into(),
1468                WorkerError {
1469                    backtrace: "some failure happened!".to_string(),
1470                    worker_actor_id: worker1_ref.actor_id().clone(),
1471                },
1472            )
1473            .await
1474            .unwrap();
1475
1476        controller_handle
1477            .remote_function_failed(
1478                &worker1,
1479                3.into(),
1480                WorkerError {
1481                    backtrace: "some failure happened!".to_string(),
1482                    worker_actor_id: worker1_ref.actor_id().clone(),
1483                },
1484            )
1485            .await
1486            .unwrap();
1487
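        // Report worker status at seq 5, one past the last issued seq, so the
        // controller can consider seqs 0-4 complete and flush the remaining
        // (successful) result for seq 4.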
1488        ControllerMessageClient::status(
1489            &controller_handle,
1490            &worker1,
1491            5.into(),
1492            worker1_ref.actor_id().clone(),
1493            false,
1494        )
1495        .await
1496        .unwrap();
1497
1498        controller_handle.drain_and_stop().unwrap();
1499        controller_handle.await;
1500
1501        let client_messages = client_rx.drain();
1502        // One result per seq: nothing should be double-reported.
1503        assert_eq!(client_messages.len(), 5);
1504
1505        let (errors, successes) =
1506            client_messages
1507                .into_iter()
1508                .fold((0, 0), |(errors, successes), client_message| {
1509                    let (_, result) = client_message.clone().into_result().unwrap();
1510                    match result {
1511                        Some(Err(Exception::Error(_, _, _))) => (errors + 1, successes),
1512                        None => (errors, successes + 1),
1513                        _ => {
1514                            panic!("should only be exceptions or no result");
1515                        }
1516                    }
1517                });
1518
1519        // Assert that we have 4 error messages and 1 non-error message
1520        assert_eq!(errors, 4);
1521        assert_eq!(successes, 1);
1522    }
1523
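    /// Bootstrapping a controller proc should yield an actor ref whose id
    /// matches the requested controller id.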
1524    #[tokio::test]
1525    async fn test_bootstrap() {
1526        let server_handle = System::serve(
1527            ChannelAddr::any(ChannelTransport::Local),
1528            Duration::from_secs(10),
1529            Duration::from_secs(10),
1530        )
1531        .await
1532        .unwrap();
1533
1534        let controller_id = id!(controller[0].root);
1535        let proc_id = id!(world[0]);
1536        let (proc_handle, actor_ref) = ControllerActor::bootstrap(
1537            controller_id.clone(),
1538            ChannelAddr::any(ChannelTransport::Local),
1539            server_handle.local_addr().clone(),
1540            ControllerParams {
1541                world_size: 1,
1542                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1543                worker_gang_ref: GangId(
1544                    WorldId(
1545                        proc_id
1546                            .world_name()
1547                            .expect("only ranked actors are supported in the controller tests")
1548                            .to_string(),
1549                    ),
1550                    "worker".to_string(),
1551                )
1552                .into(),
1553                supervision_query_interval: Duration::from_secs(1),
1554                worker_progress_check_interval: Duration::from_secs(3),
1555                operation_timeout: Duration::from_secs(30),
1556                operations_per_worker_progress_request: 100,
1557                fail_on_worker_timeout: false,
1558            },
1559            Duration::from_secs(1),
1560            HashMap::new(),
1561        )
1562        .await
1563        .unwrap();
1564        assert_eq!(*actor_ref.actor_id(), controller_id);
1565
1566        proc_handle.drain_and_stop().unwrap();
1567    }
1568
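    /// Test helper: builds a detached "proc" endpoint (a served channel
    /// address, a detached mailbox, and a bound `ProcMessage` port) so a test
    /// can join the system as a proc without running a real `ProcActor`, and
    /// therefore without ever sending supervision updates.
    ///
    /// Minimal usage sketch (illustrative only, not compiled):
    /// ```ignore
    /// let (world_id, proc_id, addr, _mbox, port, mut rx) = mock_proc_actor(0, 1).await;
    /// // Join the system with `proc_id`/`addr`/`port`, then expect
    /// // ProcMessage::Joined on `rx`.
    /// ```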
1569    async fn mock_proc_actor(
1570        idx: usize,
1571        rank: usize,
1572    ) -> (
1573        WorldId,
1574        ProcId,
1575        ChannelAddr,
1576        Mailbox,
1577        PortHandle<ProcMessage>,
1578        PortReceiver<ProcMessage>,
1579    ) {
1580        let world_id = id!(world);
1581        // Set up a local actor.
1582        let local_proc_id = world_id.proc_id(rank);
1583        let (local_proc_addr, local_proc_rx) =
1584            channel::serve(ChannelAddr::any(ChannelTransport::Local)).unwrap();
1585        let local_proc_mbox = Mailbox::new_detached(
1586            local_proc_id.actor_id(format!("test_dummy_proc{}", idx), 0),
1587        );
1588        let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1589        local_proc_message_port.bind();
1590
1591        let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1592        (
1593            world_id,
1594            local_proc_id,
1595            local_proc_addr,
1596            local_proc_mbox,
1597            local_proc_message_port,
1598            local_proc_message_receiver,
1599        )
1600    }
1601
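    /// Supervision failure under the simulated network: the 1000-second
    /// timeouts are crossed almost instantly in real time, because simnet
    /// fast-forwards the simulated clock while the training script waits.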
1602    #[tokio::test]
1603    async fn test_sim_supervision_failure() {
1604        // Start system actor.
1605        simnet::start();
1606        simnet::simnet_handle()
1607            .unwrap()
1608            .set_training_script_state(simnet::TrainingScriptState::Waiting);
1609
1610        let system_sim_addr =
1611            ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));
1612        // Set a very long supervision_update_timeout; simulated time will fast-forward past it.
1613        let server_handle = System::serve(
1614            system_sim_addr.clone(),
1615            Duration::from_secs(1000),
1616            Duration::from_secs(1000),
1617        )
1618        .await
1619        .unwrap();
1620
1621        let mut system = System::new(server_handle.local_addr().clone());
1622        let client_mailbox = system.attach().await.unwrap();
1623
1624        // Bootstrap the controller
1625        let controller_id = id!(controller[0].root);
1626        let proc_id = id!(world[0]);
1627        let controller_proc_listen_addr =
1628            ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));
1629
1630        let (_, actor_ref) = ControllerActor::bootstrap(
1631            controller_id.clone(),
1632            controller_proc_listen_addr,
1633            system_sim_addr,
1634            ControllerParams {
1635                world_size: 1,
1636                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1637                worker_gang_ref: GangId(
1638                    WorldId(
1639                        proc_id
1640                            .world_name()
1641                            .expect("only ranked actors are supported in the controller tests")
1642                            .to_string(),
1643                    ),
1644                    "worker".to_string(),
1645                )
1646                .into(),
1647                supervision_query_interval: Duration::from_secs(100),
1648                worker_progress_check_interval: Duration::from_secs(100),
1649                operation_timeout: Duration::from_secs(1000),
1650                operations_per_worker_progress_request: 100,
1651                fail_on_worker_timeout: false,
1652            },
1653            Duration::from_secs(100),
1654            HashMap::new(),
1655        )
1656        .await
1657        .unwrap();
1658        assert_eq!(*actor_ref.actor_id(), controller_id);
1659
1660        actor_ref
1661            .attach(
1662                &client_mailbox,
1663                ActorRef::attest(client_mailbox.self_id().clone()),
1664            )
1665            .await
1666            .unwrap();
1667
1668        let (client_supervision_tx, mut client_supervision_rx) =
1669            client_mailbox.open_port::<ClientMessage>();
1670        client_supervision_tx.bind_to(ClientMessage::port());
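        // Binding to the well-known ClientMessage port lets this bare mailbox
        // stand in for a client actor and receive the controller's results.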
1671
1672        // Mock a proc actor that never sends supervision updates.
1673        let (
1674            world_id,
1675            local_proc_id,
1676            local_proc_addr,
1677            _,
1678            local_proc_message_port,
1679            mut local_proc_message_receiver,
1680        ) = mock_proc_actor(0, 1).await;
1681
1682        // Join the world.
1683        server_handle
1684            .system_actor_handle()
1685            .send(SystemMessage::Join {
1686                proc_id: local_proc_id.clone(),
1687                world_id,
1688                proc_message_port: local_proc_message_port.bind(),
1689                proc_addr: local_proc_addr,
1690                labels: HashMap::new(),
1691                lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
1692            })
1693            .unwrap();
1694
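        // The system should ack the join on the proc's bound message port.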
1695        assert_matches!(
1696            local_proc_message_receiver.recv().await.unwrap(),
1697            ProcMessage::Joined()
1698        );
1699
1700        // Expect the supervision timeout, nominally 1000 real seconds, to be hit
1701        // almost immediately thanks to simulated time.
1702        let result = client_supervision_rx
1703            .recv()
1704            .await
1705            .unwrap()
1706            .into_result()
1707            .unwrap();
1708        assert_eq!(result.0, Seq::default());
1709        assert!(result.1.expect("result").is_err());
1710
1711        let records = simnet::simnet_handle().unwrap().close().await.unwrap();
1712        eprintln!("{}", serde_json::to_string_pretty(&records).unwrap());
1713    }
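
    /// Real-time counterpart of the simulated test above: the mock proc joins
    /// the system but never sends supervision updates, so the client should
    /// receive a failure result once the 6-second supervision timeout lapses.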
1714    #[tokio::test]
1715    async fn test_supervision_failure() {
1716        // Start system actor.
1717        let timeout: Duration = Duration::from_secs(6);
1718        let server_handle = System::serve(
1719            ChannelAddr::any(ChannelTransport::Local),
1720            timeout,
1721            timeout,
1722        )
1723        .await
1724        .unwrap();
1725
1726        // Client actor.
1727        let mut system = System::new(server_handle.local_addr().clone());
1728        let client_mailbox = system.attach().await.unwrap();
1729        let (client_supervision_tx, mut client_supervision_rx) =
1730            client_mailbox.open_port::<ClientMessage>();
1731        client_supervision_tx.bind_to(ClientMessage::port());
1732
1733        // Bootstrap the controller
1734        let controller_id = id!(controller[0].root);
1735        let proc_id = id!(world[0]);
1736        let (_, actor_ref) = ControllerActor::bootstrap(
1737            controller_id.clone(),
1738            ChannelAddr::any(ChannelTransport::Local),
1739            server_handle.local_addr().clone(),
1740            ControllerParams {
1741                world_size: 1,
1742                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1743                worker_gang_ref: GangId(
1744                    WorldId(
1745                        proc_id
1746                            .world_name()
1747                            .expect("only ranked actors are supported in the controller tests")
1748                            .to_string(),
1749                    ),
1750                    "worker".to_string(),
1751                )
1752                .into(),
1753                supervision_query_interval: Duration::from_secs(1),
1754                worker_progress_check_interval: Duration::from_secs(3),
1755                operation_timeout: Duration::from_secs(30),
1756                operations_per_worker_progress_request: 100,
1757                fail_on_worker_timeout: false,
1758            },
1759            Duration::from_secs(1),
1760            HashMap::new(),
1761        )
1762        .await
1763        .unwrap();
1764        assert_eq!(*actor_ref.actor_id(), controller_id);
1765
1766        actor_ref
1767            .attach(
1768                &client_mailbox,
1769                ActorRef::attest(client_mailbox.self_id().clone()),
1770            )
1771            .await
1772            .unwrap();
1773
1774        // Mock a proc actor that never sends supervision updates.
1775        let (
1776            world_id,
1777            local_proc_id,
1778            local_proc_addr,
1779            _,
1780            local_proc_message_port,
1781            mut local_proc_message_receiver,
1782        ) = mock_proc_actor(0, 1).await;
1783
1784        // Join the world.
1785        server_handle
1786            .system_actor_handle()
1787            .send(SystemMessage::Join {
1788                proc_id: local_proc_id.clone(),
1789                world_id,
1790                proc_message_port: local_proc_message_port.bind(),
1791                proc_addr: local_proc_addr,
1792                labels: HashMap::new(),
1793                lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
1794            })
1795            .unwrap();
1796
1797        assert_matches!(
1798            local_proc_message_receiver.recv().await.unwrap(),
1799            ProcMessage::Joined()
1800        );
1801
1802        // Wait a bit; supervision update should time out.
1803        RealClock.sleep(2 * timeout).await;
1804
1805        // We should have received a supervision message indicating the failure.
1806        let result = client_supervision_rx
1807            .recv()
1808            .await
1809            .unwrap()
1810            .into_result()
1811            .unwrap();
1812        assert_eq!(result.0, Seq::default());
1813        assert!(result.1.expect("result").is_err());
1814    }
1815
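    // Test fixture: a message type and an actor whose handler panics on
    // demand, used below to exercise supervision of actor failures.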
1816    #[derive(
1817        Handler,
1818        HandleClient,
1819        RefClient,
1820        Named,
1821        Debug,
1822        Clone,
1823        Serialize,
1824        Deserialize,
1825        PartialEq
1826    )]
1827    enum PanickingMessage {
1828        Panic(String),
1829    }
1830
1831    #[derive(Debug, Default, Actor)]
1832    #[hyperactor::export(
1833        handlers = [
1834            PanickingMessage,
1835        ],
1836    )]
1837    struct PanickingActor;
1838
1839    #[async_trait]
1840    #[hyperactor::forward(PanickingMessage)]
1841    impl PanickingMessageHandler for PanickingActor {
1842        async fn panic(
1843            &mut self,
1844            _cx: &Context<Self>,
1845            err_msg: String,
1846        ) -> Result<(), anyhow::Error> {
1847            panic!("{}", err_msg);
1848        }
1849    }
1850
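    // Register PanickingActor for remote spawning so `spawn::<PanickingActor>`
    // below can create it on the bootstrapped proc by name.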
1851    hyperactor::remote!(PanickingActor);
1852
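    /// Supervision should surface actor faults, not just missed updates: when
    /// a spawned actor panics, the client receives an `Exception::Failure`
    /// whose backtrace contains the panic message.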
1853    #[tokio::test]
1854    async fn test_supervision_fault() {
1855        // Start system actor.
1856        let timeout: Duration = Duration::from_secs(6);
1857        let server_handle = System::serve(
1858            ChannelAddr::any(ChannelTransport::Local),
1859            timeout,
1860            timeout,
1861        )
1862        .await
1863        .unwrap();
1864
1865        // Client actor.
1866        let mut system = System::new(server_handle.local_addr().clone());
1867        let client_mailbox = system.attach().await.unwrap();
1868        let (client_supervision_tx, mut client_supervision_rx) =
1869            client_mailbox.open_port::<ClientMessage>();
1870        client_supervision_tx.bind_to(ClientMessage::port());
1871
1872        // Bootstrap the controller
1873        let controller_id = id!(controller[0].root);
1874        let proc_id = id!(world[0]);
1875        let (_, actor_ref) = ControllerActor::bootstrap(
1876            controller_id.clone(),
1877            ChannelAddr::any(ChannelTransport::Local),
1878            server_handle.local_addr().clone(),
1879            ControllerParams {
1880                world_size: 1,
1881                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1882                worker_gang_ref: GangId(
1883                    WorldId(
1884                        proc_id
1885                            .world_name()
1886                            .expect("only ranked actors are supported in the controller tests")
1887                            .to_string(),
1888                    ),
1889                    "worker".to_string(),
1890                )
1891                .into(),
1892                supervision_query_interval: Duration::from_secs(1),
1893                worker_progress_check_interval: Duration::from_secs(3),
1894                operation_timeout: Duration::from_secs(30),
1895                operations_per_worker_progress_request: 100,
1896                fail_on_worker_timeout: false,
1897            },
1898            Duration::from_secs(1),
1899            HashMap::new(),
1900        )
1901        .await
1902        .unwrap();
1903        assert_eq!(*actor_ref.actor_id(), controller_id);
1904
1905        actor_ref
1906            .attach(
1907                &client_mailbox,
1908                ActorRef::attest(client_mailbox.self_id().clone()),
1909            )
1910            .await
1911            .unwrap();
1912
1913        // Bootstrap a proc and spawn an actor that panics on demand.
1914        let world_id = id!(world);
1915        let panic_proc_id = world_id.proc_id(1);
1916        let bootstrap = ProcActor::bootstrap(
1917            panic_proc_id,
1918            world_id,
1919            ChannelAddr::any(ChannelTransport::Local),
1920            server_handle.local_addr().clone(),
1921            Duration::from_secs(3),
1922            HashMap::new(),
1923            ProcLifecycleMode::ManagedBySystem,
1924        )
1925        .await
1926        .unwrap();
1927        let actor_handle = spawn::<PanickingActor>(
1928            &client_mailbox,
1929            &bootstrap.proc_actor.bind(),
1930            "panicker",
1931            &(),
1932        )
1933        .await
1934        .unwrap();
1935
1936        actor_handle
1937            .panic(&client_mailbox, "some random failure".to_string())
1938            .await
1939            .unwrap();
1940
1941        // Get the supervision message with the panic
1942        let result = client_supervision_rx
1943            .recv()
1944            .await
1945            .unwrap()
1946            .into_result()
1947            .unwrap();
1948        assert_eq!(result.0, Seq::default());
1949        assert!(result.1.is_some() && result.1.as_ref().unwrap().is_err());
1950        let Exception::Failure(err) = result.1.unwrap().unwrap_err() else {
1951            panic!("Expected Failure exception");
1952        };
1953        assert!(err.backtrace.contains("some random failure"));
1954    }
1955}