1use std::collections::HashMap;
15use std::fmt;
16use std::hash::Hash;
17use std::hash::Hasher;
18use std::ops::Deref;
19use std::sync::Arc;
20use std::sync::OnceLock as OnceCell;
21use std::time::Duration;
22
23use hyperactor::ActorLocal;
24use hyperactor::RemoteHandles;
25use hyperactor::RemoteMessage;
26use hyperactor::actor::ActorStatus;
27use hyperactor::actor::Referable;
28use hyperactor::context;
29use hyperactor::mailbox::PortReceiver;
30use hyperactor::message::Castable;
31use hyperactor::message::IndexedErasedUnbound;
32use hyperactor::message::Unbound;
33use hyperactor::reference as hyperactor_reference;
34use hyperactor::supervision::ActorSupervisionEvent;
35use hyperactor_config::CONFIG;
36use hyperactor_config::ConfigAttr;
37use hyperactor_config::Flattrs;
38use hyperactor_config::attrs::declare_attrs;
39use hyperactor_mesh_macros::sel;
40use ndslice::Selection;
41use ndslice::ViewExt as _;
42use ndslice::view;
43use ndslice::view::Region;
44use ndslice::view::View;
45use serde::Deserialize;
46use serde::Deserializer;
47use serde::Serialize;
48use serde::Serializer;
49use tokio::sync::watch;
50
51use crate::CommActor;
52use crate::Error;
53use crate::Name;
54use crate::ProcMeshRef;
55use crate::ValueMesh;
56use crate::casting;
57use crate::comm::multicast;
58use crate::host_mesh::GET_PROC_STATE_MAX_IDLE;
59use crate::host_mesh::mesh_to_rankedvalues_with_default;
60use crate::mesh_controller::ActorMeshController;
61use crate::mesh_controller::SUPERVISION_POLL_FREQUENCY;
62use crate::mesh_controller::Subscribe;
63use crate::mesh_controller::Unsubscribe;
64use crate::proc_agent::ActorState;
65use crate::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
66use crate::reference::ActorMeshId;
67use crate::resource;
68use crate::supervision::MeshFailure;
69use crate::supervision::Unhealthy;
70
// Configuration attributes declared by this module.
declare_attrs! {
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT".to_string()),
        Some("supervision_watchdog_timeout".to_string()),
    ))
    // Maximum time `into_watch` waits for a supervision message before it
    // publishes `MessageOrFailure::Timeout` and stops pumping. Should be at
    // least SUPERVISION_POLL_FREQUENCY + GET_ACTOR_STATE_MAX_IDLE +
    // GET_PROC_STATE_MAX_IDLE (see the warning in `into_watch`).
    pub attr SUPERVISION_WATCHDOG_TIMEOUT: Duration = Duration::from_mins(2);
}
86
/// An owned handle to a mesh of actors of type `A`, one per rank of the
/// underlying proc mesh.
#[derive(Debug)]
pub struct ActorMesh<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,
    // Cached reference exposed through `Deref`, so `ActorMeshRef` methods
    // can be called directly on the mesh.
    current_ref: ActorMeshRef<A>,
    // Controller managing this mesh's lifecycle; taken (set to `None`) by
    // `stop`.
    controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
}
103
impl<A: Referable> ActorMesh<A> {
    /// Constructs a mesh over `proc_mesh` for the actors named `name`,
    /// eagerly building the `ActorMeshRef` handed out via `Deref`.
    pub(crate) fn new(
        proc_mesh: ProcMeshRef,
        name: Name,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        let current_ref = ActorMeshRef::with_page_size(
            name.clone(),
            proc_mesh.clone(),
            DEFAULT_PAGE,
            controller.clone(),
        );

        Self {
            proc_mesh,
            name,
            current_ref,
            controller,
        }
    }

    /// The mesh's name.
    pub fn name(&self) -> &Name {
        &self.name
    }

    /// Replaces the controller on both the mesh and its cached reference,
    /// keeping the two in sync.
    pub(crate) fn set_controller(
        &mut self,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) {
        self.controller = controller.clone();
        self.current_ref.set_controller(controller);
    }

