1use std::collections::HashMap;
15use std::fmt;
16use std::hash::Hash;
17use std::hash::Hasher;
18use std::ops::Deref;
19use std::sync::Arc;
20use std::sync::OnceLock as OnceCell;
21use std::time::Duration;
22
23use hyperactor::ActorLocal;
24use hyperactor::ActorRef;
25use hyperactor::Endpoint as _;
26use hyperactor::PortRef;
27use hyperactor::RemoteEndpoint as _;
28use hyperactor::RemoteHandles;
29use hyperactor::RemoteMessage;
30use hyperactor::UnboundPort;
31use hyperactor::UnboundPortKind;
32use hyperactor::accum::ReducerMode;
33use hyperactor::actor::ActorStatus;
34use hyperactor::actor::Referable;
35use hyperactor::context;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::message::Castable;
38use hyperactor::message::ErasedUnbound;
39use hyperactor::message::IndexedErasedUnbound;
40use hyperactor::message::Unbound;
41use hyperactor::port::Port;
42use hyperactor::supervision::ActorSupervisionEvent;
43use hyperactor_config::CONFIG;
44use hyperactor_config::ConfigAttr;
45use hyperactor_config::Flattrs;
46use hyperactor_config::attrs::declare_attrs;
47use hyperactor_mesh_macros::sel;
48use ndslice::Selection;
49use ndslice::ViewExt as _;
50use ndslice::view;
51use ndslice::view::MapIntoExt;
52use ndslice::view::Region;
53use ndslice::view::View;
54use serde::Deserialize;
55use serde::Deserializer;
56use serde::Serialize;
57use serde::Serializer;
58use tokio::sync::watch;
59
60use crate::CommActor;
61use crate::Error;
62use crate::ProcMeshRef;
63use crate::ValueMesh;
64use crate::casting;
65use crate::comm::multicast;
66use crate::comm::multicast::CastMessageV1;
67use crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD;
68use crate::host_mesh::GET_PROC_STATE_MAX_IDLE;
69use crate::host_mesh::mesh_to_rankedvalues_with_default;
70use crate::mesh_controller::ActorMeshController;
71use crate::mesh_controller::SUPERVISION_POLL_FREQUENCY;
72use crate::mesh_controller::Subscribe;
73use crate::mesh_controller::Unsubscribe;
74use crate::mesh_id::ActorMeshId;
75use crate::metrics;
76use crate::proc_agent::ActorState;
77use crate::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
78use crate::resource;
79use crate::supervision::MeshFailure;
80use crate::supervision::Unhealthy;
81
// Watchdog timeout for the supervision stream.
//
// `into_watch` waits at most this long for each message from the controller
// before publishing `MessageOrFailure::Timeout`. `into_watch` also logs a
// warning when this value is smaller than the sum of
// SUPERVISION_POLL_FREQUENCY, GET_ACTOR_STATE_MAX_IDLE and
// GET_PROC_STATE_MAX_IDLE.
declare_attrs! {
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT".to_string()),
        Some("supervision_watchdog_timeout".to_string()),
    ))
    pub attr SUPERVISION_WATCHDOG_TIMEOUT: Duration = Duration::from_mins(2);
}
97
/// Handle to a mesh of actors of type `A`.
///
/// Dereferences to the [`ActorMeshRef`] stored in `current_ref`, so all
/// reference-level operations (casting, rank lookup, supervision) are
/// available directly on an `ActorMesh`.
#[derive(Debug)]
pub struct ActorMesh<A: Referable> {
    /// Proc mesh this actor mesh was constructed with.
    proc_mesh: ProcMeshRef,
    /// Identifier of this actor mesh.
    id: ActorMeshId,
    /// Reference view built from `id`, `proc_mesh` and `controller`; this is
    /// the `Deref` target.
    current_ref: ActorMeshRef<A>,
    /// Controller actor for this mesh, if any. `stop` takes it, leaving
    /// `None`.
    controller: Option<ActorRef<ActorMeshController<A>>>,
}
114
115impl<A: Referable> ActorMesh<A> {
117 pub(crate) fn new(
118 proc_mesh: ProcMeshRef,
119 id: ActorMeshId,
120 controller: Option<ActorRef<ActorMeshController<A>>>,
121 ) -> Self {
122 let current_ref = ActorMeshRef::with_page_size(
123 id.clone(),
124 proc_mesh.clone(),
125 DEFAULT_PAGE,
126 controller.clone(),
127 );
128
129 Self {
130 proc_mesh,
131 id,
132 current_ref,
133 controller,
134 }
135 }
136
    /// Returns the identifier of this actor mesh.
    pub fn id(&self) -> &ActorMeshId {
        &self.id
    }
140
141 pub(crate) fn set_controller(&mut self, controller: Option<ActorRef<ActorMeshController<A>>>) {
142 self.controller = controller.clone();
143 self.current_ref.set_controller(controller);
144 }
145
    /// Stops the mesh through its controller and records the outcome in the
    /// local health state.
    ///
    /// When a controller is present it is taken out of `self`, sent a
    /// `resource::Stop` carrying `reason`, and then queried with
    /// `resource::GetState`. The call fails if the reply carries no state or
    /// if not every reported status is terminating. In both the success and
    /// failure cases a `StreamClosed` failure is stored in the health state
    /// before the result is propagated.
    ///
    /// When no controller is present, nothing is sent. On every path that
    /// does not return early with an error, the controller of the embedded
    /// reference is cleared as well.
    ///
    /// # Errors
    /// - the error from receiving the `GetState` reply;
    /// - `Error::Other` when the reply has no state;
    /// - `Error::ActorStopError` when some rank is not terminating.
    ///
    /// # Panics
    /// Panics if the embedded reference has no rank 0 (see the `unwrap`
    /// below).
    pub async fn stop(&mut self, cx: &impl context::Actor, reason: String) -> crate::Result<()> {
        // `take()` leaves `self.controller` as `None`, so a later call skips
        // this branch entirely.
        if let Some(controller) = self.controller.take() {
            let id = self.id.resource_id().clone();
            let num_ranks = self.current_ref.region().num_ranks();
            // The stop protocol runs in its own async block so that its
            // result can be inspected (and recorded) before being returned.
            let result: crate::Result<()> = async {
                controller.post(
                    cx,
                    resource::Stop {
                        id: id.clone(),
                        reason,
                    },
                );
                let (port, mut rx) = cx.mailbox().open_port();
                // Posted after `Stop` to the same controller; the reply
                // arrives on the port opened above.
                controller.post(
                    cx,
                    resource::GetState::<resource::mesh::State<()>> {
                        id: id.clone(),
                        reply: port.bind(),
                    },
                );
                let statuses = rx.recv().await?;
                let Some(state) = &statuses.state else {
                    return Err(Error::Other(anyhow::anyhow!(
                        "non-existent state in GetState reply from controller: {}",
                        controller.actor_addr()
                    )));
                };
                let all_terminating = state.statuses.values().all(|s| s.is_terminating());
                if !all_terminating {
                    // Convert the per-rank statuses into the ranked form
                    // carried by `ActorStopError`.
                    let legacy = mesh_to_rankedvalues_with_default(
                        &state.statuses,
                        resource::Status::NotExist,
                        resource::Status::is_not_exist,
                        num_ranks,
                    );
                    return Err(Error::ActorStopError { statuses: legacy });
                }
                Ok(())
            }
            .await;

            // Whether or not the stop succeeded, the mesh is recorded as
            // stopped; a failure only changes the message.
            let status = match &result {
                Ok(()) => ActorStatus::Stopped("mesh stopped".to_string()),
                Err(e) => ActorStatus::Stopped(format!("mesh stop failed: {e}")),
            };
            let mut entry = self.health_state.entry(cx).or_default();
            let health_state = entry.get_mut();
            // The event is attributed to the actor at rank 0, with no
            // specific crashed ranks.
            health_state.unhealthy_event = Some(Unhealthy::StreamClosed(MeshFailure {
                actor_mesh_name: Some(self.id().to_string()),
                event: ActorSupervisionEvent::new(
                    ndslice::view::Ranked::get(&self.current_ref, 0)
                        .unwrap()
                        .actor_addr()
                        .clone(),
                    None,
                    status,
                    None,
                ),
                crashed_ranks: vec![],
            }));

            result?;
        }
        // Drop the controller from the embedded reference too.
        self.current_ref.controller.take();
        Ok(())
    }
248}
249
250impl<A: Referable> fmt::Display for ActorMesh<A> {
251 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252 write!(f, "{}", self.current_ref)
253 }
254}
255
/// Lets an `ActorMesh` be used wherever an [`ActorMeshRef`] is expected by
/// dereferencing to the embedded `current_ref`.
impl<A: Referable> Deref for ActorMesh<A> {
    type Target = ActorMeshRef<A>;

