use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::OnceLock as OnceCell;
use std::time::Duration;

use hyperactor::ActorLocal;
use hyperactor::RemoteHandles;
use hyperactor::RemoteMessage;
use hyperactor::actor::ActorStatus;
use hyperactor::actor::Referable;
use hyperactor::context;
use hyperactor::mailbox::PortReceiver;
use hyperactor::message::Castable;
use hyperactor::message::IndexedErasedUnbound;
use hyperactor::message::Unbound;
use hyperactor::reference as hyperactor_reference;
use hyperactor::supervision::ActorSupervisionEvent;
use hyperactor_config::CONFIG;
use hyperactor_config::ConfigAttr;
use hyperactor_config::Flattrs;
use hyperactor_config::attrs::declare_attrs;
use hyperactor_mesh_macros::sel;
use ndslice::Selection;
use ndslice::ViewExt as _;
use ndslice::view;
use ndslice::view::Region;
use ndslice::view::View;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
use tokio::sync::watch;

use crate::CommActor;
use crate::Error;
use crate::Name;
use crate::ProcMeshRef;
use crate::ValueMesh;
use crate::casting;
use crate::comm::multicast;
use crate::host_mesh::GET_PROC_STATE_MAX_IDLE;
use crate::host_mesh::mesh_to_rankedvalues_with_default;
use crate::mesh_controller::ActorMeshController;
use crate::mesh_controller::SUPERVISION_POLL_FREQUENCY;
use crate::mesh_controller::Subscribe;
use crate::mesh_controller::Unsubscribe;
use crate::proc_agent::ActorState;
use crate::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
use crate::reference::ActorMeshId;
use crate::resource;
use crate::supervision::MeshFailure;
use crate::supervision::Unhealthy;

declare_attrs! {
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT".to_string()),
        Some("supervision_watchdog_timeout".to_string()),
    ))
    pub attr SUPERVISION_WATCHDOG_TIMEOUT: Duration = Duration::from_mins(2);
}
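
// The watchdog timeout above is wired to the
// HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT environment variable and the
// "supervision_watchdog_timeout" config key via the `@meta` attribute, so the
// 2-minute default can typically be overridden through either of those.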

#[derive(Debug)]
pub struct ActorMesh<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,
    current_ref: ActorMeshRef<A>,
    controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
}

impl<A: Referable> ActorMesh<A> {
    pub(crate) fn new(
        proc_mesh: ProcMeshRef,
        name: Name,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        let current_ref = ActorMeshRef::with_page_size(
            name.clone(),
            proc_mesh.clone(),
            DEFAULT_PAGE,
            controller.clone(),
        );

        Self {
            proc_mesh,
            name,
            current_ref,
            controller,
        }
    }

    pub fn name(&self) -> &Name {
        &self.name
    }

    pub(crate) fn set_controller(
        &mut self,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) {
        self.controller = controller.clone();
        self.current_ref.set_controller(controller);
    }

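    // Hypothetical usage sketch for `stop` (names illustrative): it asks the
    // controller to stop every rank, then checks the reported statuses once;
    // if any rank is not yet terminating, it returns `Error::ActorStopError`
    // carrying the per-rank statuses.
    //
    //     mesh.stop(cx, "rolling restart".to_string()).await?;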
    pub async fn stop(&mut self, cx: &impl context::Actor, reason: String) -> crate::Result<()> {
        if let Some(controller) = self.controller.take() {
            controller
                .send(
                    cx,
                    resource::Stop {
                        name: self.name.clone(),
                        reason,
                    },
                )
                .map_err(|e| {
                    crate::Error::SendingError(controller.actor_id().clone(), Box::new(e))
                })?;
            let region = ndslice::view::Ranked::region(&self.current_ref);
            let num_ranks = region.num_ranks();
            let (port, mut rx) = cx.mailbox().open_port();

            controller
                .send(
                    cx,
                    resource::GetState::<resource::mesh::State<()>> {
                        name: self.name.clone(),
                        reply: port.bind(),
                    },
                )
                .map_err(|e| {
                    crate::Error::SendingError(controller.actor_id().clone(), Box::new(e))
                })?;

            let statuses = rx.recv().await?;
            if let Some(state) = &statuses.state {
                let all_stopped = state.statuses.values().all(|s| s.is_terminating());
                if all_stopped {
                    Ok(())
                } else {
                    let legacy = mesh_to_rankedvalues_with_default(
                        &state.statuses,
                        resource::Status::NotExist,
                        resource::Status::is_not_exist,
                        num_ranks,
                    );
                    Err(Error::ActorStopError { statuses: legacy })
                }
            } else {
                Err(Error::Other(anyhow::anyhow!(
                    "non-existent state in GetState reply from controller: {}",
                    controller.actor_id()
                )))
            }?;
            let mut entry = self.health_state.entry(cx).or_default();
            let health_state = entry.get_mut();
            health_state.unhealthy_event = Some(Unhealthy::StreamClosed(MeshFailure {
                actor_mesh_name: Some(self.name().to_string()),
                rank: None,
                event: ActorSupervisionEvent::new(
                    ndslice::view::Ranked::get(&self.current_ref, 0)
                        .unwrap()
                        .actor_id()
                        .clone(),
                    None,
                    ActorStatus::Stopped("mesh stopped".to_string()),
                    None,
                ),
            }));
        }
        self.current_ref.controller.take();
        Ok(())
    }
}

impl<A: Referable> fmt::Display for ActorMesh<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.current_ref)
    }
}

impl<A: Referable> Deref for ActorMesh<A> {
    type Target = ActorMeshRef<A>;

    fn deref(&self) -> &Self::Target {
        &self.current_ref
    }
}

impl<A: Referable> Clone for ActorMesh<A> {
    fn clone(&self) -> Self {
        Self {
            proc_mesh: self.proc_mesh.clone(),
            name: self.name.clone(),
            current_ref: self.current_ref.clone(),
            controller: self.controller.clone(),
        }
    }
}

impl<A: Referable> Drop for ActorMesh<A> {
    fn drop(&mut self) {
        tracing::info!(
            name = "ActorMeshStatus",
            actor_name = %self.name,
            status = "Dropped",
        );
    }
}

const DEFAULT_PAGE: usize = 1024;

struct Page<A: Referable> {
    slots: Box<[OnceCell<hyperactor_reference::ActorRef<A>>]>,
}

impl<A: Referable> Page<A> {
    fn new(len: usize) -> Self {
        let mut v = Vec::with_capacity(len);
        for _ in 0..len {
            v.push(OnceCell::new());
        }
        Self {
            slots: v.into_boxed_slice(),
        }
    }
}

#[derive(Default)]
struct HealthState {
    unhealthy_event: Option<Unhealthy>,
    crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
}

impl std::fmt::Debug for HealthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HealthState")
            .field("unhealthy_event", &self.unhealthy_event)
            .field("crashed_ranks", &self.crashed_ranks)
            .finish()
    }
}

#[derive(Clone)]
enum MessageOrFailure<M: Send + Sync + Clone + Default + 'static> {
    Message(M),
    Failure(String),
    Timeout,
}

impl<M: Send + Sync + Clone + Default + 'static> Default for MessageOrFailure<M> {
    fn default() -> Self {
        Self::Message(M::default())
    }
}

fn into_watch<M: Send + Sync + Clone + Default + 'static>(
    mut rx: PortReceiver<M>,
) -> watch::Receiver<MessageOrFailure<M>> {
    let (sender, receiver) = watch::channel(MessageOrFailure::<M>::default());
    let timeout = hyperactor_config::global::get(SUPERVISION_WATCHDOG_TIMEOUT);
    let poll_frequency = hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY);
    let get_actor_state_max_idle = hyperactor_config::global::get(GET_ACTOR_STATE_MAX_IDLE);
    let get_proc_state_max_idle = hyperactor_config::global::get(GET_PROC_STATE_MAX_IDLE);
    let total_time = poll_frequency + get_actor_state_max_idle + get_proc_state_max_idle;
    if timeout < total_time {
        tracing::warn!(
            "HYPERACTOR_MESH_SUPERVISION_WATCHDOG_TIMEOUT={} is too short. It should be >= {} (SUPERVISION_POLL_FREQUENCY={} + GET_ACTOR_STATE_MAX_IDLE={} + GET_PROC_STATE_MAX_IDLE={})",
            humantime::format_duration(timeout),
            humantime::format_duration(total_time),
            humantime::format_duration(poll_frequency),
            humantime::format_duration(get_actor_state_max_idle),
            humantime::format_duration(get_proc_state_max_idle),
        );
    }
    tokio::spawn(async move {
        loop {
            let message = match tokio::time::timeout(timeout, rx.recv()).await {
                Ok(Ok(msg)) => MessageOrFailure::Message(msg),
                Ok(Err(e)) => MessageOrFailure::Failure(e.to_string()),
                Err(_) => MessageOrFailure::Timeout,
            };
            let is_failure = matches!(
                message,
                MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
            );
            if sender.send(message).is_err() {
                break;
            }
            if is_failure {
                break;
            }
        }
    });
    receiver
}
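
// Minimal consumption sketch for `into_watch` (hypothetical; `rx` is any
// `PortReceiver<M>` opened by the caller):
//
//     let mut w = into_watch(rx);
//     w.changed().await?;
//     match &*w.borrow() {
//         MessageOrFailure::Message(m) => { /* latest message */ }
//         MessageOrFailure::Failure(e) => { /* the port closed with error `e` */ }
//         MessageOrFailure::Timeout => { /* watchdog timeout elapsed */ }
//     }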

pub struct ActorMeshRef<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,
    controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,

    health_state: ActorLocal<HealthState>,
    receiver: ActorLocal<
        Arc<
            tokio::sync::Mutex<(
                hyperactor_reference::PortRef<Option<MeshFailure>>,
                watch::Receiver<MessageOrFailure<Option<MeshFailure>>>,
            )>,
        >,
    >,
    pages: OnceCell<Vec<OnceCell<Box<Page<A>>>>>,
    page_size: usize,
}

impl<A: Referable> ActorMeshRef<A> {
    #[allow(clippy::result_large_err)]
    pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> crate::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone,
    {
        self.cast_with_selection(cx, sel!(*), message)
    }
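
    // Hypothetical usage sketch (assumes a message type `MyMessage` handled by
    // actor `A` and a calling actor context `cx`):
    //
    //     actor_mesh_ref.cast(cx, MyMessage { /* ... */ })?;
    //
    // The cast is refused with `Error::Supervision` if the mesh is already
    // known to be unhealthy (crashed or stream closed).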

    #[allow(clippy::result_large_err)]
    pub fn cast_for_tensor_engine_only_do_not_use<M>(
        &self,
        cx: &impl context::Actor,
        sel: Selection,
        message: M,
    ) -> crate::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone,
    {
        self.cast_with_selection(cx, sel, message)
    }

    #[allow(clippy::result_large_err)]
    fn cast_with_selection<M>(
        &self,
        cx: &impl context::Actor,
        sel: Selection,
        message: M,
    ) -> crate::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone,
    {
        {
            let health_state = self.health_state.entry(cx).or_default();
            let health_state = health_state.get();
            match &health_state.unhealthy_event {
                Some(Unhealthy::StreamClosed(failure)) => {
                    return Err(crate::Error::Supervision(Box::new(failure.clone())));
                }
                Some(Unhealthy::Crashed(failure)) => {
                    return Err(crate::Error::Supervision(Box::new(failure.clone())));
                }
                None => {
                    assert!(health_state.crashed_ranks.is_empty());
                }
            }
        }

        hyperactor_telemetry::notify_sent_message(hyperactor_telemetry::SentMessageEvent {
            timestamp: std::time::SystemTime::now(),
            sender_actor_id: hyperactor_telemetry::hash_to_u64(cx.mailbox().actor_id()),
            actor_mesh_id: hyperactor_telemetry::hash_to_u64(&self.name.to_string()),
            view_json: serde_json::to_string(view::Ranked::region(self)).unwrap_or_default(),
            shape_json: {
                let shape: ndslice::Shape = view::Ranked::region(self).into();
                serde_json::to_string(&shape).unwrap_or_default()
            },
        });

        if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
            self.cast_v0(cx, message, sel, root_comm_actor)
        } else {
            for (point, actor) in self.iter() {
                let create_rank = point.rank();
                let mut headers = Flattrs::new();
                multicast::set_cast_info_on_headers(
                    &mut headers,
                    point,
                    cx.instance().self_id().clone(),
                );

                let mut unbound = Unbound::try_from_message(message.clone())
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                unbound
                    .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
                        *rank = Some(create_rank);
                        Ok(())
                    })
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                let rebound_message = unbound
                    .bind()
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                actor
                    .send_with_headers(cx, headers, rebound_message)
                    .map_err(|e| Error::SendingError(actor.actor_id().clone(), Box::new(e)))?;
            }
            Ok(())
        }
    }

    #[allow(clippy::result_large_err)]
    fn cast_v0<M>(
        &self,
        cx: &impl context::Actor,
        message: M,
        sel: Selection,
        root_comm_actor: &hyperactor_reference::ActorRef<CommActor>,
    ) -> crate::Result<()>
    where
        A: RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone,
    {
        let cast_mesh_shape = view::Ranked::region(self).into();
        let actor_mesh_id = ActorMeshId(self.name.clone());
        match &self.proc_mesh.root_region {
            Some(root_region) => {
                let root_mesh_shape = root_region.into();
                casting::cast_to_sliced_mesh::<A, M>(
                    cx,
                    actor_mesh_id,
                    root_comm_actor,
                    &sel,
                    message,
                    &cast_mesh_shape,
                    &root_mesh_shape,
                )
                .map_err(|e| Error::CastingError(self.name.clone(), e.into()))
            }
            None => casting::actor_mesh_cast::<A, M>(
                cx,
                actor_mesh_id,
                root_comm_actor,
                sel,
                &cast_mesh_shape,
                &cast_mesh_shape,
                message,
            )
            .map_err(|e| Error::CastingError(self.name.clone(), e.into())),
        }
    }

    #[allow(clippy::result_large_err)]
    pub async fn actor_states(
        &self,
        cx: &impl context::Actor,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.actor_states_with_keepalive(cx, None).await
    }

    #[allow(clippy::result_large_err)]
    pub(crate) async fn actor_states_with_keepalive(
        &self,
        cx: &impl context::Actor,
        keepalive: Option<std::time::SystemTime>,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.proc_mesh
            .actor_states_with_keepalive(cx, self.name.clone(), keepalive)
            .await
    }

    pub(crate) fn new(
        name: Name,
        proc_mesh: ProcMeshRef,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        Self::with_page_size(name, proc_mesh, DEFAULT_PAGE, controller)
    }

    pub fn name(&self) -> &Name {
        &self.name
    }

    pub(crate) fn with_page_size(
        name: Name,
        proc_mesh: ProcMeshRef,
        page_size: usize,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) -> Self {
        Self {
            proc_mesh,
            name,
            controller,
            health_state: ActorLocal::new(),
            receiver: ActorLocal::new(),
            pages: OnceCell::new(),
            page_size: page_size.max(1),
        }
    }

    pub fn proc_mesh(&self) -> &ProcMeshRef {
        &self.proc_mesh
    }

    #[inline]
    fn len(&self) -> usize {
        view::Ranked::region(&self.proc_mesh).num_ranks()
    }

    pub fn controller(&self) -> &Option<hyperactor_reference::ActorRef<ActorMeshController<A>>> {
        &self.controller
    }

    fn set_controller(
        &mut self,
        controller: Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
    ) {
        self.controller = controller;
    }

    fn ensure_pages(&self) -> &Vec<OnceCell<Box<Page<A>>>> {
        let n = self.len().div_ceil(self.page_size);
        self.pages
            .get_or_init(|| (0..n).map(|_| OnceCell::new()).collect())
    }

    fn materialize(&self, rank: usize) -> Option<&hyperactor_reference::ActorRef<A>> {
        let len = self.len();
        if rank >= len {
            return None;
        }
        let p = self.page_size;
        let page_ix = rank / p;
        let local_ix = rank % p;

        let pages = self.ensure_pages();
        let page = pages[page_ix].get_or_init(|| {
            let base = page_ix * p;
            let remaining = len - base;
            let page_len = remaining.min(p);
            Box::new(Page::<A>::new(page_len))
        });

        Some(page.slots[local_ix].get_or_init(|| {
            debug_assert!(rank < self.len(), "rank must be within [0, len)");
            debug_assert!(
                ndslice::view::Ranked::get(&self.proc_mesh, rank).is_some(),
                "proc_mesh must be dense/aligned with this view"
            );
            let proc_ref =
                ndslice::view::Ranked::get(&self.proc_mesh, rank).expect("rank in-bounds");
            proc_ref.attest(&self.name)
        }))
    }

    fn init_supervision_receiver(
        controller: &hyperactor_reference::ActorRef<ActorMeshController<A>>,
        cx: &impl context::Actor,
    ) -> (
        hyperactor_reference::PortRef<Option<MeshFailure>>,
        watch::Receiver<MessageOrFailure<Option<MeshFailure>>>,
    ) {
        let (tx, rx) = cx.mailbox().open_port();
        let tx = tx.bind();
        controller
            .send(cx, Subscribe(tx.clone()))
            .expect("failed to send Subscribe");
        (tx, into_watch(rx))
    }

    pub async fn next_supervision_event(
        &self,
        cx: &impl context::Actor,
    ) -> Result<MeshFailure, anyhow::Error> {
        let controller = if let Some(c) = self.controller() {
            c
        } else {
            let health_state = self.health_state.entry(cx).or_default();
            let health_state = health_state.get();
            return match &health_state.unhealthy_event {
                Some(Unhealthy::StreamClosed(f)) => Ok(f.clone()),
                Some(Unhealthy::Crashed(f)) => Ok(f.clone()),
                None => Err(anyhow::anyhow!(
                    "unexpected healthy state while controller is gone"
                )),
            };
        };
        let rx = {
            let entry = self.receiver.entry(cx).or_insert_with(|| {
                Arc::new(tokio::sync::Mutex::new(Self::init_supervision_receiver(
                    controller, cx,
                )))
            });
            Arc::clone(entry.get())
        };
        let message = {
            let mut rx = rx.lock().await;
            let subscriber_port = rx.0.clone();
            let message = rx
                .1
                .wait_for(|message| {
                    if let MessageOrFailure::Message(message) = message {
                        if let Some(message) = &message {
                            if let Some(rank) = &message.rank {
                                ndslice::view::Ranked::region(self).slice().contains(*rank)
                            } else {
                                true
                            }
                        } else {
                            false
                        }
                    } else {
                        true
                    }
                })
                .await?;
            let message = message.clone();
            let is_failure = matches!(
                message,
                MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
            );
            if is_failure {
                let mut port = controller.port();
                port.return_undeliverable(false);
                let _ = port.send(cx, Unsubscribe(subscriber_port));
            }
            match message {
                MessageOrFailure::Message(message) => Ok::<MeshFailure, anyhow::Error>(
                    message.expect("filter excludes any None messages"),
                ),
                MessageOrFailure::Failure(failure) => Err(anyhow::anyhow!("{}", failure)),
                MessageOrFailure::Timeout => Ok(MeshFailure {
                    actor_mesh_name: Some(self.name().to_string()),
                    rank: None,
                    event: ActorSupervisionEvent::new(
                        controller.actor_id().clone(),
                        None,
                        ActorStatus::generic_failure(format!(
                            "timed out reaching controller {} for mesh {}. Assuming controller's proc is dead",
                            controller.actor_id(),
                            self.name()
                        )),
                        None,
                    ),
                }),
            }?
        };
        let rank = message.rank.unwrap_or_default();
        let event = &message.event;
        let mut entry = self.health_state.entry(cx).or_default();
        let health_state = entry.get_mut();
        if let ActorStatus::Failed(_) = event.actor_status {
            health_state.crashed_ranks.insert(rank, event.clone());
        }
        health_state.unhealthy_event = match &event.actor_status {
            ActorStatus::Failed(_) => Some(Unhealthy::Crashed(message.clone())),
            ActorStatus::Stopped(_) => Some(Unhealthy::StreamClosed(message.clone())),
            _ => None,
        };
        Ok(message)
    }
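
    // A typical consumer of `next_supervision_event` loops until it observes a
    // failure it cares about (hypothetical sketch; names illustrative):
    //
    //     loop {
    //         let failure = actor_mesh_ref.next_supervision_event(cx).await?;
    //         // decide whether to restart work, reroute, or tear the mesh down
    //     }
    //
    // When the controller is gone, the call resolves immediately from the
    // cached unhealthy state instead of subscribing.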

    pub fn clone_with_supervision_receiver(&self) -> Self {
        Self {
            proc_mesh: self.proc_mesh.clone(),
            name: self.name.clone(),
            controller: self.controller.clone(),
            health_state: self.health_state.clone(),
            receiver: self.receiver.clone(),
            pages: OnceCell::new(),
            page_size: self.page_size,
        }
    }
}

impl<A: Referable> Clone for ActorMeshRef<A> {
    fn clone(&self) -> Self {
        Self {
            proc_mesh: self.proc_mesh.clone(),
            name: self.name.clone(),
            controller: self.controller.clone(),
            health_state: ActorLocal::new(),
            receiver: ActorLocal::new(),
            pages: OnceCell::new(),
            page_size: self.page_size,
        }
    }
}

impl<A: Referable> fmt::Display for ActorMeshRef<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}:{}@{}", self.name, A::typename(), self.proc_mesh)
    }
}

impl<A: Referable> PartialEq for ActorMeshRef<A> {
    fn eq(&self, other: &Self) -> bool {
        self.proc_mesh == other.proc_mesh && self.name == other.name
    }
}
impl<A: Referable> Eq for ActorMeshRef<A> {}

impl<A: Referable> Hash for ActorMeshRef<A> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.proc_mesh.hash(state);
        self.name.hash(state);
    }
}
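
// Note: equality and hashing key only on `(proc_mesh, name)`; the controller,
// the lazily built page cache, and per-actor health state are deliberately
// excluded, so logically identical refs compare equal regardless of cache
// state (see `test_actor_mesh_ref_lazy_materialization`).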

impl<A: Referable> fmt::Debug for ActorMeshRef<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ActorMeshRef")
            .field("proc_mesh", &self.proc_mesh)
            .field("name", &self.name)
            .field("page_size", &self.page_size)
            .finish_non_exhaustive()
    }
}

impl<A: Referable> Serialize for ActorMeshRef<A> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        (&self.proc_mesh, &self.name, &self.controller).serialize(serializer)
    }
}

impl<'de, A: Referable> Deserialize<'de> for ActorMeshRef<A> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let (proc_mesh, name, controller) = <(
            ProcMeshRef,
            Name,
            Option<hyperactor_reference::ActorRef<ActorMeshController<A>>>,
        )>::deserialize(deserializer)?;
        Ok(ActorMeshRef::with_page_size(
            name,
            proc_mesh,
            DEFAULT_PAGE,
            controller,
        ))
    }
}
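
// Serde round-trip note: only `(proc_mesh, name, controller)` cross the wire.
// The page cache and health/receiver state are rebuilt empty on the receiving
// side via `with_page_size`, so a deserialized ref lazily re-materializes its
// `ActorRef`s on first use.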

impl<A: Referable> view::Ranked for ActorMeshRef<A> {
    type Item = hyperactor_reference::ActorRef<A>;

    #[inline]
    fn region(&self) -> &Region {
        view::Ranked::region(&self.proc_mesh)
    }

    #[inline]
    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.materialize(rank)
    }
}

impl<A: Referable> view::RankedSliceable for ActorMeshRef<A> {
    fn sliced(&self, region: Region) -> Self {
        debug_assert!(region.is_subset(view::Ranked::region(self)));
        let proc_mesh = self.proc_mesh.subset(region).unwrap();
        Self::with_page_size(
            self.name.clone(),
            proc_mesh,
            self.page_size,
            self.controller.clone(),
        )
    }
}
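
// Hypothetical slicing sketch (assumes the mesh extent has a "hosts" dimension,
// as in the tests below): slicing goes through `RankedSliceable::sliced`, so a
// sub-view keeps the same name and controller but starts with an empty page
// cache sized to the sliced region.
//
//     let sub = actor_mesh_ref.range("hosts", 1..3).expect("valid slice");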

#[cfg(test)]
mod tests {

    use std::collections::HashSet;
    use std::ops::Deref;

    use hyperactor::actor::ActorErrorKind;
    use hyperactor::actor::ActorStatus;
    use hyperactor::context::Mailbox as _;
    use hyperactor::mailbox;
    use ndslice::Extent;
    use ndslice::ViewExt;
    use ndslice::extent;
    use ndslice::view::Ranked;
    use timed_test::async_timed_test;
    use tokio::time::Duration;

    use super::ActorMesh;
    use crate::ActorMeshRef;
    use crate::Name;
    use crate::ProcMesh;
    use crate::proc_mesh::ACTOR_SPAWN_MAX_IDLE;
    use crate::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
    use crate::supervision::MeshFailure;
    use crate::testactor;
    use crate::testing;

    #[test]
    fn test_actor_mesh_ref_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ActorMeshRef<()>>();
    }

    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_actor_mesh_ref_lazy_materialization() {
        let instance = testing::instance();
        let mut hm = testing::host_mesh(3).await;
        let pm: ProcMesh = hm.spawn(instance, "test", extent!(gpus = 2)).await.unwrap();
        let am: ActorMesh<testactor::TestActor> = pm.spawn(instance, "test", &()).await.unwrap();

        let page_size = 2;
        let amr: ActorMeshRef<testactor::TestActor> =
            ActorMeshRef::with_page_size(am.name.clone(), pm.clone(), page_size, None);
        assert_eq!(amr.extent(), extent!(hosts = 3, gpus = 2));
        assert_eq!(amr.region().num_ranks(), 6);

        let p0_a = amr.get(0).expect("rank 0 exists") as *const _;
        let p0_b = amr.get(0).expect("rank 0 exists") as *const _;
        assert_eq!(p0_a, p0_b, "same rank should return same cached pointer");

        let p1_a = amr.get(1).expect("rank 1 exists") as *const _;
        let p1_b = amr.get(1).expect("rank 1 exists") as *const _;
        assert_eq!(p1_a, p1_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p1_a, "different ranks have different cache slots");

        let p2_a = amr.get(2).expect("rank 2 exists") as *const _;
        let p2_b = amr.get(2).expect("rank 2 exists") as *const _;
        assert_eq!(p2_a, p2_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p2_a, "different pages have different cache slots");

        let amr_clone = amr.clone();
        let orig_id_0 = amr.get(0).unwrap().actor_id().clone();
        let clone_id_0 = amr_clone.get(0).unwrap().actor_id().clone();
        assert_eq!(orig_id_0, clone_id_0, "clone preserves identity");
        let p0_clone = amr_clone.get(0).unwrap() as *const _;
        assert_ne!(
            p0_a, p0_clone,
            "cloned ActorMeshRef has a fresh cache (different pointer)"
        );

        let sliced = amr.range("hosts", 1..).expect("slice should be valid");
        assert_eq!(sliced.region().num_ranks(), 4);
        let sp0_a = sliced.get(0).unwrap() as *const _;
        let sp0_b = sliced.get(0).unwrap() as *const _;
        assert_eq!(sp0_a, sp0_b, "sliced view has its own cache slot per rank");
        let sp2 = sliced.get(2).unwrap() as *const _;
        assert_ne!(sp0_a, sp2, "sliced view crosses its own page boundary");

        let mut set = HashSet::new();
        set.insert(amr.clone());
        set.insert(amr.clone());
        assert_eq!(set.len(), 1, "cache state must not affect Hash/Eq");

        let (port, mut rx) = mailbox::open_port(instance);
        amr.get(0)
            .expect("rank 0 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 0 should succeed");
        amr.get(3)
            .expect("rank 3 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 3 should succeed");
        let id_a = tokio::time::timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for first reply")
            .expect("channel closed before first reply");
        let id_b = tokio::time::timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for second reply")
            .expect("channel closed before second reply");
        assert_ne!(id_a, id_b, "two different ranks responded");

        let _ = hm.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_states_with_panic() {
        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance();
        let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let mut hm = testing::host_mesh(num_replicas).await;
        let proc_mesh = hm.spawn(instance, "test", Extent::unity()).await.unwrap();
        let child_name = Name::new("child").unwrap();

        let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
            )
            .await
            .unwrap();

        actor_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::Panic,
                    send_to_children: true,
                },
            )
            .unwrap();

        let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
        actor_mesh
            .cast(
                instance,
                testactor::NextSupervisionFailure(failure_port.bind()),
            )
            .unwrap();
        let failure = failure_receiver
            .recv()
            .await
            .unwrap()
            .expect("no supervision event found on ref from wrapper actor");
        let check_failure = move |failure: MeshFailure| {
            assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
            assert_eq!(
                failure.event.actor_id.name(),
                child_name.clone().to_string()
            );
            if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
                assert!(msg.contains("panic"), "{}", msg);
                assert!(msg.contains("for testing"), "{}", msg);
            } else {
                panic!("actor status is not failed: {}", failure.event.actor_status);
            }
        };
        check_failure(failure);

        for _ in 0..num_replicas {
            let failure =
                tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
                    .await
                    .expect("timeout")
                    .unwrap();
            check_failure(failure);
        }

        let _ = hm.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_states_with_process_exit() {
        hyperactor_telemetry::initialize_logging_for_test();

        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(GET_ACTOR_STATE_MAX_IDLE, Duration::from_secs(1));

        let instance = testing::instance();
        let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let mut hm = testing::host_mesh(num_replicas).await;
        let proc_mesh = hm.spawn(instance, "test", Extent::unity()).await.unwrap();
        let mut second_hm = testing::host_mesh(num_replicas).await;
        let second_proc_mesh = second_hm
            .spawn(instance, "test2", Extent::unity())
            .await
            .unwrap();
        let child_name = Name::new("child").unwrap();

        let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(
                    second_proc_mesh.deref().clone(),
                    supervisor,
                    child_name.clone(),
                ),
            )
            .await
            .unwrap();

        actor_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::ProcessExit(1),
                    send_to_children: true,
                },
            )
            .unwrap();

        let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
        actor_mesh
            .cast(
                instance,
                testactor::NextSupervisionFailure(failure_port.bind()),
            )
            .unwrap();
        let failure = failure_receiver
            .recv()
            .await
            .unwrap()
            .expect("no supervision event found on ref from wrapper actor");

        let check_failure = move |failure: MeshFailure| {
            assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
            assert_eq!(failure.event.actor_id.name(), child_name.to_string());
            if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
                assert!(
                    msg.contains("process exited with non-zero code 1"),
                    "{}",
                    msg
                );
            } else {
                panic!("actor status is not failed: {}", failure.event.actor_status);
            }
        };
        check_failure(failure);

        for _ in 0..num_replicas {
            let failure =
                tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
                    .await
                    .expect("timeout")
                    .unwrap();
            check_failure(failure);
        }

        let _ = second_hm.shutdown(instance).await;
        let _ = hm.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_states_on_sliced_mesh() {
        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance();
        let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let mut hm = testing::host_mesh(num_replicas).await;
        let proc_mesh = hm.spawn(instance, "test", Extent::unity()).await.unwrap();
        let child_name = Name::new("child").unwrap();

        let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
            )
            .await
            .unwrap();
        let sliced = actor_mesh
            .range("hosts", 1..3)
            .expect("slice should be valid");
        let sliced_replicas = sliced.len();

        sliced
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::Panic,
                    send_to_children: true,
                },
            )
            .unwrap();

        for _ in 0..sliced_replicas {
            let supervision_message =
                tokio::time::timeout(Duration::from_secs(20), supervision_receiver.recv())
                    .await
                    .expect("timeout")
                    .unwrap();
            let event = supervision_message.event;
            assert_eq!(event.actor_id.name(), format!("{}", child_name.clone()));
            if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &event.actor_status {
                assert!(msg.contains("panic"));
                assert!(msg.contains("for testing"));
            } else {
                panic!("actor status is not failed: {}", event.actor_status);
            }
        }

        let _ = hm.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_cast() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);

        let instance = testing::instance();
        let mut host_mesh = testing::host_mesh(4).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", Extent::unity())
            .await
            .unwrap();
        let actor_mesh: ActorMesh<testactor::TestActor> =
            proc_mesh.spawn(instance, "test", &()).await.unwrap();

        let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
        actor_mesh
            .cast(
                instance,
                testactor::GetCastInfo {
                    cast_info: cast_info.bind(),
                },
            )
            .unwrap();

        let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
        while !point_to_actor.is_empty() {
            let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
            let key = (point, origin_actor_ref);
            assert!(
                point_to_actor.remove(&key),
                "key {:?} not present or removed twice",
                key
            );
            assert_eq!(&sender_actor_id, instance.self_id());
        }

        let _ = host_mesh.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_undeliverable_message_return() {
        use hyperactor::mailbox::MessageEnvelope;
        use hyperactor::mailbox::Undeliverable;
        use hyperactor::testing::pingpong::PingPongActor;
        use hyperactor::testing::pingpong::PingPongMessage;

        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance();

        let mut hm = testing::host_mesh(2).await;
        let proc_mesh = hm.spawn(instance, "test", Extent::unity()).await.unwrap();

        let (undeliverable_port, mut undeliverable_rx) =
            instance.open_port::<Undeliverable<MessageEnvelope>>();

        let ping_proc_mesh = proc_mesh.range("hosts", 0..1).unwrap();
        let pong_proc_mesh = proc_mesh.range("hosts", 1..2).unwrap();

        let ping_mesh: ActorMesh<PingPongActor> = ping_proc_mesh
            .spawn(
                instance,
                "ping",
                &(Some(undeliverable_port.bind()), None, None),
            )
            .await
            .unwrap();

        let mut pong_mesh: ActorMesh<PingPongActor> = pong_proc_mesh
            .spawn(instance, "pong", &(None, None, None))
            .await
            .unwrap();

        let ping_handle = ping_mesh.values().next().unwrap();
        let pong_handle = pong_mesh.values().next().unwrap();

        let (done_tx, done_rx) = instance.open_once_port();
        ping_handle
            .send(
                instance,
                PingPongMessage(2, pong_handle.clone(), done_tx.bind()),
            )
            .unwrap();
        assert!(
            done_rx.recv().await.unwrap(),
            "Initial ping-pong should work"
        );

        pong_mesh
            .stop(instance, "test stop".to_string())
            .await
            .unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            std::time::Duration::from_secs(5),
        );

        let n = 100usize;
        for i in 1..=n {
            let ttl = 66 + i as u64;
            let (once_tx, _once_rx) = instance.open_once_port();
            ping_handle
                .send(
                    instance,
                    PingPongMessage(ttl, pong_handle.clone(), once_tx.bind()),
                )
                .unwrap();
        }

        let mut count = 0;
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10);
        while count < n && tokio::time::Instant::now() < deadline {
            match tokio::time::timeout(std::time::Duration::from_secs(1), undeliverable_rx.recv())
                .await
            {
                Ok(Ok(Undeliverable(envelope))) => {
                    let _: PingPongMessage = envelope.deserialized().unwrap();
                    count += 1;
                }
                Ok(Err(_)) => break,
                Err(_) => break,
            }
        }

        assert_eq!(
            count, n,
            "Expected {} undeliverable messages, got {}",
            n, count
        );

        let _ = hm.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_mesh_stop_timeout() {
        hyperactor_telemetry::initialize_logging_for_test();

        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(ACTOR_SPAWN_MAX_IDLE, std::time::Duration::from_secs(1));

        let instance = testing::instance();

        let mut hm = testing::host_mesh(2).await;
        let proc_mesh = hm.spawn(instance, "test", Extent::unity()).await.unwrap();

        let mut sleep_mesh: ActorMesh<testactor::SleepActor> =
            proc_mesh.spawn(instance, "sleepers", &()).await.unwrap();

        for actor_ref in sleep_mesh.values() {
            actor_ref
                .send(instance, std::time::Duration::from_secs(5))
                .unwrap();
        }

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        let expected_actors = sleep_mesh.values().count();

        let stop_start = tokio::time::Instant::now();
        let result = sleep_mesh.stop(instance, "test stop".to_string()).await;
        let stop_duration = tokio::time::Instant::now().duration_since(stop_start);

        match result {
            Ok(_) => {
                tracing::warn!("Actors stopped gracefully (unexpected but ok)");
            }
            Err(ref e) => {
                let err_str = format!("{:?}", e);
                assert!(
                    err_str.contains("Timeout"),
                    "Expected Timeout error, got: {:?}",
                    e
                );
                tracing::info!(
                    "Stop timed out as expected for {} actors, they were aborted",
                    expected_actors
                );
            }
        }

        assert!(
            stop_duration < std::time::Duration::from_secs(3),
            "Stop took {:?}, expected < 3s (actors should have been aborted, not waited for)",
            stop_duration
        );
        assert!(
            stop_duration >= std::time::Duration::from_millis(900),
            "Stop took {:?}, expected >= 900ms (should have waited for timeout)",
            stop_duration
        );

        let _ = hm.shutdown(instance).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_actor_mesh_stop_graceful() {
        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance();

        let mut hm = testing::host_mesh(2).await;
        let proc_mesh = hm.spawn(instance, "test", Extent::unity()).await.unwrap();

        let mut actor_mesh: ActorMesh<testactor::TestActor> =
            proc_mesh.spawn(instance, "test_actors", &()).await.unwrap();

        let mesh_ref = actor_mesh.deref().clone();

        let expected_actors = actor_mesh.values().count();
        assert!(expected_actors > 0, "Should have spawned some actors");

        let stop_start = tokio::time::Instant::now();
        let result = actor_mesh.stop(instance, "test stop".to_string()).await;
        let stop_duration = tokio::time::Instant::now().duration_since(stop_start);

        assert!(
            result.is_ok(),
            "Stop should succeed for responsive actors, got: {:?}",
            result.err()
        );

        assert!(
            stop_duration < std::time::Duration::from_secs(2),
            "Graceful stop took {:?}, expected < 2s (actors should stop quickly)",
            stop_duration
        );

        tracing::info!(
            "Successfully stopped {} actors in {:?}",
            expected_actors,
            stop_duration
        );

        let next_event = actor_mesh.next_supervision_event(instance).await.unwrap();
        assert_eq!(
            next_event.actor_mesh_name,
            Some(mesh_ref.name().to_string())
        );
        assert!(matches!(
            next_event.event.actor_status,
            ActorStatus::Stopped(_)
        ));
        let next_event = mesh_ref.next_supervision_event(instance).await.unwrap();
        assert_eq!(
            next_event.actor_mesh_name,
            Some(mesh_ref.name().to_string())
        );
        assert!(matches!(
            next_event.event.actor_status,
            ActorStatus::Stopped(_)
        ));

        let _ = hm.shutdown(instance).await;
    }
}