    /// Stops all actors in the mesh via the controller (if one is present),
    /// verifies that every rank reached a terminating status, and records a
    /// `StreamClosed` unhealthy event so later casts fail fast. Clears the
    /// controller from both the mesh and its cached reference.
    pub async fn stop(&mut self, cx: &impl context::Actor, reason: String) -> crate::Result<()> {
        // `take()` makes `stop` idempotent: a second call is a no-op.
        if let Some(controller) = self.controller.take() {
            controller
                .send(
                    cx,
                    resource::Stop {
                        name: self.name.clone(),
                        reason,
                    },
                )
                .map_err(|e| {
                    crate::Error::SendingError(controller.actor_id().clone(), Box::new(e))
                })?;
            let region = ndslice::view::Ranked::region(&self.current_ref);
            let num_ranks = region.num_ranks();
            let (port, mut rx) = cx.mailbox().open_port();

            // Query the post-stop statuses to confirm the stop took effect.
            controller
                .send(
                    cx,
                    resource::GetState::<resource::mesh::State<()>> {
                        name: self.name.clone(),
                        reply: port.bind(),
                    },
                )
                .map_err(|e| {
                    crate::Error::SendingError(controller.actor_id().clone(), Box::new(e))
                })?;

            let statuses = rx.recv().await?;
            if let Some(state) = &statuses.state {
                let all_stopped = state.statuses.values().all(|s| s.is_terminating());
                if all_stopped {
                    Ok(())
                } else {
                    // Report the per-rank statuses in the legacy
                    // ranked-values form expected by `ActorStopError`.
                    let legacy = mesh_to_rankedvalues_with_default(
                        &state.statuses,
                        resource::Status::NotExist,
                        resource::Status::is_not_exist,
                        num_ranks,
                    );
                    Err(Error::ActorStopError { statuses: legacy })
                }
            } else {
                Err(Error::Other(anyhow::anyhow!(
                    "non-existent state in GetState reply from controller: {}",
                    controller.actor_id()
                )))
            }?;
            // Mark the mesh unhealthy so subsequent casts are rejected.
            let mut entry = self.health_state.entry(cx).or_default();
            let health_state = entry.get_mut();
            health_state.unhealthy_event = Some(Unhealthy::StreamClosed(MeshFailure {
                actor_mesh_name: Some(self.name().to_string()),
                event: ActorSupervisionEvent::new(
                    // Rank 0's actor id stands in for the whole mesh.
                    ndslice::view::Ranked::get(&self.current_ref, 0)
                        .unwrap()
                        .actor_id()
                        .clone(),
                    None,
                    ActorStatus::Stopped("mesh stopped".to_string()),
                    None,
                ),
                crashed_ranks: vec![],
            }));
        }
        // Also detach the cached reference from the (now-stopped) controller.
        self.current_ref.controller.take();
        Ok(())
    }
}
223
224impl<A: Referable> fmt::Display for ActorMesh<A> {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 write!(f, "{}", self.current_ref)
227 }
228}
229
230impl<A: Referable> Deref for ActorMesh<A> {
231 type Target = ActorMeshRef<A>;
232
233 fn deref(&self) -> &Self::Target {
234 &self.current_ref
235 }
236}
237
238impl<A: Referable> Clone for ActorMesh<A> {
241 fn clone(&self) -> Self {
242 Self {
243 proc_mesh: self.proc_mesh.clone(),
244 name: self.name.clone(),
245 current_ref: self.current_ref.clone(),
246 controller: self.controller.clone(),
247 }
248 }
249}
250
impl<A: Referable> Drop for ActorMesh<A> {
    fn drop(&mut self) {
        // Emit a status log so mesh lifetimes are traceable. Note this
        // fires for every dropped clone, not only the last one.
        tracing::info!(
            name = "ActorMeshStatus",
            actor_name = %self.name,
            status = "Dropped",
        );
    }
}
260
/// Default number of `ActorRef` slots per lazily allocated cache page.
const DEFAULT_PAGE: usize = 1024;
265
/// One fixed-size page of lazily materialized actor references.
struct Page<A: Referable> {
    // Each slot is filled at most once, on first access to its rank.
    slots: Box<[OnceCell<hyperactor_reference::ActorRef<A>>]>,
}
270
271impl<A: Referable> Page<A> {
272 fn new(len: usize) -> Self {
273 let mut v = Vec::with_capacity(len);
274 for _ in 0..len {
275 v.push(OnceCell::new());
276 }
277 Self {
278 slots: v.into_boxed_slice(),
279 }
280 }
281}
282
/// Per-owning-actor supervision bookkeeping for an `ActorMeshRef`.
#[derive(Default)]
struct HealthState {
    // Set once the mesh is known unhealthy (crashed or stream closed);
    // consulted by `cast_with_selection` to fail fast.
    unhealthy_event: Option<Unhealthy>,
    // Ranks observed to have crashed, with the triggering event.
    crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
}
288
289impl std::fmt::Debug for HealthState {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("HealthState")
292 .field("unhealthy_event", &self.unhealthy_event)
293 .field("crashed_ranks", &self.crashed_ranks)
294 .finish()
295 }
296}
297
/// Value published on the supervision watch channel by `into_watch`.
#[derive(Clone)]
enum MessageOrFailure<M: Send + Sync + Clone + Default + 'static> {
    /// A message successfully received from the underlying port.
    Message(M),
    /// The port receiver failed; carries the error rendered as a string.
    Failure(String),
    /// No message arrived within `SUPERVISION_WATCHDOG_TIMEOUT`.
    Timeout,
}
306
307impl<M: Send + Sync + Clone + Default + 'static> Default for MessageOrFailure<M> {
308 fn default() -> Self {
309 Self::Message(M::default())
310 }
311}
312
/// Adapts a `PortReceiver` into a tokio `watch` channel with a watchdog:
/// each received message is published as `Message`; a receive error is
/// published as `Failure`, and silence longer than
/// `SUPERVISION_WATCHDOG_TIMEOUT` as `Timeout`. The pump task exits after
/// publishing a failure/timeout or when all watchers are gone.
fn into_watch<M: Send + Sync + Clone + Default + 'static>(
    mut rx: PortReceiver<M>,
) -> watch::Receiver<MessageOrFailure<M>> {
    let (sender, receiver) = watch::channel(MessageOrFailure::<M>::default());
    let timeout = hyperactor_config::global::get(SUPERVISION_WATCHDOG_TIMEOUT);
    // Sanity-check the configured timeout against the longest legitimate
    // quiet period: one poll interval plus both state-query idle budgets.
    let poll_frequency = hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY);
    let get_actor_state_max_idle = hyperactor_config::global::get(GET_ACTOR_STATE_MAX_IDLE);
    let get_proc_state_max_idle = hyperactor_config::global::get(GET_PROC_STATE_MAX_IDLE);
    let total_time = poll_frequency + get_actor_state_max_idle + get_proc_state_max_idle;
    if timeout < total_time {
        tracing::warn!(
            "HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT={} is too short. It should be >= {} (SUPERVISION_POLL_FREQUENCY={} + GET_ACTOR_STATE_MAX_IDLE={} + GET_PROC_STATE_MAX_IDLE={})",
            humantime::format_duration(timeout),
            humantime::format_duration(total_time),
            humantime::format_duration(poll_frequency),
            humantime::format_duration(get_actor_state_max_idle),
            humantime::format_duration(get_proc_state_max_idle),
        );
    }
    tokio::spawn(async move {
        loop {
            let message = match tokio::time::timeout(timeout, rx.recv()).await {
                Ok(Ok(msg)) => MessageOrFailure::Message(msg),
                Ok(Err(e)) => MessageOrFailure::Failure(e.to_string()),
                Err(_) => MessageOrFailure::Timeout,
            };
            // Compute terminality before `message` is moved into `send`.
            let is_failure = matches!(
                message,
                MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
            );
            // Stop pumping when every watcher has been dropped...
            if sender.send(message).is_err() {
                break;
            }
            // ...or after publishing a terminal failure/timeout.
            if is_failure {
                break;
            }
        }
    });
    receiver
}
365
/// A serializable, sliceable reference to a mesh of actors of type `A`,
/// one per rank of `proc_mesh`.
pub struct ActorMeshRef<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,
    controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,

    // Supervision health as observed by the owning actor (actor-local, so
    // each owner tracks its own view).
    health_state: ActorLocal<HealthState>,
    // Per-owning-actor supervision subscription: the port subscribed to the
    // controller plus the watchdog-wrapped watch receiver, behind a mutex so
    // concurrent `next_supervision_event` calls serialize.
    receiver: ActorLocal<
        Arc<
            tokio::sync::Mutex<(
                hyperactor_reference::PortRef<Option<MeshFailure>>,
                watch::Receiver<MessageOrFailure<Option<MeshFailure>>>,
            )>,
        >,
    >,
    // Lazily allocated pages of lazily materialized `ActorRef`s, indexed by
    // rank; see `ensure_pages` / `materialize`.
    pages: OnceCell<Vec<OnceCell<Box<Page<A>>>>>,
    // Number of slots per page; clamped to >= 1 in `with_page_size`.
    page_size: usize,
}
407
408impl<A: Referable> ActorMeshRef<A> {
    /// Casts `message` to every actor in the mesh.
    #[allow(clippy::result_large_err)]
    pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> crate::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone, {
        // `sel!(*)` selects all ranks.
        self.cast_with_selection(cx, sel!(*), message)
    }
418
    /// Casts `message` to the ranks chosen by `sel`.
    ///
    /// Escape hatch for the tensor engine only (as the name warns):
    /// arbitrary selections are not part of the supported casting surface.
    #[allow(clippy::result_large_err)]
    pub fn cast_for_tensor_engine_only_do_not_use<M>(
        &self,
        cx: &impl context::Actor,
        sel: Selection,
        message: M,
    ) -> crate::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone, {
        self.cast_with_selection(cx, sel, message)
    }
436
    /// Casts `message` to the ranks matched by `sel`, failing fast with a
    /// supervision error if the mesh is already known to be unhealthy.
    #[allow(clippy::result_large_err)]
    fn cast_with_selection<M>(
        &self,
        cx: &impl context::Actor,
        sel: Selection,
        message: M,
    ) -> crate::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone, {
        // Health gate: refuse to cast into a mesh we already observed to be
        // crashed or whose supervision stream closed.
        {
            let health_state = self.health_state.entry(cx).or_default();
            let health_state = health_state.get();
            match &health_state.unhealthy_event {
                Some(Unhealthy::StreamClosed(failure)) => {
                    return Err(crate::Error::Supervision(Box::new(failure.clone())));
                }
                Some(Unhealthy::Crashed(failure)) => {
                    return Err(crate::Error::Supervision(Box::new(failure.clone())));
                }
                None => {
                    // Invariant: crashed ranks are only recorded together
                    // with an unhealthy event.
                    assert!(health_state.crashed_ranks.is_empty());
                }
            }
        }

        // Telemetry: record the send with hashed ids and the view/shape
        // rendered as JSON (best-effort; empty string on failure).
        hyperactor_telemetry::notify_sent_message(hyperactor_telemetry::SentMessageEvent {
            timestamp: std::time::SystemTime::now(),
            sender_actor_id: hyperactor_telemetry::hash_to_u64(cx.mailbox().actor_id()),
            actor_mesh_id: hyperactor_telemetry::hash_to_u64(&self.name.to_string()),
            view_json: serde_json::to_string(view::Ranked::region(self)).unwrap_or_default(),
            shape_json: {
                let shape: ndslice::Shape = view::Ranked::region(self).into();
                serde_json::to_string(&shape).unwrap_or_default()
            },
        });

        if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
            // Legacy (v0) multicast path through the root comm actor.
            self.cast_v0(cx, message, sel, root_comm_actor)
        } else {
            // Direct path: send point-to-point to each rank, stamping cast
            // metadata and the per-destination rank into the message.
            for (point, actor) in self.iter() {
                let create_rank = point.rank();
                let mut headers = Flattrs::new();
                multicast::set_cast_info_on_headers(
                    &mut headers,
                    point,
                    cx.instance().self_id().clone(),
                );

                // Unbind so we can rewrite any embedded `resource::Rank`
                // to this destination's rank, then rebind.
                let mut unbound = Unbound::try_from_message(message.clone())
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                unbound
                    .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
                        *rank = Some(create_rank);
                        Ok(())
                    })
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                let rebound_message = unbound
                    .bind()
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                actor
                    .send_with_headers(cx, headers, rebound_message)
                    .map_err(|e| Error::SendingError(actor.actor_id().clone(), Box::new(e)))?;
            }
            Ok(())
        }
    }
