1use std::collections::HashMap;
10use std::fmt;
11use std::ops::Deref;
12use std::panic::Location;
13use std::sync::Arc;
14use std::sync::atomic::AtomicUsize;
15use std::sync::atomic::Ordering;
16
17use async_trait::async_trait;
18use dashmap::DashMap;
19use futures::future::join_all;
20use hyperactor::Actor;
21use hyperactor::ActorHandle;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Instance;
25use hyperactor::RemoteMessage;
26use hyperactor::WorldId;
27use hyperactor::actor::ActorStatus;
28use hyperactor::actor::Referable;
29use hyperactor::actor::remote::Remote;
30use hyperactor::channel;
31use hyperactor::channel::ChannelAddr;
32use hyperactor::channel::ChannelTransport;
33use hyperactor::config;
34use hyperactor::config::CONFIG;
35use hyperactor::config::ConfigAttr;
36use hyperactor::context;
37use hyperactor::declare_attrs;
38use hyperactor::mailbox;
39use hyperactor::mailbox::BoxableMailboxSender;
40use hyperactor::mailbox::BoxedMailboxSender;
41use hyperactor::mailbox::DialMailboxRouter;
42use hyperactor::mailbox::MailboxServer;
43use hyperactor::mailbox::MessageEnvelope;
44use hyperactor::mailbox::PortHandle;
45use hyperactor::mailbox::PortReceiver;
46use hyperactor::mailbox::Undeliverable;
47use hyperactor::metrics;
48use hyperactor::proc::Proc;
49use hyperactor::reference::ProcId;
50use hyperactor::supervision::ActorSupervisionEvent;
51use ndslice::Range;
52use ndslice::Shape;
53use ndslice::ShapeError;
54use ndslice::View;
55use ndslice::ViewExt;
56use strum::AsRefStr;
57use tokio::sync::mpsc;
58use tracing::Instrument;
59use tracing::Level;
60use tracing::span;
61
62use crate::CommActor;
63use crate::Mesh;
64use crate::actor_mesh::RootActorMesh;
65use crate::alloc::Alloc;
66use crate::alloc::AllocExt;
67use crate::alloc::AllocatedProc;
68use crate::alloc::AllocatorError;
69use crate::alloc::ProcState;
70use crate::alloc::ProcStopReason;
71use crate::alloc::serve_with_config;
72use crate::assign::Ranks;
73use crate::comm::CommActorMode;
74use crate::proc_mesh::mesh_agent::GspawnResult;
75use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
76use crate::proc_mesh::mesh_agent::ProcMeshAgent;
77use crate::proc_mesh::mesh_agent::StopActorResult;
78use crate::proc_mesh::mesh_agent::update_event_actor_id;
79use crate::reference::ProcMeshId;
80use crate::router;
81use crate::shortuuid::ShortUuid;
82use crate::v1;
83use crate::v1::Name;
84
85pub mod mesh_agent;
86
87use std::sync::OnceLock;
88use std::sync::RwLock;
89
declare_attrs! {
    // Default channel transport for the mesh. The attribute is registered in
    // the config system under the environment name
    // `HYPERACTOR_MESH_DEFAULT_TRANSPORT` and the Python name
    // `default_transport`; its built-in default is `ChannelTransport::Unix`.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string()),
        py_name: Some("default_transport".to_string()),
    })
    pub attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix;
}
98
/// Returns the default [`ChannelTransport`], read (as a clone) from the
/// global configuration attribute [`DEFAULT_TRANSPORT`].
pub fn default_transport() -> ChannelTransport {
    config::global::get_cloned(DEFAULT_TRANSPORT)
}
103
/// Process-wide slot holding the port that supervision events are forwarded
/// to. The slot is created lazily (see [`sink_cell`]) and starts out empty;
/// it is written by [`set_global_supervision_sink`] and read by
/// [`get_global_supervision_sink`].
static GLOBAL_SUPERVISION_SINK: OnceLock<RwLock<Option<PortHandle<ActorSupervisionEvent>>>> =
    OnceLock::new();