    fn deref(&self) -> &Self::Target {
        &self.current_ref
    }
}
263
264impl<A: Referable> Clone for ActorMesh<A> {
267 fn clone(&self) -> Self {
268 Self {
269 proc_mesh: self.proc_mesh.clone(),
270 id: self.id.clone(),
271 current_ref: self.current_ref.clone(),
272 controller: self.controller.clone(),
273 }
274 }
275}
276
/// Emits an `ActorMeshStatus` log event with status `"Dropped"` whenever an
/// `ActorMesh` value is dropped. This runs for every value, including
/// clones; it only logs and sends no messages.
impl<A: Referable> Drop for ActorMesh<A> {
    fn drop(&mut self) {
        tracing::info!(
            name = "ActorMeshStatus",
            actor_name = %self.id,
            status = "Dropped",
        );
    }
}
286
// Number of rank slots per cache page used when no explicit page size is
// given (`ActorMesh::new`, `ActorMeshRef::new`, and deserialization).
const DEFAULT_PAGE: usize = 1024;
291
/// One page of the per-rank `ActorRef` cache: a fixed-length run of slots,
/// each initialized at most once on first access.
struct Page<A: Referable> {
    /// `slots[i]` caches the `ActorRef` for the i-th rank of this page.
    slots: Box<[OnceCell<ActorRef<A>>]>,
}
296
297impl<A: Referable> Page<A> {
298 fn new(len: usize) -> Self {
299 let mut v = Vec::with_capacity(len);
300 for _ in 0..len {
301 v.push(OnceCell::new());
302 }
303 Self {
304 slots: v.into_boxed_slice(),
305 }
306 }
307}
308
/// Locally cached supervision state for a mesh reference.
#[derive(Default)]
struct HealthState {
    /// The recorded failure, if any; `None` means no failure is cached.
    unhealthy_event: Option<Unhealthy>,
    /// Supervision events recorded per crashed rank.
    crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
}
314
315impl HealthState {
316 fn failure_for_region(&self, region: &Region) -> Option<MeshFailure> {
317 let unhealthy = self.unhealthy_event.as_ref()?;
318 let mut failure = match unhealthy {
319 Unhealthy::StreamClosed(failure) | Unhealthy::Crashed(failure) => failure.clone(),
320 };
321 if failure.crashed_ranks.is_empty() {
322 return Some(failure);
323 }
324 let mut crashed_ranks = self
325 .crashed_ranks
326 .keys()
327 .copied()
328 .filter(|rank| region.slice().contains(*rank))
329 .collect::<Vec<_>>();
330 crashed_ranks.sort_unstable();
331 if crashed_ranks.is_empty() {
332 return None;
333 }
334 failure.crashed_ranks = crashed_ranks;
335 Some(failure)
336 }
337}
338
339impl std::fmt::Debug for HealthState {
340 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341 f.debug_struct("HealthState")
342 .field("unhealthy_event", &self.unhealthy_event)
343 .field("crashed_ranks", &self.crashed_ranks)
344 .finish()
345 }
346}
347
/// Value published on the supervision watch channel.
#[derive(Clone)]
enum MessageOrFailure<M>
where
    M: Send + Sync + Clone + Default + 'static,
{
    /// A message received from the underlying port.
    Message(M),
    /// Receiving from the port failed; carries the error text.
    Failure(String),
    /// No message arrived within the watchdog timeout.
    Timeout,
}

/// The default is a `Message` wrapping `M::default()`; it seeds the watch
/// channel before anything has been received.
impl<M> Default for MessageOrFailure<M>
where
    M: Send + Sync + Clone + Default + 'static,
{
    fn default() -> Self {
        MessageOrFailure::Message(Default::default())
    }
}
362
363fn into_watch<M: Send + Sync + Clone + Default + 'static>(
367 mut rx: PortReceiver<M>,
368) -> watch::Receiver<MessageOrFailure<M>> {
369 let (sender, receiver) = watch::channel(MessageOrFailure::<M>::default());
370 let timeout = hyperactor_config::global::get(SUPERVISION_WATCHDOG_TIMEOUT);
378 let poll_frequency = hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY);
379 let get_actor_state_max_idle = hyperactor_config::global::get(GET_ACTOR_STATE_MAX_IDLE);
380 let get_proc_state_max_idle = hyperactor_config::global::get(GET_PROC_STATE_MAX_IDLE);
381 let total_time = poll_frequency + get_actor_state_max_idle + get_proc_state_max_idle;
382 if timeout < total_time {
383 tracing::warn!(
384 "HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT={} is too short. It should be >= {} (SUPERVISION_POLL_FREQUENCY={} + GET_ACTOR_STATE_MAX_IDLE={} + GET_PROC_STATE_MAX_IDLE={})",
385 humantime::format_duration(timeout),
386 humantime::format_duration(total_time),
387 humantime::format_duration(poll_frequency),
388 humantime::format_duration(get_actor_state_max_idle),
389 humantime::format_duration(get_proc_state_max_idle),
390 );
391 }
392 tokio::spawn(async move {
393 loop {
394 let message = match tokio::time::timeout(timeout, rx.recv()).await {
395 Ok(Ok(msg)) => MessageOrFailure::Message(msg),
396 Ok(Err(e)) => MessageOrFailure::Failure(e.to_string()),
397 Err(_) => MessageOrFailure::Timeout,
398 };
399 let is_failure = matches!(
400 message,
401 MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
402 );
403 if sender.send(message).is_err() {
404 break;
406 }
407 if is_failure {
408 break;
410 }
411 }
412 });
413 receiver
414}
415
/// Reference to a mesh of actors of type `A`.
///
/// Only `proc_mesh`, `id` and `controller` are serialized; equality and
/// hashing use `proc_mesh` and `id` only. The remaining fields are local
/// caches and supervision state.
#[derive(typeuri::Named)]
pub struct ActorMeshRef<A: Referable> {
    /// Proc mesh whose region defines the ranks of this view.
    proc_mesh: ProcMeshRef,
    /// Identifier of the actor mesh.
    id: ActorMeshId,
    /// Controller to subscribe to for supervision events, if any.
    controller: Option<ActorRef<ActorMeshController<A>>>,

