1#![feature(assert_matches)]
10#![allow(unsafe_op_in_unsafe_fn)]
13
14pub mod bootstrap;
15pub mod history;
16
17use std::collections::HashMap;
18use std::collections::HashSet;
19use std::time::Duration;
20
21use async_trait::async_trait;
22use hyperactor::Actor;
23use hyperactor::ActorId;
24use hyperactor::ActorRef;
25use hyperactor::Context;
26use hyperactor::GangId;
27use hyperactor::GangRef;
28use hyperactor::Handler;
29use hyperactor::Named;
30use hyperactor::actor::ActorHandle;
31use hyperactor::actor::ActorStatus;
32use hyperactor::cap;
33use hyperactor::channel::ChannelAddr;
34use hyperactor::clock::Clock;
35use hyperactor::data::Serialized;
36use hyperactor_mesh::comm::CommActor;
37use hyperactor_mesh::comm::CommActorMode;
38use hyperactor_mesh::comm::multicast::CastMessage;
39use hyperactor_mesh::comm::multicast::CastMessageEnvelope;
40use hyperactor_mesh::comm::multicast::DestinationPort;
41use hyperactor_mesh::comm::multicast::Uslice;
42use hyperactor_mesh::reference::ActorMeshId;
43use hyperactor_mesh::reference::ProcMeshId;
44use hyperactor_multiprocess::proc_actor::ProcActor;
45use hyperactor_multiprocess::proc_actor::spawn;
46use hyperactor_multiprocess::supervision::WorldSupervisionMessageClient;
47use hyperactor_multiprocess::supervision::WorldSupervisor;
48use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
49use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
50use monarch_messages::client::ClientActor;
51use monarch_messages::client::ClientMessageClient;
52use monarch_messages::client::Exception;
53use monarch_messages::client::LogLevel;
54use monarch_messages::controller::ControllerMessage;
55use monarch_messages::controller::ControllerMessageHandler;
56use monarch_messages::controller::DeviceFailure;
57use monarch_messages::controller::Ranks;
58use monarch_messages::controller::Seq;
59use monarch_messages::controller::WorkerError;
60use monarch_messages::debugger::DebuggerAction;
61use monarch_messages::worker::Ref;
62use monarch_messages::worker::WorkerActor;
63use monarch_messages::worker::WorkerMessage;
64use ndslice::Selection;
65use ndslice::Shape;
66use ndslice::Slice;
67use ndslice::reshape::Limit;
68use ndslice::reshape::ReshapeShapeExt;
69use ndslice::selection::dsl;
70use ndslice::shape::Range;
71use serde::Deserialize;
72use serde::Serialize;
73use tokio::sync::OnceCell;
74
/// Per-level fan-out target for the synthetic shape used when casting
/// messages to workers (see `ControllerMessageHandler::send`).
const CASTING_FANOUT_SIZE: usize = 8;
76
/// Controller for a single worker world: tracks the history of dispatched
/// operations, casts messages to workers via the comm actor, polls
/// supervision state and worker progress, and reports results and failures
/// back to the attached client.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        ControllerMessage,
    ],
)]
pub(crate) struct ControllerActor {
    // Set exactly once when a client attaches (see `attach`).
    client_actor_ref: OnceCell<ActorRef<ClientActor>>,
    // Comm actor used to cast messages to the worker gang.
    comm_actor_ref: ActorRef<CommActor>,
    // The gang of workers this controller drives.
    worker_gang_ref: GangRef<WorkerActor>,
    // Per-rank invocation/result bookkeeping.
    history: history::History,
    // Interval between supervision-state polls of the system actor.
    supervision_query_interval: Duration,
    // The system actor in its role as world supervisor.
    system_supervision_actor_ref: ActorRef<WorldSupervisor>,
    // Interval between worker-progress (deadline) checks.
    worker_progress_check_interval: Duration,
    // How long an operation may remain incomplete before its ranks are
    // reported as timed out.
    operation_timeout: Duration,
    // Issue a controller-side RequestStatus roughly every this many ops.
    operations_per_worker_progress_request: u64,
    // Seq and timestamp of the last controller-initiated RequestStatus;
    // used to throttle repeated requests.
    last_controller_request_status: Option<(Seq, tokio::time::Instant)>,
    // When true, a missed deadline also fails the expected seq with a
    // DeviceFailure (in addition to the warn log sent to the client).
    fail_on_worker_timeout: bool,
    // Number of worker ranks in the world.
    world_size: usize,
}
102
/// Spawn parameters for [`ControllerActor`].
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) struct ControllerParams {
    /// Number of worker ranks in the world.
    pub(crate) world_size: usize,

    /// Comm actor used to cast messages to the workers.
    pub(crate) comm_actor_ref: ActorRef<CommActor>,

    /// The gang of workers the controller drives.
    pub(crate) worker_gang_ref: GangRef<WorkerActor>,

    /// Interval between supervision-state polls.
    pub(crate) supervision_query_interval: Duration,

    /// Interval between worker-progress checks.
    pub(crate) worker_progress_check_interval: Duration,

    /// How long an operation may remain incomplete before being reported
    /// as timed out.
    pub(crate) operation_timeout: Duration,

    /// Request worker status roughly every this many operations.
    pub(crate) operations_per_worker_progress_request: u64,

    /// When true, a missed deadline also fails the affected seq with a
    /// device failure.
    pub(crate) fail_on_worker_timeout: bool,
}
132
#[async_trait]
impl Actor for ControllerActor {
    type Params = ControllerParams;

    /// Build the controller state from spawn params. The client ref starts
    /// unset and is populated later by `attach`.
    async fn new(params: ControllerParams) -> Result<Self, anyhow::Error> {
        Ok(Self {
            client_actor_ref: OnceCell::new(),
            comm_actor_ref: params.comm_actor_ref,
            worker_gang_ref: params.worker_gang_ref,
            history: history::History::new(params.world_size),
            supervision_query_interval: params.supervision_query_interval,
            // The system actor doubles as the world supervisor; attest a
            // typed ref from its well-known actor id.
            system_supervision_actor_ref: ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone()),
            worker_progress_check_interval: params.worker_progress_check_interval,
            operation_timeout: params.operation_timeout,
            operations_per_worker_progress_request: params.operations_per_worker_progress_request,
            last_controller_request_status: None,
            fail_on_worker_timeout: params.fail_on_worker_timeout,
            world_size: params.world_size,
        })
    }

    /// Switch the comm actor into implicit mode scoped to the worker
    /// world so casts are routed to that world's procs.
    async fn init(&mut self, cx: &hyperactor::Instance<Self>) -> Result<(), anyhow::Error> {
        self.comm_actor_ref.send(
            cx,
            CommActorMode::ImplicitWithWorldId(self.worker_gang_ref.gang_id().world_id().clone()),
        )?;
        Ok(())
    }
}
162
impl ControllerActor {
    /// Bootstrap a controller: boot a `ProcActor` for `controller_id`'s
    /// proc (serving on `listen_addr`, joining the system at
    /// `bootstrap_addr`), then spawn the controller on that proc using
    /// `params` — with `params.comm_actor_ref` overridden by the freshly
    /// bootstrapped proc's comm actor.
    ///
    /// Returns the proc actor handle and a ref to the spawned controller.
    pub async fn bootstrap(
        controller_id: ActorId,
        listen_addr: ChannelAddr,
        bootstrap_addr: ChannelAddr,
        params: ControllerParams,
        supervision_update_interval: Duration,
        labels: HashMap<String, String>,
    ) -> Result<(ActorHandle<ProcActor>, ActorRef<ControllerActor>), anyhow::Error> {
        let bootstrap = ProcActor::bootstrap(
            controller_id.proc_id().clone(),
            controller_id
                .proc_id()
                .world_id()
                .expect("multiprocess supports only ranked procs")
                .clone(),
            listen_addr,
            bootstrap_addr.clone(),
            supervision_update_interval,
            labels,
            ProcLifecycleMode::ManagedBySystem,
        )
        .await?;

        // Attach a transient client to the system purely to issue the
        // remote spawn of the controller actor.
        let mut system = hyperactor_multiprocess::System::new(bootstrap_addr);
        let client = system.attach().await?;

        let controller_actor_ref = spawn::<ControllerActor>(
            &client,
            &bootstrap.proc_actor.bind(),
            controller_id.clone().name(),
            &ControllerParams {
                comm_actor_ref: bootstrap.comm_actor.bind(),
                ..params
            },
        )
        .await?;

        Ok((bootstrap.proc_actor, controller_actor_ref))
    }