512
    /// Casts via the legacy comm-actor (v0) multicast path.
    #[allow(clippy::result_large_err)]
    fn cast_v0<M>(
        &self,
        cx: &impl context::Actor,
        message: M,
        sel: Selection,
        root_comm_actor: &hyperactor_reference::ActorRef<CommActor>,
    ) -> crate::Result<()>
    where
        A: RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone, {
        let cast_mesh_shape = view::Ranked::region(self).into();
        let actor_mesh_id = ActorMeshId(self.name.clone());
        match &self.proc_mesh.root_region {
            // This reference is a slice of a larger root mesh: the cast
            // must be routed relative to the root shape.
            Some(root_region) => {
                let root_mesh_shape = root_region.into();
                casting::cast_to_sliced_mesh::<A, M>(
                    cx,
                    actor_mesh_id,
                    root_comm_actor,
                    &sel,
                    message,
                    &cast_mesh_shape,
                    &root_mesh_shape,
                )
                .map_err(|e| Error::CastingError(self.name.clone(), e.into()))
            }
            // Unsliced: cast shape and root shape coincide.
            None => casting::actor_mesh_cast::<A, M>(
                cx,
                actor_mesh_id,
                root_comm_actor,
                sel,
                &cast_mesh_shape,
                &cast_mesh_shape,
                message,
            )
            .map_err(|e| Error::CastingError(self.name.clone(), e.into())),
        }
    }
553
    /// Fetches the current state of every actor in the mesh.
    #[allow(clippy::result_large_err)]
    pub async fn actor_states(
        &self,
        cx: &impl context::Actor,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.actor_states_with_keepalive(cx, None).await
    }
566
    /// Like [`Self::actor_states`], but forwards an optional keepalive
    /// timestamp to the underlying proc-mesh query.
    #[allow(clippy::result_large_err)]
    pub(crate) async fn actor_states_with_keepalive(
        &self,
        cx: &impl context::Actor,
        keepalive: Option<std::time::SystemTime>,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.proc_mesh
            .actor_states_with_keepalive(cx, self.name.clone(), keepalive)
            .await
    }