    /// Per-actor cached supervision state. `Clone` starts from a fresh
    /// value; `clone_with_supervision_receiver` and `sliced` carry it over
    /// via `ActorLocal::clone`.
    health_state: ActorLocal<HealthState>,
    /// Per-actor supervision subscription, created on first use by
    /// `next_supervision_event`: the subscriber port registered with the
    /// controller and the watch receiver fed from it.
    receiver: ActorLocal<
        Arc<
            tokio::sync::Mutex<(
                PortRef<Option<MeshFailure>>,
                watch::Receiver<MessageOrFailure<Option<MeshFailure>>>,
            )>,
        >,
    >,
    /// Paged cache of per-rank `ActorRef`s, filled on demand by
    /// `materialize`. Never copied by any clone path.
    pages: OnceCell<Vec<OnceCell<Box<Page<A>>>>>,
    /// Number of rank slots per page; `with_page_size` clamps it to at
    /// least 1.
    page_size: usize,
}
458
459impl<A: Referable> ActorMeshRef<A> {
460 fn cached_failure(&self, cx: &impl context::Actor) -> Option<MeshFailure> {
461 let health_state = self.health_state.entry(cx).or_default();
462 health_state
463 .get()
464 .failure_for_region(ndslice::view::Ranked::region(self))
465 }
466
467 #[allow(clippy::result_large_err)]
469 pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> crate::Result<()>
470 where
471 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
472 M: Castable + RemoteMessage + Clone, {
474 self.cast_with_selection(cx, sel!(*), message, &Flattrs::new())
475 }
476
477 #[allow(clippy::result_large_err)]
484 pub fn cast_with_headers<M>(
485 &self,
486 cx: &impl context::Actor,
487 caller_headers: &Flattrs,
488 message: M,
489 ) -> crate::Result<()>
490 where
491 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
492 M: Castable + RemoteMessage + Clone,
493 {
494 self.cast_with_selection(cx, sel!(*), message, caller_headers)
495 }
496
497 #[allow(clippy::result_large_err)]
502 pub fn cast_for_tensor_engine_only_do_not_use<M>(
503 &self,
504 cx: &impl context::Actor,
505 sel: Selection,
506 message: M,
507 ) -> crate::Result<()>
508 where
509 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
510 M: Castable + RemoteMessage + Clone, {
512 self.cast_with_selection(cx, sel, message, &Flattrs::new())
513 }
514
515 #[allow(clippy::result_large_err)]
516 fn cast_with_selection<M>(
517 &self,
518 cx: &impl context::Actor,
519 sel: Selection,
520 message: M,
521 caller_headers: &Flattrs,
522 ) -> crate::Result<()>
523 where
524 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
525 M: Castable + RemoteMessage + Clone, {
527 if let Some(failure) = self.cached_failure(cx) {
530 tracing::debug!(
531 actor_mesh = %self.id,
532 crashed_ranks = ?failure.crashed_ranks,
533 "rejecting cast due to cached supervision failure"
534 );
535 return Err(crate::Error::Supervision(Box::new(failure)));
536 }
537
538 hyperactor_telemetry::notify_sent_message(hyperactor_telemetry::SentMessageEvent {
539 timestamp: std::time::SystemTime::now(),
540 sender_actor_id: hyperactor_telemetry::hash_to_u64(cx.mailbox().actor_addr()),
541 actor_mesh_id: hyperactor_telemetry::hash_to_u64(&self.id.to_string()),
542 view_json: serde_json::to_string(view::Ranked::region(self)).unwrap_or_default(),
543 shape_json: {
544 let shape: ndslice::Shape = view::Ranked::region(self).into();
545 serde_json::to_string(&shape).unwrap_or_default()
546 },
547 });
548
549 if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
551 if casting::v1_casting_enabled() {
552 if Selection::is_equivalent_to_true(&sel) {
553 self.cast_v1(cx, message, root_comm_actor, caller_headers);
554 return Ok(());
555 }
556 } else {
559 return self.cast_v0(cx, message, sel, root_comm_actor, caller_headers);
560 }
561 }
562
563 let selected_ranks: std::collections::HashSet<usize> = sel
564 .eval(
565 &ndslice::selection::EvalOpts::lenient(),
566 view::Ranked::region(self).slice(),
567 )
568 .map_err(|e| Error::CastingError(self.id.clone(), e.into()))?
569 .collect();
570
571 for (point, actor) in self.iter() {
572 if !selected_ranks.contains(&point.rank()) {
573 continue;
574 }
575 let create_rank = point.rank();
576 let mut headers = caller_headers.clone();
577 multicast::set_cast_info_on_headers(
578 &mut headers,
579 point,
580 cx.instance().self_addr().clone(),
581 );
582
583 let mut unbound = Unbound::try_from_message(message.clone())
586 .map_err(|e| Error::CastingError(self.id.clone(), e))?;
587 unbound
588 .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
589 *rank = Some(create_rank);
590 Ok(())
591 })
592 .map_err(|e| Error::CastingError(self.id.clone(), e))?;
593 let rebound_message = unbound
594 .bind()
595 .map_err(|e| Error::CastingError(self.id.clone(), e))?;
596 actor.post_with_headers(cx, headers, rebound_message);
597 }
598 Ok(())
599 }
600
601 #[allow(clippy::result_large_err)]
602 fn cast_v0<M>(
603 &self,
604 cx: &impl context::Actor,
605 message: M,
606 sel: Selection,
607 root_comm_actor: &ActorRef<CommActor>,
608 caller_headers: &Flattrs,
609 ) -> crate::Result<()>
610 where
611 A: RemoteHandles<IndexedErasedUnbound<M>>,
612 M: Castable + RemoteMessage + Clone, {
614 let cast_mesh_shape = view::Ranked::region(self).into();
615 let actor_mesh_id = self.id.clone();
616 match &self.proc_mesh.root_region {
617 Some(root_region) => {
618 let root_mesh_shape = root_region.into();
619 casting::cast_to_sliced_mesh::<A, M>(
620 cx,
621 actor_mesh_id,
622 root_comm_actor,
623 &sel,
624 message,
625 &cast_mesh_shape,
626 &root_mesh_shape,
627 caller_headers,
628 )
629 .map_err(|e| Error::CastingError(self.id.clone(), e.into()))
630 }
631 None => casting::actor_mesh_cast::<A, M>(
632 cx,
633 actor_mesh_id,
634 root_comm_actor,
635 sel,
636 &cast_mesh_shape,
637 &cast_mesh_shape,
638 message,
639 caller_headers,
640 )
641 .map_err(|e| Error::CastingError(self.id.clone(), e.into())),
642 }
643 }
644
645 fn cast_v1<M>(
646 &self,
647 cx: &impl context::Actor,
648 message: M,
649 root_comm_actor: &ActorRef<CommActor>,
650 caller_headers: &Flattrs,
651 ) where
652 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
653 M: Castable + RemoteMessage,
654 {
655 let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
656 "message_type" => <M as typeuri::Named>::typename(),
657 "message_variant" => message.arm().unwrap_or_default(),
658 ));
659
660 let actor_ids: ValueMesh<_> = self.proc_mesh.map_into(|proc| proc.actor_addr(&self.id));
661
662 let mut headers = caller_headers.clone();
663 headers.set(
664 multicast::CAST_ORIGINATING_SENDER,
665 cx.instance().self_addr().clone(),
666 );
667 headers.set(casting::CAST_ACTOR_MESH_ID, self.id.clone());
671
672 let region = view::Ranked::region(self).clone();
673 let num_ranks = region.num_ranks();
674 let threshold = hyperactor_config::global::get(V1_CAST_POINT_TO_POINT_THRESHOLD);
675
676 if threshold > 0 && num_ranks < threshold {
677 let sender = cx.instance().self_addr().clone();
681 let dest_port = <IndexedErasedUnbound<M> as typeuri::Named>::port();
682
683 let mut data = ErasedUnbound::try_from_message(message)
684 .expect("cast message serialization should not fail");
685
686 data.visit_mut::<UnboundPort>(
689 |UnboundPort(port_id, reducer_spec, return_undeliverable, kind, unsplit)| {
690 if *unsplit {
691 return Ok(());
692 }
693 let reducer_mode = match kind {
694 UnboundPortKind::Streaming(opts) => {
695 ReducerMode::Streaming(opts.clone().unwrap_or_default())
696 }
697 UnboundPortKind::Once if reducer_spec.is_none() => {
698 return Ok(());
701 }
702 UnboundPortKind::Once => ReducerMode::Once(num_ranks),
703 };
704 let split = port_id.split(
705 cx,
706 reducer_spec.clone(),
707 reducer_mode,
708 *return_undeliverable,
709 )?;
710 *port_id = split;
711 Ok(())
712 },
713 )
714 .expect("port splitting should not fail");
715
716 for rank in 0..num_ranks {
717 let mut rank_data = data.clone();
718
719 let cast_point = region
720 .point_of_base_rank(rank)
721 .expect("rank should be valid in region");
722
723 rank_data
724 .visit_mut::<resource::Rank>(|resource::Rank(r)| {
725 *r = Some(cast_point.rank());
726 Ok(())
727 })
728 .expect("rank replacement should not fail");
729
730 let mut rank_headers = headers.clone();
731 multicast::set_cast_info_on_headers(&mut rank_headers, cast_point, sender.clone());
732
733 let port_id = actor_ids
734 .get(rank)
735 .expect("mismatched actor_ids and dest_region")
736 .port_addr(Port::from(dest_port));
737
738 cx.instance().post(
739 port_id,
740 rank_headers,
741 wirevalue::Any::serialize(&rank_data)
742 .expect("cast message serialization should not fail"),
743 );
744 }
745 } else {
746 let sequencer = cx.instance().sequencer();
750 let seqs: ValueMesh<u64> = actor_ids.map_into(|actor_id| {
751 let hyperactor::ordering::SeqInfo::Session { seq, .. } = sequencer
752 .assign_seq(&actor_id.port_addr(Port::from(<M as typeuri::Named>::port())))
753 else {
754 unreachable!("assign_seq always returns SeqInfo::Session")
755 };
756 seq
757 });
758
759 let mut headers = caller_headers.clone();
760 headers.set(
761 multicast::CAST_ORIGINATING_SENDER,
762 cx.instance().self_addr().clone(),
763 );
764 headers.set(casting::CAST_ACTOR_MESH_ID, self.id.clone());
768 let cast_message = CastMessageV1::new::<A, M>(
769 cx.instance().self_addr().clone(),
770 &self.id,
771 region,
772 headers.clone(),
773 message,
774 sequencer.session_id(),
775 seqs,
776 )
777 .expect("infallible because CastMessage should not fail for serialization");
778
779 root_comm_actor.post_with_headers(cx, headers, cast_message);
781 }
782 }
    /// Queries the state of every actor in this mesh, without a keepalive.
    ///
    /// Equivalent to `actor_states_with_keepalive(cx, None)`.
    #[allow(clippy::result_large_err)]
    pub async fn actor_states(
        &self,
        cx: &impl context::Actor,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.actor_states_with_keepalive(cx, None).await
    }