    /// Ref to the attached client, or an error if `attach` has not been
    /// called yet.
    fn client(&self) -> Result<ActorRef<ClientActor>, anyhow::Error> {
        self.client_actor_ref
            .get()
            .ok_or_else(|| anyhow::anyhow!("client actor ref not set"))
            .cloned()
    }

    /// If the history reports a pending deadline, broadcast a
    /// controller-initiated `RequestStatus` for the expected seq to all
    /// ranks — throttled so we don't re-request the same seq, and only
    /// when either enough new operations have accumulated or the progress
    /// check interval has elapsed since the last request.
    async fn request_status_if_needed(
        &mut self,
        cx: &Context<'_, Self>,
    ) -> Result<(), anyhow::Error> {
        if let Some((expected_seq, ..)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            // Throttle: request again only if the expected seq advanced by
            // at least a full batch, or the check interval elapsed — and
            // never for a seq we already requested.
            if self.last_controller_request_status.is_none_or(
                |(last_requested_seq, last_requested_time)| {
                    (expected_seq
                        >= (u64::from(last_requested_seq)
                            + self.operations_per_worker_progress_request)
                            .into()
                        || last_requested_time.elapsed() > self.worker_progress_check_interval)
                        && last_requested_seq != expected_seq
                },
            ) {
                // Broadcast to every rank in [0, world_size).
                self.send(
                    cx,
                    Ranks::Slice(
                        ndslice::Slice::new(0, vec![self.history.world_size()], vec![1]).unwrap(),
                    ),
                    Serialized::serialize(&WorkerMessage::RequestStatus {
                        seq: expected_seq.clone(),
                        controller: true,
                    })
                    .unwrap(),
                )
                .await?;

                self.last_controller_request_status =
                    Some((expected_seq.clone(), cx.clock().now()));
            }
        }

        Ok(())
    }
}
261
/// Self-message driving the periodic worker-progress (deadline) check.
#[derive(Debug)]
struct CheckWorkerProgress;
264
#[async_trait]
impl Handler<CheckWorkerProgress> for ControllerActor {
    /// Periodic check: if the oldest expected operation has missed its
    /// deadline, warn the client (and optionally fail the seq with a
    /// device failure), then re-arm the timer.
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        _check_worker_progress: CheckWorkerProgress,
    ) -> Result<(), anyhow::Error> {
        let client = self.client()?;

        if let Some((expected_seq, deadline, reported)) = self.history.deadline(
            self.operations_per_worker_progress_request,
            self.operation_timeout,
            cx.clock(),
        ) {
            if !reported
                && cx.clock().now() > deadline
                && expected_seq >= self.history.min_incomplete_seq_reported()
            {
                // Ranks whose first incomplete seq is still at or before
                // the expected seq have not progressed past the deadline.
                let timed_out_ranks = self
                    .history
                    .first_incomplete_seqs_controller()
                    .iter()
                    .enumerate()
                    .filter(|(_, seq)| seq <= &&expected_seq)
                    .map(|(rank, _)| rank)
                    .collect::<Vec<_>>();

                // NOTE(review): assumes at least one rank is timed out when
                // the deadline has passed — confirm against History::deadline.
                let failed_rank = timed_out_ranks.first().unwrap().clone();

                let timed_out_ranks_string = timed_out_ranks
                    .into_iter()
                    .map(|rank| rank.to_string())
                    .collect::<Vec<_>>()
                    .join(", ");

                let message = format!(
                    "ranks {} have operations that have not completed after {} seconds",
                    timed_out_ranks_string,
                    self.operation_timeout.as_secs()
                );
                // Only mark the deadline as reported if the client actually
                // received the warning; otherwise we retry next tick.
                if client
                    .log(cx, LogLevel::Warn, message.clone())
                    .await
                    .is_ok()
                {
                    self.history.report_deadline_missed();
                }

                if self.fail_on_worker_timeout {
                    // Attribute the failure to the first timed-out rank.
                    client
                        .result(
                            cx,
                            expected_seq,
                            Some(Err(Exception::Failure(DeviceFailure {
                                actor_id: self.worker_gang_ref.rank(failed_rank).actor_id().clone(),
                                address: "unknown".into(),
                                backtrace: message,
                            }))),
                        )
                        .await?;
                }
            }
            self.request_status_if_needed(cx).await?;
        }

        // Re-arm the periodic check.
        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }
}
334
335fn slice_to_selection(slice: Slice) -> Selection {
337 match (slice.sizes(), slice.strides()) {
338 ([], []) => dsl::range(slice.offset()..=slice.offset(), dsl::true_()),
340 ([size, rsizes @ ..], [stride, ..]) if rsizes.iter().all(|s| *s == 1) => dsl::range(
342 Range(
343 slice.offset(),
344 Some(slice.offset() + *size * *stride),
345 *stride,
346 ),
347 dsl::true_(),
348 ),
349 _ => {
351 let mut selection = Selection::False;
352 let mut selected_ranks = HashSet::new();
353 for rank in slice.iter() {
354 if !selected_ranks.insert(rank) {
355 continue;
356 }
357 selection = dsl::union(dsl::range(rank..=rank, dsl::true_()), selection);
358 }
359 selection
360 }
361 }
362}
363
#[async_trait]
#[hyperactor::forward(ControllerMessage)]
impl ControllerMessageHandler for ControllerActor {
    /// Attach the client that will receive results, logs, and debugger
    /// messages. May only succeed once; also starts the periodic
    /// supervision and worker-progress checks.
    async fn attach(
        &mut self,
        cx: &Context<Self>,
        client_actor: ActorRef<ClientActor>,
    ) -> Result<(), anyhow::Error> {
        tracing::debug!("attaching client actor {}", client_actor);
        self.client_actor_ref
            .set(client_actor)
            .map_err(|actor_ref| anyhow::anyhow!("client actor {} already attached", actor_ref))?;

        // Kick off the periodic self-checks now that there is a client to
        // report to.
        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?;
        Ok(())
    }