577
    /// Creates a reference with the default cache page size.
    pub(crate) fn new(
        name: Name,
        proc_mesh: ProcMeshRef,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        Self::with_page_size(name, proc_mesh, DEFAULT_PAGE, controller)
    }
585
    /// The mesh's name.
    pub fn name(&self) -> &Name {
        &self.name
    }
589
    /// Creates a reference with an explicit cache page size. All lazy state
    /// (health, subscription, page cache) starts empty.
    pub(crate) fn with_page_size(
        name: Name,
        proc_mesh: ProcMeshRef,
        page_size: usize,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        Self {
            proc_mesh,
            name,
            controller,
            health_state: ActorLocal::new(),
            receiver: ActorLocal::new(),
            pages: OnceCell::new(),
            // Clamp to >= 1: a zero page size would divide by zero in
            // `materialize`.
            page_size: page_size.max(1),
        }
    }
606
    /// The underlying proc mesh this actor mesh spans.
    pub fn proc_mesh(&self) -> &ProcMeshRef {
        &self.proc_mesh
    }
610
    /// Number of ranks in the mesh (one actor per proc-mesh rank).
    #[inline]
    fn len(&self) -> usize {
        view::Ranked::region(&self.proc_mesh).num_ranks()
    }
615
    /// The mesh's controller, if one is attached.
    pub fn controller(&self) -> &Option<hyperactor_reference::ActorRef<ActorMeshController<A>>> {
        &self.controller
    }
619
    /// Replaces (or clears, with `None`) the controller reference.
    fn set_controller(
        &mut self,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) {
        self.controller = controller;
    }
626
627 fn ensure_pages(&self) -> &Vec<OnceCell<Box<Page<A>>>> {
628 let n = self.len().div_ceil(self.page_size); self.pages
630 .get_or_init(|| (0..n).map(|_| OnceCell::new()).collect())
631 }
632
    /// Returns the `ActorRef` for `rank`, creating its page and slot on
    /// first access; `None` if `rank` is out of bounds.
    fn materialize(&self, rank: usize) -> Option<&hyperactor_reference::ActorRef<A>> {
        let len = self.len();
        if rank >= len {
            return None;
        }
        let p = self.page_size;
        let page_ix = rank / p;
        let local_ix = rank % p;

        let pages = self.ensure_pages();
        let page = pages[page_ix].get_or_init(|| {
            // The last page may be shorter than `page_size`.
            let base = page_ix * p;
            let remaining = len - base;
            let page_len = remaining.min(p);
            Box::new(Page::<A>::new(page_len))
        });

        Some(page.slots[local_ix].get_or_init(|| {
            debug_assert!(rank < self.len(), "rank must be within [0, len)");
            debug_assert!(
                ndslice::view::Ranked::get(&self.proc_mesh, rank).is_some(),
                "proc_mesh must be dense/aligned with this view"
            );
            let proc_ref =
                ndslice::view::Ranked::get(&self.proc_mesh, rank).expect("rank in-bounds");
            // Build the actor reference by attesting the name on that proc.
            proc_ref.attest(&self.name)
        }))
    }
668
    /// Opens a port, subscribes it to the controller's supervision stream,
    /// and wraps the receiver in a watchdog-monitored watch channel. The
    /// bound port is returned alongside so callers can `Unsubscribe` later.
    fn init_supervision_receiver(
        controller: &hyperactor_reference::ActorRef<ActorMeshController<A>>,
        cx: &impl context::Actor,
    ) -> (
        hyperactor_reference::PortRef<Option<MeshFailure>>,
        watch::Receiver<MessageOrFailure<Option<MeshFailure>>>,
    ) {
        let (tx, rx) = cx.mailbox().open_port();
        let tx = tx.bind();
        controller
            .send(cx, Subscribe(tx.clone()))
            .expect("failed to send Subscribe");
        (tx, into_watch(rx))
    }
683
    /// Waits for the next supervision failure relevant to this mesh view,
    /// records it in the local health state, and returns it. A controller
    /// timeout is converted into a synthetic `MeshFailure` (the controller's
    /// proc is presumed dead) rather than an error.
    pub async fn next_supervision_event(
        &self,
        cx: &impl context::Actor,
    ) -> Result<MeshFailure, anyhow::Error> {
        let controller = if let Some(c) = self.controller() {
            c
        } else {
            // No controller (e.g. after `stop`): replay the recorded
            // unhealthy event instead of waiting.
            let health_state = self.health_state.entry(cx).or_default();
            let health_state = health_state.get();
            return match &health_state.unhealthy_event {
                Some(Unhealthy::StreamClosed(f)) => Ok(f.clone()),
                Some(Unhealthy::Crashed(f)) => Ok(f.clone()),
                None => Err(anyhow::anyhow!(
                    "unexpected healthy state while controller is gone"
                )),
            };
        };
        // Lazily subscribe once per owning actor; the Arc lets us drop the
        // actor-local entry borrow before awaiting.
        let rx = {
            let entry = self.receiver.entry(cx).or_insert_with(|| {
                Arc::new(tokio::sync::Mutex::new(Self::init_supervision_receiver(
                    controller, cx,
                )))
            });
            Arc::clone(entry.get())
        };
        let message = {
            let mut rx = rx.lock().await;
            let subscriber_port = rx.0.clone();
            // Wait until there is something relevant: any failure/timeout,
            // or a real failure message that touches a rank in this view.
            let message =
                rx.1.wait_for(|message| {
                    if let MessageOrFailure::Message(message) = message {
                        if let Some(message) = &message {
                            let region = ndslice::view::Ranked::region(self).slice();
                            if message.crashed_ranks.is_empty() {
                                // No rank attribution: applies to the whole
                                // mesh, so it is relevant.
                                true
                            } else {
                                // Only relevant if a crashed rank lies in
                                // this (possibly sliced) view.
                                message.crashed_ranks.iter().any(|r| region.contains(*r))
                            }
                        } else {
                            // `None` is the channel's initial placeholder;
                            // keep waiting.
                            false
                        }
                    } else {
                        // Failure/Timeout is always relevant.
                        true
                    }
                })
                .await?;
            let message = message.clone();
            let is_failure = matches!(
                message,
                MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
            );
            if is_failure {
                // Best-effort unsubscribe; the controller may already be
                // unreachable, so ignore delivery failures.
                let mut port = controller.port();
                port.return_undeliverable(false);
                let _ = port.send(cx, Unsubscribe(subscriber_port));
            }
            match message {
                MessageOrFailure::Message(message) => Ok::<MeshFailure, anyhow::Error>(
                    message.expect("filter excludes any None messages"),
                ),
                MessageOrFailure::Failure(failure) => Err(anyhow::anyhow!("{}", failure)),
                MessageOrFailure::Timeout => {
                    // Synthesize a failure attributed to the controller.
                    Ok(MeshFailure {
                        actor_mesh_name: Some(self.name().to_string()),
                        event: ActorSupervisionEvent::new(
                            controller.actor_id().clone(),
                            None,
                            ActorStatus::generic_failure(format!(
                                "timed out reaching controller {} for mesh {}. Assuming controller's proc is dead",
                                controller.actor_id(),
                                self.name()
                            )),
                            None,
                        ),
                        crashed_ranks: vec![],
                    })
                }
            }?
        };
        // Record the event in the owner-local health state so subsequent
        // casts fail fast.
        let event = &message.event;
        let mut entry = self.health_state.entry(cx).or_default();
        let health_state = entry.get_mut();
        if let ActorStatus::Failed(_) = event.actor_status {
            for &rank in &message.crashed_ranks {
                health_state.crashed_ranks.insert(rank, event.clone());
            }
        }
        health_state.unhealthy_event = match &event.actor_status {
            ActorStatus::Failed(_) => Some(Unhealthy::Crashed(message.clone())),
            ActorStatus::Stopped(_) => Some(Unhealthy::StreamClosed(message.clone())),
            _ => None,
        };
        Ok(message)
    }