795
    /// Queries the state of every actor in this mesh by delegating to the
    /// proc mesh, passing this mesh's id and the optional `keepalive`
    /// timestamp through unchanged.
    #[allow(clippy::result_large_err)]
    pub(crate) async fn actor_states_with_keepalive(
        &self,
        cx: &impl context::Actor,
        keepalive: Option<std::time::SystemTime>,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.proc_mesh
            .actor_states_with_keepalive(cx, self.id.clone(), keepalive)
            .await
    }
806
    /// Creates a reference with the default page size (`DEFAULT_PAGE`).
    pub(crate) fn new(
        id: ActorMeshId,
        proc_mesh: ProcMeshRef,
        controller: Option<ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        Self::with_page_size(id, proc_mesh, DEFAULT_PAGE, controller)
    }
814
    /// Returns the identifier of the referenced actor mesh.
    pub fn id(&self) -> &ActorMeshId {
        &self.id
    }
818
819 pub(crate) fn with_page_size(
820 id: ActorMeshId,
821 proc_mesh: ProcMeshRef,
822 page_size: usize,
823 controller: Option<ActorRef<ActorMeshController<A>>>,
824 ) -> Self {
825 Self {
826 proc_mesh,
827 id,
828 controller,
829 health_state: ActorLocal::new(),
830 receiver: ActorLocal::new(),
831 pages: OnceCell::new(),
832 page_size: page_size.max(1),
833 }
834 }
835
    /// Returns the proc mesh backing this view.
    pub fn proc_mesh(&self) -> &ProcMeshRef {
        &self.proc_mesh
    }
839
    /// Number of ranks in this view, taken from the proc mesh's region.
    #[inline]
    fn len(&self) -> usize {
        view::Ranked::region(&self.proc_mesh).num_ranks()
    }
844
    /// Returns the controller associated with this reference, if any.
    pub fn controller(&self) -> &Option<ActorRef<ActorMeshController<A>>> {
        &self.controller
    }
848
    /// Replaces the controller. Existing health state and any supervision
    /// receiver already created are left untouched.
    fn set_controller(&mut self, controller: Option<ActorRef<ActorMeshController<A>>>) {
        self.controller = controller;
    }
852
853 fn ensure_pages(&self) -> &Vec<OnceCell<Box<Page<A>>>> {
854 let n = self.len().div_ceil(self.page_size); self.pages
856 .get_or_init(|| (0..n).map(|_| OnceCell::new()).collect())
857 }
858
859 fn materialize(&self, rank: usize) -> Option<&ActorRef<A>> {
860 let len = self.len();
861 if rank >= len {
862 return None;
863 }
864 let p = self.page_size;
865 let page_ix = rank / p;
866 let local_ix = rank % p;
867
868 let pages = self.ensure_pages();
869 let page = pages[page_ix].get_or_init(|| {
870 let base = page_ix * p;
872 let remaining = len - base;
873 let page_len = remaining.min(p);
874 Box::new(Page::<A>::new(page_len))
875 });
876
877 Some(page.slots[local_ix].get_or_init(|| {
878 debug_assert!(rank < self.len(), "rank must be within [0, len)");
885 debug_assert!(
886 ndslice::view::Ranked::get(&self.proc_mesh, rank).is_some(),
887 "proc_mesh must be dense/aligned with this view"
888 );
889 let proc_ref =
890 ndslice::view::Ranked::get(&self.proc_mesh, rank).expect("rank in-bounds");
891 proc_ref.attest(&self.id)
892 }))
893 }
894
895 fn init_supervision_receiver(
896 controller: &ActorRef<ActorMeshController<A>>,
897 cx: &impl context::Actor,
898 ) -> (
899 PortRef<Option<MeshFailure>>,
900 watch::Receiver<MessageOrFailure<Option<MeshFailure>>>,
901 ) {
902 let (tx, rx) = cx.mailbox().open_port();
903 let tx = tx.bind();
904 controller.post(cx, Subscribe(tx.clone()));
905 (tx, into_watch(rx))
906 }
907
    /// Waits for the next supervision failure relevant to this view.
    ///
    /// Returns immediately with the cached failure when one applies to this
    /// view's region. Otherwise it subscribes to the controller (once per
    /// actor context; the subscription is reused by later calls) and waits
    /// until the watch channel holds one of:
    ///
    /// - a failure with no crashed ranks, or with at least one crashed rank
    ///   inside this view's slice;
    /// - a receive failure, returned as an error;
    /// - a watchdog timeout, turned into a synthetic failure attributed to
    ///   the controller.
    ///
    /// The returned failure is recorded in the local health state before
    /// returning.
    ///
    /// # Errors
    /// - no failure is cached and this reference has no controller;
    /// - the watch channel was closed;
    /// - receiving from the subscription port failed.
    pub async fn next_supervision_event(
        &self,
        cx: &impl context::Actor,
    ) -> Result<MeshFailure, anyhow::Error> {
        if let Some(failure) = self.cached_failure(cx) {
            tracing::debug!(
                actor_mesh = %self.id,
                crashed_ranks = ?failure.crashed_ranks,
                "returning cached supervision failure"
            );
            return Ok(failure);
        }
        let controller = if let Some(c) = self.controller() {
            c
        } else {
            return Err(anyhow::anyhow!(
                "unexpected healthy state while controller is gone"
            ));
        };
        // Fetch (or lazily create) the subscription for this actor context.
        // Only the `Arc` is cloned out, so the entry is released before any
        // `.await`.
        let rx = {
            let entry = self.receiver.entry(cx).or_insert_with(|| {
                Arc::new(tokio::sync::Mutex::new(Self::init_supervision_receiver(
                    controller, cx,
                )))
            });
            Arc::clone(entry.get())
        };
        let message = {
            // The async mutex is held across the wait below.
            let mut rx = rx.lock().await;
            let subscriber_port = rx.0.clone();
            let message =
                rx.1.wait_for(|message| {
                    if let MessageOrFailure::Message(message) = message {
                        if let Some(message) = &message {
                            let region = ndslice::view::Ranked::region(self).slice();
                            if message.crashed_ranks.is_empty() {
                                // No specific ranks: accepted for any view.
                                true
                            } else {
                                // Accept only failures touching this view.
                                message.crashed_ranks.iter().any(|r| region.contains(*r))
                            }
                        } else {
                            // `None` carries no failure; keep waiting.
                            false
                        }
                    } else {
                        // Failure / Timeout always end the wait.
                        true
                    }
                })
                .await?;
            let message = message.clone();
            let is_failure = matches!(
                message,
                MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
            );
            if is_failure {
                // Best-effort unsubscribe: undeliverable returns are turned
                // off and the post result is ignored.
                let mut port = controller.port();
                port.return_undeliverable(false);
                let _ = port.post(cx, Unsubscribe(subscriber_port));
            }
            match message {
                MessageOrFailure::Message(message) => Ok::<MeshFailure, anyhow::Error>(
                    message.expect("filter excludes any None messages"),
                ),
                MessageOrFailure::Failure(failure) => Err(anyhow::anyhow!("{}", failure)),
                MessageOrFailure::Timeout => {
                    // Synthesize a failure attributed to the controller.
                    Ok(MeshFailure {
                        actor_mesh_name: Some(self.id().to_string()),
                        event: ActorSupervisionEvent::new(
                            controller.actor_addr().clone(),
                            None,
                            ActorStatus::generic_failure(format!(
                                "timed out reaching controller {} for mesh {}. Assuming controller's proc is dead",
                                controller.actor_addr(),
                                self.id()
                            )),
                            None,
                        ),
                        crashed_ranks: vec![],
                    })
                }
            }?
        };
        let event = &message.event;
        let mut entry = self.health_state.entry(cx).or_default();
        let health_state = entry.get_mut();
        // Per-rank events are recorded only for `Failed` statuses.
        if let ActorStatus::Failed(_) = event.actor_status {
            for &rank in &message.crashed_ranks {
                health_state.crashed_ranks.insert(rank, event.clone());
            }
        }
        // Any status other than `Failed` / `Stopped` clears the cached
        // failure.
        health_state.unhealthy_event = match &event.actor_status {
            ActorStatus::Failed(_) => Some(Unhealthy::Crashed(message.clone())),
            ActorStatus::Stopped(_) => Some(Unhealthy::StreamClosed(message.clone())),
            _ => None,
        };
        Ok(message)
    }