    /// Record a new invocation (seq, its defined refs, and the refs it
    /// uses) in the history. Failures the history already knows propagate
    /// to this node are reported to the client immediately.
    async fn node(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        defs: Vec<Ref>,
        uses: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        let failures = self.history.add_invocation(seq, uses, defs);
        let client = self.client()?;
        for (seq, failure) in failures {
            // Best-effort delivery; send errors to the client are ignored.
            let _ = client.result(cx, seq, failure).await;
        }
        self.request_status_if_needed(cx).await?;

        Ok(())
    }

    /// Drop history bookkeeping for refs the client no longer holds.
    async fn drop_refs(
        &mut self,
        _cx: &Context<Self>,
        refs: Vec<Ref>,
    ) -> Result<(), anyhow::Error> {
        self.history.delete_invocations_for_refs(refs);
        Ok(())
    }

    /// Cast a serialized `WorkerMessage` to the given ranks via the comm
    /// actor.
    async fn send(
        &mut self,
        cx: &Context<Self>,
        ranks: Ranks,
        message: Serialized,
    ) -> Result<(), anyhow::Error> {
        // Translate the requested ranks into a Selection over the flat
        // [0, world_size) slice.
        let selection = match ranks {
            Ranks::Slice(slice) => {
                if slice.len() == self.world_size {
                    // Fast path: every rank is selected.
                    Selection::True
                } else {
                    slice_to_selection(slice)
                }
            }
            Ranks::SliceList(slices) => slices.into_iter().fold(dsl::false_(), |sel, slice| {
                dsl::union(sel, slice_to_selection(slice))
            }),
        };

        // Build a synthetic shape over the flat world, reshaped so the
        // comm actor casts along a tree with bounded fan-out.
        let slice = Slice::new(0usize, vec![self.world_size], vec![1])?;
        let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())?
            .reshape(Limit::from(CASTING_FANOUT_SIZE))
            .shape;

        let message = CastMessageEnvelope::from_serialized(
            ActorMeshId(
                ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
                self.worker_gang_ref.gang_id().name().to_string(),
            ),
            cx.self_id().clone(),
            DestinationPort::new::<WorkerActor, WorkerMessage>(
                // All worker ranks share the same actor name; rank 0's
                // name stands in for the gang.
                self.worker_gang_ref
                    .gang_id()
                    .actor_id(0)
                    .name()
                    .to_string(),
            ),
            made_up_shape,
            message,
        );