112
113fn sink_cell() -> &'static RwLock<Option<PortHandle<ActorSupervisionEvent>>> {
119 GLOBAL_SUPERVISION_SINK.get_or_init(|| RwLock::new(None))
120}
121
122pub(crate) fn set_global_supervision_sink(
136 sink: PortHandle<ActorSupervisionEvent>,
137) -> Option<PortHandle<ActorSupervisionEvent>> {
138 let cell = sink_cell();
139 let mut guard = cell.write().unwrap();
140 let prev = guard.take();
141 *guard = Some(sink);
142 prev
143}
144
145pub(crate) fn get_global_supervision_sink() -> Option<PortHandle<ActorSupervisionEvent>> {
157 sink_cell().read().unwrap().clone()
158}
159
160pub fn global_root_client() -> &'static Instance<()> {
164 static GLOBAL_INSTANCE: OnceLock<(Instance<()>, ActorHandle<()>)> = OnceLock::new();
165 &GLOBAL_INSTANCE.get_or_init(|| {
166 let client_proc = Proc::direct_with_default(
167 ChannelAddr::any(default_transport()),
168 "mesh_root_client_proc".into(),
169 router::global().clone().boxed(),
170 )
171 .unwrap();
172
173 router::global().bind(client_proc.proc_id().clone().into(), client_proc.clone());
176
177 let (client, handle) = client_proc
178 .instance("client")
179 .expect("root instance create");
180
181 let (_undeliverable_tx, undeliverable_rx) =
196 client.bind_actor_port::<Undeliverable<MessageEnvelope>>();
197 hyperactor::mailbox::supervise_undeliverable_messages_with(
198 undeliverable_rx,
199 crate::proc_mesh::get_global_supervision_sink,
200 |env| {
201 let sink_present = crate::proc_mesh::get_global_supervision_sink().is_some();
202 tracing::info!(
203 actor_id = %env.dest().actor_id(),
204 "global root client undeliverable observed with headers {:?} {}", env.headers(), sink_present
205 );
206 },
207 );
208
209 (client, handle)
210 }).0
211}
212
/// Shared routing table from an actor mesh name to the channel on which that
/// mesh receives its [`ActorSupervisionEvent`]s.
type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
214
/// A mesh of procs. It is backed either by the V0 implementation, created
/// through [`ProcMesh::allocate`], or by a [`v1::ProcMeshRef`] converted via
/// `From`.
pub struct ProcMesh {
    /// The backing implementation.
    inner: ProcMeshKind,
    /// Cache for the mesh shape, filled on the first call to
    /// [`ProcMesh::shape`].
    shape: OnceLock<Shape>,
}
221
/// The two backing implementations of a [`ProcMesh`].
enum ProcMeshKind {
    /// Mesh built directly from an [`Alloc`] by `ProcMesh::allocate`.
    V0 {
        /// Alloc and supervision receiver; `Some` until `ProcMesh::events`
        /// takes it.
        event_state: Option<EventState>,
        /// Routes supervision events to the actor meshes spawned on this
        /// proc mesh, keyed by actor mesh name.
        actor_event_router: ActorEventRouter,
        /// Shape reported by the alloc at allocation time.
        shape: Shape,
        /// Per-rank entries: (create key, proc id, proc address, mesh agent).
        ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef<ProcMeshAgent>)>,
        /// The proc hosting this mesh's client and supervisor instances.
        #[allow(dead_code)]
        client_proc: Proc,
        /// Client instance used to message the mesh agents and comm actors.
        client: Instance<()>,
        /// One comm actor per proc, in rank order.
        comm_actors: Vec<ActorRef<CommActor>>,
        /// World id reported by the alloc.
        world_id: WorldId,
    },

    /// Mesh that delegates to a v1 proc mesh reference.
    V1(v1::ProcMeshRef),
}
239
/// The event sources owned by a V0 proc mesh until they are handed over to a
/// [`ProcEvents`] stream.
struct EventState {
    /// The alloc backing the mesh; polled for proc state changes.
    alloc: Box<dyn Alloc + Send + Sync>,
    /// Receiver for supervision events reported to the mesh's supervisor.
    supervision_events: PortReceiver<ActorSupervisionEvent>,
}
244
245impl From<v1::ProcMeshRef> for ProcMesh {
246 fn from(proc_mesh: v1::ProcMeshRef) -> Self {
247 ProcMesh {
248 inner: ProcMeshKind::V1(proc_mesh),
249 shape: OnceLock::new(),
250 }
251 }
252}
253
254impl ProcMesh {
255 #[hyperactor::instrument(fields(name = "ProcMesh::allocate"))]
256 pub async fn allocate(
257 alloc: impl Alloc + Send + Sync + 'static,
258 ) -> Result<Self, AllocatorError> {
259 ProcMesh::allocate_boxed(Box::new(alloc)).await
260 }
261
262 #[track_caller]
266 pub fn allocate_boxed(
267 alloc: Box<dyn Alloc + Send + Sync>,
268 ) -> impl std::future::Future<Output = Result<Self, AllocatorError>> {
269 Self::allocate_boxed_inner(alloc, Location::caller())
270 }
271
272 fn alloc_counter() -> &'static AtomicUsize {
273 static C: OnceLock<AtomicUsize> = OnceLock::new();
274 C.get_or_init(|| AtomicUsize::new(0))
275 }
276
    /// Builds a V0 proc mesh on top of `alloc`.
    ///
    /// `loc` is the call site captured by [`ProcMesh::allocate_boxed`]; it is
    /// only used for logging. In order, this routine:
    ///
    /// 1. initializes the alloc and obtains the running procs;
    /// 2. builds a dial router for them and a client proc served on the
    ///    alloc's transport;
    /// 3. creates `supervisor` and `client` instances on the client proc and
    ///    installs the supervisor's port as the global supervision sink;
    /// 4. serves the router and configures every mesh agent, waiting for
    ///    each rank to acknowledge;
    /// 5. spawns one comm actor per proc and tells each of them about its
    ///    peers.
    ///
    /// # Errors
    ///
    /// Returns an [`AllocatorError`] if any of those steps fails.
    #[tracing::instrument(skip_all)]
    #[hyperactor::observe_result("ProcMesh")]
    async fn allocate_boxed_inner(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        loc: &'static Location<'static>,
    ) -> Result<Self, AllocatorError> {
        // 1-based id of this allocation attempt within the process; used to
        // correlate log lines and metrics.
        let alloc_id = Self::alloc_counter().fetch_add(1, Ordering::Relaxed) + 1;
        let world = alloc.world_id().name().to_string();
        tracing::info!(
            name = "ProcMesh::Allocate::Attempt",
            %world,
            alloc_id,
            caller = %format!("{}:{}", loc.file(), loc.line()),
            shape = ?alloc.shape(),
            "allocating proc mesh"
        );

        // Bring the alloc up; `running` holds one entry per allocated proc.
        let running = alloc
            .initialize()
            .instrument(span!(
                Level::INFO,
                "ProcMesh::Allocate::Initialize",
                alloc_id
            ))
            .await?;

        // Mesh-local router that falls back to the global router. Procs with
        // direct ids are skipped here and get no explicit binding.
        let router = DialMailboxRouter::new_with_default(router::global().boxed());
        for AllocatedProc { proc_id, addr, .. } in running.iter() {
            if proc_id.is_direct() {
                continue;
            }
            router.bind(proc_id.clone().into(), addr.clone());
        }

        // The client proc lives in its own `<world>_client` world at rank 0
        // and listens on the same transport as the alloc.
        let client_proc_id =
            ProcId::Ranked(WorldId(format!("{}_client", alloc.world_id().name())), 0);
        let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(alloc.transport()))
            .map_err(|err| AllocatorError::Other(err.into()))?;
        tracing::info!(
            name = "ProcMesh::Allocate::ChannelServe",
            alloc_id = alloc_id,
            "client proc started listening on addr: {client_proc_addr}"
        );
        let client_proc = Proc::new(
            client_proc_id.clone(),
            BoxedMailboxSender::new(router.clone()),
        );
        client_proc.clone().serve(client_rx);
        router.bind(client_proc_id.clone().into(), client_proc_addr.clone());

        // Register the mesh-local router with the global router.
        router::global().bind_dial_router(&router);

        // The supervisor instance owns the port on which supervision events
        // for this mesh arrive.
        let (supervisor, _supervisor_handle) = client_proc.instance("supervisor")?;
        let (supervision_port, supervision_events) =
            supervisor.open_port::<ActorSupervisionEvent>();

        // Make this mesh's supervision port the process-wide sink. Any sink
        // installed earlier is replaced and dropped.
        let _prev = set_global_supervision_sink(supervision_port.clone());

        // Undeliverable messages returned to the per-mesh client are
        // reported on the supervision port as well.
        let (client, _handle) = client_proc.instance("client")?;
        let (_undeliverable_messages, client_undeliverable_receiver) =
            client.bind_actor_port::<Undeliverable<MessageEnvelope>>();
        hyperactor::mailbox::supervise_undeliverable_messages(
            supervision_port.clone(),
            client_undeliverable_receiver,
            |env| {
                tracing::info!(actor=%env.dest().actor_id(), "per-mesh client undeliverable observed");
            },
        );

        // Serve the router on the address the alloc designates for it.
        let (router_channel_addr, router_rx) =
            serve_with_config(alloc.client_router_addr()).map_err(AllocatorError::Other)?;
        router.serve(router_rx);
        tracing::info!("router channel started listening on addr: {router_channel_addr}");

        // Map from each mesh agent's proc id to that proc's address.
        let address_book: HashMap<_, _> = running
            .iter()
            .map(
                |AllocatedProc {
                     addr, mesh_agent, ..
                 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
            )
            .collect();

        // Configure every mesh agent, then wait until every rank has
        // reported back on `config_receiver`.
        // NOTE(review): the meaning of the trailing `false` argument to
        // `configure` is not visible from this file.
        let (config_handle, mut config_receiver) = client.open_port();
        for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
            mesh_agent
                .configure(
                    &client,
                    rank,
                    router_channel_addr.clone(),
                    Some(supervision_port.bind()),
                    address_book.clone(),
                    config_handle.bind(),
                    false,
                )
                .await?;
        }
        let mut completed = Ranks::new(running.len());
        while !completed.is_full() {
            let rank = config_receiver
                .recv()
                .await
                .map_err(|err| AllocatorError::Other(err.into()))?;
            if completed.insert(rank, rank).is_some() {
                tracing::warn!("multiple completions received for rank {}", rank);
            }
        }

        // Named function (rather than a closure) projecting an allocated
        // proc to its mesh agent reference.
        fn project_mesh_agent_ref(allocated_proc: &AllocatedProc) -> ActorRef<ProcMeshAgent> {
            allocated_proc.mesh_agent.clone()
        }

        // Spawn one comm actor per proc, then switch each into mesh mode
        // with its rank and the rank -> comm actor table.
        let comm_actors = Self::spawn_on_procs::<CommActor>(
            &client,
            running.iter().map(project_mesh_agent_ref),
            "comm",
            &Default::default(),
        )
        .await?;
        let address_book: HashMap<_, _> = comm_actors.iter().cloned().enumerate().collect();
        for (rank, comm_actor) in comm_actors.iter().enumerate() {
            comm_actor
                .send(&client, CommActorMode::Mesh(rank, address_book.clone()))
                .map_err(anyhow::Error::from)?;
        }

        let shape = alloc.shape().clone();
        let world_id = alloc.world_id().clone();
        metrics::PROC_MESH_ALLOCATION.add(
            running.len() as u64,
            hyperactor_telemetry::kv_pairs!("alloc_id" => alloc_id.to_string()),
        );

        Ok(Self {
            inner: ProcMeshKind::V0 {
                event_state: Some(EventState {
                    alloc,
                    supervision_events,
                }),
                actor_event_router: Arc::new(DashMap::new()),
                shape,
                ranks: running
                    .into_iter()
                    .map(
                        |AllocatedProc {
                             create_key,
                             proc_id,
                             addr,
                             mesh_agent,
                         }| (create_key, proc_id, addr, mesh_agent),
                    )
                    .collect(),
                client_proc,
                client,
                comm_actors,
                world_id,
            },
            // The public shape cache is filled lazily by `ProcMesh::shape`.
            shape: OnceLock::new(),
        })
    }