809
810 pub fn clone_with_supervision_receiver(&self) -> Self {
814 Self {
815 proc_mesh: self.proc_mesh.clone(),
816 name: self.name.clone(),
817 controller: self.controller.clone(),
818 health_state: self.health_state.clone(),
819 receiver: self.receiver.clone(),
820 pages: OnceCell::new(),
822 page_size: self.page_size,
823 }
824 }
825}
826
impl<A: Referable> Clone for ActorMeshRef<A> {
    // Manual impl: a derive would require `A: Clone`, and the caches below
    // are deliberately reset rather than cloned.
    fn clone(&self) -> Self {
        Self {
            proc_mesh: self.proc_mesh.clone(),
            name: self.name.clone(),
            controller: self.controller.clone(),
            // Fresh actor-local state and page cache: a clone does not share
            // health tracking, the supervision subscription, or materialized
            // refs with the original (cf. `clone_with_supervision_receiver`).
            health_state: ActorLocal::new(),
            receiver: ActorLocal::new(),
            pages: OnceCell::new(),
            page_size: self.page_size,
        }
    }
}
842
843impl<A: Referable> fmt::Display for ActorMeshRef<A> {
844 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
845 write!(f, "{}:{}@{}", self.name, A::typename(), self.proc_mesh)
846 }
847}
848
849impl<A: Referable> PartialEq for ActorMeshRef<A> {
850 fn eq(&self, other: &Self) -> bool {
851 self.proc_mesh == other.proc_mesh && self.name == other.name
852 }
853}
854impl<A: Referable> Eq for ActorMeshRef<A> {}
855
856impl<A: Referable> Hash for ActorMeshRef<A> {
857 fn hash<H: Hasher>(&self, state: &mut H) {
858 self.proc_mesh.hash(state);
859 self.name.hash(state);
860 }
861}
862
863impl<A: Referable> fmt::Debug for ActorMeshRef<A> {
864 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
865 f.debug_struct("ActorMeshRef")
866 .field("proc_mesh", &self.proc_mesh)
867 .field("name", &self.name)
868 .field("page_size", &self.page_size)
869 .finish_non_exhaustive() }
871}
872
impl<A: Referable> Serialize for ActorMeshRef<A> {
    /// Serializes as the `(proc_mesh, name, controller)` tuple; the lazy
    /// caches (`health_state`, `receiver`, `pages`) are not serialized and
    /// are rebuilt empty on deserialization.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        (&self.proc_mesh, &self.name, &self.controller).serialize(serializer)
    }
}
883
impl<'de, A: Referable> Deserialize<'de> for ActorMeshRef<A> {
    /// Mirrors `Serialize`: reads the `(proc_mesh, name, controller)` tuple
    /// and rebuilds a reference with empty caches and the default page size.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let (proc_mesh, name, controller) = <(
            ProcMeshRef,
            Name,
            Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
        )>::deserialize(deserializer)?;
        Ok(ActorMeshRef::with_page_size(
            name,
            proc_mesh,
            DEFAULT_PAGE,
            controller,
        ))
    }
}
903
impl<A: Referable> view::Ranked for ActorMeshRef<A> {
    type Item = hyperactor_reference::ActorRef<A>;

    /// The region is the underlying proc mesh's region: actor rank `i`
    /// lives on proc rank `i`.
    #[inline]
    fn region(&self) -> &Region {
        view::Ranked::region(&self.proc_mesh)
    }