        self.comm_actor_ref.send(
            cx,
            CastMessage {
                dest: Uslice {
                    slice,
                    selection,
                },
                message,
            },
        )?;
        Ok(())
    }

    /// A worker reported a failed remote function call: record the
    /// exception against `seq`, mark the reporting rank complete, and
    /// propagate any resulting exceptions to the client.
    async fn remote_function_failed(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        error: WorkerError,
    ) -> Result<(), anyhow::Error> {
        let rank = error.worker_actor_id.rank();
        self.history
            .propagate_exception(seq, Exception::Error(seq, seq, error.clone()));
        mark_worker_complete_and_propagate_exceptions(self, cx, rank, &seq).await?;
        Ok(())
    }

    /// A worker reported status for `seq`. Controller-initiated requests
    /// only update deadline tracking; worker-initiated reports mark the
    /// rank complete up to `seq` and propagate results.
    async fn status(
        &mut self,
        cx: &Context<Self>,
        seq: Seq,
        worker_actor_id: ActorId,
        controller: bool,
    ) -> Result<(), anyhow::Error> {
        let rank = worker_actor_id.rank();

        if controller {
            self.history.update_deadline_tracking(rank, seq);
        } else {
            mark_worker_complete_and_propagate_exceptions(self, cx, rank, &seq).await?;
        }
        Ok(())
    }

    /// Store a fetched value (or worker error) for `seq` in the history.
    async fn fetch_result(
        &mut self,
        _cx: &Context<Self>,
        seq: Seq,
        result: Result<Serialized, WorkerError>,
    ) -> Result<(), anyhow::Error> {
        self.history.set_result(seq, result);
        Ok(())
    }

    /// Poll the system's world supervisor for failed procs in the worker
    /// world. On failure, report a device failure to the client; always
    /// re-arm the periodic check.
    async fn check_supervision(&mut self, cx: &Context<Self>) -> Result<(), anyhow::Error> {
        let gang_id: GangId = self.worker_gang_ref.clone().into();
        let world_state = self
            .system_supervision_actor_ref
            .state(cx, gang_id.world_id().clone())
            .await?;

        if let Some(world_state) = world_state {
            if !world_state.procs.is_empty() {
                tracing::error!(
                    "found procs with failures in world {}, state: {:?}",
                    gang_id.world_id(),
                    world_state
                );

                // Only the first failed proc (and its first failed actor,
                // if any) is surfaced to the client.
                let (_, failed_state) = world_state.procs.iter().next().unwrap();
                let (failed_actor, failure_reason) =
                    failed_state.failed_actors.first().map_or_else(
                        || {
                            // No failed actor recorded: the proc itself is
                            // unreachable (e.g. heartbeat timeout).
                            let proc_id = &failed_state.proc_id;
                            (
                                ActorId(proc_id.clone(), "none".into(), 0),
                                format!(
                                    "proc is dead due to heartbeat timeout; no backtrace is \
                                    available; check the log of host {} running proc {} to \
                                    figure out the root cause",
                                    failed_state.proc_addr, proc_id
                                ),
                            )
                        },
                        |(actor, status)| {
                            (
                                actor.clone(),
                                match status {
                                    ActorStatus::Failed(msg) => msg.clone(),
                                    _ => format!("unexpected actor status {status}"),
                                },
                            )
                        },
                    );

                let exc = Exception::Failure(DeviceFailure {
                    actor_id: failed_actor,
                    address: failed_state.proc_addr.to_string(),
                    backtrace: failure_reason,
                });
                tracing::error!("Sending failure to client: {exc:?}");
                // NOTE(review): Seq::default() is used as the seq for this
                // world-level failure — presumably interpreted specially by
                // the client; confirm against ClientActor.
                self.client()?
                    .result(cx, Seq::default(), Some(Err(exc)))
                    .await?;
                tracing::error!("Failure successfully sent to client");

            }
        }

        // Re-arm the periodic supervision check.
        cx.self_message_with_delay(
            ControllerMessage::CheckSupervision {},
            self.supervision_query_interval,
        )?;
        Ok(())
    }

    /// Forward a debugger message from a worker to the attached client.
    async fn debugger_message(
        &mut self,
        cx: &Context<Self>,
        debugger_actor_id: ActorId,
        action: DebuggerAction,
    ) -> Result<(), anyhow::Error> {
        self.client()?
            .debugger_message(cx, debugger_actor_id, action)
            .await
    }

    /// Test-only accessor for the per-rank first-incomplete seqs.
    #[cfg(test)]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        Ok(self.history.first_incomplete_seqs().to_vec())
    }

    /// Non-test builds must not call the test-only accessor.
    #[cfg(not(test))]
    async fn get_first_incomplete_seqs_unit_tests_only(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Vec<Seq>, anyhow::Error> {
        unimplemented!("get_first_incomplete_seqs_unit_tests_only is only for unit tests")
    }
}
605
606async fn mark_worker_complete_and_propagate_exceptions(
607 actor: &mut ControllerActor,
608 instance: &(impl cap::CanSend + cap::CanOpenPort),
609 rank: usize,
610 seq: &Seq,
611) -> Result<(), anyhow::Error> {
612 let results = actor.history.rank_completed(rank, seq.clone());
613 let client = actor.client()?;
614 for (seq, result) in results.iter() {
616 let _ = client.result(instance, seq.clone(), result.clone()).await;
617 }
618 Ok(())
619}
620
621#[cfg(test)]
622mod tests {
623 use core::panic;
624 use std::assert_matches::assert_matches;
625 use std::collections::HashMap;
626 use std::collections::HashSet;
627 use std::time::Duration;
628
629 use hyperactor::HandleClient;
630 use hyperactor::Handler;
631 use hyperactor::RefClient;
632 use hyperactor::channel;
633 use hyperactor::channel::ChannelTransport;
634 use hyperactor::clock::Clock;
635 use hyperactor::clock::RealClock;
636 use hyperactor::data::Named;
637 use hyperactor::id;
638 use hyperactor::mailbox::BoxedMailboxSender;
639 use hyperactor::mailbox::DialMailboxRouter;
640 use hyperactor::mailbox::Mailbox;
641 use hyperactor::mailbox::MailboxClient;
642 use hyperactor::mailbox::MailboxServer;
643 use hyperactor::mailbox::PortHandle;
644 use hyperactor::mailbox::PortReceiver;
645 use hyperactor::message::IndexedErasedUnbound;
646 use hyperactor::proc::Proc;
647 use hyperactor::reference::GangId;
648 use hyperactor::reference::ProcId;
649 use hyperactor::reference::WorldId;
650 use hyperactor::simnet;
651 use hyperactor_mesh::comm::CommActorParams;
652 use hyperactor_multiprocess::System;
653 use hyperactor_multiprocess::proc_actor::ProcMessage;
654 use hyperactor_multiprocess::supervision::ProcSupervisionMessage;
655 use hyperactor_multiprocess::supervision::ProcSupervisor;
656 use hyperactor_multiprocess::system_actor::SystemMessage;
657 use monarch_messages::client::ClientMessage;
658 use monarch_messages::controller::ControllerMessageClient;
659 use monarch_messages::wire_value::WireValue;
660 use monarch_messages::worker::CallFunctionParams;
661 use monarch_messages::worker::WorkerMessage;
662 use monarch_types::PyTree;
663 use torch_sys::RValue;
664
665 use super::*;
666
    /// End-to-end smoke test against locally attached client/worker
    /// mailboxes: node registration, message casting, status reporting,
    /// failure propagation, and result fetching.
    #[tokio::test]
    async fn basic_controller() {
        let proc = Proc::local();
        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker_ref.clone(), &worker)
            .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    supervision_query_interval: Duration::from_secs(1),
                    worker_progress_check_interval: Duration::from_secs(3),
                    operation_timeout: Duration::from_secs(30),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        // Seq 0 defines ref 0; seq 1 defines refs 1,2 using 0; seq 20
        // defines refs 3,4.
        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();
        controller_handle
            .node(&client, 1.into(), vec![1.into(), 2.into()], vec![0.into()])
            .await
            .unwrap();
        controller_handle
            .node(&client, 20.into(), vec![3.into(), 4.into()], vec![])
            .await
            .unwrap();

        // Cast a CallFunction for seq 1 to the single worker rank.
        ControllerMessageClient::send(
            &controller_handle,
            &worker,
            Ranks::Slice(ndslice::Slice::new(0, vec![1], vec![1]).unwrap()),
            Serialized::serialize(&WorkerMessage::CallFunction(CallFunctionParams {
                seq: 1.into(),
                results: vec![Some(1.into()), Some(2.into())],
                mutates: vec![],
                function: "os.path.split".into(),
                args: vec![WireValue::String("/fbs/fbc/foo/bar".into())],
                kwargs: HashMap::new(),
                stream: 1.into(),
                remote_process_groups: vec![],
            }))
            .unwrap(),
        )
        .await
        .unwrap();

        // A worker-initiated status at seq 0 leaves seq 0 still incomplete
        // (status marks completion up to, not including, the reported seq).
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            0.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();
        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 0.into());

        // Fail seq 1, then report status 2: seqs 0 and 1 are now complete.
        controller_handle
            .remote_function_failed(
                &worker,
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                },
            )
            .await
            .unwrap();
        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            2.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 2.into());

        // Provide a fetched value for seq 20, then complete it via status.
        controller_handle
            .fetch_result(
                &worker,
                20.into(),
                Ok(Serialized::serialize_anon(&PyTree::from(RValue::Int(42))).unwrap()),
            )
            .await
            .unwrap();

        ControllerMessageClient::status(
            &controller_handle,
            &worker,
            21.into(),
            worker_ref.actor_id().clone(),
            false,
        )
        .await
        .unwrap();

        let incomplete_seqs = controller_handle
            .get_first_incomplete_seqs_unit_tests_only(&worker)
            .await
            .unwrap();
        assert_eq!(incomplete_seqs[0], 21.into());

        controller_handle.drain_and_stop().unwrap();
        controller_handle.await;
        // Exactly one non-RequestStatus message (the CallFunction cast)
        // reached the worker.
        let worker_messages: Vec<WorkerMessage> = worker_rx.drain();
        assert_eq!(
            worker_messages
                .iter()
                .filter(|msg| !matches!(msg, WorkerMessage::RequestStatus { .. }))
                .count(),
            1
        );
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 3);
        // Seq 1 surfaced the injected worker error.
        let client_message = client_messages[1].clone().into_result().unwrap();
        assert_eq!(client_message.0, 1.into());
        assert_eq!(
            client_message.1,
            Some(Err(Exception::Error(
                1.into(),
                1.into(),
                WorkerError {
                    backtrace: "some failure happened!".to_string(),
                    worker_actor_id: worker_ref.actor_id().clone(),
                }
            )))
        );

        // Seq 20 surfaced the fetched value 42.
        let client_message = client_messages[2].clone().into_result().unwrap();
        assert_eq!(client_message.0, 20.into());
        assert_matches!(
            client_message
                .1
                .unwrap()
                .unwrap()
                .deserialized::<PyTree<RValue>>()
                .unwrap()
                .into_leaf()
                .unwrap(),
            RValue::Int(42),
        );
    }
853
    /// With a paused tokio clock: a worker that reports progress exactly at
    /// the deadline triggers no warning, while one that blows well past it
    /// produces a single warn log to the client.
    #[tokio::test]
    async fn worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();
        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker_ref.clone(), &worker)
            .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(
                        WorldId(
                            proc.proc_id()
                                .world_name()
                                .expect("only ranked actors are supported in the controller tests")
                                .to_string(),
                        ),
                        "worker".to_string(),
                    )
                    .into(),
                    // Effectively disable supervision polling for this test.
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: false,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                // Advance exactly to the deadline, then report progress:
                // no warning expected.
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                // Advance well past the deadline before reporting progress:
                // this time the deadline is missed.
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }

        // Exactly one client message: the warn log for the missed deadline.
        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 1);

        let (level, message) = client_messages[0].clone().into_log().unwrap();
        assert_matches!(level, LogLevel::Warn);
        assert_eq!(
            message,
            "ranks 0 have operations that have not completed after 3 seconds"
        );
    }