495
496 async fn spawn_on_procs<A: Actor + Referable>(
505 cx: &impl context::Actor,
506 agents: impl IntoIterator<Item = ActorRef<ProcMeshAgent>> + '_,
507 actor_name: &str,
508 params: &A::Params,
509 ) -> Result<Vec<ActorRef<A>>, anyhow::Error>
510 where
511 A::Params: RemoteMessage,
512 {
513 let remote = Remote::collect();
514 let actor_type = remote
515 .name_of::<A>()
516 .ok_or(anyhow::anyhow!("actor not registered"))?
517 .to_string();
518
519 let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
520 let mut n = 0;
521 for agent in agents {
522 agent
523 .gspawn(
524 cx,
525 actor_type.clone(),
526 actor_name.to_string(),
527 bincode::serialize(params)?,
528 completed_handle.bind(),
529 )
530 .await?;
531 n += 1;
532 }
533 let mut completed = Ranks::new(n);
534 while !completed.is_full() {
535 let result = completed_receiver.recv().await?;
536 match result {
537 GspawnResult::Success { rank, actor_id } => {
538 if completed.insert(rank, actor_id).is_some() {
539 tracing::warn!("multiple completions received for rank {}", rank);
540 }
541 }
542 GspawnResult::Error(error_msg) => {
543 metrics::PROC_MESH_ACTOR_FAILURES.add(
544 1,
545 hyperactor_telemetry::kv_pairs!(
546 "actor_name" => actor_name.to_string(),
547 "error" => error_msg.clone(),
548 ),
549 );
550
551 anyhow::bail!("gspawn failed: {}", error_msg);
552 }
553 }
554 }
555
556 Ok(completed
559 .into_iter()
560 .map(Option::unwrap)
561 .map(ActorRef::attest)
562 .collect())
563 }
564
565 fn agents(&self) -> Box<dyn Iterator<Item = ActorRef<ProcMeshAgent>> + '_ + Send> {
566 match &self.inner {
567 ProcMeshKind::V0 { ranks, .. } => {
568 Box::new(ranks.iter().map(|(_, _, _, agent)| agent.clone()))
569 }
570 ProcMeshKind::V1(proc_mesh) => Box::new(
571 proc_mesh
572 .agent_mesh()
573 .iter()
574 .map(|(_point, agent)| agent.clone())
575 .collect::<Vec<_>>()
583 .into_iter(),
584 ),
585 }
586 }
587
    /// Returns the comm actor used as the entry point into this mesh: the
    /// first (rank 0) comm actor for V0, or the v1 mesh's root comm actor.
    ///
    /// # Panics
    ///
    /// Panics if a V0 mesh has no comm actors, or if the v1 mesh does not
    /// yield a root comm actor.
    pub(crate) fn comm_actor(&self) -> &ActorRef<CommActor> {
        match &self.inner {
            ProcMeshKind::V0 { comm_actors, .. } => &comm_actors[0],
            ProcMeshKind::V1(proc_mesh) => proc_mesh.root_comm_actor().unwrap(),
        }
    }