    /// Lazily materializes (and caches) the `ActorRef` at `rank`.
    #[inline]
    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.materialize(rank)
    }
}
917
918impl<A: Referable> view::RankedSliceable for ActorMeshRef<A> {
919 fn sliced(&self, region: Region) -> Self {
920 debug_assert!(region.is_subset(view::Ranked::region(self)));
923 let proc_mesh = self.proc_mesh.subset(region).unwrap();
924 Self::with_page_size(
925 self.name.clone(),
926 proc_mesh,
927 self.page_size,
928 self.controller.clone(),
929 )
930 }
931}
932
933#[cfg(test)]
934mod tests {
935
936 use std::collections::HashSet;
937 use std::ops::Deref;
938
939 use hyperactor::actor::ActorErrorKind;
940 use hyperactor::actor::ActorStatus;
941 use hyperactor::context::Mailbox as _;
942 use hyperactor::mailbox;
943 use ndslice::Extent;
944 use ndslice::ViewExt;
945 use ndslice::extent;
946 use ndslice::view::Ranked;
947 use timed_test::async_timed_test;
948 use tokio::time::Duration;
949
950 use super::ActorMesh;
951 use crate::ActorMeshRef;
952 use crate::Name;
953 use crate::ProcMesh;
954 use crate::proc_mesh::ACTOR_SPAWN_MAX_IDLE;
955 use crate::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
956 use crate::supervision::MeshFailure;
957 use crate::testactor;
958 use crate::testing;
959
960 #[test]
961 fn test_actor_mesh_ref_is_send_and_sync() {
962 fn assert_send_sync<T: Send + Sync>() {}
963 assert_send_sync::<ActorMeshRef<()>>();
964 }
965
    // Verifies the lazy, paged `ActorRef` cache: stable pointers per rank,
    // fresh caches on clone/slice, Hash/Eq independence from cache state,
    // and that materialized refs are live (messages round-trip).
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_actor_mesh_ref_lazy_materialization() {
        let instance = testing::instance();
        let mut hm = testing::host_mesh(3).await;
        let pm: ProcMesh = hm
            .spawn(instance, "test", extent!(gpus = 2), None)
            .await
            .unwrap();
        let am: ActorMesh<testactor::TestActor> = pm.spawn(instance, "test", &()).await.unwrap();

        // Tiny page size so the 6 ranks span multiple pages.
        let page_size = 2;
        let amr: ActorMeshRef<testactor::TestActor> =
            ActorMeshRef::with_page_size(am.name.clone(), pm.clone(), page_size, None);
        assert_eq!(amr.extent(), extent!(hosts = 3, gpus = 2));
        assert_eq!(amr.region().num_ranks(), 6);

        // Repeated gets of a rank return the same cached slot (pointer-equal).
        let p0_a = amr.get(0).expect("rank 0 exists") as *const _;
        let p0_b = amr.get(0).expect("rank 0 exists") as *const _;
        assert_eq!(p0_a, p0_b, "same rank should return same cached pointer");

        let p1_a = amr.get(1).expect("rank 1 exists") as *const _;
        let p1_b = amr.get(1).expect("rank 1 exists") as *const _;
        assert_eq!(p1_a, p1_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p1_a, "different ranks have different cache slots");

        // Rank 2 lives on the second page (page_size = 2).
        let p2_a = amr.get(2).expect("rank 2 exists") as *const _;
        let p2_b = amr.get(2).expect("rank 2 exists") as *const _;
        assert_eq!(p2_a, p2_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p2_a, "different pages have different cache slots");

        // Clones preserve identity but start with a fresh cache.
        let amr_clone = amr.clone();
        let orig_id_0 = amr.get(0).unwrap().actor_id().clone();
        let clone_id_0 = amr_clone.get(0).unwrap().actor_id().clone();
        assert_eq!(orig_id_0, clone_id_0, "clone preserves identity");
        let p0_clone = amr_clone.get(0).unwrap() as *const _;
        assert_ne!(
            p0_a, p0_clone,
            "cloned ActorMeshRef has a fresh cache (different pointer)"
        );

        // Sliced views re-rank from 0 and carry their own caches/pages.
        let sliced = amr.range("hosts", 1..).expect("slice should be valid");
        assert_eq!(sliced.region().num_ranks(), 4);
        let sp0_a = sliced.get(0).unwrap() as *const _;
        let sp0_b = sliced.get(0).unwrap() as *const _;
        assert_eq!(sp0_a, sp0_b, "sliced view has its own cache slot per rank");
        let sp2 = sliced.get(2).unwrap() as *const _;
        assert_ne!(sp0_a, sp2, "sliced view crosses its own page boundary");

        // Hash/Eq ignore cache state: two clones collapse to one set entry.
        let mut set = HashSet::new();
        set.insert(amr.clone());
        set.insert(amr.clone());
        assert_eq!(set.len(), 1, "cache state must not affect Hash/Eq");

        // Materialized refs are live: two distinct ranks answer.
        let (port, mut rx) = mailbox::open_port(instance);
        amr.get(0)
            .expect("rank 0 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 0 should succeed");
        amr.get(3)
            .expect("rank 3 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 3 should succeed");
        let id_a = tokio::time::timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for first reply")
            .expect("channel closed before first reply");
        let id_b = tokio::time::timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for second reply")
            .expect("channel closed before second reply");
        assert_ne!(id_a, id_b, "two different ranks responded");

        let _ = hm.shutdown(instance).await;
    }
1064
    // A child actor panics on every replica; the wrapper actor's mesh ref
    // surfaces the failure via NextSupervisionFailure, and every replica's
    // supervisor port also receives the event.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_states_with_panic() {
        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance();
        let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let mut hm = testing::host_mesh(num_replicas).await;
        let proc_mesh = hm
            .spawn(instance, "test", Extent::unity(), None)
            .await
            .unwrap();
        let child_name = Name::new("child").unwrap();

        // Each wrapper spawns a child mesh named `child` on `proc_mesh` and
        // reports its supervision events to `supervisor`.
        let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
            )
            .await
            .unwrap();

        actor_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::Panic,
                    send_to_children: true,
                },
            )
            .unwrap();

        // Ask the wrappers for the next failure observed on their child ref.
        let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
        actor_mesh
            .cast(
                instance,
                testactor::NextSupervisionFailure(failure_port.bind()),
            )
            .unwrap();
        let failure = failure_receiver
            .recv()
            .await
            .unwrap()
            .expect("no supervision event found on ref from wrapper actor");
        let check_failure = move |failure: MeshFailure| {
            assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
            assert_eq!(
                failure.event.actor_id.name(),
                child_name.clone().to_string()
            );
            if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
                assert!(msg.contains("panic"), "{}", msg);
                assert!(msg.contains("for testing"), "{}", msg);
            } else {
                panic!("actor status is not failed: {}", failure.event.actor_status);
            }
        };
        check_failure(failure);

        // Every replica's supervisor port should also see the event.
        for _ in 0..num_replicas {
            let failure =
                tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
                    .await
                    .expect("timeout")
                    .unwrap();
            check_failure(failure);
        }

        let _ = hm.shutdown(instance).await;
    }