1035
1036 pub fn clone_with_supervision_receiver(&self) -> Self {
1040 Self {
1041 proc_mesh: self.proc_mesh.clone(),
1042 id: self.id.clone(),
1043 controller: self.controller.clone(),
1044 health_state: self.health_state.clone(),
1045 receiver: self.receiver.clone(),
1046 pages: OnceCell::new(),
1048 page_size: self.page_size,
1049 }
1050 }
1051}
1052
1053impl<A: Referable> Clone for ActorMeshRef<A> {
1054 fn clone(&self) -> Self {
1055 Self {
1056 proc_mesh: self.proc_mesh.clone(),
1057 id: self.id.clone(),
1058 controller: self.controller.clone(),
1059 health_state: ActorLocal::new(),
1062 receiver: ActorLocal::new(),
1063 pages: OnceCell::new(), page_size: self.page_size,
1065 }
1066 }
1067}
1068
1069impl<A: Referable> fmt::Display for ActorMeshRef<A> {
1070 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1071 write!(f, "{}:{}@{}", self.id, A::typename(), self.proc_mesh)
1072 }
1073}
1074
1075impl<A: Referable> PartialEq for ActorMeshRef<A> {
1076 fn eq(&self, other: &Self) -> bool {
1077 self.proc_mesh == other.proc_mesh && self.id == other.id
1078 }
1079}
1080impl<A: Referable> Eq for ActorMeshRef<A> {}
1081
1082impl<A: Referable> Hash for ActorMeshRef<A> {
1083 fn hash<H: Hasher>(&self, state: &mut H) {
1084 self.proc_mesh.hash(state);
1085 self.id.hash(state);
1086 }
1087}
1088
1089impl<A: Referable> fmt::Debug for ActorMeshRef<A> {
1090 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1091 f.debug_struct("ActorMeshRef")
1092 .field("proc_mesh", &self.proc_mesh)
1093 .field("id", &self.id)
1094 .field("page_size", &self.page_size)
1095 .finish_non_exhaustive() }
1097}
1098
1099impl<A: Referable> Serialize for ActorMeshRef<A> {
1101 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1102 where
1103 S: Serializer,
1104 {
1105 (&self.proc_mesh, &self.id, &self.controller).serialize(serializer)
1106 }
1107}
1108
1109impl<'de, A: Referable> Deserialize<'de> for ActorMeshRef<A> {
1111 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1112 where
1113 D: Deserializer<'de>,
1114 {
1115 let (proc_mesh, id, controller) = <(
1116 ProcMeshRef,
1117 ActorMeshId,
1118 Option<ActorRef<ActorMeshController<A>>>,
1119 )>::deserialize(deserializer)?;
1120 Ok(ActorMeshRef::with_page_size(
1121 id,
1122 proc_mesh,
1123 DEFAULT_PAGE,
1124 controller,
1125 ))
1126 }
1127}
1128
/// Ranked view over the mesh's actors: the region is the proc mesh's region,
/// and items are `ActorRef`s materialized (and cached) on demand.
impl<A: Referable> view::Ranked for ActorMeshRef<A> {
    type Item = ActorRef<A>;

    /// The region of the backing proc mesh.
    #[inline]
    fn region(&self) -> &Region {
        view::Ranked::region(&self.proc_mesh)
    }

    /// The `ActorRef` at `rank`, or `None` when `rank` is out of bounds.
    #[inline]
    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.materialize(rank)
    }
}
1142
1143impl<A: Referable> view::RankedSliceable for ActorMeshRef<A> {
1144 fn sliced(&self, region: Region) -> Self {
1145 debug_assert!(region.is_subset(view::Ranked::region(self)));
1150 let proc_mesh = self.proc_mesh.subset(region).unwrap();
1151 Self {
1152 proc_mesh,
1153 id: self.id.clone(),
1154 controller: self.controller.clone(),
1155 health_state: self.health_state.clone(),
1156 receiver: ActorLocal::new(),
1157 pages: OnceCell::new(),
1158 page_size: self.page_size,
1159 }
1160 }
1161}
1162
1163#[cfg(all(test, fbcode_build))]
1164mod tests {
1165
1166 use std::collections::HashSet;
1167 use std::ops::Deref;
1168
1169 use hyperactor::Endpoint as _;
1170 use hyperactor::actor::ActorErrorKind;
1171 use hyperactor::actor::ActorStatus;
1172 use hyperactor::context::Mailbox as _;
1173 use hyperactor::id::Label;
1174 use hyperactor::mailbox;
1175 use ndslice::Extent;
1176 use ndslice::ViewExt;
1177 use ndslice::extent;
1178 use ndslice::view::Ranked;
1179 use timed_test::assert_no_process_leak;
1180 use timed_test::async_timed_test;
1181 use tokio::time::Duration;
1182
1183 use super::ActorMesh;
1184 use crate::ActorMeshRef;
1185 use crate::ProcMesh;
1186 use crate::host_mesh::GET_PROC_STATE_MAX_IDLE;
1187 use crate::host_mesh::PROC_SPAWN_MAX_IDLE;
1188 use crate::mesh_controller::SUPERVISION_POLL_FREQUENCY;
1189 use crate::mesh_id::ActorMeshId;
1190 use crate::proc_mesh::ACTOR_SPAWN_MAX_IDLE;
1191 use crate::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
1192 use crate::supervision::MeshFailure;
1193 use crate::testactor;
1194 use crate::testing;
1195
    // Compile-time check: `ActorMeshRef` must be `Send + Sync`. The test
    // body does no runtime work; it fails to compile if the bound is lost.
    #[test]
    fn test_actor_mesh_ref_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ActorMeshRef<()>>();
    }