595
596 pub async fn spawn<A: Actor + Referable>(
612 &self,
613 cx: &impl context::Actor,
614 actor_name: &str,
615 params: &A::Params,
616 ) -> Result<RootActorMesh<'_, A>, anyhow::Error>
617 where
618 A::Params: RemoteMessage,
619 {
620 match &self.inner {
621 ProcMeshKind::V0 {
622 actor_event_router,
623 client,
624 ..
625 } => {
626 let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
627 {
628 actor_event_router.insert(actor_name.to_string(), tx);
630 tracing::info!(
631 name = "router_insert",
632 actor_name = %actor_name,
633 "the length of the router is {}", actor_event_router.len(),
634 );
635 }
636 let root_mesh = RootActorMesh::new(
637 self,
638 actor_name.to_string(),
639 rx,
640 Self::spawn_on_procs::<A>(client, self.agents(), actor_name, params).await?,
641 );
642 Ok(root_mesh)
643 }
644 ProcMeshKind::V1(proc_mesh) => {
645 let actor_mesh = proc_mesh.spawn(cx, actor_name, params).await?;
646 Ok(RootActorMesh::new_v1(actor_mesh.detach()))
647 }
648 }
649 }
650
651 pub fn client(&self) -> &Instance<()> {
653 match &self.inner {
654 ProcMeshKind::V0 { client, .. } => client,
655 ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client for v1::ProcMesh"),
656 }
657 }
658
659 pub fn client_proc(&self) -> &Proc {
660 match &self.inner {
661 ProcMeshKind::V0 { client_proc, .. } => client_proc,
662 ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client proc for v1::ProcMesh"),
663 }
664 }
665
666 pub fn proc_id(&self) -> &ProcId {
667 self.client_proc().proc_id()
668 }
669
670 pub fn world_id(&self) -> &WorldId {
671 match &self.inner {
672 ProcMeshKind::V0 { world_id, .. } => world_id,
673 ProcMeshKind::V1(_proc_mesh) => unimplemented!("no world_id for v1::ProcMesh"),
674 }
675 }
676
677 pub fn events(&mut self) -> Option<ProcEvents> {
680 match &mut self.inner {
681 ProcMeshKind::V0 {
682 event_state,
683 ranks,
684 actor_event_router,
685 ..
686 } => event_state.take().map(|event_state| ProcEvents {
687 event_state,
688 ranks: ranks
689 .iter()
690 .enumerate()
691 .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| {
692 (proc_id.clone(), (rank, create_key.clone()))
693 })
694 .collect(),
695 actor_event_router: actor_event_router.clone(),
696 }),
697 #[allow(clippy::todo)]
698 ProcMeshKind::V1(_proc_mesh) => todo!(),
699 }
700 }
701
702 pub fn shape(&self) -> &Shape {
703 self.shape.get_or_init(|| match &self.inner {
706 ProcMeshKind::V0 { shape, .. } => shape.clone(),
707 ProcMeshKind::V1(proc_mesh) => proc_mesh.region().into(),
708 })
709 }
710
    /// Stops the actors of the actor mesh named `mesh_name` on every proc.
    ///
    /// For a V0 mesh, a stop request is sent concurrently to every mesh
    /// agent for the actor `(proc, mesh_name, 0)`, using the configured
    /// `STOP_ACTOR_TIMEOUT`. Per-actor outcomes (success, not found,
    /// timeout, error) are only logged, so this branch always returns
    /// `Ok(())`. The requests go through the mesh's own client; `cx` is not
    /// used.
    ///
    /// For a V1 mesh the call is delegated to the v1 proc mesh using `cx`,
    /// and its error, if any, is returned.
    #[hyperactor::observe_result("ProcMesh")]
    pub async fn stop_actor_by_name(
        &self,
        cx: &impl context::Actor,
        mesh_name: &str,
    ) -> Result<(), anyhow::Error> {
        match &self.inner {
            ProcMeshKind::V0 { client, .. } => {
                let timeout =
                    hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
                // Issue all stop requests at once and wait for all replies.
                let results = join_all(self.agents().map(|agent| async move {
                    // The target is index 0 of `mesh_name` on the agent's proc.
                    let actor_id =
                        ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0);
                    (
                        actor_id.clone(),
                        agent
                            .clone()
                            .stop_actor(client, actor_id, timeout.as_millis() as u64)
                            .await,
                    )
                }))
                .await;

                // Best effort: report each outcome, never fail the call.
                for (actor_id, result) in results {
                    match result {
                        Ok(StopActorResult::Timeout) => {
                            tracing::warn!("timed out while stopping actor {}", actor_id);
                        }
                        Ok(StopActorResult::NotFound) => {
                            tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id());
                        }
                        Ok(StopActorResult::Success) => {
                            tracing::info!("stopped actor {}", actor_id);
                        }
                        Err(e) => {
                            tracing::warn!("error stopping actor {}: {}", actor_id, e);
                        }
                    }
                }
                Ok(())
            }
            ProcMeshKind::V1(proc_mesh) => {
                proc_mesh
                    .stop_actor_by_name(cx, Name::new_reserved(mesh_name))
                    .await?;
                Ok(())
            }
        }
    }