1154
    // Children live on a *second* proc mesh and exit their processes with a
    // non-zero code; the wrappers (on the first mesh) must still observe and
    // propagate the supervision failure.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_states_with_process_exit() {
        hyperactor_telemetry::initialize_logging_for_test();

        // Shorten the state-query idle budget so the dead procs are
        // detected quickly.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(GET_ACTOR_STATE_MAX_IDLE, Duration::from_secs(1));

        let instance = testing::instance();
        let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let mut hm = testing::host_mesh(num_replicas).await;
        let proc_mesh = hm
            .spawn(instance, "test", Extent::unity(), None)
            .await
            .unwrap();
        let mut second_hm = testing::host_mesh(num_replicas).await;
        let second_proc_mesh = second_hm
            .spawn(instance, "test2", Extent::unity(), None)
            .await
            .unwrap();
        let child_name = Name::new("child").unwrap();

        // Wrappers run on `proc_mesh` but spawn their children on
        // `second_proc_mesh`, whose processes will be killed.
        let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(
                    second_proc_mesh.deref().clone(),
                    supervisor,
                    child_name.clone(),
                ),
            )
            .await
            .unwrap();

        actor_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::ProcessExit(1),
                    send_to_children: true,
                },
            )
            .unwrap();

        // Ask the wrappers for the next failure observed on their child ref.
        let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
        actor_mesh
            .cast(
                instance,
                testactor::NextSupervisionFailure(failure_port.bind()),
            )
            .unwrap();
        let failure = failure_receiver
            .recv()
            .await
            .unwrap()
            .expect("no supervision event found on ref from wrapper actor");

        let check_failure = move |failure: MeshFailure| {
            assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
            assert_eq!(failure.event.actor_id.name(), child_name.to_string());
            if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
                assert!(msg.contains("exited with non-zero code 1"), "{}", msg);
            } else {
                panic!("actor status is not failed: {}", failure.event.actor_status);
            }
        };
        check_failure(failure);

        // Every replica's supervisor port should also see the event.
        for _ in 0..num_replicas {
            let failure =
                tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
                    .await
                    .expect("timeout")
                    .unwrap();
            check_failure(failure);
        }

        let _ = second_hm.shutdown(instance).await;
        let _ = hm.shutdown(instance).await;
    }
1245
    // Casting a panic trigger to a *sliced* mesh must produce supervision
    // events only for the ranks in the slice (one per sliced replica).
    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_states_on_sliced_mesh() {
        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance();
        let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let mut hm = testing::host_mesh(num_replicas).await;
        let proc_mesh = hm
            .spawn(instance, "test", Extent::unity(), None)
            .await
            .unwrap();
        let child_name = Name::new("child").unwrap();

        // Wrapper-per-replica, each supervising a child mesh on `proc_mesh`.
        let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
            )
            .await
            .unwrap();
        // Restrict to hosts 1..3: only these replicas receive the cast.
        let sliced = actor_mesh
            .range("hosts", 1..3)
            .expect("slice should be valid");
        let sliced_replicas = sliced.len();

        sliced
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::Panic,
                    send_to_children: true,
                },
            )
            .unwrap();

        // Exactly one supervision event per sliced replica is expected.
        for _ in 0..sliced_replicas {
            let supervision_message =
                tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
                    .await
                    .expect("timeout")
                    .unwrap();
            let event = supervision_message.event;
            assert_eq!(event.actor_id.name(), format!("{}", child_name.clone()));
            if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &event.actor_status {
                assert!(msg.contains("panic"));
                assert!(msg.contains("for testing"));
            } else {
                panic!("actor status is not failed: {}", event.actor_status);
            }
        }

        let _ = hm.shutdown(instance).await;
    }