1201
    // Exercises the paged, lazily materialized `ActorRef` cache of
    // `ActorMeshRef`: pointer stability per rank, fresh caches for clones
    // and slices, Hash/Eq independence from cache state, and that the
    // materialized refs are usable for messaging.
    #[tokio::test]
    async fn test_actor_mesh_ref_lazy_materialization() {
        let instance = testing::instance();
        // 2 hosts x 2 gpus = 4 ranks.
        let mut hm = testing::host_mesh(2).await;
        let pm: ProcMesh = hm
            .spawn(instance, "test", extent!(gpus = 2), None, None)
            .await
            .unwrap();
        let am: ActorMesh<testactor::TestActor> = pm.spawn(instance, "test", &()).await.unwrap();

        // A small page size so that the 4 ranks span two pages
        // (ranks 0-1 on page 0, ranks 2-3 on page 1).
        let page_size = 2;
        let amr: ActorMeshRef<testactor::TestActor> =
            ActorMeshRef::with_page_size(am.id.clone(), pm.clone(), page_size, None);
        assert_eq!(amr.extent(), extent!(hosts = 2, gpus = 2));
        assert_eq!(amr.region().num_ranks(), 4);

        // Repeated lookups of one rank return the same cached slot.
        let p0_a = amr.get(0).expect("rank 0 exists") as *const _;
        let p0_b = amr.get(0).expect("rank 0 exists") as *const _;
        assert_eq!(p0_a, p0_b, "same rank should return same cached pointer");

        // Another rank on the same page has its own slot.
        let p1_a = amr.get(1).expect("rank 1 exists") as *const _;
        let p1_b = amr.get(1).expect("rank 1 exists") as *const _;
        assert_eq!(p1_a, p1_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p1_a, "different ranks have different cache slots");

        // Rank 2 lives on the second page.
        let p2_a = amr.get(2).expect("rank 2 exists") as *const _;
        let p2_b = amr.get(2).expect("rank 2 exists") as *const _;
        assert_eq!(p2_a, p2_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p2_a, "different pages have different cache slots");

        // A clone refers to the same actors but does not share the cache.
        let amr_clone = amr.clone();
        let orig_id_0 = amr.get(0).unwrap().actor_addr().clone();
        let clone_id_0 = amr_clone.get(0).unwrap().actor_addr().clone();
        assert_eq!(orig_id_0, clone_id_0, "clone preserves identity");
        let p0_clone = amr_clone.get(0).unwrap() as *const _;
        assert_ne!(
            p0_a, p0_clone,
            "cloned ActorMeshRef has a fresh cache (different pointer)"
        );

        // A slice covering the full `hosts` range still has 4 ranks and its
        // own cache.
        let sliced = amr.range("hosts", 0..2).expect("slice should be valid");
        assert_eq!(sliced.region().num_ranks(), 4);
        let sp0_a = sliced.get(0).unwrap() as *const _;
        let sp0_b = sliced.get(0).unwrap() as *const _;
        assert_eq!(sp0_a, sp0_b, "sliced view has its own cache slot per rank");
        let sp2 = sliced.get(2).unwrap() as *const _;
        assert_ne!(sp0_a, sp2, "sliced view crosses its own page boundary");

        // Hash/Eq only consider identity, so a populated cache and a fresh
        // clone collapse to one set entry.
        let mut set = HashSet::new();
        set.insert(amr.clone());
        set.insert(amr.clone());
        assert_eq!(set.len(), 1, "cache state must not affect Hash/Eq");

        // The materialized refs are live: two different ranks answer with
        // different actor ids.
        let (port, mut rx) = mailbox::open_port(instance);
        amr.get(0)
            .expect("rank 0 exists")
            .post(instance, testactor::GetActorId(port.bind()));
        amr.get(3)
            .expect("rank 3 exists")
            .post(instance, testactor::GetActorId(port.bind()));
        let id_a = tokio::time::timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for first reply")
            .expect("channel closed before first reply");
        let id_b = tokio::time::timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for second reply")
            .expect("channel closed before second reply");
        assert_ne!(id_a, id_b, "two different ranks responded");

        // Best-effort shutdown; the result is ignored.
        let _ = hm.shutdown(instance).await;
    }