761}
762
/// Proc lifecycle events produced by [`ProcEvents::next`].
#[derive(Debug, Clone)]
pub enum ProcEvent {
    /// The proc at the given rank stopped, for the given reason.
    Stopped(usize, ProcStopReason),
    /// An actor on the proc at the given rank reported a supervision event;
    /// the string is the rendered event.
    Crashed(usize, String),
}
772
/// Stages of supervision event handling. Each variant's string form (via
/// `AsRefStr`) is used as the `name` field of the corresponding tracing
/// event.
#[derive(Debug, Clone, AsRefStr)]
pub enum SupervisionEventState {
    /// An event is being forwarded to its registered actor mesh.
    SupervisionEventForward,
    /// An event could not be forwarded to an actor mesh.
    SupervisionEventForwardFailed,
    /// An event was received from the supervision port.
    SupervisionEventReceived,
    /// A synthesized event could not be transmitted to an actor mesh.
    SupervisionEventTransmitFailed,
}
780
781impl fmt::Display for ProcEvent {
782 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783 match self {
784 ProcEvent::Stopped(rank, reason) => {
785 write!(f, "Proc at rank {} stopped: {}", rank, reason)
786 }
787 ProcEvent::Crashed(rank, reason) => {
788 write!(f, "Proc at rank {} crashed: {}", rank, reason)
789 }
790 }
791 }
792}
793
/// Name of an actor mesh, as passed to `spawn`; used as the key of the
/// actor event router.
type ActorMeshName = String;
795
/// Event stream of a V0 proc mesh, obtained from [`ProcMesh::events`].
pub struct ProcEvents {
    /// The alloc and the supervision event receiver being polled.
    event_state: EventState,
    /// Maps each proc id to its rank and create key.
    ranks: HashMap<ProcId, (usize, ShortUuid)>,
    /// Shared with the proc mesh; used to deliver supervision events to the
    /// actor meshes spawned on it.
    actor_event_router: ActorEventRouter,
}
804
805impl ProcEvents {
806 pub async fn next(&mut self) -> Option<ProcEvent> {
809 loop {
810 tokio::select! {
811 result = self.event_state.alloc.next() => {
812 tracing::debug!(name = "ProcEventReceived", "received ProcEvent alloc update: {result:?}");
813 let Some(alloc_event) = result else {
815 self.actor_event_router.clear();
816 break None;
817 };
818
819 let ProcState::Stopped { create_key, reason } = alloc_event else {
820 continue;
822 };
823
824 let Some((proc_id, (rank, _create_key))) = self.ranks.iter().find(|(_proc_id, (_rank, key))| key == &create_key) else {
825 tracing::warn!("received stop event for unmapped proc {}", create_key);
826 continue;
827 };
828
829 metrics::PROC_MESH_PROC_STOPPED.add(
830 1,
831 hyperactor_telemetry::kv_pairs!(
832 "create_key" => create_key.to_string(),
833 "rank" => rank.to_string(),
834 "reason" => reason.to_string(),
835 ),
836 );
837
838 for entry in self.actor_event_router.iter() {
841 let event = ActorSupervisionEvent::new(
844 proc_id.actor_id("any", 0),
845 None,
846 ActorStatus::generic_failure(format!("proc {} is stopped", proc_id)),
847 None,
848 );
849 tracing::debug!(name = "SupervisionEvent", %event);
850 if entry.value().send(event.clone()).is_err() {
851 tracing::warn!(
852 name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
853 "unable to transmit supervision event to actor {}", entry.key()
854 );
855 }
856 }
857
858 let event = ProcEvent::Stopped(*rank, reason.clone());
859 tracing::debug!(name = "SupervisionEvent", %event);
860
861 break Some(ProcEvent::Stopped(*rank, reason));
862 }
863
864 Ok(event) = self.event_state.supervision_events.recv() => {
875 let had_headers = event.message_headers.is_some();
876 tracing::info!(
877 name = SupervisionEventState::SupervisionEventReceived.as_ref(),
878 actor_id = %event.actor_id,
879 actor_name = %event.actor_id.name(),
880 status = %event.actor_status,
881 "proc supervision: event received with {had_headers} headers"
882 );
883 tracing::debug!(
884 name = "SupervisionEvent",
885 %event,
886 "proc supervision: full event");
887
888 let event = update_event_actor_id(event);
890
891 let actor_id = event.actor_id.clone();
897 let actor_status = event.actor_status.clone();
898 let reason = event.to_string();
899 if let Some(tx) = self.actor_event_router.get(actor_id.name()) {
900 tracing::info!(
901 name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
902 actor_id = %actor_id,
903 actor_name = actor_id.name(),
904 status = %actor_status,
905 "proc supervision: delivering event to registered ActorMesh"
906 );
907 if tx.send(event).is_err() {
908 tracing::warn!(
909 name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
910 actor_id = %actor_id,
911 "proc supervision: registered ActorMesh dropped receiver; unable to deliver"
912 );
913 }
914 } else {
915 let registered_meshes: Vec<_> = self.actor_event_router.iter().map(|e| e.key().clone()).collect();
916 tracing::warn!(
917 name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
918 actor_id = %actor_id,
919 "proc supervision: no ActorMesh registered for this actor {:?}", registered_meshes,
920 );
921 }
922 let Some((rank, _)) = self.ranks.get(actor_id.proc_id()) else {
926 tracing::warn!(
927 actor_id = %actor_id,
928 "proc supervision: actor belongs to an unmapped proc; dropping event"
929 );
930 continue;
931 };
932
933 metrics::PROC_MESH_ACTOR_FAILURES.add(
934 1,
935 hyperactor_telemetry::kv_pairs!(
936 "actor_id" => actor_id.to_string(),
937 "rank" => rank.to_string(),
938 "status" => actor_status.to_string(),
939 ),
940 );
941
942 break Some(ProcEvent::Crashed(*rank, reason))
945 }
946 }
947 }
948 }
949
950 pub fn into_alloc(self) -> Box<dyn Alloc + Send + Sync> {
951 self.event_state.alloc
952 }
953}
954
/// Spawning of actor meshes from a shared (owned, pointer-like) handle to a
/// [`ProcMesh`], such that the resulting [`RootActorMesh`] is `'static`.
#[async_trait]
pub trait SharedSpawnable {
    /// Spawns an actor mesh of `A` named `actor_name` with `params` on
    /// every proc of the mesh, consuming the handle.
    async fn spawn<A: Actor + Referable>(
        self,
        cx: &impl context::Actor,
        actor_name: &str,
        params: &A::Params,
    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
    where
        A::Params: RemoteMessage;
}
970
971#[async_trait]
972impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D {
973 async fn spawn<A: Actor + Referable>(
976 self,
977 cx: &impl context::Actor,
978 actor_name: &str,
979 params: &A::Params,
980 ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
981 where
982 A::Params: RemoteMessage,
983 {
984 match &self.deref().inner {
985 ProcMeshKind::V0 {
986 actor_event_router,
987 client,
988 ..
989 } => {
990 let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
991 {
992 actor_event_router.insert(actor_name.to_string(), tx);
994 tracing::info!(
995 name = "router_insert",
996 actor_name = %actor_name,
997 "the length of the router is {}", actor_event_router.len(),
998 );
999 }
1000 let ranks =
1001 ProcMesh::spawn_on_procs::<A>(client, self.agents(), actor_name, params)
1002 .await?;
1003 Ok(RootActorMesh::new_shared(
1004 self,
1005 actor_name.to_string(),
1006 rx,
1007 ranks,
1008 ))
1009 }
1010 ProcMeshKind::V1(proc_mesh) => Ok(RootActorMesh::from(
1011 proc_mesh.spawn_service(cx, actor_name, params).await?,
1012 )),
1013 }
1014 }
1015}
1016
1017#[async_trait]
1018impl Mesh for ProcMesh {
1019 type Node = ProcId;
1020 type Id = ProcMeshId;
1021 type Sliced<'a> = SlicedProcMesh<'a>;
1022
1023 fn shape(&self) -> &Shape {
1024 ProcMesh::shape(self)
1025 }
1026
1027 fn select<R: Into<Range>>(
1028 &self,
1029 label: &str,
1030 range: R,
1031 ) -> Result<Self::Sliced<'_>, ShapeError> {
1032 Ok(SlicedProcMesh(self, self.shape().select(label, range)?))
1033 }
1034
1035 fn get(&self, rank: usize) -> Option<ProcId> {
1036 match &self.inner {
1037 ProcMeshKind::V0 { ranks, .. } => Some(ranks[rank].1.clone()),
1038 ProcMeshKind::V1(proc_mesh) => proc_mesh.get(rank).map(|proc| proc.proc_id().clone()),
1039 }
1040 }
1041
1042 fn id(&self) -> Self::Id {
1043 match &self.inner {
1044 ProcMeshKind::V0 { world_id, .. } => ProcMeshId(world_id.name().to_string()),
1045 ProcMeshKind::V1(proc_mesh) => ProcMeshId(proc_mesh.name().to_string()),
1046 }
1047 }
1048}
1049
1050impl fmt::Display for ProcMesh {
1051 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1052 write!(f, "{{ shape: {} }}", self.shape())
1053 }
1054}
1055
1056impl fmt::Debug for ProcMesh {
1057 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1058 match &self.inner {
1059 ProcMeshKind::V0 {
1060 shape,
1061 ranks,
1062 client_proc,
1063 ..
1064 } => f
1065 .debug_struct("ProcMesh::V0")
1066 .field("shape", shape)
1067 .field("ranks", ranks)
1068 .field("client_proc", client_proc)
1069 .field("client", &"<Instance>")
1070 .finish(),
1072 ProcMeshKind::V1(proc_mesh) => fmt::Debug::fmt(proc_mesh, f),
1073 }
1074 }
1075}
1076
/// A view of a [`ProcMesh`] restricted to a sub-shape: the borrowed parent
/// mesh together with the selected [`Shape`].
pub struct SlicedProcMesh<'a>(&'a ProcMesh, Shape);
1078
1079#[async_trait]
1080impl Mesh for SlicedProcMesh<'_> {
1081 type Node = ProcId;
1082 type Id = ProcMeshId;
1083 type Sliced<'b>
1084 = SlicedProcMesh<'b>
1085 where
1086 Self: 'b;
1087
1088 fn shape(&self) -> &Shape {
1089 &self.1
1090 }
1091
1092 fn select<R: Into<Range>>(
1093 &self,
1094 label: &str,
1095 range: R,
1096 ) -> Result<Self::Sliced<'_>, ShapeError> {
1097 Ok(Self(self.0, self.1.select(label, range)?))
1098 }
1099
1100 fn get(&self, _index: usize) -> Option<ProcId> {
1101 unimplemented!()
1102 }
1103
1104 fn id(&self) -> Self::Id {
1105 self.0.id()
1106 }
1107}
1108
// Tests for proc mesh allocation, lifecycle event propagation, supervision
// and spawning, run against the local allocator.
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::actor::ActorStatus;
    use ndslice::extent;

    use super::*;
    use crate::actor_mesh::ActorMesh;
    use crate::actor_mesh::test_util::Error;
    use crate::actor_mesh::test_util::TestActor;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::local::LocalAllocator;
    use crate::sel_from_shape;

    // A freshly allocated mesh exposes procs whose world name matches the
    // alloc's name.
    #[tokio::test]
    async fn test_basic() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let name = alloc.name().to_string();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        assert_eq!(mesh.get(0).unwrap().world_name(), Some(name.as_str()));
    }

    // Proc stops, whether injected for a single rank or caused by stopping
    // the whole alloc, surface as `ProcEvent::Stopped`, after which the
    // stream ends.
    #[tokio::test]
    async fn test_propagate_lifecycle_events() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let stop = alloc.stopper();
        let monkey = alloc.chaos_monkey();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        // Kill rank 1 only.
        monkey(1, ProcStopReason::Killed(1, false));
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Killed(1, false))
        );

        // Stopping the alloc stops the three remaining procs.
        stop();
        for _ in 0..3 {
            assert_matches!(
                events.next().await.unwrap(),
                ProcEvent::Stopped(_, ProcStopReason::Stopped)
            );
        }
        assert!(events.next().await.is_none());
    }

    // An actor failure is reported both on the proc event stream (as
    // `Crashed`) and on the actor mesh's own event stream; stopping the
    // alloc afterwards yields stop events and a final synthesized failure
    // event on the actor mesh stream.
    #[tokio::test]
    async fn test_supervision_failure() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 2),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();
        let stop = alloc.stopper();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        let instance = crate::v1::testing::instance().await;

        let mut actors = mesh
            .spawn::<TestActor>(&instance, "failing", &())
            .await
            .unwrap();
        let mut actor_events = actors.events().unwrap();

        // Make the actor at replica 0 fail.
        actors
            .cast(
                mesh.client(),
                sel_from_shape!(actors.shape(), replica = 0),
                Error("failmonkey".to_string()),
            )
            .unwrap();

        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Crashed(0, reason) if reason.contains("failmonkey")
        );

        let mut event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.1, "failing".to_string());
        assert_eq!(event.actor_id.2, 0);

        stop();
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(0, ProcStopReason::Stopped),
        );
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Stopped),
        );

        assert!(events.next().await.is_none());
        event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.2, 0);
    }

    // Spawning a second actor mesh under an already used name fails.
    #[timed_test::async_timed_test(timeout_secs = 5)]
    async fn test_spawn_twice() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        let instance = crate::v1::testing::instance().await;
        mesh.spawn::<TestActor>(&instance, "dup", &())
            .await
            .unwrap();
        let result = mesh.spawn::<TestActor>(&instance, "dup", &()).await;
        assert!(result.is_err());
    }

    // Tests for a `ProcMesh` that wraps a v1 proc mesh.
    mod shim {
        use std::collections::HashSet;

        use hyperactor::context::Mailbox;
        use ndslice::Extent;
        use ndslice::Selection;

        use super::*;
        use crate::sel;

        // Casting to an actor mesh spawned through the shim reaches every
        // actor exactly once, with the test instance as the sender.
        #[tokio::test]
        #[cfg(fbcode_build)]
        async fn test_basic() {
            let instance = v1::testing::instance().await;
            let ext = extent!(host = 4);
            let host_mesh = v1::testing::host_mesh(ext.clone()).await;
            let proc_mesh = host_mesh
                .spawn(instance, "test", Extent::unity())
                .await
                .unwrap();
            let proc_mesh_v0: ProcMesh = proc_mesh.detach().into();

            let actor_mesh = proc_mesh_v0
                .spawn::<v1::testactor::TestActor>(instance, "test", &())
                .await
                .unwrap();

            let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
            actor_mesh
                .cast(
                    instance,
                    sel!(*),
                    v1::testactor::GetCastInfo {
                        cast_info: cast_info.bind(),
                    },
                )
                .unwrap();

            // Expected (point, actor ref) pairs; each reply must remove
            // exactly one of them.
            let mut point_to_actor: HashSet<_> = actor_mesh
                .iter_actor_refs()
                .enumerate()
                .map(|(rank, actor_ref)| (ext.point_of_rank(rank).unwrap(), actor_ref))
                .collect();
            while !point_to_actor.is_empty() {
                let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
                let key = (point, origin_actor_ref);
                assert!(
                    point_to_actor.remove(&key),
                    "key {:?} not present or removed twice",
                    key
                );
                assert_eq!(&sender_actor_id, instance.self_id());
            }
        }
    }
}