1use std::collections::HashMap;
10use std::fmt;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::ops::Deref;
14use std::sync::Arc;
15use std::sync::OnceLock as OnceCell;
16use std::time::Duration;
17
18use hyperactor::ActorRef;
19use hyperactor::RemoteHandles;
20use hyperactor::RemoteMessage;
21use hyperactor::actor::ActorStatus;
22use hyperactor::actor::Referable;
23use hyperactor::clock::Clock;
24use hyperactor::clock::RealClock;
25use hyperactor::context;
26use hyperactor::mailbox::PortReceiver;
27use hyperactor::message::Castable;
28use hyperactor::message::IndexedErasedUnbound;
29use hyperactor::message::Unbound;
30use hyperactor::supervision::ActorSupervisionEvent;
31use hyperactor_config::CONFIG;
32use hyperactor_config::ConfigAttr;
33use hyperactor_config::attrs::Attrs;
34use hyperactor_config::attrs::declare_attrs;
35use hyperactor_mesh_macros::sel;
36use ndslice::Selection;
37use ndslice::ViewExt as _;
38use ndslice::view;
39use ndslice::view::Ranked;
40use ndslice::view::Region;
41use ndslice::view::View;
42use serde::Deserialize;
43use serde::Deserializer;
44use serde::Serialize;
45use serde::Serializer;
46use tokio::sync::watch;
47
48use crate::CommActor;
49use crate::actor_mesh as v0_actor_mesh;
50use crate::comm::multicast;
51use crate::proc_mesh::mesh_agent::ActorState;
52use crate::reference::ActorMeshId;
53use crate::resource;
54use crate::supervision::MeshFailure;
55use crate::supervision::Unhealthy;
56use crate::v1;
57use crate::v1::Error;
58use crate::v1::Name;
59use crate::v1::ProcMeshRef;
60use crate::v1::ValueMesh;
61use crate::v1::host_mesh::mesh_to_rankedvalues_with_default;
62use crate::v1::mesh_controller::ActorMeshController;
63use crate::v1::mesh_controller::Subscribe;
64
// Liveness timeout for the supervision stream consumed by `ActorMeshRef`.
// `into_watch` reads this value once per subscription; if no message arrives
// from the controller's supervision port within this duration, it publishes
// `MessageOrFailure::Timeout`, which `next_supervision_event` turns into a
// failure event attributed to the mesh controller.
// Config metadata: env var HYPERACTOR_MESH_SUPERVISION_LIVENESS_TIMEOUT,
// Python name `supervision_liveness_timeout`. Default: 30 seconds.
declare_attrs! {
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_SUPERVISION_LIVENESS_TIMEOUT".to_string()),
        py_name: Some("supervision_liveness_timeout".to_string()),
    })
    pub attr SUPERVISION_LIVENESS_TIMEOUT: Duration = Duration::from_secs(30);
}
77
/// An owned mesh of actors of type `A` hosted on a proc mesh.
///
/// Dereferences to its current [`ActorMeshRef`], so ref operations (casting,
/// per-rank lookup, supervision) are available directly on the owned mesh.
#[derive(Debug)]
pub struct ActorMesh<A: Referable> {
    /// The proc mesh hosting this actor mesh.
    proc_mesh: ProcMeshRef,
    /// Name of this actor mesh.
    name: Name,
    /// Ref to this mesh, built with the default page size; the `Deref` target.
    current_ref: ActorMeshRef<A>,
    /// Controller actor for this mesh, if any. `stop` takes it, leaving `None`.
    controller: Option<ActorRef<ActorMeshController<A>>>,
}
94
95impl<A: Referable> ActorMesh<A> {
98 pub(crate) fn new(
99 proc_mesh: ProcMeshRef,
100 name: Name,
101 controller: Option<ActorRef<ActorMeshController<A>>>,
102 ) -> Self {
103 let current_ref = ActorMeshRef::with_page_size(
104 name.clone(),
105 proc_mesh.clone(),
106 DEFAULT_PAGE,
107 controller.clone(),
108 );
109
110 Self {
111 proc_mesh,
112 name,
113 current_ref,
114 controller,
115 }
116 }
117
118 pub fn name(&self) -> &Name {
119 &self.name
120 }
121
122 pub(crate) fn detach(self) -> ActorMeshRef<A> {
124 self.current_ref.clone()
125 }
126
127 pub(crate) fn set_controller(&mut self, controller: Option<ActorRef<ActorMeshController<A>>>) {
128 self.controller = controller.clone();
129 self.current_ref.set_controller(controller);
130 }
131
132 pub async fn stop(&mut self, cx: &impl context::Actor) -> v1::Result<()> {
134 if let Some(controller) = self.controller.take() {
140 controller
142 .send(
143 cx,
144 resource::Stop {
145 name: self.name.clone(),
146 },
147 )
148 .map_err(|e| v1::Error::SendingError(controller.actor_id().clone(), Box::new(e)))?;
149 let region = ndslice::view::Ranked::region(&self.current_ref);
150 let num_ranks = region.num_ranks();
151 let (port, mut rx) = cx.mailbox().open_port();
153
154 controller
155 .send(
156 cx,
157 resource::GetState::<resource::mesh::State<()>> {
158 name: self.name.clone(),
159 reply: port.bind(),
160 },
161 )
162 .map_err(|e| v1::Error::SendingError(controller.actor_id().clone(), Box::new(e)))?;
163
164 let statuses = rx.recv().await?;
165 if let Some(state) = &statuses.state {
166 let all_stopped = state.statuses.values().all(|s| s.is_terminating());
170 if all_stopped {
171 Ok(())
172 } else {
173 let legacy = mesh_to_rankedvalues_with_default(
174 &state.statuses,
175 resource::Status::NotExist,
176 resource::Status::is_not_exist,
177 num_ranks,
178 );
179 Err(Error::ActorStopError { statuses: legacy })
180 }
181 } else {
182 Err(Error::Other(anyhow::anyhow!(
183 "non-existent state in GetState reply from controller: {}",
184 controller.actor_id()
185 )))
186 }?;
187 let mut health_state = self.health_state.lock().expect("lock poisoned");
189 health_state.unhealthy_event = Some(Unhealthy::StreamClosed(MeshFailure {
190 actor_mesh_name: Some(self.name().to_string()),
191 rank: None,
192 event: ActorSupervisionEvent::new(
193 ndslice::view::Ranked::get(&self.current_ref, 0)
195 .unwrap()
196 .actor_id()
197 .clone(),
198 None,
199 ActorStatus::Stopped,
200 None,
201 ),
202 }));
203 }
204 self.current_ref.controller.take();
207 Ok(())
208 }
209}
210
211impl<A: Referable> fmt::Display for ActorMesh<A> {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 write!(f, "{}", self.current_ref)
214 }
215}
216
217impl<A: Referable> Deref for ActorMesh<A> {
218 type Target = ActorMeshRef<A>;
219
220 fn deref(&self) -> &Self::Target {
221 &self.current_ref
222 }
223}
224
225impl<A: Referable> Clone for ActorMesh<A> {
228 fn clone(&self) -> Self {
229 Self {
230 proc_mesh: self.proc_mesh.clone(),
231 name: self.name.clone(),
232 current_ref: self.current_ref.clone(),
233 controller: self.controller.clone(),
234 }
235 }
236}
237
impl<A: Referable> Drop for ActorMesh<A> {
    /// Logs an info-level "ActorMeshStatus" event recording that this owned
    /// handle was dropped. It only logs; it sends nothing to the actors.
    fn drop(&mut self) {
        tracing::info!(
            name = "ActorMeshStatus",
            actor_name = %self.name,
            status = "Dropped",
        );
    }
}
247
/// Default number of rank slots per lazily allocated page of cached actor refs
/// in [`ActorMeshRef`].
const DEFAULT_PAGE: usize = 1024;
252
/// A fixed-length page of actor-ref slots; each slot is initialized at most once.
struct Page<A: Referable> {
    slots: Box<[OnceCell<ActorRef<A>>]>,
}
257
258impl<A: Referable> Page<A> {
259 fn new(len: usize) -> Self {
260 let mut v = Vec::with_capacity(len);
261 for _ in 0..len {
262 v.push(OnceCell::new());
263 }
264 Self {
265 slots: v.into_boxed_slice(),
266 }
267 }
268}
269
/// Health of an [`ActorMeshRef`] as derived from observed supervision events.
#[derive(Default)]
struct HealthState {
    /// Mesh-level unhealthy condition, overwritten on each observed supervision
    /// event (and set by `ActorMesh::stop`); `None` means none is recorded.
    unhealthy_event: Option<Unhealthy>,
    /// Supervision events for ranks that reported `ActorStatus::Failed`.
    crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
}
275
276impl std::fmt::Debug for HealthState {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 f.debug_struct("HealthState")
279 .field("unhealthy_event", &self.unhealthy_event)
280 .field("crashed_ranks", &self.crashed_ranks)
281 .finish()
282 }
283}
284
/// The latest value published on a watch channel created by `into_watch`.
#[derive(Clone)]
enum MessageOrFailure<M: Send + Sync + Clone + Default + 'static> {
    /// A message received from the underlying port.
    Message(M),
    /// Receiving from the port failed; holds the error's `Display` text.
    Failure(String),
    /// No message arrived within `SUPERVISION_LIVENESS_TIMEOUT`.
    Timeout,
}
293
294impl<M: Send + Sync + Clone + Default + 'static> Default for MessageOrFailure<M> {
295 fn default() -> Self {
296 Self::Message(M::default())
297 }
298}
299
300fn into_watch<M: Send + Sync + Clone + Default + 'static>(
304 mut rx: PortReceiver<M>,
305) -> watch::Receiver<MessageOrFailure<M>> {
306 let (sender, receiver) = watch::channel(MessageOrFailure::<M>::default());
307 let timeout = hyperactor_config::global::get(SUPERVISION_LIVENESS_TIMEOUT);
315 tokio::spawn(async move {
316 loop {
317 let message = match RealClock.timeout(timeout, rx.recv()).await {
318 Ok(Ok(msg)) => MessageOrFailure::Message(msg),
319 Ok(Err(e)) => MessageOrFailure::Failure(e.to_string()),
320 Err(_) => MessageOrFailure::Timeout,
321 };
322 let is_failure = matches!(
323 message,
324 MessageOrFailure::Failure(_) | MessageOrFailure::Timeout
325 );
326 if sender.send(message).is_err() {
327 break;
329 }
330 if is_failure {
331 break;
333 }
334 }
335 });
336 receiver
337}
338
/// A reference to a mesh of actors of type `A`, laid out over the ranks of a
/// proc mesh.
///
/// Per-rank actor refs are materialized lazily and cached in pages. Equality
/// and hashing use only the proc mesh and name; serialization writes the proc
/// mesh, name, and controller. Caches and health state are local to each
/// instance.
pub struct ActorMeshRef<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,
    /// Controller for this mesh, if any; subscribed to for supervision events.
    controller: Option<ActorRef<ActorMeshController<A>>>,

    /// Health derived from supervision events seen through this ref (also set
    /// by `ActorMesh::stop`). `Clone` gives each clone a fresh, unshared one.
    health_state: Arc<std::sync::Mutex<HealthState>>,
    /// Lazily created supervision watch receiver. A tokio mutex because
    /// `next_supervision_event` holds it across an `.await`.
    receiver:
        Arc<tokio::sync::Mutex<Option<watch::Receiver<MessageOrFailure<Option<MeshFailure>>>>>>,
    /// Lazily allocated page table of cached actor refs, indexed by `rank / page_size`.
    pages: OnceCell<Vec<OnceCell<Box<Page<A>>>>>,
    /// Ranks per page; `with_page_size` clamps it to at least 1.
    page_size: usize,
}
373
374impl<A: Referable> ActorMeshRef<A> {
375 #[allow(clippy::result_large_err)]
377 pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> v1::Result<()>
378 where
379 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
380 M: Castable + RemoteMessage + Clone, {
382 self.cast_with_selection(cx, sel!(*), message)
383 }
384
385 #[allow(clippy::result_large_err)]
390 pub(crate) fn cast_for_tensor_engine_only_do_not_use<M>(
391 &self,
392 cx: &impl context::Actor,
393 sel: Selection,
394 message: M,
395 ) -> v1::Result<()>
396 where
397 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
398 M: Castable + RemoteMessage + Clone, {
400 self.cast_with_selection(cx, sel, message)
401 }
402
403 #[allow(clippy::result_large_err)]
404 fn cast_with_selection<M>(
405 &self,
406 cx: &impl context::Actor,
407 sel: Selection,
408 message: M,
409 ) -> v1::Result<()>
410 where
411 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
412 M: Castable + RemoteMessage + Clone, {
414 let health_state = self
417 .health_state()
418 .lock()
419 .unwrap_or_else(|e| e.into_inner());
420 let region = Ranked::region(self);
421 match &health_state.unhealthy_event {
422 Some(Unhealthy::StreamClosed(_)) => {
423 return Err(v1::Error::Other(anyhow::anyhow!(
424 "actor mesh is stopped due to proc mesh shutdown",
425 )));
426 }
427 Some(Unhealthy::Crashed(failure)) => {
428 return Err(v1::Error::Other(anyhow::anyhow!(
429 "Actor {} is unhealthy with reason: {}",
430 failure.event.actor_id,
431 failure.event.actor_status
432 )));
433 }
434 None => {
435 if let Some(event) = region
438 .slice()
439 .iter()
440 .find_map(|rank| health_state.crashed_ranks.get(&rank).clone())
441 {
442 return Err(v1::Error::Other(anyhow::anyhow!(
443 "Actor {} is unhealthy with reason: {}",
444 event.actor_id,
445 event.actor_status
446 )));
447 }
448 }
449 }
450 drop(health_state);
451
452 if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
454 self.cast_v0(cx, message, sel, root_comm_actor)
455 } else {
456 for (point, actor) in self.iter() {
457 let create_rank = point.rank();
458 let mut headers = Attrs::new();
459 headers.set(
460 multicast::CAST_ORIGINATING_SENDER,
461 cx.instance().self_id().clone(),
462 );
463 headers.set(multicast::CAST_POINT, point);
464
465 let mut unbound = Unbound::try_from_message(message.clone())
468 .map_err(|e| Error::CastingError(self.name.clone(), e))?;
469 unbound
470 .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
471 *rank = Some(create_rank);
472 Ok(())
473 })
474 .map_err(|e| Error::CastingError(self.name.clone(), e))?;
475 let rebound_message = unbound
476 .bind()
477 .map_err(|e| Error::CastingError(self.name.clone(), e))?;
478 actor
479 .send_with_headers(cx, headers, rebound_message)
480 .map_err(|e| Error::SendingError(actor.actor_id().clone(), Box::new(e)))?;
481 }
482 Ok(())
483 }
484 }
485
486 #[allow(clippy::result_large_err)]
487 fn cast_v0<M>(
488 &self,
489 cx: &impl context::Actor,
490 message: M,
491 sel: Selection,
492 root_comm_actor: &ActorRef<CommActor>,
493 ) -> v1::Result<()>
494 where
495 A: RemoteHandles<IndexedErasedUnbound<M>>,
496 M: Castable + RemoteMessage + Clone, {
498 let cast_mesh_shape = view::Ranked::region(self).into();
499 let actor_mesh_id = ActorMeshId::V1(self.name.clone());
500 match &self.proc_mesh.root_region {
501 Some(root_region) => {
502 let root_mesh_shape = root_region.into();
503 v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
504 cx,
505 actor_mesh_id,
506 root_comm_actor,
507 &sel,
508 message,
509 &cast_mesh_shape,
510 &root_mesh_shape,
511 )
512 .map_err(|e| Error::CastingError(self.name.clone(), e.into()))
513 }
514 None => v0_actor_mesh::actor_mesh_cast::<A, M>(
515 cx,
516 actor_mesh_id,
517 root_comm_actor,
518 sel,
519 &cast_mesh_shape,
520 &cast_mesh_shape,
521 message,
522 )
523 .map_err(|e| Error::CastingError(self.name.clone(), e.into())),
524 }
525 }
526
527 #[allow(clippy::result_large_err)]
528 pub async fn actor_states(
529 &self,
530 cx: &impl context::Actor,
531 ) -> v1::Result<ValueMesh<resource::State<ActorState>>> {
532 self.proc_mesh.actor_states(cx, self.name.clone()).await
533 }
534
535 pub(crate) fn new(
536 name: Name,
537 proc_mesh: ProcMeshRef,
538 controller: Option<ActorRef<ActorMeshController<A>>>,
539 ) -> Self {
540 Self::with_page_size(name, proc_mesh, DEFAULT_PAGE, controller)
541 }
542
543 pub fn name(&self) -> &Name {
544 &self.name
545 }
546
547 pub(crate) fn with_page_size(
548 name: Name,
549 proc_mesh: ProcMeshRef,
550 page_size: usize,
551 controller: Option<ActorRef<ActorMeshController<A>>>,
552 ) -> Self {
553 Self {
554 proc_mesh,
555 name,
556 controller,
557 health_state: Arc::new(std::sync::Mutex::new(HealthState::default())),
558 receiver: Arc::new(tokio::sync::Mutex::new(None)),
559 pages: OnceCell::new(),
560 page_size: page_size.max(1),
561 }
562 }
563
564 pub fn proc_mesh(&self) -> &ProcMeshRef {
565 &self.proc_mesh
566 }
567
568 #[inline]
569 fn len(&self) -> usize {
570 view::Ranked::region(&self.proc_mesh).num_ranks()
571 }
572
573 pub fn controller(&self) -> &Option<ActorRef<ActorMeshController<A>>> {
574 &self.controller
575 }
576
577 fn set_controller(&mut self, controller: Option<ActorRef<ActorMeshController<A>>>) {
578 self.controller = controller;
579 }
580
581 fn ensure_pages(&self) -> &Vec<OnceCell<Box<Page<A>>>> {
582 let n = self.len().div_ceil(self.page_size); self.pages
584 .get_or_init(|| (0..n).map(|_| OnceCell::new()).collect())
585 }
586
587 fn materialize(&self, rank: usize) -> Option<&ActorRef<A>> {
588 let len = self.len();
589 if rank >= len {
590 return None;
591 }
592 let p = self.page_size;
593 let page_ix = rank / p;
594 let local_ix = rank % p;
595
596 let pages = self.ensure_pages();
597 let page = pages[page_ix].get_or_init(|| {
598 let base = page_ix * p;
600 let remaining = len - base;
601 let page_len = remaining.min(p);
602 Box::new(Page::<A>::new(page_len))
603 });
604
605 Some(page.slots[local_ix].get_or_init(|| {
606 debug_assert!(rank < self.len(), "rank must be within [0, len)");
614 debug_assert!(
615 ndslice::view::Ranked::get(&self.proc_mesh, rank).is_some(),
616 "proc_mesh must be dense/aligned with this view"
617 );
618 let proc_ref =
619 ndslice::view::Ranked::get(&self.proc_mesh, rank).expect("rank in-bounds");
620 proc_ref.attest(&self.name)
621 }))
622 }
623
624 fn health_state(&self) -> &Arc<std::sync::Mutex<HealthState>> {
625 &self.health_state
626 }
627
628 fn init_supervision_receiver(
629 controller: &ActorRef<ActorMeshController<A>>,
630 cx: &impl context::Actor,
631 ) -> watch::Receiver<MessageOrFailure<Option<MeshFailure>>> {
632 let (tx, rx) = cx.mailbox().open_port();
633 controller
634 .send(cx, Subscribe(tx.bind()))
635 .expect("failed to send Subscribe");
636 into_watch(rx)
637 }
638
639 pub async fn next_supervision_event(
646 &self,
647 cx: &impl context::Actor,
648 ) -> Result<MeshFailure, anyhow::Error> {
649 let controller = if let Some(c) = self.controller() {
650 c
651 } else {
652 let health_state = self.health_state.lock().expect("lock poisoned");
653 return match &health_state.unhealthy_event {
654 Some(Unhealthy::StreamClosed(f)) => Ok(f.clone()),
655 Some(Unhealthy::Crashed(f)) => Ok(f.clone()),
656 None => Err(anyhow::anyhow!(
657 "unexpected healthy state while controller is gone"
658 )),
659 };
660 };
661 let message = {
662 let mut receiver = self.receiver.lock().await;
664 let rx =
665 receiver.get_or_insert_with(|| Self::init_supervision_receiver(controller, cx));
666 let message = rx
667 .wait_for(|message| {
668 if let MessageOrFailure::Message(message) = message {
672 if let Some(message) = &message {
673 if let Some(rank) = &message.rank {
674 ndslice::view::Ranked::region(self).slice().contains(*rank)
675 } else {
676 true
678 }
679 } else {
680 false
684 }
685 } else {
686 true
688 }
689 })
690 .await?;
691 let message = message.clone();
692 match message {
693 MessageOrFailure::Message(message) => Ok::<MeshFailure, anyhow::Error>(
694 message.expect("filter excludes any None messages"),
695 ),
696 MessageOrFailure::Failure(failure) => Err(anyhow::anyhow!("{}", failure)),
697 MessageOrFailure::Timeout => {
698 Ok(MeshFailure {
701 actor_mesh_name: Some(self.name().to_string()),
702 rank: None,
703 event: ActorSupervisionEvent::new(
704 controller.actor_id().clone(),
705 None,
706 ActorStatus::generic_failure(format!(
707 "timed out reaching controller {} for mesh {}. Assuming controller's proc is dead",
708 controller.actor_id(),
709 self.name()
710 )),
711 None,
712 ),
713 })
714 }
715 }?
716 };
717 let rank = message.rank.unwrap_or_default();
719 let event = &message.event;
720 let mut health_state = self.health_state.lock().expect("lock poisoned");
722 if let ActorStatus::Failed(_) = event.actor_status {
723 health_state.crashed_ranks.insert(rank, event.clone());
724 }
725 health_state.unhealthy_event = match &event.actor_status {
726 ActorStatus::Failed(_) => Some(Unhealthy::Crashed(message.clone())),
727 ActorStatus::Stopped => Some(Unhealthy::StreamClosed(message.clone())),
728 _ => None,
729 };
730 Ok(message)
731 }
732}
733
734impl<A: Referable> Clone for ActorMeshRef<A> {
735 fn clone(&self) -> Self {
736 Self {
737 proc_mesh: self.proc_mesh.clone(),
738 name: self.name.clone(),
739 controller: self.controller.clone(),
740 health_state: Arc::new(std::sync::Mutex::new(HealthState::default())),
743 receiver: Arc::new(tokio::sync::Mutex::new(None)),
744 pages: OnceCell::new(), page_size: self.page_size,
746 }
747 }
748}
749
750impl<A: Referable> fmt::Display for ActorMeshRef<A> {
751 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
752 write!(f, "{}:{}@{}", self.name, A::typename(), self.proc_mesh)
753 }
754}
755
756impl<A: Referable> PartialEq for ActorMeshRef<A> {
757 fn eq(&self, other: &Self) -> bool {
758 self.proc_mesh == other.proc_mesh && self.name == other.name
759 }
760}
761impl<A: Referable> Eq for ActorMeshRef<A> {}
762
763impl<A: Referable> Hash for ActorMeshRef<A> {
764 fn hash<H: Hasher>(&self, state: &mut H) {
765 self.proc_mesh.hash(state);
766 self.name.hash(state);
767 }
768}
769
770impl<A: Referable> fmt::Debug for ActorMeshRef<A> {
771 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
772 f.debug_struct("ActorMeshRef")
773 .field("proc_mesh", &self.proc_mesh)
774 .field("name", &self.name)
775 .field("page_size", &self.page_size)
776 .finish_non_exhaustive() }
778}
779
780impl<A: Referable> Serialize for ActorMeshRef<A> {
782 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
783 where
784 S: Serializer,
785 {
786 (&self.proc_mesh, &self.name, &self.controller).serialize(serializer)
788 }
789}
790
791impl<'de, A: Referable> Deserialize<'de> for ActorMeshRef<A> {
793 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
794 where
795 D: Deserializer<'de>,
796 {
797 let (proc_mesh, name, controller) =
798 <(ProcMeshRef, Name, Option<ActorRef<ActorMeshController<A>>>)>::deserialize(
799 deserializer,
800 )?;
801 Ok(ActorMeshRef::with_page_size(
802 name,
803 proc_mesh,
804 DEFAULT_PAGE,
805 controller,
806 ))
807 }
808}
809
810impl<A: Referable> view::Ranked for ActorMeshRef<A> {
811 type Item = ActorRef<A>;
812
813 #[inline]
814 fn region(&self) -> &Region {
815 view::Ranked::region(&self.proc_mesh)
816 }
817
818 #[inline]
819 fn get(&self, rank: usize) -> Option<&Self::Item> {
820 self.materialize(rank)
821 }
822}
823
824impl<A: Referable> view::RankedSliceable for ActorMeshRef<A> {
825 fn sliced(&self, region: Region) -> Self {
826 debug_assert!(region.is_subset(view::Ranked::region(self)));
829 let proc_mesh = self.proc_mesh.subset(region).unwrap();
830 Self::with_page_size(
831 self.name.clone(),
832 proc_mesh,
833 self.page_size,
834 self.controller.clone(),
835 )
836 }
837}
838
839#[cfg(test)]
840mod tests {
841
842 use std::collections::HashSet;
843 use std::ops::Deref;
844
845 use hyperactor::actor::ActorErrorKind;
846 use hyperactor::actor::ActorStatus;
847 use hyperactor::clock::Clock;
848 use hyperactor::clock::RealClock;
849 use hyperactor::context::Mailbox as _;
850 use hyperactor::mailbox;
851 use ndslice::Extent;
852 use ndslice::ViewExt;
853 use ndslice::extent;
854 use ndslice::view::Ranked;
855 use timed_test::async_timed_test;
856 use tokio::time::Duration;
857
858 use super::ActorMesh;
859 use crate::supervision::MeshFailure;
860 use crate::v1::ActorMeshRef;
861 use crate::v1::Name;
862 use crate::v1::ProcMesh;
863 use crate::v1::proc_mesh::ACTOR_SPAWN_MAX_IDLE;
864 use crate::v1::proc_mesh::GET_ACTOR_STATE_MAX_IDLE;
865 use crate::v1::testactor;
866 use crate::v1::testing;
867
868 #[test]
869 fn test_actor_mesh_ref_is_send_and_sync() {
870 fn assert_send_sync<T: Send + Sync>() {}
871 assert_send_sync::<ActorMeshRef<()>>();
872 }
873
874 #[tokio::test]
875 #[cfg(fbcode_build)]
876 async fn test_actor_mesh_ref_lazy_materialization() {
877 let instance = testing::instance();
879 let extent = extent!(replicas = 3, hosts = 2); let pm: ProcMesh = testing::proc_meshes(instance, extent.clone())
883 .await
884 .into_iter()
885 .next()
886 .expect("at least one proc mesh");
887 let am: ActorMesh<testactor::TestActor> = pm.spawn(instance, "test", &()).await.unwrap();
888
889 let page_size = 2;
893 let amr: ActorMeshRef<testactor::TestActor> =
894 ActorMeshRef::with_page_size(am.name.clone(), pm.clone(), page_size, None);
895 assert_eq!(amr.extent(), extent);
896 assert_eq!(amr.region().num_ranks(), 6);
897
898 let p0_a = amr.get(0).expect("rank 0 exists") as *const _;
900 let p0_b = amr.get(0).expect("rank 0 exists") as *const _;
901 assert_eq!(p0_a, p0_b, "same rank should return same cached pointer");
902
903 let p1_a = amr.get(1).expect("rank 1 exists") as *const _;
905 let p1_b = amr.get(1).expect("rank 1 exists") as *const _;
906 assert_eq!(p1_a, p1_b, "same rank should return same cached pointer");
907 assert_ne!(p0_a, p1_a, "different ranks have different cache slots");
910
911 let p2_a = amr.get(2).expect("rank 2 exists") as *const _;
913 let p2_b = amr.get(2).expect("rank 2 exists") as *const _;
914 assert_eq!(p2_a, p2_b, "same rank should return same cached pointer");
915 assert_ne!(p0_a, p2_a, "different pages have different cache slots");
916
917 let amr_clone = amr.clone();
919 let orig_id_0 = amr.get(0).unwrap().actor_id().clone();
920 let clone_id_0 = amr_clone.get(0).unwrap().actor_id().clone();
921 assert_eq!(orig_id_0, clone_id_0, "clone preserves identity");
922 let p0_clone = amr_clone.get(0).unwrap() as *const _;
923 assert_ne!(
924 p0_a, p0_clone,
925 "cloned ActorMeshRef has a fresh cache (different pointer)"
926 );
927
928 let sliced = amr.range("replicas", 1..).expect("slice should be valid"); assert_eq!(sliced.region().num_ranks(), 4);
932 let sp0_a = sliced.get(0).unwrap() as *const _;
934 let sp0_b = sliced.get(0).unwrap() as *const _;
935 assert_eq!(sp0_a, sp0_b, "sliced view has its own cache slot per rank");
936 let sp2 = sliced.get(2).unwrap() as *const _;
939 assert_ne!(sp0_a, sp2, "sliced view crosses its own page boundary");
940
941 let mut set = HashSet::new();
944 set.insert(amr.clone());
945 set.insert(amr.clone());
946 assert_eq!(set.len(), 1, "cache state must not affect Hash/Eq");
947
948 let (port, mut rx) = mailbox::open_port(instance);
951 amr.get(0)
954 .expect("rank 0 exists")
955 .send(instance, testactor::GetActorId(port.bind()))
956 .expect("send to rank 0 should succeed");
957 amr.get(3)
958 .expect("rank 3 exists")
959 .send(instance, testactor::GetActorId(port.bind()))
960 .expect("send to rank 3 should succeed");
961 let id_a = RealClock
962 .timeout(Duration::from_secs(3), rx.recv())
963 .await
964 .expect("timed out waiting for first reply")
965 .expect("channel closed before first reply");
966 let id_b = RealClock
967 .timeout(Duration::from_secs(3), rx.recv())
968 .await
969 .expect("timed out waiting for second reply")
970 .expect("channel closed before second reply");
971 assert_ne!(id_a, id_b, "two different ranks responded");
972 }
973
974 #[async_timed_test(timeout_secs = 30)]
975 #[cfg(fbcode_build)]
976 async fn test_actor_states_with_panic() {
977 hyperactor_telemetry::initialize_logging_for_test();
978
979 let instance = testing::instance();
980 let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
982 let supervisor = supervision_port.bind();
983 let num_replicas = 4;
984 let meshes = testing::proc_meshes(instance, extent!(replicas = num_replicas)).await;
985 let proc_mesh = &meshes[1];
986 let child_name = Name::new("child").unwrap();
987
988 let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
991 .spawn(
992 instance,
993 "wrapper",
994 &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
995 )
996 .await
997 .unwrap();
998
999 actor_mesh
1001 .cast(
1002 instance,
1003 testactor::CauseSupervisionEvent {
1004 kind: testactor::SupervisionEventType::Panic,
1005 send_to_children: true,
1006 },
1007 )
1008 .unwrap();
1009
1010 let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
1020 actor_mesh
1021 .cast(
1022 instance,
1023 testactor::NextSupervisionFailure(failure_port.bind()),
1024 )
1025 .unwrap();
1026 let failure = failure_receiver
1027 .recv()
1028 .await
1029 .unwrap()
1030 .expect("no supervision event found on ref from wrapper actor");
1031 let check_failure = move |failure: MeshFailure| {
1032 assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
1033 assert_eq!(
1034 failure.event.actor_id.name(),
1035 child_name.clone().to_string()
1036 );
1037 if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
1038 assert!(msg.contains("panic"), "{}", msg);
1039 assert!(msg.contains("for testing"), "{}", msg);
1040 } else {
1041 panic!("actor status is not failed: {}", failure.event.actor_status);
1042 }
1043 };
1044 check_failure(failure);
1045
1046 for _ in 0..num_replicas {
1050 let failure = RealClock
1051 .timeout(Duration::from_secs(10), supervision_receiver.recv())
1052 .await
1053 .expect("timeout")
1054 .unwrap();
1055 check_failure(failure);
1056 }
1057 }
1058
1059 #[async_timed_test(timeout_secs = 30)]
1060 #[cfg(fbcode_build)]
1061 async fn test_actor_states_with_process_exit() {
1062 hyperactor_telemetry::initialize_logging_for_test();
1063
1064 let config = hyperactor_config::global::lock();
1065 let _guard = config.override_key(GET_ACTOR_STATE_MAX_IDLE, Duration::from_secs(1));
1066
1067 let instance = testing::instance();
1068 let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
1070 let supervisor = supervision_port.bind();
1071 let num_replicas = 4;
1072 let meshes = testing::proc_meshes(instance, extent!(replicas = num_replicas)).await;
1073 let second_meshes = testing::proc_meshes(instance, extent!(replicas = num_replicas)).await;
1074 let proc_mesh = &meshes[1];
1075 let second_proc_mesh = &second_meshes[1];
1076 let child_name = Name::new("child").unwrap();
1077
1078 let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
1081 .spawn(
1082 instance,
1083 "wrapper",
1084 &(
1085 second_proc_mesh.deref().clone(),
1088 supervisor,
1089 child_name.clone(),
1090 ),
1091 )
1092 .await
1093 .unwrap();
1094
1095 actor_mesh
1096 .cast(
1097 instance,
1098 testactor::CauseSupervisionEvent {
1099 kind: testactor::SupervisionEventType::ProcessExit(1),
1100 send_to_children: true,
1101 },
1102 )
1103 .unwrap();
1104
1105 let (failure_port, mut failure_receiver) = instance.open_port::<Option<MeshFailure>>();
1107 actor_mesh
1108 .cast(
1109 instance,
1110 testactor::NextSupervisionFailure(failure_port.bind()),
1111 )
1112 .unwrap();
1113 let failure = failure_receiver
1114 .recv()
1115 .await
1116 .unwrap()
1117 .expect("no supervision event found on ref from wrapper actor");
1118
1119 let check_failure = move |failure: MeshFailure| {
1120 assert_eq!(failure.actor_mesh_name, Some(child_name.to_string()));
1122 assert_eq!(failure.event.actor_id.name(), "mesh");
1123 if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &failure.event.actor_status {
1124 assert!(
1125 msg.contains("timeout waiting for message from proc mesh agent"),
1126 "{}",
1127 msg
1128 );
1129 } else {
1130 panic!("actor status is not failed: {}", failure.event.actor_status);
1131 }
1132 };
1133 check_failure(failure);
1134
1135 for _ in 0..num_replicas {
1137 let failure = RealClock
1138 .timeout(Duration::from_secs(10), supervision_receiver.recv())
1139 .await
1140 .expect("timeout")
1141 .unwrap();
1142 check_failure(failure);
1143 }
1144 }
1145
1146 #[async_timed_test(timeout_secs = 30)]
1147 #[cfg(fbcode_build)]
1148 async fn test_actor_states_on_sliced_mesh() {
1149 hyperactor_telemetry::initialize_logging_for_test();
1150
1151 let instance = testing::instance();
1152 let (supervision_port, mut supervision_receiver) = instance.open_port::<MeshFailure>();
1154 let supervisor = supervision_port.bind();
1155 let num_replicas = 4;
1156 let meshes = testing::proc_meshes(instance, extent!(replicas = num_replicas)).await;
1157 let proc_mesh = &meshes[1];
1158 let child_name = Name::new("child").unwrap();
1159
1160 let actor_mesh: ActorMesh<testactor::WrapperActor> = proc_mesh
1163 .spawn(
1164 instance,
1165 "wrapper",
1166 &(proc_mesh.deref().clone(), supervisor, child_name.clone()),
1167 )
1168 .await
1169 .unwrap();
1170 let sliced = actor_mesh
1171 .range("replicas", 1..3)
1172 .expect("slice should be valid");
1173 let sliced_replicas = sliced.len();
1174
1175 sliced
1177 .cast(
1178 instance,
1179 testactor::CauseSupervisionEvent {
1180 kind: testactor::SupervisionEventType::Panic,
1181 send_to_children: true,
1182 },
1183 )
1184 .unwrap();
1185
1186 for _ in 0..sliced_replicas {
1187 let supervision_message = RealClock
1188 .timeout(Duration::from_secs(10), supervision_receiver.recv())
1189 .await
1190 .expect("timeout")
1191 .unwrap();
1192 let event = supervision_message.event;
1193 assert_eq!(event.actor_id.name(), format!("{}", child_name.clone()));
1194 if let ActorStatus::Failed(ActorErrorKind::Generic(msg)) = &event.actor_status {
1195 assert!(msg.contains("panic"));
1196 assert!(msg.contains("for testing"));
1197 } else {
1198 panic!("actor status is not failed: {}", event.actor_status);
1199 }
1200 }
1201 }
1202
1203 #[async_timed_test(timeout_secs = 30)]
1204 #[cfg(fbcode_build)]
1205 async fn test_cast() {
1206 let config = hyperactor_config::global::lock();
1207 let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
1208
1209 let instance = testing::instance();
1210 let mut host_mesh = testing::host_mesh(extent!(host = 4)).await;
1211 let proc_mesh = host_mesh
1212 .spawn(instance, "test", Extent::unity())
1213 .await
1214 .unwrap();
1215 let actor_mesh: ActorMesh<testactor::TestActor> =
1216 proc_mesh.spawn(instance, "test", &()).await.unwrap();
1217
1218 let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
1219 actor_mesh
1220 .cast(
1221 instance,
1222 testactor::GetCastInfo {
1223 cast_info: cast_info.bind(),
1224 },
1225 )
1226 .unwrap();
1227
1228 let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
1229 while !point_to_actor.is_empty() {
1230 let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
1231 let key = (point, origin_actor_ref);
1232 assert!(
1233 point_to_actor.remove(&key),
1234 "key {:?} not present or removed twice",
1235 key
1236 );
1237 assert_eq!(&sender_actor_id, instance.self_id());
1238 }
1239
1240 let _ = host_mesh.shutdown(&instance).await;
1241 }
1242
#[async_timed_test(timeout_secs = 30)]
#[cfg(fbcode_build)]
async fn test_undeliverable_message_return() {
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::Undeliverable;
    use hyperactor::test_utils::pingpong::PingPongActor;
    use hyperactor::test_utils::pingpong::PingPongMessage;

    hyperactor_telemetry::initialize_logging_for_test();

    // Shorten the delivery timeout to 1s for this test so that messages
    // addressed to a stopped actor are reported back promptly.
    let config = hyperactor_config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        std::time::Duration::from_secs(1),
    );

    let instance = testing::instance();

    let meshes = testing::proc_meshes(instance, extent!(replicas = 2)).await;
    let proc_mesh = &meshes[1];

    // Undeliverable envelopes returned to the "ping" actor are forwarded here.
    let (undeliverable_port, mut undeliverable_rx) =
        instance.open_port::<Undeliverable<MessageEnvelope>>();

    // Split the two replicas: one hosts "ping", the other hosts "pong".
    let ping_proc_mesh = proc_mesh.range("replicas", 0..1).unwrap();
    let pong_proc_mesh = proc_mesh.range("replicas", 1..2).unwrap();

    let ping_mesh: ActorMesh<PingPongActor> = ping_proc_mesh
        .spawn(
            instance,
            "ping",
            &(Some(undeliverable_port.bind()), None, None),
        )
        .await
        .unwrap();
    let mut pong_mesh: ActorMesh<PingPongActor> = pong_proc_mesh
        .spawn(instance, "pong", &(None, None, None))
        .await
        .unwrap();

    let ping = ping_mesh.values().next().unwrap();
    let pong = pong_mesh.values().next().unwrap();

    // Sanity check: a short exchange succeeds while both actors are alive.
    let (done_tx, done_rx) = instance.open_once_port();
    ping.send(instance, PingPongMessage(2, pong.clone(), done_tx.bind()))
        .unwrap();
    assert!(
        done_rx.recv().await.unwrap(),
        "Initial ping-pong should work"
    );

    // Stop "pong", then give the stop a moment to settle before sending more.
    pong_mesh.stop(instance).await.unwrap();
    RealClock.sleep(std::time::Duration::from_millis(200)).await;

    // Fire n exchanges at the stopped peer; each should come back undeliverable.
    let n = 100usize;
    for i in 1..=n {
        let ttl = 66 + i as u64;
        let (once_tx, _once_rx) = instance.open_once_port();
        ping.send(instance, PingPongMessage(ttl, pong.clone(), once_tx.bind()))
            .unwrap();
    }

    // Drain returned envelopes until all n arrive, the overall 5s deadline
    // passes, a 1s receive times out, or the port errors.
    let mut count = 0;
    let deadline = RealClock.now() + std::time::Duration::from_secs(5);
    while count < n && RealClock.now() < deadline {
        let received = RealClock
            .timeout(std::time::Duration::from_secs(1), undeliverable_rx.recv())
            .await;
        let Ok(Ok(Undeliverable(envelope))) = received else {
            break;
        };
        let _: PingPongMessage = envelope.deserialized().unwrap();
        count += 1;
    }

    assert_eq!(
        count, n,
        "Expected {} undeliverable messages, got {}",
        n, count
    );
}
1355
#[async_timed_test(timeout_secs = 30)]
#[cfg(fbcode_build)]
async fn test_actor_mesh_stop_timeout() {
    hyperactor_telemetry::initialize_logging_for_test();

    // Override ACTOR_SPAWN_MAX_IDLE to 1s; the duration bounds asserted below
    // assume the stop path waits roughly this long before giving up.
    let config = hyperactor_config::global::lock();
    let _guard = config.override_key(ACTOR_SPAWN_MAX_IDLE, std::time::Duration::from_secs(1));

    let instance = testing::instance();

    let meshes = testing::proc_meshes(instance, extent!(replicas = 2)).await;
    let proc_mesh = &meshes[1];
    let mut sleep_mesh: ActorMesh<testactor::SleepActor> =
        proc_mesh.spawn(instance, "sleepers", &()).await.unwrap();

    // Tell every sleeper to sleep for 5s, far longer than the 1s override.
    for sleeper in sleep_mesh.values() {
        sleeper
            .send(instance, std::time::Duration::from_secs(5))
            .unwrap();
    }

    // Let the sleep requests be picked up before stopping.
    RealClock.sleep(std::time::Duration::from_millis(200)).await;

    let expected_actors = sleep_mesh.values().count();

    let stop_start = RealClock.now();
    let result = sleep_mesh.stop(instance).await;
    let stop_duration = RealClock.now().duration_since(stop_start);

    if let Err(e) = &result {
        let err_str = format!("{:?}", e);
        assert!(
            err_str.contains("Timeout"),
            "Expected Timeout error, got: {:?}",
            e
        );
        tracing::info!(
            "Stop timed out as expected for {} actors, they were aborted",
            expected_actors
        );
    } else {
        tracing::warn!("Actors stopped gracefully (unexpected but ok)");
    }

    // Stop must neither wait out the full 5s sleep nor return before ~1s.
    assert!(
        stop_duration < std::time::Duration::from_secs(3),
        "Stop took {:?}, expected < 3s (actors should have been aborted, not waited for)",
        stop_duration
    );
    assert!(
        stop_duration >= std::time::Duration::from_millis(900),
        "Stop took {:?}, expected >= 900ms (should have waited for timeout)",
        stop_duration
    );
}
1446
#[async_timed_test(timeout_secs = 30)]
#[cfg(fbcode_build)]
async fn test_actor_mesh_stop_graceful() {
    hyperactor_telemetry::initialize_logging_for_test();

    let instance = testing::instance();

    let meshes = testing::proc_meshes(instance, extent!(replicas = 2)).await;
    let proc_mesh = &meshes[1];

    let mut actor_mesh: ActorMesh<testactor::TestActor> =
        proc_mesh.spawn(instance, "test_actors", &()).await.unwrap();

    // Keep a separate copy of the mesh ref so supervision can also be
    // observed through it after the owning mesh has been stopped.
    let mesh_ref = actor_mesh.deref().clone();
    let mesh_name = mesh_ref.name().to_string();

    let expected_actors = actor_mesh.values().count();
    assert!(expected_actors > 0, "Should have spawned some actors");

    let stop_start = RealClock.now();
    let result = actor_mesh.stop(instance).await;
    let stop_duration = RealClock.now().duration_since(stop_start);

    assert!(
        result.is_ok(),
        "Stop should succeed for responsive actors, got: {:?}",
        result.err()
    );
    assert!(
        stop_duration < std::time::Duration::from_secs(2),
        "Graceful stop took {:?}, expected < 2s (actors should stop quickly)",
        stop_duration
    );
    tracing::info!(
        "Successfully stopped {} actors in {:?}",
        expected_actors,
        stop_duration
    );

    // Both the owning mesh and the cloned ref must report a Stopped
    // supervision event tagged with this mesh's name.
    let owner_event = actor_mesh.next_supervision_event(instance).await.unwrap();
    assert_eq!(owner_event.actor_mesh_name, Some(mesh_name.clone()));
    assert_eq!(owner_event.event.actor_status, ActorStatus::Stopped);

    let ref_event = mesh_ref.next_supervision_event(instance).await.unwrap();
    assert_eq!(ref_event.actor_mesh_name, Some(mesh_name.clone()));
    assert_eq!(ref_event.event.actor_status, ActorStatus::Stopped);
}
1522}