1297
1298 #[async_timed_test(timeout_secs = 300)]
1299 async fn test_actor_states_with_panic() {
1300 hyperactor_telemetry::initialize_logging_for_test();
1301
1302 let instance = testing::instance();
1303 let config = hyperactor_config::global::lock();
1304 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(120));
1305 let _actor_spawn = config.override_key(ACTOR_SPAWN_MAX_IDLE, Duration::from_secs(120));
1306 let _host_spawn = config.override_key(
1307 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1308 Duration::from_secs(120),
1309 );
1310
1311 let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
1313 let supervisor = supervision_port.bind();
1314 let num_replicas = 1;
1315 let mut hm = testing::host_mesh(num_replicas).await;
1316 let proc_mesh = hm
1317 .spawn(instance, "test", Extent::unity(), None, None)
1318 .await
1319 .unwrap();
1320 let child_name = ActorMeshId::instance(Label::new("child").unwrap());
1321
1322 let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
1325 .spawn(
1326 instance,
1327 "wrapper",
1328 &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
1329 )
1330 .await
1331 .unwrap();
1332
1333 actor_mesh
1335 .cast(
1336 instance,
1337 testactor::CauseSupervisionEvent {
1338 kind: testactor::SupervisionEventType::Panic,
1339 send_to_children: true,
1340 },
1341 )
1342 .unwrap();
1343
1344 let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
1354 actor_mesh
1355 .cast(
1356 instance,
1357 testactor::NextSupervisionFailure(failure_port.bind()),
1358 )
1359 .unwrap();
1360 let failure = failure_receiver
1361 .recv()
1362 .await
1363 .unwrap()
1364 .expect("no supervision event found on ref from wrapper actor");
1365 let check_failure = move |failure: MeshFailure| {
1366 assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
1367 assert!(
1368 failure
1369 .event
1370 .actor_id
1371 .label()
1372 .unwrap()
1373 .as_str()
1374 .starts_with(child_name.label().unwrap().as_str())
1375 );
1376 if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
1377 assert!(msg.contains("panic"), "{}", msg);
1378 assert!(msg.contains("for testing"), "{}", msg);
1379 } else {
1380 panic!("actor status is not failed: {}", failure.event.actor_status);
1381 }
1382 };
1383 check_failure(failure);
1384
1385 for _ in 0..num_replicas {
1389 let failure =
1390 tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
1391 .await
1392 .expect("timeout")
1393 .unwrap();
1394 check_failure(failure);
1395 }
1396
1397 let _ = hm.shutdown(instance).await;
1398 }
1399
1400 #[assert_no_process_leak]
1401 #[async_timed_test(timeout_secs = 300)]
1402 async fn test_actor_states_with_process_exit() {
1403 hyperactor_telemetry::initialize_logging_for_test();
1404
1405 let config = hyperactor_config::global::lock();
1406 let _poll = config.override_key(SUPERVISION_POLL_FREQUENCY, Duration::from_secs(1));
1407 let _guard = config.override_key(GET_ACTOR_STATE_MAX_IDLE, Duration::from_secs(1));
1408 let _proc_guard = config.override_key(GET_PROC_STATE_MAX_IDLE, Duration::from_secs(1));
1409 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(120));
1410 let _host_spawn = config.override_key(
1411 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1412 Duration::from_secs(120),
1413 );
1414
1415 let instance = testing::instance();
1416 let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
1418 let supervisor = supervision_port.bind();
1419 let num_replicas = 1;
1420 let mut hm = testing::host_mesh(num_replicas).await;
1421 let proc_mesh = hm
1422 .spawn(instance, "test", Extent::unity(), None, None)
1423 .await
1424 .unwrap();
1425 let mut second_hm = testing::host_mesh(num_replicas).await;
1426 let second_proc_mesh = second_hm
1427 .spawn(instance, "test2", Extent::unity(), None, None)
1428 .await
1429 .unwrap();
1430 let child_name = ActorMeshId::instance(Label::new("child").unwrap());
1431
1432 let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
1435 .spawn(
1436 instance,
1437 "wrapper",
1438 &(
1439 second_proc_mesh.deref().clone(),
1442 supervisor,
1443 child_name.clone(),
1444 ),
1445 )
1446 .await
1447 .unwrap();
1448
1449 actor_mesh
1450 .cast(
1451 instance,
1452 testactor::CauseSupervisionEvent {
1453 kind: testactor::SupervisionEventType::ProcessExit(1),
1454 send_to_children: true,
1455 },
1456 )
1457 .unwrap();
1458
1459 let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
1461 actor_mesh
1462 .cast(
1463 instance,
1464 testactor::NextSupervisionFailure(failure_port.bind()),
1465 )
1466 .unwrap();
1467 let failure = failure_receiver
1468 .recv()
1469 .await
1470 .unwrap()
1471 .expect("no supervision event found on ref from wrapper actor");
1472
1473 let check_failure = move |failure: MeshFailure| {
1474 assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
1475 assert!(
1476 failure
1477 .event
1478 .actor_id
1479 .label()
1480 .unwrap()
1481 .as_str()
1482 .starts_with(child_name.label().unwrap().as_str())
1483 );
1484 if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
1485 assert!(msg.contains("exited with non-zero code 1"), "{}", msg);
1486 } else {
1487 panic!("actor status is not failed: {}", failure.event.actor_status);
1488 }
1489 };
1490 check_failure(failure);
1491
1492 for _ in 0..num_replicas {
1494 let failure =
1495 tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
1496 .await
1497 .expect("timeout")
1498 .unwrap();
1499 check_failure(failure);
1500 }
1501
1502 let _ = second_hm.shutdown(instance).await;
1503 let _ = hm.shutdown(instance).await;
1504 }
1505
1506 #[async_timed_test(timeout_secs = 300)]
1507 async fn test_actor_states_on_sliced_mesh() {
1508 hyperactor_telemetry::initialize_logging_for_test();
1509
1510 let instance = testing::instance();
1511 let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
1513 let supervisor = supervision_port.bind();
1514 let (mut hm, _actor_mesh, sliced, sliced_replicas, child_name) = {
1515 let config = hyperactor_config::global::lock();
1516 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(120));
1517 let _actor_spawn = config.override_key(ACTOR_SPAWN_MAX_IDLE, Duration::from_secs(120));
1518 let _host_spawn = config.override_key(
1519 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1520 Duration::from_secs(120),
1521 );
1522 let num_replicas = 2;
1523 let hm = testing::host_mesh(num_replicas).await;
1524 let proc_mesh = hm
1525 .spawn(instance, "test", Extent::unity(), None, None)
1526 .await
1527 .unwrap();
1528 let child_name = ActorMeshId::instance(Label::new("child").unwrap());
1529
1530 let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
1533 .spawn(
1534 instance,
1535 "wrapper",
1536 &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
1537 )
1538 .await
1539 .unwrap();
1540 let sliced = actor_mesh
1541 .range("hosts", 1..2)
1542 .expect("slice should be valid");
1543 let sliced_replicas = sliced.len();
1544 (hm, actor_mesh, sliced, sliced_replicas, child_name)
1545 };
1546
1547 sliced
1549 .cast(
1550 instance,
1551 testactor::CauseSupervisionEvent {
1552 kind: testactor::SupervisionEventType::Panic,
1553 send_to_children: true,
1554 },
1555 )
1556 .unwrap();
1557
1558 for _ in 0..sliced_replicas {
1559 let supervision_message =
1560 tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
1561 .await
1562 .expect("timeout")
1563 .unwrap();
1564 let event = supervision_message.event;
1565 assert!(
1566 event
1567 .actor_id
1568 .label()
1569 .unwrap()
1570 .as_str()
1571 .starts_with(child_name.label().unwrap().as_str())
1572 );
1573 if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &event.actor_status {
1574 assert!(msg.contains("panic"));
1575 assert!(msg.contains("for testing"));
1576 } else {
1577 panic!("actor status is not failed: {}", event.actor_status);
1578 }
1579 }
1580
1581 let _ = hm.shutdown(instance).await;
1582 }
1583
1584 async fn execute_cast(config: &hyperactor_config::global::ConfigLock) {
1585 let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
1586 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(60));
1587 let _host_spawn = config.override_key(
1588 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1589 Duration::from_secs(60),
1590 );
1591
1592 let instance = testing::instance();
1593 let mut host_mesh = testing::host_mesh(2).await;
1594 let proc_mesh = host_mesh
1595 .spawn(instance, "test", Extent::unity(), None, None)
1596 .await
1597 .unwrap();
1598 let actor_mesh: ActorMesh<testactor::TestActor> =
1599 proc_mesh.spawn(instance, "test", &()).await.unwrap();
1600
1601 let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
1602 actor_mesh
1603 .cast(
1604 instance,
1605 testactor::GetCastInfo {
1606 cast_info: cast_info.bind(),
1607 },
1608 )
1609 .unwrap();
1610
1611 let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
1612 while !point_to_actor.is_empty() {
1613 let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
1614 let key = (point, origin_actor_ref);
1615 assert!(
1616 point_to_actor.remove(&key),
1617 "key {:?} not present or removed twice",
1618 key
1619 );
1620 assert_eq!(&sender_actor_id, instance.self_addr());
1621 }
1622
1623 let _ = host_mesh.shutdown(instance).await;
1624 }
1625
1626 #[async_timed_test(timeout_secs = 60)]
1627 async fn test_cast_with_selection_v1_fallback() {
1628 use hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER;
1629 use hyperactor_mesh_macros::sel;
1630 use ndslice::Selection;
1631
1632 let config = hyperactor_config::global::lock();
1633 let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
1634 let _v1 = config.override_key(crate::comm::ENABLE_NATIVE_V1_CASTING, true);
1635 let _reorder = config.override_key(ENABLE_DEST_ACTOR_REORDERING_BUFFER, true);
1636 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(60));
1637 let _host_spawn = config.override_key(
1638 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1639 Duration::from_secs(60),
1640 );
1641
1642 let instance = testing::instance();
1643 let mut host_mesh = testing::host_mesh(2).await;
1644 let proc_mesh = host_mesh
1645 .spawn(instance, "test", Extent::unity(), None, None)
1646 .await
1647 .unwrap();
1648 let actor_mesh: ActorMesh<testactor::TestActor> =
1649 proc_mesh.spawn(instance, "test", &()).await.unwrap();
1650
1651 let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
1653 actor_mesh
1654 .cast_for_tensor_engine_only_do_not_use(
1655 instance,
1656 sel!(0:1),
1657 testactor::GetCastInfo {
1658 cast_info: cast_info.bind(),
1659 },
1660 )
1661 .unwrap();
1662
1663 let (point, _actor_ref, _sender) = cast_info_rx.recv().await.unwrap();
1664 let received_ranks = HashSet::from([point.rank()]);
1665 assert_eq!(received_ranks, HashSet::from([0]));
1666
1667 let (cast_info2, mut cast_info_rx2) = instance.mailbox().open_port();
1669 actor_mesh
1670 .cast(
1671 instance,
1672 testactor::GetCastInfo {
1673 cast_info: cast_info2.bind(),
1674 },
1675 )
1676 .unwrap();
1677
1678 let mut all_ranks: HashSet<usize> = HashSet::new();
1679 for _ in 0..2 {
1680 let (point, _actor_ref, _sender) = cast_info_rx2.recv().await.unwrap();
1681 all_ranks.insert(point.rank());
1682 }
1683 assert_eq!(all_ranks, HashSet::from([0, 1]));
1684
1685 let _ = host_mesh.shutdown(instance).await;
1686 }
1687
1688 #[async_timed_test(timeout_secs = 30)]
1689 async fn test_cast() {
1690 let config = hyperactor_config::global::lock();
1691 execute_cast(&config).await;
1692 }
1693
1694 #[async_timed_test(timeout_secs = 30)]
1695 async fn test_cast_p2p() {
1696 let config = hyperactor_config::global::lock();
1697 let _guard = config.override_key(crate::comm::ENABLE_NATIVE_V1_CASTING, true);
1698 let _guard2 = config.override_key(
1699 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1700 true,
1701 );
1702 let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1703 execute_cast(&config).await;
1704 }
1705 #[assert_no_process_leak]
1711 #[async_timed_test(timeout_secs = 60)]
1712 async fn test_undeliverable_message_return() {
1713 use hyperactor::mailbox::MessageEnvelope;
1714 use hyperactor::mailbox::Undeliverable;
1715 use hyperactor::testing::pingpong::PingPongActor;
1716 use hyperactor::testing::pingpong::PingPongMessage;
1717
1718 hyperactor_telemetry::initialize_logging_for_test();
1719
1720 let instance = testing::instance();
1721
1722 let (mut hm, proc_mesh) = {
1724 let config = hyperactor_config::global::lock();
1725 let _proc_spawn_guard =
1726 config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(60));
1727 let _host_spawn_guard = config.override_key(
1728 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1729 Duration::from_secs(60),
1730 );
1731 let hm = testing::host_mesh(2).await;
1732 let proc_mesh = hm
1733 .spawn(instance, "test", Extent::unity(), None, None)
1734 .await
1735 .unwrap();
1736 (hm, proc_mesh)
1737 };
1738
1739 let (undeliverable_port, mut undeliverable_rx) =
1741 instance.open_port::<Undeliverable<MessageEnvelope>>();
1742
1743 let ping_proc_mesh = proc_mesh.range("hosts", 0..1).unwrap();
1746 let pong_proc_mesh = proc_mesh.range("hosts", 1..2).unwrap();
1747
1748 let ping_mesh: ActorMesh<PingPongActor> = ping_proc_mesh
1749 .spawn(
1750 instance,
1751 "ping",
1752 &(Some(undeliverable_port.bind()), None, None),
1753 )
1754 .await
1755 .unwrap();
1756
1757 let mut pong_mesh: ActorMesh<PingPongActor> = pong_proc_mesh
1758 .spawn(instance, "pong", &(None, None, None))
1759 .await
1760 .unwrap();
1761
1762 let ping_handle = ping_mesh.values().next().unwrap();
1764 let pong_handle = pong_mesh.values().next().unwrap();
1765
1766 let (done_tx, done_rx) = instance.open_once_port();
1768 ping_handle.post(
1769 instance,
1770 PingPongMessage(2, pong_handle.clone(), done_tx.bind()),
1771 );
1772 assert!(
1773 done_rx.recv().await.unwrap(),
1774 "Initial ping-pong should work"
1775 );
1776
1777 pong_mesh
1779 .stop(instance, "test stop".to_string())
1780 .await
1781 .unwrap();
1782
1783 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1785
1786 let config = hyperactor_config::global::lock();
1788 let _guard = config.override_key(
1789 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1790 std::time::Duration::from_secs(5),
1791 );
1792
1793 let n = 100usize;
1795 for i in 1..=n {
1796 let ttl = 66 + i as u64; let (once_tx, _once_rx) = instance.open_once_port();
1798 ping_handle.post(
1799 instance,
1800 PingPongMessage(ttl, pong_handle.clone(), once_tx.bind()),
1801 );
1802 }
1803
1804 let mut count = 0;
1808 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10);
1809 while count < n && tokio::time::Instant::now() < deadline {
1810 match tokio::time::timeout(std::time::Duration::from_secs(1), undeliverable_rx.recv())
1811 .await
1812 {
1813 Ok(Ok(Undeliverable::Message(envelope))) => {
1814 let _: PingPongMessage = envelope.deserialized().unwrap();
1815 count += 1;
1816 }
1817 Ok(Ok(Undeliverable::Lost(_))) => break,
1818 Ok(Err(_)) => break, Err(_) => break, }
1821 }
1822
1823 assert_eq!(
1824 count, n,
1825 "Expected {} undeliverable messages, got {}",
1826 n, count
1827 );
1828
1829 let _ = hm.shutdown(instance).await;
1830 }
1831
1832 #[async_timed_test(timeout_secs = 30)]
1842 async fn test_actor_mesh_stop_timeout() {
1843 hyperactor_telemetry::initialize_logging_for_test();
1844
1845 let config = hyperactor_config::global::lock();
1849 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(60));
1850 let _host_spawn = config.override_key(
1851 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1852 Duration::from_secs(60),
1853 );
1854
1855 let instance = testing::instance();
1856
1857 let mut hm = testing::host_mesh(2).await;
1859 let proc_mesh = hm
1860 .spawn(instance, "test", Extent::unity(), None, None)
1861 .await
1862 .unwrap();
1863
1864 let mut sleep_mesh: ActorMesh<testactor::SleepActor> =
1867 proc_mesh.spawn(instance, "sleepers", &()).await.unwrap();
1868 let _guard = config.override_key(ACTOR_SPAWN_MAX_IDLE, std::time::Duration::from_secs(1));
1869
1870 for actor_ref in sleep_mesh.values() {
1875 actor_ref.post(instance, std::time::Duration::from_secs(5));
1876 }
1877
1878 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1880
1881 let expected_actors = sleep_mesh.values().count();
1883
1884 let stop_start = tokio::time::Instant::now();
1889 let result = sleep_mesh.stop(instance, "test stop".to_string()).await;
1890 let stop_duration = tokio::time::Instant::now().duration_since(stop_start);
1891
1892 match result {
1896 Ok(_) => {
1897 tracing::info!(
1898 "stop returned Ok for {} actors; their tokio tasks \
1899 may still be running until their handler yields",
1900 expected_actors
1901 );
1902 }
1903 Err(ref e) => {
1904 let err_str = format!("{:?}", e);
1905 assert!(
1906 err_str.contains("Timeout"),
1907 "Expected Timeout error, got: {:?}",
1908 e
1909 );
1910 }
1911 }
1912
1913 assert!(
1918 stop_duration < std::time::Duration::from_millis(4500),
1919 "Stop took {:?}, expected < 4.5s (controller should have given up waiting for Stopped)",
1920 stop_duration
1921 );
1922 assert!(
1923 stop_duration >= std::time::Duration::from_millis(900),
1924 "Stop took {:?}, expected >= 900ms (should have waited for the 1s idle timeout)",
1925 stop_duration
1926 );
1927
1928 let _ = hm.shutdown(instance).await;
1929 }
1930
1931 #[async_timed_test(timeout_secs = 60)]
1937 async fn test_actor_mesh_stop_graceful() {
1938 hyperactor_telemetry::initialize_logging_for_test();
1939
1940 let config = hyperactor_config::global::lock();
1941 let _proc_spawn = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(60));
1942 let _host_spawn = config.override_key(
1943 hyperactor::config::HOST_SPAWN_READY_TIMEOUT,
1944 Duration::from_secs(60),
1945 );
1946
1947 let instance = testing::instance();
1948
1949 let mut hm = testing::host_mesh(2).await;
1951 let proc_mesh = hm
1952 .spawn(instance, "test", Extent::unity(), None, None)
1953 .await
1954 .unwrap();
1955
1956 let mut actor_mesh: ActorMesh<testactor::TestActor> =
1959 proc_mesh.spawn(instance, "test_actors", &()).await.unwrap();
1960
1961 let mesh_ref = actor_mesh.deref().clone();
1964
1965 let expected_actors = actor_mesh.values().count();
1966 assert!(expected_actors > 0, "Should have spawned some actors");
1967
1968 let stop_start = tokio::time::Instant::now();
1970 let result = actor_mesh.stop(instance, "test stop".to_string()).await;
1971 let stop_duration = tokio::time::Instant::now().duration_since(stop_start);
1972
1973 assert!(
1975 result.is_ok(),
1976 "Stop should succeed for responsive actors, got: {:?}",
1977 result.err()
1978 );
1979
1980 assert!(
1984 stop_duration < std::time::Duration::from_secs(5),
1985 "Graceful stop took {:?}, expected < 5s (actors should stop quickly)",
1986 stop_duration
1987 );
1988
1989 tracing::info!(
1990 "Successfully stopped {} actors in {:?}",
1991 expected_actors,
1992 stop_duration
1993 );
1994
1995 let next_event = actor_mesh.next_supervision_event(instance).await.unwrap();
2001 assert_eq!(next_event.actor_mesh_name, Some(mesh_ref.id().to_string()));
2002 assert!(matches!(
2003 next_event.event.actor_status,
2004 ActorStatus::Stopped(_)
2005 ));
2006 let next_event = mesh_ref.next_supervision_event(instance).await.unwrap();
2009 assert_eq!(next_event.actor_mesh_name, Some(mesh_ref.id().to_string()));
2010 assert!(matches!(
2011 next_event.event.actor_status,
2012 ActorStatus::Stopped(_)
2013 ));
2014
2015 let _ = hm.shutdown(instance).await;
2016 }
2017}