970
    /// Same setup as `worker_timeout`, but with `fail_on_worker_timeout`
    /// enabled: a missed deadline produces both the warn log and a
    /// device-failure result attributed to the timed-out worker.
    #[tokio::test]
    async fn test_failure_on_worker_timeout() {
        tokio::time::pause();
        let timeout_secs = 3;
        let proc = Proc::local();

        let (client, client_ref, mut client_rx) = proc
            .attach_actor::<ClientActor, ClientMessage>("client")
            .unwrap();

        let (worker, worker_ref, mut worker_rx) = proc
            .attach_actor::<WorkerActor, WorkerMessage>("worker")
            .unwrap();
        IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker_ref.clone(), &worker)
            .unwrap();

        let comm_handle = proc
            .spawn::<CommActor>("comm", CommActorParams {})
            .await
            .unwrap();

        let world_id = WorldId(
            proc.proc_id()
                .world_name()
                .expect("only ranked actors are supported in the controller tests")
                .to_string(),
        );
        let controller_handle = proc
            .spawn::<ControllerActor>(
                "controller",
                ControllerParams {
                    world_size: 1,
                    comm_actor_ref: comm_handle.bind(),
                    worker_gang_ref: GangId(world_id, "worker".to_string()).into(),
                    // Effectively disable supervision polling for this test.
                    supervision_query_interval: Duration::from_secs(100000),
                    worker_progress_check_interval: Duration::from_secs(1),
                    operation_timeout: Duration::from_secs(timeout_secs),
                    operations_per_worker_progress_request: 100,
                    fail_on_worker_timeout: true,
                },
            )
            .await
            .unwrap();

        controller_handle.attach(&client, client_ref).await.unwrap();

        controller_handle
            .node(&client, 0.into(), vec![0.into()], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 0.into() && controller => {
                // Advance exactly to the deadline, then report progress:
                // neither warning nor failure expected yet.
                for _ in 0..timeout_secs {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    1.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 0"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 0);

        controller_handle
            .node(&client, 1.into(), vec![], vec![])
            .await
            .unwrap();

        match worker_rx.recv().await.unwrap().into_request_status().ok() {
            Some((seq, controller)) if seq == 1.into() && controller => {
                // Blow well past the deadline before reporting progress.
                for _ in 0..timeout_secs * 2 {
                    tokio::time::advance(Duration::from_secs(1)).await;
                }

                ControllerMessageClient::status(
                    &controller_handle,
                    &worker,
                    2.into(),
                    worker_ref.actor_id().clone(),
                    true,
                )
                .await
                .unwrap();
            }
            _ => panic!("Expected request status message for seq 1"),
        }

        let client_messages = client_rx.drain();
        assert_eq!(client_messages.len(), 2);

        // First message: the warn log about the missed deadline.
        let (level, message) = client_messages[0].clone().into_log().unwrap();
        assert_matches!(level, LogLevel::Warn);
        assert_eq!(
            message,
            "ranks 0 have operations that have not completed after 3 seconds"
        );

        // Second message: a device failure for seq 1 attributed to the
        // worker at rank 0.
        let (seq, failure) = client_messages[1].clone().into_result().unwrap();
        assert_eq!(seq, 1.into());
        let DeviceFailure {
            backtrace,
            actor_id,
            ..
        } = failure
            .unwrap()
            .err()
            .unwrap()
            .as_failure()
            .unwrap()
            .clone();
        assert_eq!(actor_id, proc.proc_id().actor_id("worker", 0));
        assert!(
            backtrace.contains("ranks 0 have operations that have not completed after 3 seconds")
        );
    }
1103
1104 #[tokio::test]
1105 async fn failure_propagation() {
1106 let server_handle = System::serve(
1108 ChannelAddr::any(ChannelTransport::Local),
1109 Duration::from_secs(10),
1110 Duration::from_secs(10),
1111 )
1112 .await
1113 .unwrap();
1114 let mut system = System::new(server_handle.local_addr().clone());
1115
1116 let sup_mail = system.attach().await.unwrap();
1118 let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
1119 sup_tx.bind_to(ProcSupervisionMessage::port());
1120 let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.actor_id().clone());
1121
1122 let system_sender = BoxedMailboxSender::new(MailboxClient::new(
1124 channel::dial(server_handle.local_addr().clone()).unwrap(),
1125 ));
1126
1127 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
1129 let proc_forwarder =
1130 BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
1131
1132 let world_id = id!(local);
1134 let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
1135 let proc_actor_0 = ProcActor::bootstrap_for_proc(
1136 proc.clone(),
1137 world_id.clone(),
1138 listen_addr,
1139 server_handle.local_addr().clone(),
1140 sup_ref.clone(),
1141 Duration::from_secs(2),
1142 HashMap::new(),
1143 ProcLifecycleMode::ManagedBySystem,
1144 )
1145 .await
1146 .unwrap();
1147
1148 let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
1150 let _proc_actor_1 = ProcActor::bootstrap_for_proc(
1151 proc2.clone(),
1152 world_id.clone(),
1153 ChannelAddr::any(ChannelTransport::Local),
1154 server_handle.local_addr().clone(),
1155 sup_ref.clone(),
1156 Duration::from_secs(2),
1157 HashMap::new(),
1158 ProcLifecycleMode::ManagedBySystem,
1159 )
1160 .await
1161 .unwrap();
1162
1163 let (client, client_ref, mut client_rx) = proc
1165 .attach_actor::<ClientActor, ClientMessage>("client")
1166 .unwrap();
1167 let (worker1, worker1_ref, _) = proc
1168 .attach_actor::<WorkerActor, WorkerMessage>("worker")
1169 .unwrap();
1170 IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker1_ref.clone(), &worker1)
1171 .unwrap();
1172 let (worker2, worker2_ref, _) = proc2
1173 .attach_actor::<WorkerActor, WorkerMessage>("worker")
1174 .unwrap();
1175 IndexedErasedUnbound::<WorkerMessage>::bind_for_test_only(worker2_ref.clone(), &worker2)
1176 .unwrap();
1177
1178 let controller_handle = proc
1179 .spawn::<ControllerActor>(
1180 "controller",
1181 ControllerParams {
1182 world_size: 2,
1183 comm_actor_ref: proc_actor_0.comm_actor.bind(),
1184 worker_gang_ref: GangId(
1185 WorldId(world_id.name().to_string()),
1186 "worker".to_string(),
1187 )
1188 .into(),
1189 supervision_query_interval: Duration::from_secs(1),
1190 worker_progress_check_interval: Duration::from_secs(3),
1191 operation_timeout: Duration::from_secs(30),
1192 operations_per_worker_progress_request: 100,
1193 fail_on_worker_timeout: false,
1194 },
1195 )
1196 .await
1197 .unwrap();
1198
1199 controller_handle.attach(&client, client_ref).await.unwrap();
1200
1201 controller_handle
1202 .node(&client, 0.into(), vec![1.into(), 2.into()], vec![])
1203 .await
1204 .unwrap();
1205 controller_handle
1206 .node(&client, 1.into(), vec![3.into()], vec![1.into()])
1207 .await
1208 .unwrap();
1209 controller_handle
1210 .node(&client, 2.into(), vec![4.into()], vec![3.into()])
1211 .await
1212 .unwrap();
1213 controller_handle
1214 .node(&client, 3.into(), vec![5.into()], vec![3.into()])
1215 .await
1216 .unwrap();
1217 controller_handle
1218 .node(&client, 4.into(), vec![6.into()], vec![3.into()])
1219 .await
1220 .unwrap();
1221 controller_handle
1222 .node(&client, 5.into(), vec![7.into()], vec![4.into()])
1223 .await
1224 .unwrap();
1225 controller_handle
1226 .node(&client, 6.into(), vec![8.into()], vec![4.into()])
1227 .await
1228 .unwrap();
1229
1230 ControllerMessageClient::status(
1231 &controller_handle,
1232 &worker1,
1233 1.into(),
1234 worker1_ref.actor_id().clone(),
1235 false,
1236 )
1237 .await
1238 .unwrap();
1239 ControllerMessageClient::status(
1240 &controller_handle,
1241 &worker2,
1242 1.into(),
1243 worker2_ref.actor_id().clone(),
1244 false,
1245 )
1246 .await
1247 .unwrap();
1248 controller_handle
1249 .remote_function_failed(
1250 &worker1,
1251 2.into(),
1252 WorkerError {
1253 backtrace: "some failure happened!".to_string(),
1254 worker_actor_id: worker1_ref.actor_id().clone(),
1255 },
1256 )
1257 .await
1258 .unwrap();
1259 controller_handle
1260 .remote_function_failed(
1261 &worker2,
1262 2.into(),
1263 WorkerError {
1264 backtrace: "some failure happened!".to_string(),
1265 worker_actor_id: worker2_ref.actor_id().clone(),
1266 },
1267 )
1268 .await
1269 .unwrap();
1270 for s in 3..=7 {
1271 ControllerMessageClient::status(
1272 &controller_handle,
1273 &worker1,
1274 s.into(),
1275 worker1_ref.actor_id().clone(),
1276 false,
1277 )
1278 .await
1279 .unwrap();
1280 ControllerMessageClient::status(
1281 &controller_handle,
1282 &worker2,
1283 s.into(),
1284 worker2_ref.actor_id().clone(),
1285 false,
1286 )
1287 .await
1288 .unwrap();
1289 }
1290
1291 controller_handle.drain_and_stop().unwrap();
1292 controller_handle.await;
1293 let mut client_messages = client_rx.drain();
1294 client_messages.sort_by_key(|msg| msg.clone().into_result().unwrap().0);
1295 assert_eq!(client_messages.len(), 7);
1296 let client_message = client_messages[2].clone().into_result().unwrap();
1297 assert_eq!(client_message.0, 2.into());
1298 assert_eq!(
1299 client_message.1,
1300 Some(Err(Exception::Error(
1301 2.into(),
1302 2.into(),
1303 WorkerError {
1304 backtrace: "some failure happened!".to_string(),
1305 worker_actor_id: worker1_ref.actor_id().clone(),
1306 }
1307 )))
1308 );
1309
1310 assert_eq!(
1311 client_messages
1312 .into_iter()
1313 .map(|msg| msg.into_result().unwrap().0)
1314 .collect::<HashSet<Seq>>(),
1315 HashSet::from([
1316 0.into(),
1317 3.into(),
1318 1.into(),
1319 4.into(),
1320 2.into(),
1321 5.into(),
1322 6.into()
1323 ])
1324 )
1325 }
1326
1327 #[tokio::test]
1328 async fn test_eager_failure_reporting() {
1329 let server_handle = System::serve(
1331 ChannelAddr::any(ChannelTransport::Local),
1332 Duration::from_secs(10),
1333 Duration::from_secs(10),
1334 )
1335 .await
1336 .unwrap();
1337 let mut system = System::new(server_handle.local_addr().clone());
1338
1339 let sup_mail = system.attach().await.unwrap();
1341 let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
1342 sup_tx.bind_to(ProcSupervisionMessage::port());
1343 let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.actor_id().clone());
1344
1345 let system_sender = BoxedMailboxSender::new(MailboxClient::new(
1347 channel::dial(server_handle.local_addr().clone()).unwrap(),
1348 ));
1349
1350 let listen_addr = ChannelAddr::any(ChannelTransport::Local);
1352 let proc_forwarder =
1353 BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
1354
1355 let world_id = id!(local);
1357 let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
1358 let proc_actor_0 = ProcActor::bootstrap_for_proc(
1359 proc.clone(),
1360 world_id.clone(),
1361 listen_addr,
1362 server_handle.local_addr().clone(),
1363 sup_ref.clone(),
1364 Duration::from_secs(2),
1365 HashMap::new(),
1366 ProcLifecycleMode::ManagedBySystem,
1367 )
1368 .await
1369 .unwrap();
1370
1371 let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
1373 let _proc_actor_1 = ProcActor::bootstrap_for_proc(
1374 proc2.clone(),
1375 world_id.clone(),
1376 ChannelAddr::any(ChannelTransport::Local),
1377 server_handle.local_addr().clone(),
1378 sup_ref.clone(),
1379 Duration::from_secs(2),
1380 HashMap::new(),
1381 ProcLifecycleMode::ManagedBySystem,
1382 )
1383 .await
1384 .unwrap();
1385
1386 let (client, client_ref, mut client_rx) = proc
1388 .attach_actor::<ClientActor, ClientMessage>("client")
1389 .unwrap();
1390 let (worker1, worker1_ref, _) = proc
1391 .attach_actor::<WorkerActor, WorkerMessage>("worker")
1392 .unwrap();
1393
1394 let controller_handle = proc
1395 .spawn::<ControllerActor>(
1396 "controller",
1397 ControllerParams {
1398 world_size: 1,
1399 comm_actor_ref: proc_actor_0.comm_actor.bind(),
1400 worker_gang_ref: GangId(
1401 WorldId(world_id.name().to_string()),
1402 "worker".to_string(),
1403 )
1404 .into(),
1405 supervision_query_interval: Duration::from_secs(1),
1406 worker_progress_check_interval: Duration::from_secs(3),
1407 operation_timeout: Duration::from_secs(30),
1408 operations_per_worker_progress_request: 100,
1409 fail_on_worker_timeout: false,
1410 },
1411 )
1412 .await
1413 .unwrap();
1414
1415 controller_handle.attach(&client, client_ref).await.unwrap();
1416
1417 controller_handle
1418 .node(&client, 0.into(), vec![1.into()], vec![])
1419 .await
1420 .unwrap();
1421
1422 controller_handle
1423 .node(&client, 1.into(), vec![2.into()], vec![1.into()])
1424 .await
1425 .unwrap();
1426
1427 controller_handle
1428 .node(&client, 2.into(), vec![3.into()], vec![2.into()])
1429 .await
1430 .unwrap();
1431
1432 controller_handle
1433 .node(&client, 3.into(), vec![], vec![3.into()])
1434 .await
1435 .unwrap();
1436
1437 controller_handle
1438 .node(&client, 4.into(), vec![], vec![])
1439 .await
1440 .unwrap();
1441
1442 controller_handle
1443 .remote_function_failed(
1444 &worker1,
1445 0.into(),
1446 WorkerError {
1447 backtrace: "some failure happened!".to_string(),
1448 worker_actor_id: worker1_ref.actor_id().clone(),
1449 },
1450 )
1451 .await
1452 .unwrap();
1453
1454 controller_handle
1455 .remote_function_failed(
1456 &worker1,
1457 3.into(),
1458 WorkerError {
1459 backtrace: "some failure happened!".to_string(),
1460 worker_actor_id: worker1_ref.actor_id().clone(),
1461 },
1462 )
1463 .await
1464 .unwrap();
1465
1466 ControllerMessageClient::status(
1467 &controller_handle,
1468 &worker1,
1469 5.into(),
1470 worker1_ref.actor_id().clone(),
1471 false,
1472 )
1473 .await
1474 .unwrap();
1475
1476 controller_handle.drain_and_stop().unwrap();
1477 controller_handle.await;
1478
1479 let client_messages = client_rx.drain();
1480 assert_eq!(client_messages.len(), 5);
1482
1483 let (errors, successes) =
1484 client_messages
1485 .into_iter()
1486 .fold((0, 0), |(errors, successes), client_message| {
1487 let (_, result) = client_message.clone().into_result().unwrap();
1488 match result {
1489 Some(Err(Exception::Error(_, _, _))) => (errors + 1, successes),
1490 None => (errors, successes + 1),
1491 _ => {
1492 panic!("should only be exceptions or no result");
1493 }
1494 }
1495 });
1496
1497 assert_eq!(errors, 4);
1499 assert_eq!(successes, 1);
1500 }
1501
1502 #[tokio::test]
1503 async fn test_bootstrap() {
1504 let server_handle = System::serve(
1505 ChannelAddr::any(ChannelTransport::Local),
1506 Duration::from_secs(10),
1507 Duration::from_secs(10),
1508 )
1509 .await
1510 .unwrap();
1511
1512 let controller_id = id!(controller[0].root);
1513 let proc_id = id!(world[0]);
1514 let (proc_handle, actor_ref) = ControllerActor::bootstrap(
1515 controller_id.clone(),
1516 ChannelAddr::any(ChannelTransport::Local),
1517 server_handle.local_addr().clone(),
1518 ControllerParams {
1519 world_size: 1,
1520 comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1521 worker_gang_ref: GangId(
1522 WorldId(
1523 proc_id
1524 .world_name()
1525 .expect("only ranked actors are supported in the controller tests")
1526 .to_string(),
1527 ),
1528 "worker".to_string(),
1529 )
1530 .into(),
1531 supervision_query_interval: Duration::from_secs(1),
1532 worker_progress_check_interval: Duration::from_secs(3),
1533 operation_timeout: Duration::from_secs(30),
1534 operations_per_worker_progress_request: 100,
1535 fail_on_worker_timeout: false,
1536 },
1537 Duration::from_secs(1),
1538 HashMap::new(),
1539 )
1540 .await
1541 .unwrap();
1542 assert_eq!(*actor_ref.actor_id(), controller_id);
1543
1544 proc_handle.drain_and_stop().unwrap();
1545 }
1546
1547 async fn mock_proc_actor(
1548 idx: usize,
1549 rank: usize,
1550 ) -> (
1551 WorldId,
1552 ProcId,
1553 ChannelAddr,
1554 Mailbox,
1555 PortHandle<ProcMessage>,
1556 PortReceiver<ProcMessage>,
1557 ) {
1558 let world_id = id!(world);
1559 let local_proc_id = world_id.proc_id(rank);
1561 let (local_proc_addr, local_proc_rx) =
1562 channel::serve(ChannelAddr::any(ChannelTransport::Local))
1563 .await
1564 .unwrap();
1565 let local_proc_mbox = Mailbox::new_detached(
1566 local_proc_id.actor_id(format!("test_dummy_proc{}", idx).to_string(), 0),
1567 );
1568 let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1569 local_proc_message_port.bind();
1570
1571 let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1572 (
1573 world_id,
1574 local_proc_id,
1575 local_proc_addr,
1576 local_proc_mbox,
1577 local_proc_message_port,
1578 local_proc_message_receiver,
1579 )
1580 }
1581
#[tokio::test]
async fn test_sim_supervision_failure() {
    // Run the whole test on the simulated network so timing is driven by
    // simnet rather than the wall clock.
    simnet::start();
    simnet::simnet_handle()
        .unwrap()
        .set_training_script_state(simnet::TrainingScriptState::Waiting);

    // System actor served over a simulated unix-socket transport.
    let system_sim_addr =
        ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));
    let server_handle = System::serve(
        system_sim_addr.clone(),
        Duration::from_secs(1000),
        Duration::from_secs(1000),
    )
    .await
    .unwrap();

    let mut system = System::new(server_handle.local_addr().clone());
    let client_mailbox = system.attach().await.unwrap();

    let controller_id = id!(controller[0].root);
    let proc_id = id!(world[0]);
    let controller_proc_listen_addr =
        ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix)));

    // Bootstrap the controller with long intervals/timeouts: the failure
    // this test expects should come from supervision, not from operation
    // or progress timeouts firing first.
    let (_, actor_ref) = ControllerActor::bootstrap(
        controller_id.clone(),
        controller_proc_listen_addr,
        system_sim_addr,
        ControllerParams {
            world_size: 1,
            comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
            worker_gang_ref: GangId(
                WorldId(
                    proc_id
                        .world_name()
                        .expect("only ranked actors are supported in the controller tests")
                        .to_string(),
                ),
                "worker".to_string(),
            )
            .into(),
            supervision_query_interval: Duration::from_secs(100),
            worker_progress_check_interval: Duration::from_secs(100),
            operation_timeout: Duration::from_secs(1000),
            operations_per_worker_progress_request: 100,
            fail_on_worker_timeout: false,
        },
        Duration::from_secs(100),
        HashMap::new(),
    )
    .await
    .unwrap();
    assert_eq!(*actor_ref.actor_id(), controller_id);

    // Attach this test's mailbox as the controller's client so failure
    // reports are delivered to `client_supervision_rx` below.
    actor_ref
        .attach(
            &client_mailbox,
            ActorRef::attest(client_mailbox.actor_id().clone()),
        )
        .await
        .unwrap();

    let (client_supervision_tx, mut client_supervision_rx) =
        client_mailbox.open_port::<ClientMessage>();
    client_supervision_tx.bind_to(ClientMessage::port());

    // A mock proc joins the world at rank 1 but never participates further;
    // the system is expected to surface a failure for it.
    let (
        world_id,
        local_proc_id,
        local_proc_addr,
        _,
        local_proc_message_port,
        mut local_proc_message_receiver,
    ) = mock_proc_actor(0, 1).await;

    server_handle
        .system_actor_handle()
        .send(SystemMessage::Join {
            proc_id: local_proc_id.clone(),
            world_id,
            proc_message_port: local_proc_message_port.bind(),
            proc_addr: local_proc_addr,
            labels: HashMap::new(),
            lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
        })
        .unwrap();

    // The mock proc must observe its own join before we wait on the client.
    assert_matches!(
        local_proc_message_receiver.recv().await.unwrap(),
        ProcMessage::Joined()
    );

    // The supervision failure reaches the client as an error result on the
    // default (zero) sequence number.
    let result = client_supervision_rx
        .recv()
        .await
        .unwrap()
        .into_result()
        .unwrap();
    assert_eq!(result.0, Seq::default());
    assert!(result.1.expect("result").is_err());

    // Close the simnet and dump its event records for debugging.
    let records = simnet::simnet_handle().unwrap().close().await.unwrap();
    eprintln!("{}", serde_json::to_string_pretty(&records).unwrap());
}
1694 #[tokio::test]
1695 async fn test_supervision_failure() {
1696 let timeout: Duration = Duration::from_secs(6);
1698 let server_handle = System::serve(
1699 ChannelAddr::any(ChannelTransport::Local),
1700 timeout.clone(),
1701 timeout.clone(),
1702 )
1703 .await
1704 .unwrap();
1705
1706 let mut system = System::new(server_handle.local_addr().clone());
1708 let client_mailbox = system.attach().await.unwrap();
1709 let (client_supervision_tx, mut client_supervision_rx) =
1710 client_mailbox.open_port::<ClientMessage>();
1711 client_supervision_tx.bind_to(ClientMessage::port());
1712
1713 let controller_id = id!(controller[0].root);
1715 let proc_id = id!(world[0]);
1716 let (_, actor_ref) = ControllerActor::bootstrap(
1717 controller_id.clone(),
1718 ChannelAddr::any(ChannelTransport::Local),
1719 server_handle.local_addr().clone(),
1720 ControllerParams {
1721 world_size: 1,
1722 comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1723 worker_gang_ref: GangId(
1724 WorldId(
1725 proc_id
1726 .world_name()
1727 .expect("only ranked actors are supported in the controller tests")
1728 .to_string(),
1729 ),
1730 "worker".to_string(),
1731 )
1732 .into(),
1733 supervision_query_interval: Duration::from_secs(1),
1734 worker_progress_check_interval: Duration::from_secs(3),
1735 operation_timeout: Duration::from_secs(30),
1736 operations_per_worker_progress_request: 100,
1737 fail_on_worker_timeout: false,
1738 },
1739 Duration::from_secs(1),
1740 HashMap::new(),
1741 )
1742 .await
1743 .unwrap();
1744 assert_eq!(*actor_ref.actor_id(), controller_id);
1745
1746 actor_ref
1747 .attach(
1748 &client_mailbox,
1749 ActorRef::attest(client_mailbox.actor_id().clone()),
1750 )
1751 .await
1752 .unwrap();
1753
1754 let (
1756 world_id,
1757 local_proc_id,
1758 local_proc_addr,
1759 _,
1760 local_proc_message_port,
1761 mut local_proc_message_receiver,
1762 ) = mock_proc_actor(0, 1).await;
1763
1764 server_handle
1766 .system_actor_handle()
1767 .send(SystemMessage::Join {
1768 proc_id: local_proc_id.clone(),
1769 world_id,
1770 proc_message_port: local_proc_message_port.bind(),
1771 proc_addr: local_proc_addr,
1772 labels: HashMap::new(),
1773 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
1774 })
1775 .unwrap();
1776
1777 assert_matches!(
1778 local_proc_message_receiver.recv().await.unwrap(),
1779 ProcMessage::Joined()
1780 );
1781
1782 RealClock.sleep(2 * timeout.clone()).await;
1784
1785 let result = client_supervision_rx
1787 .recv()
1788 .await
1789 .unwrap()
1790 .into_result()
1791 .unwrap();
1792 assert_eq!(result.0, Seq::default());
1793 assert!(result.1.expect("result").is_err());
1794 }
1795
/// Test-only message: instructs the receiving actor to panic with the
/// contained string, used to trigger a genuine actor fault for the
/// supervision tests.
#[derive(
    Handler,
    HandleClient,
    RefClient,
    Named,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq
)]
enum PanickingMessage {
    // Panic with this error message.
    Panic(String),
}
1810
/// Test-only actor whose sole behavior is to panic on request; exported so
/// it can handle `PanickingMessage` sent from other procs.
#[derive(Debug, Default, Actor)]
#[hyperactor::export(
    handlers = [
        PanickingMessage,
    ],
)]
struct PanickingActor;
1818
#[async_trait]
#[hyperactor::forward(PanickingMessage)]
impl PanickingMessageHandler for PanickingActor {
    /// Panics with the supplied message; never returns. The resulting actor
    /// failure is what the supervision-fault test asserts on.
    async fn panic(
        &mut self,
        _cx: &Context<Self>,
        err_msg: String,
    ) -> Result<(), anyhow::Error> {
        panic!("{}", err_msg);
    }
}
1830
1831 hyperactor::remote!(PanickingActor);
1832
1833 #[tokio::test]
1834 async fn test_supervision_fault() {
1835 let timeout: Duration = Duration::from_secs(6);
1837 let server_handle = System::serve(
1838 ChannelAddr::any(ChannelTransport::Local),
1839 timeout.clone(),
1840 timeout.clone(),
1841 )
1842 .await
1843 .unwrap();
1844
1845 let mut system = System::new(server_handle.local_addr().clone());
1847 let client_mailbox = system.attach().await.unwrap();
1848 let (client_supervision_tx, mut client_supervision_rx) =
1849 client_mailbox.open_port::<ClientMessage>();
1850 client_supervision_tx.bind_to(ClientMessage::port());
1851
1852 let controller_id = id!(controller[0].root);
1854 let proc_id = id!(world[0]);
1855 let (_, actor_ref) = ControllerActor::bootstrap(
1856 controller_id.clone(),
1857 ChannelAddr::any(ChannelTransport::Local),
1858 server_handle.local_addr().clone(),
1859 ControllerParams {
1860 world_size: 1,
1861 comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)),
1862 worker_gang_ref: GangId(
1863 WorldId(
1864 proc_id
1865 .world_name()
1866 .expect("only ranked actors are supported in the controller tests")
1867 .to_string(),
1868 ),
1869 "worker".to_string(),
1870 )
1871 .into(),
1872 supervision_query_interval: Duration::from_secs(1),
1873 worker_progress_check_interval: Duration::from_secs(3),
1874 operation_timeout: Duration::from_secs(30),
1875 operations_per_worker_progress_request: 100,
1876 fail_on_worker_timeout: false,
1877 },
1878 Duration::from_secs(1),
1879 HashMap::new(),
1880 )
1881 .await
1882 .unwrap();
1883 assert_eq!(*actor_ref.actor_id(), controller_id);
1884
1885 actor_ref
1886 .attach(
1887 &client_mailbox,
1888 ActorRef::attest(client_mailbox.actor_id().clone()),
1889 )
1890 .await
1891 .unwrap();
1892
1893 let world_id = id!(world);
1895 let panic_proc_id = world_id.proc_id(1);
1896 let bootstrap = ProcActor::bootstrap(
1897 panic_proc_id,
1898 world_id,
1899 ChannelAddr::any(ChannelTransport::Local),
1900 server_handle.local_addr().clone(),
1901 Duration::from_secs(3),
1902 HashMap::new(),
1903 ProcLifecycleMode::ManagedBySystem,
1904 )
1905 .await
1906 .unwrap();
1907 let actor_handle = spawn::<PanickingActor>(
1908 &client_mailbox,
1909 &bootstrap.proc_actor.bind(),
1910 "panicker",
1911 &(),
1912 )
1913 .await
1914 .unwrap();
1915
1916 actor_handle
1917 .panic(&client_mailbox, "some random failure".to_string())
1918 .await
1919 .unwrap();
1920
1921 let result = client_supervision_rx
1923 .recv()
1924 .await
1925 .unwrap()
1926 .into_result()
1927 .unwrap();
1928 assert_eq!(result.0, Seq::default());
1929 assert!(result.1.is_some() && result.1.as_ref().unwrap().is_err());
1930 let Exception::Failure(err) = result.1.unwrap().unwrap_err() else {
1931 panic!("Expected Failure exception");
1932 };
1933 assert!(err.backtrace.contains("some random failure"));
1934 }
1935}