#![feature(assert_matches)]
#![allow(unsafe_op_in_unsafe_fn)]

pub mod bootstrap;
pub mod history;

use std::collections::HashMap;
use std::collections::HashSet;
use std::time::Duration;

use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Context;
use hyperactor::GangId;
use hyperactor::GangRef;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::actor::ActorHandle;
use hyperactor::actor::ActorStatus;
use hyperactor::channel::ChannelAddr;
use hyperactor::clock::Clock;
use hyperactor::context;
use hyperactor::data::Serialized;
use hyperactor_mesh::comm::CommActor;
use hyperactor_mesh::comm::CommActorMode;
use hyperactor_mesh::comm::multicast::CastMessage;
use hyperactor_mesh::comm::multicast::CastMessageEnvelope;
use hyperactor_mesh::comm::multicast::DestinationPort;
use hyperactor_mesh::comm::multicast::Uslice;
use hyperactor_mesh::reference::ActorMeshId;
use hyperactor_mesh::reference::ProcMeshId;
use hyperactor_multiprocess::proc_actor::ProcActor;
use hyperactor_multiprocess::proc_actor::spawn;
use hyperactor_multiprocess::supervision::WorldSupervisionMessageClient;
use hyperactor_multiprocess::supervision::WorldSupervisor;
use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
use monarch_messages::client::ClientActor;
use monarch_messages::client::ClientMessageClient;
use monarch_messages::client::Exception;
use monarch_messages::client::LogLevel;
use monarch_messages::controller::ControllerMessage;
use monarch_messages::controller::ControllerMessageHandler;
use monarch_messages::controller::DeviceFailure;
use monarch_messages::controller::Ranks;
use monarch_messages::controller::Seq;
use monarch_messages::controller::WorkerError;
use monarch_messages::debugger::DebuggerAction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::WorkerActor;
use monarch_messages::worker::WorkerMessage;
use ndslice::Selection;
use ndslice::Shape;
use ndslice::Slice;
use ndslice::reshape::Limit;
use ndslice::reshape::ReshapeShapeExt;
use ndslice::selection::dsl;
use ndslice::shape::Range;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::OnceCell;

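// Rough fan-out sketch (an assumption based on how the constant below is used
// in `send`): reshaping a flat world of, say, 64 ranks with `Limit::from(8)`
// yields an [8, 8] shape, so each hop of the casting tree addresses at most 8
// peers instead of forwarding to every rank directly.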
const CASTING_FANOUT_SIZE: usize = 8;

/// Controller for a gang of workers: tracks invocation history, casts worker
/// messages through the comm actor, monitors supervision state and worker
/// progress, and reports results and failures back to the attached client.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        ControllerMessage,
    ],
)]
pub(crate) struct ControllerActor {
    client_actor_ref: OnceCell<ActorRef<ClientActor>>,
    comm_actor_ref: ActorRef<CommActor>,
    worker_gang_ref: GangRef<WorkerActor>,
    history: history::History,
    supervision_query_interval: Duration,
    system_supervision_actor_ref: ActorRef<WorldSupervisor>,
    worker_progress_check_interval: Duration,
    operation_timeout: Duration,
    operations_per_worker_progress_request: u64,
    last_controller_request_status: Option<(Seq, tokio::time::Instant)>,
    fail_on_worker_timeout: bool,
    world_size: usize,
}

/// Parameters used to spawn a [`ControllerActor`].
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) struct ControllerParams {
    pub(crate) world_size: usize,

    pub(crate) comm_actor_ref: ActorRef<CommActor>,

    pub(crate) worker_gang_ref: GangRef<WorkerActor>,

    pub(crate) supervision_query_interval: Duration,

    pub(crate) worker_progress_check_interval: Duration,

    pub(crate) operation_timeout: Duration,

    pub(crate) operations_per_worker_progress_request: u64,

    pub(crate) fail_on_worker_timeout: bool,
}
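// A minimal construction sketch (values mirror the unit tests at the bottom of
// this file; real deployments will want different world sizes and intervals):
//
//     let params = ControllerParams {
//         world_size: 1,
//         comm_actor_ref: comm_handle.bind(),
//         worker_gang_ref: GangId(WorldId("world".into()), "worker".into()).into(),
//         supervision_query_interval: Duration::from_secs(1),
//         worker_progress_check_interval: Duration::from_secs(3),
//         operation_timeout: Duration::from_secs(30),
//         operations_per_worker_progress_request: 100,
//         fail_on_worker_timeout: false,
//     };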

#[async_trait]
impl Actor for ControllerActor {
    type Params = ControllerParams;

    async fn new(params: ControllerParams) -> Result<Self, anyhow::Error> {
        Ok(Self {
            client_actor_ref: OnceCell::new(),
            comm_actor_ref: params.comm_actor_ref,
            worker_gang_ref: params.worker_gang_ref,
            history: history::History::new(params.world_size),
            supervision_query_interval: params.supervision_query_interval,
            system_supervision_actor_ref: ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone()),
            worker_progress_check_interval: params.worker_progress_check_interval,
            operation_timeout: params.operation_timeout,
            operations_per_worker_progress_request: params.operations_per_worker_progress_request,
            last_controller_request_status: None,
            fail_on_worker_timeout: params.fail_on_worker_timeout,
            world_size: params.world_size,
        })
    }

    async fn init(&mut self, cx: &hyperactor::Instance<Self>) -> Result<(), anyhow::Error> {
        self.comm_actor_ref.send(
            cx,
            CommActorMode::ImplicitWithWorldId(self.worker_gang_ref.gang_id().world_id().clone()),
        )?;
        Ok(())
    }
}

impl ControllerActor {
    /// Bootstrap a controller proc against the system at `bootstrap_addr`,
    /// then spawn a [`ControllerActor`] configured by `params` on that proc.
    /// Returns the proc actor handle together with a reference to the
    /// controller.
    pub async fn bootstrap(
        controller_id: ActorId,
        listen_addr: ChannelAddr,
        bootstrap_addr: ChannelAddr,
        params: ControllerParams,
        supervision_update_interval: Duration,
        labels: HashMap<String, String>,
    ) -> Result<(ActorHandle<ProcActor>, ActorRef<ControllerActor>), anyhow::Error> {
        let bootstrap = ProcActor::bootstrap(
            controller_id.proc_id().clone(),
            controller_id
                .proc_id()
                .world_id()
                .expect("multiprocess supports only ranked procs")
                .clone(),
            listen_addr,
            bootstrap_addr.clone(),
            supervision_update_interval,
            labels,
            ProcLifecycleMode::ManagedBySystem,
        )
        .await?;

        let mut system = hyperactor_multiprocess::System::new(bootstrap_addr);
        let client = system.attach().await?;

        let controller_actor_ref = spawn::<ControllerActor>(
            &client,
            &bootstrap.proc_actor.bind(),
            controller_id.clone().name(),
            &ControllerParams {
                comm_actor_ref: bootstrap.comm_actor.bind(),
                ..params
            },
        )
        .await?;

        Ok((bootstrap.proc_actor, controller_actor_ref))
    }
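
    // Usage sketch (mirrors `test_bootstrap` below; the system address, params,
    // and controller id are placeholders):
    //
    //     let (proc_handle, controller_ref) = ControllerActor::bootstrap(
    //         id!(controller[0].root),
    //         ChannelAddr::any(ChannelTransport::Local),
    //         system_addr,
    //         params,
    //         Duration::from_secs(1),
    //         HashMap::new(),
    //     )
    //     .await?;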

    fn client(&self) -> Result<ActorRef<ClientActor>, anyhow::Error> {
        self.client_actor_ref
            .get()
            .ok_or_else(|| anyhow::anyhow!("client actor ref not set"))
            .cloned()
    }

    /// Ask the workers for a status update if enough new operations have been
    /// enqueued, or enough time has elapsed, since the last request.
    async fn request_status_if_needed(
        &mut self,
        cx: &Context<'_, Self>,
    ) -> Result<(), anyhow::Error> {
        if let Some((expected_seq, ..)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if self.last_controller_request_status.is_none_or(
                |(last_requested_seq, last_requested_time)| {
                    (expected_seq
                        >= (u64::from(last_requested_seq)
                            + self.operations_per_worker_progress_request)
                            .into()
                        || last_requested_time.elapsed() > self.worker_progress_check_interval)
                        && last_requested_seq != expected_seq
                },
            ) {
                self.send(
                    cx,
                    Ranks::Slice(
                        ndslice::Slice::new(0, vec![self.history.world_size()], vec![1]).unwrap(),
                    ),
                    Serialized::serialize(&WorkerMessage::RequestStatus {
                        seq: expected_seq.clone(),
                        controller: true,
                    })
                    .unwrap(),
                )
                .await?;

                self.last_controller_request_status =
                    Some((expected_seq.clone(), cx.clock().now()));
            }
        }

        Ok(())
    }
}
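
// Worked example of the pacing in `request_status_if_needed` (illustrative
// numbers): with `operations_per_worker_progress_request = 100`, a RequestStatus
// last issued at seq 100 is not repeated until the expected seq reaches 200,
// unless `worker_progress_check_interval` has elapsed since that request; in
// both cases nothing is sent if the expected seq has not advanced at all.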

#[derive(Debug)]
struct CheckWorkerProgress;

#[async_trait]
impl Handler<CheckWorkerProgress> for ControllerActor {
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        _check_worker_progress: CheckWorkerProgress,
    ) -> Result<(), anyhow::Error> {
        let client = self.client()?;

        if let Some((expected_seq, deadline, reported)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if !reported
                && cx.clock().now() > deadline
                && expected_seq >= self.history.min_incomplete_seq_reported()
            {
                let timed_out_ranks = self
                    .history
                    .first_incomplete_seqs_controller()
                    .iter()
                    .enumerate()
                    .filter(|(_, seq)| seq <= &&expected_seq)
                    .map(|(rank, _)| rank)
                    .collect::<Vec<_>>();

                let failed_rank = timed_out_ranks.first().unwrap().clone();

                let timed_out_ranks_string = timed_out_ranks
                    .into_iter()
                    .map(|rank| rank.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");

                let message = format!(
                    "ranks {} have operations that have not completed after {} seconds",
                    timed_out_ranks_string,
                    self.operation_timeout.as_secs()
                );
                if client
                    .log(cx, LogLevel::Warn, message.clone())
                    .await
                    .is_ok()
                {
                    self.history.report_deadline_missed();
                }

                if self.fail_on_worker_timeout {
                    client
                        .result(
                            cx,
                            expected_seq,
                            Some(Err(Exception::Failure(DeviceFailure {
                                actor_id: self.worker_gang_ref.rank(failed_rank).actor_id().clone(),
                                address: "unknown".into(),
                                backtrace: message,
                            }))),
                        )
                        .await?;
                }
            }
            self.request_status_if_needed(cx).await?;
        }

        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }
}

/// Convert a flat `Slice` of ranks into a `Selection` over the worker mesh.
fn slice_to_selection(slice: Slice) -> Selection {
    match (slice.sizes(), slice.strides()) {
        ([], []) => dsl::range(slice.offset()..=slice.offset(), dsl::true_()),
        ([size, rsizes @ ..], [stride, ..]) if rsizes.iter().all(|s| *s == 1) => dsl::range(
            Range(
                slice.offset(),
                Some(slice.offset() + *size * *stride),
                *stride,
            ),
            dsl::true_(),
        ),
        _ => {
            let mut selection = Selection::False;
            let mut selected_ranks = HashSet::new();
            for rank in slice.iter() {
                if !selected_ranks.insert(rank) {
                    continue;
                }
                selection = dsl::union(dsl::range(rank..=rank, dsl::true_()), selection);
            }
            selection
        }
    }
}
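
// Illustrative mappings for `slice_to_selection`, read off the match arms above:
//   * a 0-d slice at offset 3                  -> range(3..=3, true_())
//   * sizes [4], strides [2], offset 0         -> range(Range(0, Some(8), 2), true_())
//   * multi-dim slices whose trailing sizes
//     are not all 1                            -> union of one singleton range per distinct rank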

#[async_trait]
#[hyperactor::forward(ControllerMessage)]
impl ControllerMessageHandler for ControllerActor {
    async fn attach(
        &mut self,
        cx: &Context<Self>,
        client_actor: ActorRef<ClientActor>,
    ) -> Result<(), anyhow::Error> {
        tracing::debug!("attaching client actor {}", client_actor);
        self.client_actor_ref
            .set(client_actor)
            .map_err(|actor_ref| anyhow::anyhow!("client actor {} already attached", actor_ref))?;

        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }

    async fn node(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        defs: Vec<Ref>,
        uses: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        let failures = self.history.add_invocation(seq, uses, defs);
        let client = self.client()?;
        for (seq, failure) in failures {
            let _ = client.result(cx, seq, failure).await;
        }
        self.request_status_if_needed(cx).await?;

        Ok(())
    }

    async fn drop_refs(
        &mut self,
        _cx: &Context<Self>,
        refs: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        self.history.delete_invocations_for_refs(refs);
        Ok(())
    }

    async fn send(
        &mut self,
        cx: &Context<Self>,
        ranks: Ranks,
        message: Serialized,
    ) -> Result<(), anyhow::Error> {
        let selection = match ranks {
            Ranks::Slice(slice) => {
                if slice.len() == self.world_size {
                    Selection::True
                } else {
                    slice_to_selection(slice)
                }
            }
            Ranks::SliceList(slices) => slices.into_iter().fold(dsl::false_(), |sel, slice| {
                dsl::union(sel, slice_to_selection(slice))
            }),
        };

        let slice = Slice::new(0usize, vec![self.world_size], vec![1])?;
        let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())?
            .reshape(Limit::from(CASTING_FANOUT_SIZE))
            .shape;

        let message = CastMessageEnvelope::from_serialized(
            ActorMeshId::V0(
                ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
                self.worker_gang_ref.gang_id().name().to_string(),
            ),
            cx.self_id().clone(),
            DestinationPort::new::<WorkerActor, WorkerMessage>(
                self.worker_gang_ref
                    .gang_id()
                    .actor_id(0)
                    .name()
                    .to_string(),
            ),
            made_up_shape,
            message,
        );

        self.comm_actor_ref.send(
            cx,
            CastMessage {
                dest: Uslice {
                    slice,
                    selection,
                },
                message,
            },
        )?;
        Ok(())
    }
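
    // Casting sketch (an informal summary of `send` above): the serialized
    // WorkerMessage is wrapped in a CastMessageEnvelope addressed to the worker
    // gang and handed to the comm actor together with a selection over a flat
    // [world_size] slice. The single "fake_in_controller" dimension appears to
    // exist only so the shape can be reshaped into a CASTING_FANOUT_SIZE-bounded
    // tree; it does not correspond to a real mesh dimension.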

    async fn remote_function_failed(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        error: WorkerError,
    ) -> Result<(), anyhow::Error> {
        let rank = error.worker_actor_id.rank();
        self.history
            .propagate_exception(seq, Exception::Error(seq, seq, error.clone()));
        mark_worker_complete_and_propagate_exceptions(cx, self, rank, &seq).await?;
        Ok(())
    }

    async fn status(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: ActorId,
        controller: bool,
    ) -> Result<(), anyhow::Error> {
        let rank = worker_actor_id.rank();

        if controller {
            self.history.update_deadline_tracking(rank, seq);
        } else {
            mark_worker_complete_and_propagate_exceptions(cx, self, rank, &seq).await?;
        }
        Ok(())
    }

    async fn fetch_result(
        &mut self,
        _cx: &Context<Self>,
        seq: Seq,
        result: Result<Serialized, WorkerError>,
    ) -> Result<(), anyhow::Error> {
        self.history.set_result(seq, result);
        Ok(())
    }

    async fn check_supervision(&mut self, cx: &Context<Self>) -> Result<(), anyhow::Error> {
        let gang_id: GangId = self.worker_gang_ref.clone().into();
        let world_state = self
            .system_supervision_actor_ref
            .state(cx, gang_id.world_id().clone())
            .await?;

        if let Some(world_state) = world_state {
            if !world_state.procs.is_empty() {
                tracing::error!(
                    "found procs with failures in world {}, state: {:?}",
                    gang_id.world_id(),
                    world_state
                );

                let (_, failed_state) = world_state.procs.iter().next().unwrap();
                let (failed_actor, failure_reason) =
                    failed_state.failed_actors.first().map_or_else(
                        || {
                            let proc_id = &failed_state.proc_id;
                            (
                                ActorId(proc_id.clone(), "none".into(), 0),
                                format!(
                                    "proc is dead due to heartbeat timeout; no backtrace is \
                                    available; check the log of host {} running proc {} to \
                                    figure out the root cause",
                                    failed_state.proc_addr, proc_id
                                ),
                            )
                        },
                        |(actor, status)| {
                            (
                                actor.clone(),
                                match status {
                                    ActorStatus::Failed(msg) => msg.clone(),
                                    _ => format!("unexpected actor status {status}"),
                                },
                            )
                        },
                    );

                let exc = Exception::Failure(DeviceFailure {
                    actor_id: failed_actor,
                    address: failed_state.proc_addr.to_string(),
                    backtrace: failure_reason,
                });
                tracing::error!("Sending failure to client: {exc:?}");
                self.client()?
                    .result(cx, Seq::default(), Some(Err(exc)))
                    .await?;
                tracing::error!("Failure successfully sent to client");
            }
        }

        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        Ok(())
    }

    async fn debugger_message(
        &mut self,
        cx: &Context<Self>,
        debugger_actor_id: ActorId,
        action: DebuggerAction,
    ) -> Result<(), anyhow::Error> {
        self.client()?
            .debugger_message(cx, debugger_actor_id, action)
            .await
    }

    #[cfg(test)]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        Ok(self.history.first_incomplete_seqs().to_vec())
    }

    #[cfg(not(test))]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        unimplemented!("get_first_incomplete_seqs_unit_tests_only is only for unit tests")
    }
}

/// Mark `seq` as completed on `rank` and forward any results or propagated
/// exceptions that become ready to the attached client.
async fn mark_worker_complete_and_propagate_exceptions(
    cx: &impl context::Actor,
    actor: &mut ControllerActor,
    rank: usize,
    seq: &Seq,
) -> Result<(), anyhow::Error> {
    let results = actor.history.rank_completed(rank, seq.clone());
    let client = actor.client()?;
    for (seq, result) in results.iter() {
        let _ = client.result(cx, seq.clone(), result.clone()).await;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use core::panic;
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;
    use std::collections::HashSet;
    use std::time::Duration;

    use hyperactor::HandleClient;
    use hyperactor::Handler;
    use hyperactor::RefClient;
    use hyperactor::channel;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::context::Mailbox as _;
    use hyperactor::data::Named;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::DialMailboxRouter;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxClient;
    use hyperactor::mailbox::MailboxServer;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::message::IndexedErasedUnbound;
    use hyperactor::proc::Proc;
    use hyperactor::reference::GangId;
    use hyperactor::reference::ProcId;
    use hyperactor::reference::WorldId;
    use hyperactor::simnet;
    use hyperactor_mesh::comm::CommActorParams;
    use hyperactor_multiprocess::System;
    use hyperactor_multiprocess::proc_actor::ProcMessage;
    use hyperactor_multiprocess::supervision::ProcSupervisionMessage;
    use hyperactor_multiprocess::supervision::ProcSupervisor;
    use hyperactor_multiprocess::system_actor::SystemMessage;
    use monarch_messages::client::ClientMessage;
    use monarch_messages::controller::ControllerMessageClient;
    use monarch_messages::wire_value::WireValue;
    use monarch_messages::worker::CallFunctionParams;
    use monarch_messages::worker::WorkerMessage;
    use monarch_types::PyTree;
    use torch_sys::RValue;

    use super::*;

    #[tokio::test]
    async fn basic_controller() {
        let proc = Proc::local();
        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();

        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker_ref.clone(),
            worker.clone_for_py(),
            worker.mailbox().clone(),
        )
        .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(1),
                    worker_progress_check_interval: Duration::from_secs(3),
                    operation_timeout: Duration::from_secs(30),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();
        controller_handle
            .node(&client, 1.into(), vec![1.into(), 2.into()], vec![0.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 20.into(), vec![3.into(), 4.into()], vec![])
            .await
            .unwrap();

        ControllerMessageClient::send(
            &controller_handle,
            &worker,
            Ranks::Slice(ndslice::Slice::new(0, vec![1], vec![1]).unwrap()),
            Serialized::serialize(&WorkerMessage::CallFunction(CallFunctionParams {
                seq: 1.into(),
                results: vec![Some(1.into()), Some(2.into())],
                mutates: vec![],
                function: "os.path.split".into(),
                args: vec![WireValue::String("/fbs/fbc/foo/bar".into())],
                kwargs: HashMap::new(),
                stream: 1.into(),
                remote_process_groups: vec![],
            }))
            .unwrap(),
        )
        .await
        .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            0.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();
        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 0.into());

        controller_handle
            .remote_function_failed(
                &worker,
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            2.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 2.into());

        controller_handle
            .fetch_result(
                &worker,
                20.into(),
                Ok(Serialized::serialize(&PyTree::from(RValue::Int(42))).unwrap()),
            )
            .await
            .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            21.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 21.into());

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;
        let worker_messages: Vec<WorkerMessage> = worker_rx.drain();
        assert_eq!(
            worker_messages
                .iter()
                .filter(|msg| !matches!(msg, WorkerMessage::RequestStatus { .. }))
                .count(),
            1
        );
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 3);
        let client_message = client_messages[1].clone().into_result().unwrap();
        assert_eq!(client_message.0, 1.into());
        assert_eq!(
            client_message.1,
            Some(Err(Exception::Error(
                1.into(),
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                }
            )))
        );

        let client_message = client_messages[2].clone().into_result().unwrap();
        assert_eq!(client_message.0, 20.into());
        assert_matches!(
            client_message
                .1
                .unwrap()
                .unwrap()
                .deserialized::<PyTree<RValue>>()
                .unwrap()
                .into_leaf()
                .unwrap(),
            RValue::Int(42),
        );
    }

    #[tokio::test]
    async fn worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker_ref.clone(),
            worker.clone_for_py(),
            worker.mailbox().clone(),
        )
        .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 1);

        let (level, message) = client_messages[0].clone().into_log().unwrap();
        assert_matches!(level, LogLevel::Warn);
        assert_eq!(
            message,
            "ranks 0 have operations that have not completed after 3 seconds"
        );
    }

    #[tokio::test]
    async fn test_failure_on_worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();

        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker_ref.clone(),
            worker.clone_for_py(),
            worker.mailbox().clone(),
        )
        .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let world_id = WorldId(
            proc.proc_id()
                .world_name()
                .expect("only ranked actors are supported in the controller tests")
                .to_string(),
        );
        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(world_id, "worker".to_string()).into(),
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: true,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 2);

        let (level, message) = client_messages[0].clone().into_log().unwrap();
        assert_matches!(level, LogLevel::Warn);
        assert_eq!(
            message,
            "ranks 0 have operations that have not completed after 3 seconds"
        );

        let (seq, failure) = client_messages[1].clone().into_result().unwrap();
        assert_eq!(seq, 1.into());
        let DeviceFailure {
            backtrace,
            actor_id,
            ..
        } = failure
            .unwrap()
            .err()
            .unwrap()
            .as_failure()
            .unwrap()
            .clone();
        assert_eq!(actor_id, proc.proc_id().actor_id("worker", 0));
        assert!(
            backtrace.contains("ranks 0 have operations that have not completed after 3 seconds")
        );
    }

    #[tokio::test]
    async fn failure_propagation() {
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            Duration::from_secs(10),
            Duration::from_secs(10),
        )
        .await
        .unwrap();
        let mut system = System::new(server_handle.local_addr().clone());

        let sup_mail = system.attach().await.unwrap();
        let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
        sup_tx.bind_to(ProcSupervisionMessage::port());
        let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.self_id().clone());

        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
            channel::dial(server_handle.local_addr().clone()).unwrap(),
        ));

        let listen_addr = ChannelAddr::any(ChannelTransport::Local);
        let proc_forwarder =
            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

        let world_id = id!(local);
        let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
        let proc_actor_0 = ProcActor::bootstrap_for_proc(
            proc.clone(),
            world_id.clone(),
            listen_addr,
            server_handle.local_addr().clone(),
            sup_ref.clone(),
            Duration::from_secs(2),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();

        let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
        let _proc_actor_1 = ProcActor::bootstrap_for_proc(
            proc2.clone(),
            world_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            sup_ref.clone(),
            Duration::from_secs(2),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker1, worker1_ref, _) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker1_ref.clone(),
            worker1.clone_for_py(),
            worker1.mailbox().clone(),
        )
        .unwrap();
        let (worker2, worker2_ref, _) = proc2
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(
            worker2_ref.clone(),
            worker2.clone_for_py(),
            worker2.mailbox().clone(),
        )
        .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 2,
                    comm_actor_ref: proc_actor_0.comm_actor.bind(),
                    worker_gang_ref: GangId(
                        WorldId(world_id.name().to_string()),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(1),
                    worker_progress_check_interval: Duration::from_secs(3),
                    operation_timeout: Duration::from_secs(30),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![1.into(), 2.into()], vec![])
            .await
            .unwrap();
        controller_handle
            .node(&client, 1.into(), vec![3.into()], vec![1.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 2.into(), vec![4.into()], vec![3.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 3.into(), vec![5.into()], vec![3.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 4.into(), vec![6.into()], vec![3.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 5.into(), vec![7.into()], vec![4.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 6.into(), vec![8.into()], vec![4.into()])
            .await
            .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker1,
            1.into(),
            worker1_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();
        ControllerMessageClient::status(
            &controller_handle,
            &worker2,
            1.into(),
            worker2_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();
        controller_handle
            .remote_function_failed(
                &worker1,
                2.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker1_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();
        controller_handle
            .remote_function_failed(
                &worker2,
                2.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker2_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();
        for s in 3..=7 {
            ControllerMessageClient::status(
                &controller_handle,
                &worker1,
                s.into(),
                worker1_ref.actor_id().clone(),
                false,
            )
            .await
            .unwrap();
            ControllerMessageClient::status(
                &controller_handle,
                &worker2,
                s.into(),
                worker2_ref.actor_id().clone(),
                false,
            )
            .await
            .unwrap();
        }

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;
        let mut client_messages = client_rx.drain();
        client_messages.sort_by_key(|msg| msg.clone().into_result().unwrap().0);
        assert_eq!(client_messages.len(), 7);
        let client_message = client_messages[2].clone().into_result().unwrap();
        assert_eq!(client_message.0, 2.into());
        assert_eq!(
            client_message.1,
            Some(Err(Exception::Error(
                2.into(),
                2.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker1_ref.actor_id().clone(),
                }
            )))
        );

        assert_eq!(
            client_messages
                .into_iter()
                .map(|msg| msg.into_result().unwrap().0)
                .collect::<HashSet<Seq>>(),
            HashSet::from([
                0.into(),
                3.into(),
                1.into(),
                4.into(),
                2.into(),
                5.into(),
                6.into()
            ])
        )
    }

    #[tokio::test]
    async fn test_eager_failure_reporting() {
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            Duration::from_secs(10),
            Duration::from_secs(10),
        )
        .await
        .unwrap();
        let mut system = System::new(server_handle.local_addr().clone());

        let sup_mail = system.attach().await.unwrap();
        let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
        sup_tx.bind_to(ProcSupervisionMessage::port());
        let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.self_id().clone());

        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
            channel::dial(server_handle.local_addr().clone()).unwrap(),
        ));

        let listen_addr = ChannelAddr::any(ChannelTransport::Local);
        let proc_forwarder =
            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

        let world_id = id!(local);
        let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
        let proc_actor_0 = ProcActor::bootstrap_for_proc(
            proc.clone(),
            world_id.clone(),
            listen_addr,
            server_handle.local_addr().clone(),
            sup_ref.clone(),
            Duration::from_secs(2),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();

        let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
        let _proc_actor_1 = ProcActor::bootstrap_for_proc(
            proc2.clone(),
            world_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            sup_ref.clone(),
            Duration::from_secs(2),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker1, worker1_ref, _) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: proc_actor_0.comm_actor.bind(),
                    worker_gang_ref: GangId(
                        WorldId(world_id.name().to_string()),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(1),
                    worker_progress_check_interval: Duration::from_secs(3),
                    operation_timeout: Duration::from_secs(30),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![1.into()], vec![])
            .await
            .unwrap();

        controller_handle
            .node(&client, 1.into(), vec![2.into()], vec![1.into()])
            .await
            .unwrap();

        controller_handle
            .node(&client, 2.into(), vec![3.into()], vec![2.into()])
            .await
            .unwrap();

        controller_handle
            .node(&client, 3.into(), vec![], vec![3.into()])
            .await
            .unwrap();

        controller_handle
            .node(&client, 4.into(), vec![], vec![])
            .await
            .unwrap();

        controller_handle
            .remote_function_failed(
                &worker1,
                0.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker1_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();

        controller_handle
            .remote_function_failed(
                &worker1,
                3.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker1_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker1,
            5.into(),
            worker1_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 5);

        let (errors, successes) =
            client_messages
                .into_iter()
                .fold((0, 0), |(errors, successes), client_message| {
                    let (_, result) = client_message.clone().into_result().unwrap();
                    match result {
                        Some(Err(Exception::Error(_, _, _))) => (errors + 1, successes),
                        None => (errors, successes + 1),
                        _ => {
                            panic!("should only be exceptions or no result");
                        }
                    }
                });

        assert_eq!(errors, 4);
        assert_eq!(successes, 1);
    }

    #[tokio::test]
    async fn test_bootstrap() {
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            Duration::from_secs(10),
            Duration::from_secs(10),
        )
        .await
        .unwrap();

        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let (proc_handle, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(1),
                worker_progress_check_interval: Duration::from_secs(3),
                operation_timeout: Duration::from_secs(30),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(1),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        proc_handle.drain_and_stop().unwrap();
    }

    async fn mock_proc_actor(
        idx: usize,
        rank: usize,
    ) -> (
        WorldId,
        ProcId,
        ChannelAddr,
        Mailbox,
        PortHandle<ProcMessage>,
        PortReceiver<ProcMessage>,
    ) {
        let world_id = id!(world);
        let local_proc_id = world_id.proc_id(rank);
        let (local_proc_addr, local_proc_rx) =
            channel::serve(ChannelAddr::any(ChannelTransport::Local)).unwrap();
        let local_proc_mbox = Mailbox::new_detached(
            local_proc_id.actor_id(format!("test_dummy_proc{}", idx).to_string(), 0),
        );
        let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
        local_proc_message_port.bind();

        let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
        (
            world_id,
            local_proc_id,
            local_proc_addr,
            local_proc_mbox,
            local_proc_message_port,
            local_proc_message_receiver,
        )
    }

    #[tokio::test]
    async fn test_sim_supervision_failure() {
        simnet::start();
        simnet::simnet_handle()
            .unwrap()
            .set_training_script_state(simnet::TrainingScriptState::Waiting);

        let system_sim_addr =
            ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));
        let server_handle = System::serve(
            system_sim_addr.clone(),
            Duration::from_secs(1000),
            Duration::from_secs(1000),
        )
        .await
        .unwrap();

        let mut system = System::new(server_handle.local_addr().clone());
        let client_mailbox = system.attach().await.unwrap();

        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let controller_proc_listen_addr =
            ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));

        let (_, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            controller_proc_listen_addr,
            system_sim_addr,
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(100),
                worker_progress_check_interval: Duration::from_secs(100),
                operation_timeout: Duration::from_secs(1000),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(100),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        actor_ref
            .attach(
                &client_mailbox,
                ActorRef::attest(client_mailbox.self_id().clone()),
            )
            .await
            .unwrap();

        let (client_supervision_tx, mut client_supervision_rx) =
            client_mailbox.open_port::<ClientMessage>();
        client_supervision_tx.bind_to(ClientMessage::port());

        let (
            world_id,
            local_proc_id,
            local_proc_addr,
            _,
            local_proc_message_port,
            mut local_proc_message_receiver,
        ) = mock_proc_actor(0, 1).await;

        server_handle
            .system_actor_handle()
            .send(SystemMessage::Join {
                proc_id: local_proc_id.clone(),
                world_id,
                proc_message_port: local_proc_message_port.bind(),
                proc_addr: local_proc_addr,
                labels: HashMap::new(),
                lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
            })
            .unwrap();

        assert_matches!(
            local_proc_message_receiver.recv().await.unwrap(),
            ProcMessage::Joined()
        );

        let result = client_supervision_rx
            .recv()
            .await
            .unwrap()
            .into_result()
            .unwrap();
        assert_eq!(result.0, Seq::default());
        assert!(result.1.expect("result").is_err());

        let records = simnet::simnet_handle().unwrap().close().await.unwrap();
        eprintln!("{}", serde_json::to_string_pretty(&records).unwrap());
    }

    #[tokio::test]
    async fn test_supervision_failure() {
        let timeout: Duration = Duration::from_secs(6);
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            timeout.clone(),
            timeout.clone(),
        )
        .await
        .unwrap();

        let mut system = System::new(server_handle.local_addr().clone());
        let client_mailbox = system.attach().await.unwrap();
        let (client_supervision_tx, mut client_supervision_rx) =
            client_mailbox.open_port::<ClientMessage>();
        client_supervision_tx.bind_to(ClientMessage::port());

        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let (_, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(1),
                worker_progress_check_interval: Duration::from_secs(3),
                operation_timeout: Duration::from_secs(30),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(1),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        actor_ref
            .attach(
                &client_mailbox,
                ActorRef::attest(client_mailbox.self_id().clone()),
            )
            .await
            .unwrap();

        let (
            world_id,
            local_proc_id,
            local_proc_addr,
            _,
            local_proc_message_port,
            mut local_proc_message_receiver,
        ) = mock_proc_actor(0, 1).await;

        server_handle
            .system_actor_handle()
            .send(SystemMessage::Join {
                proc_id: local_proc_id.clone(),
                world_id,
                proc_message_port: local_proc_message_port.bind(),
                proc_addr: local_proc_addr,
                labels: HashMap::new(),
                lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
            })
            .unwrap();

        assert_matches!(
            local_proc_message_receiver.recv().await.unwrap(),
            ProcMessage::Joined()
        );

        RealClock.sleep(2 * timeout.clone()).await;

        let result = client_supervision_rx
            .recv()
            .await
            .unwrap()
            .into_result()
            .unwrap();
        assert_eq!(result.0, Seq::default());
        assert!(result.1.expect("result").is_err());
    }

    #[derive(
        Handler,
        HandleClient,
        RefClient,
        Named,
        Debug,
        Clone,
        Serialize,
        Deserialize,
        PartialEq
    )]
    enum PanickingMessage {
        Panic(String),
    }

    #[derive(Debug, Default, Actor)]
    #[hyperactor::export(
        handlers = [
            PanickingMessage,
        ],
    )]
    struct PanickingActor;

    #[async_trait]
    #[hyperactor::forward(PanickingMessage)]
    impl PanickingMessageHandler for PanickingActor {
        async fn panic(
            &mut self,
            _cx: &Context<Self>,
            err_msg: String,
        ) -> Result<(), anyhow::Error> {
            panic!("{}", err_msg);
        }
    }

    hyperactor::remote!(PanickingActor);

    #[tokio::test]
    async fn test_supervision_fault() {
        let timeout: Duration = Duration::from_secs(6);
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            timeout.clone(),
            timeout.clone(),
        )
        .await
        .unwrap();

        let mut system = System::new(server_handle.local_addr().clone());
        let client_mailbox = system.attach().await.unwrap();
        let (client_supervision_tx, mut client_supervision_rx) =
            client_mailbox.open_port::<ClientMessage>();
        client_supervision_tx.bind_to(ClientMessage::port());

        let controller_id = id!(controller[0].root);
        let proc_id = id!(world[0]);
        let (_, actor_ref) = ControllerActor::bootstrap(
            controller_id.clone(),
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            ControllerParams {
                world_size: 1,
                comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
                worker_gang_ref: GangId(
                    WorldId(
                        proc_id
                            .world_name()
                            .expect("only ranked actors are supported in the controller tests")
                            .to_string(),
                    ),
                    "worker".to_string(),
                )
                .into(),
                supervision_query_interval: Duration::from_secs(1),
                worker_progress_check_interval: Duration::from_secs(3),
                operation_timeout: Duration::from_secs(30),
                operations_per_worker_progress_request: 100,
                fail_on_worker_timeout: false,
            },
            Duration::from_secs(1),
            HashMap::new(),
        )
        .await
        .unwrap();
        assert_eq!(*actor_ref.actor_id(), controller_id);

        actor_ref
            .attach(
                &client_mailbox,
                ActorRef::attest(client_mailbox.self_id().clone()),
            )
            .await
            .unwrap();

        let world_id = id!(world);
        let panic_proc_id = world_id.proc_id(1);
        let bootstrap = ProcActor::bootstrap(
            panic_proc_id,
            world_id,
            ChannelAddr::any(ChannelTransport::Local),
            server_handle.local_addr().clone(),
            Duration::from_secs(3),
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();
        let actor_handle = spawn::<PanickingActor>(
            &client_mailbox,
            &bootstrap.proc_actor.bind(),
            "panicker",
            &(),
        )
        .await
        .unwrap();

        actor_handle
            .panic(&client_mailbox, "some random failure".to_string())
            .await
            .unwrap();

        let result = client_supervision_rx
            .recv()
            .await
            .unwrap()
            .into_result()
            .unwrap();
        assert_eq!(result.0, Seq::default());
        assert!(result.1.is_some() && result.1.as_ref().unwrap().is_err());
        let Exception::Failure(err) = result.1.unwrap().unwrap_err() else {
            panic!("Expected Failure exception");
        };
        assert!(err.backtrace.contains("some random failure"));
    }
}