1307
1308 #[async_timed_test(timeout_secs = 30)]
1309 #[cfg(fbcode_build)]
1310 async fn test_cast() {
1311 let config = hyperactor_config::global::lock();
1312 let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
1313
1314 let instance = testing::instance();
1315 let mut host_mesh = testing::host_mesh(4).await;
1316 let proc_mesh = host_mesh
1317 .spawn(instance, "test", Extent::unity(), None)
1318 .await
1319 .unwrap();
1320 let actor_mesh: ActorMesh<testactor::TestActor> =
1321 proc_mesh.spawn(instance, "test", &()).await.unwrap();
1322
1323 let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
1324 actor_mesh
1325 .cast(
1326 instance,
1327 testactor::GetCastInfo {
1328 cast_info: cast_info.bind(),
1329 },
1330 )
1331 .unwrap();
1332
1333 let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
1334 while !point_to_actor.is_empty() {
1335 let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
1336 let key = (point, origin_actor_ref);
1337 assert!(
1338 point_to_actor.remove(&key),
1339 "key {:?} not present or removed twice",
1340 key
1341 );
1342 assert_eq!(&sender_actor_id, instance.self_id());
1343 }
1344
1345 let _ = host_mesh.shutdown(instance).await;
1346 }
1347
1348 #[async_timed_test(timeout_secs = 60)]
1354 #[cfg(fbcode_build)]
1355 async fn test_undeliverable_message_return() {
1356 use hyperactor::mailbox::MessageEnvelope;
1357 use hyperactor::mailbox::Undeliverable;
1358 use hyperactor::testing::pingpong::PingPongActor;
1359 use hyperactor::testing::pingpong::PingPongMessage;
1360
1361 hyperactor_telemetry::initialize_logging_for_test();
1362
1363 let instance = testing::instance();
1364
1365 let mut hm = testing::host_mesh(2).await;
1367 let proc_mesh = hm
1368 .spawn(instance, "test", Extent::unity(), None)
1369 .await
1370 .unwrap();
1371
1372 let (undeliverable_port, mut undeliverable_rx) =
1374 instance.open_port::<Undeliverable<MessageEnvelope>>();
1375
1376 let ping_proc_mesh = proc_mesh.range("hosts", 0..1).unwrap();
1379 let pong_proc_mesh = proc_mesh.range("hosts", 1..2).unwrap();
1380
1381 let ping_mesh: ActorMesh<PingPongActor> = ping_proc_mesh
1382 .spawn(
1383 instance,
1384 "ping",
1385 &(Some(undeliverable_port.bind()), None, None),
1386 )
1387 .await
1388 .unwrap();
1389
1390 let mut pong_mesh: ActorMesh<PingPongActor> = pong_proc_mesh
1391 .spawn(instance, "pong", &(None, None, None))
1392 .await
1393 .unwrap();
1394
1395 let ping_handle = ping_mesh.values().next().unwrap();
1397 let pong_handle = pong_mesh.values().next().unwrap();
1398
1399 let (done_tx, done_rx) = instance.open_once_port();
1401 ping_handle
1402 .send(
1403 instance,
1404 PingPongMessage(2, pong_handle.clone(), done_tx.bind()),
1405 )
1406 .unwrap();
1407 assert!(
1408 done_rx.recv().await.unwrap(),
1409 "Initial ping-pong should work"
1410 );
1411
1412 pong_mesh
1414 .stop(instance, "test stop".to_string())
1415 .await
1416 .unwrap();
1417
1418 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1420
1421 let config = hyperactor_config::global::lock();
1423 let _guard = config.override_key(
1424 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1425 std::time::Duration::from_secs(5),
1426 );
1427
1428 let n = 100usize;
1430 for i in 1..=n {
1431 let ttl = 66 + i as u64; let (once_tx, _once_rx) = instance.open_once_port();
1433 ping_handle
1434 .send(
1435 instance,
1436 PingPongMessage(ttl, pong_handle.clone(), once_tx.bind()),
1437 )
1438 .unwrap();
1439 }
1440
1441 let mut count = 0;
1445 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10);
1446 while count < n && tokio::time::Instant::now() < deadline {
1447 match tokio::time::timeout(std::time::Duration::from_secs(1), undeliverable_rx.recv())
1448 .await
1449 {
1450 Ok(Ok(Undeliverable(envelope))) => {
1451 let _: PingPongMessage = envelope.deserialized().unwrap();
1452 count += 1;
1453 }
1454 Ok(Err(_)) => break, Err(_) => break, }
1457 }
1458
1459 assert_eq!(
1460 count, n,
1461 "Expected {} undeliverable messages, got {}",
1462 n, count
1463 );
1464
1465 let _ = hm.shutdown(instance).await;
1466 }
1467
1468 #[async_timed_test(timeout_secs = 30)]
1472 #[cfg(fbcode_build)]
1473 async fn test_actor_mesh_stop_timeout() {
1474 hyperactor_telemetry::initialize_logging_for_test();
1475
1476 let config = hyperactor_config::global::lock();
1486 let _guard = config.override_key(ACTOR_SPAWN_MAX_IDLE, std::time::Duration::from_secs(1));
1487
1488 let instance = testing::instance();
1489
1490 let mut hm = testing::host_mesh(2).await;
1492 let proc_mesh = hm
1493 .spawn(instance, "test", Extent::unity(), None)
1494 .await
1495 .unwrap();
1496
1497 let mut sleep_mesh: ActorMesh<testactor::SleepActor> =
1500 proc_mesh.spawn(instance, "sleepers", &()).await.unwrap();
1501
1502 for actor_ref in sleep_mesh.values() {
1505 actor_ref
1506 .send(instance, std::time::Duration::from_secs(5))
1507 .unwrap();
1508 }
1509
1510 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1512
1513 let expected_actors = sleep_mesh.values().count();
1515
1516 let stop_start = tokio::time::Instant::now();
1519 let result = sleep_mesh.stop(instance, "test stop".to_string()).await;
1520 let stop_duration = tokio::time::Instant::now().duration_since(stop_start);
1521
1522 match result {
1526 Ok(_) => {
1527 tracing::warn!("Actors stopped gracefully (unexpected but ok)");
1530 }
1531 Err(ref e) => {
1532 let err_str = format!("{:?}", e);
1534 assert!(
1535 err_str.contains("Timeout"),
1536 "Expected Timeout error, got: {:?}",
1537 e
1538 );
1539 tracing::info!(
1540 "Stop timed out as expected for {} actors, they were aborted",
1541 expected_actors
1542 );
1543 }
1544 }
1545
1546 assert!(
1551 stop_duration < std::time::Duration::from_secs(3),
1552 "Stop took {:?}, expected < 3s (actors should have been aborted, not waited for)",
1553 stop_duration
1554 );
1555 assert!(
1556 stop_duration >= std::time::Duration::from_millis(900),
1557 "Stop took {:?}, expected >= 900ms (should have waited for timeout)",
1558 stop_duration
1559 );
1560
1561 let _ = hm.shutdown(instance).await;
1562 }
1563
1564 #[async_timed_test(timeout_secs = 30)]
1570 #[cfg(fbcode_build)]
1571 async fn test_actor_mesh_stop_graceful() {
1572 hyperactor_telemetry::initialize_logging_for_test();
1573
1574 let instance = testing::instance();
1575
1576 let mut hm = testing::host_mesh(2).await;
1578 let proc_mesh = hm
1579 .spawn(instance, "test", Extent::unity(), None)
1580 .await
1581 .unwrap();
1582
1583 let mut actor_mesh: ActorMesh<testactor::TestActor> =
1586 proc_mesh.spawn(instance, "test_actors", &()).await.unwrap();
1587
1588 let mesh_ref = actor_mesh.deref().clone();
1591
1592 let expected_actors = actor_mesh.values().count();
1593 assert!(expected_actors > 0, "Should have spawned some actors");
1594
1595 let stop_start = tokio::time::Instant::now();
1597 let result = actor_mesh.stop(instance, "test stop".to_string()).await;
1598 let stop_duration = tokio::time::Instant::now().duration_since(stop_start);
1599
1600 assert!(
1602 result.is_ok(),
1603 "Stop should succeed for responsive actors, got: {:?}",
1604 result.err()
1605 );
1606
1607 assert!(
1611 stop_duration < std::time::Duration::from_secs(2),
1612 "Graceful stop took {:?}, expected < 2s (actors should stop quickly)",
1613 stop_duration
1614 );
1615
1616 tracing::info!(
1617 "Successfully stopped {} actors in {:?}",
1618 expected_actors,
1619 stop_duration
1620 );
1621
1622 let next_event = actor_mesh.next_supervision_event(instance).await.unwrap();
1628 assert_eq!(
1629 next_event.actor_mesh_name,
1630 Some(mesh_ref.name().to_string())
1631 );
1632 assert!(matches!(
1633 next_event.event.actor_status,
1634 ActorStatus::Stopped(_)
1635 ));
1636 let next_event = mesh_ref.next_supervision_event(instance).await.unwrap();
1639 assert_eq!(
1640 next_event.actor_mesh_name,
1641 Some(mesh_ref.name().to_string())
1642 );
1643 assert!(matches!(
1644 next_event.event.actor_status,
1645 ActorStatus::Stopped(_)
1646 ));
1647
1648 let _ = hm.shutdown(instance).await;
1649 }
1650}