1use std::collections::HashMap;
10use std::fmt;
11use std::ops::Deref;
12use std::panic::Location;
13use std::sync::Arc;
14use std::sync::atomic::AtomicUsize;
15use std::sync::atomic::Ordering;
16
17use async_trait::async_trait;
18use dashmap::DashMap;
19use futures::future::join_all;
20use hyperactor::Actor;
21use hyperactor::ActorHandle;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Context;
25use hyperactor::Handler;
26use hyperactor::Instance;
27use hyperactor::RemoteMessage;
28use hyperactor::RemoteSpawn;
29use hyperactor::WorldId;
30use hyperactor::actor::ActorError;
31use hyperactor::actor::ActorErrorKind;
32use hyperactor::actor::ActorStatus;
33use hyperactor::actor::Signal;
34use hyperactor::actor::remote::Remote;
35use hyperactor::channel;
36use hyperactor::channel::BindSpec;
37use hyperactor::channel::ChannelAddr;
38use hyperactor::channel::ChannelTransport;
39use hyperactor::context;
40use hyperactor::mailbox;
41use hyperactor::mailbox::BoxableMailboxSender;
42use hyperactor::mailbox::BoxedMailboxSender;
43use hyperactor::mailbox::DialMailboxRouter;
44use hyperactor::mailbox::MailboxServer;
45use hyperactor::mailbox::MessageEnvelope;
46use hyperactor::mailbox::PortHandle;
47use hyperactor::mailbox::PortReceiver;
48use hyperactor::mailbox::Undeliverable;
49use hyperactor::metrics;
50use hyperactor::proc::Proc;
51use hyperactor::proc::WorkCell;
52use hyperactor::reference::ProcId;
53use hyperactor::supervision::ActorSupervisionEvent;
54use hyperactor_config::CONFIG;
55use hyperactor_config::ConfigAttr;
56use hyperactor_config::attrs::declare_attrs;
57use hyperactor_config::global;
58use ndslice::Range;
59use ndslice::Shape;
60use ndslice::ShapeError;
61use ndslice::View;
62use ndslice::ViewExt;
63use strum::AsRefStr;
64use tokio::sync::mpsc;
65use tokio::task::JoinHandle;
66use tracing::Instrument;
67use tracing::Level;
68use tracing::span;
69
70use crate::CommActor;
71use crate::Mesh;
72use crate::actor_mesh::RootActorMesh;
73use crate::alloc::Alloc;
74use crate::alloc::AllocExt;
75use crate::alloc::AllocatedProc;
76use crate::alloc::AllocatorError;
77use crate::alloc::ProcState;
78use crate::alloc::ProcStopReason;
79use crate::assign::Ranks;
80use crate::comm::CommActorMode;
81use crate::proc_mesh::mesh_agent::GspawnResult;
82use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
83use crate::proc_mesh::mesh_agent::ProcMeshAgent;
84use crate::proc_mesh::mesh_agent::StopActorResult;
85use crate::proc_mesh::mesh_agent::update_event_actor_id;
86use crate::reference::ProcMeshId;
87use crate::router;
88use crate::shortuuid::ShortUuid;
89use crate::supervision::MeshFailure;
90use crate::v1;
91use crate::v1::Name;
92
93pub mod mesh_agent;
94
95use std::sync::OnceLock;
96use std::sync::RwLock;
97
// Declares the `DEFAULT_TRANSPORT` configuration attribute.
//
// Its metadata names the environment variable
// `HYPERACTOR_MESH_DEFAULT_TRANSPORT` and the Python-facing key
// `default_transport`. The declared default value is
// `BindSpec::Any(ChannelTransport::Unix)`.
declare_attrs! {
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string()),
        py_name: Some("default_transport".to_string()),
    })
    pub attr DEFAULT_TRANSPORT: BindSpec = BindSpec::Any(ChannelTransport::Unix);
}
106
107pub fn default_transport() -> ChannelTransport {
114 match default_bind_spec() {
115 BindSpec::Any(transport) => transport,
116 BindSpec::Addr(addr) => panic!("default_bind_spec() returned BindSpec::Addr({addr})"),
117 }
118}
119
120pub fn default_bind_spec() -> BindSpec {
122 global::get_cloned(DEFAULT_TRANSPORT)
123}
124
/// Process-wide slot for the supervision-event sink.
///
/// The slot is lazily initialized (see `sink_cell`) and holds at most one
/// `PortHandle<ActorSupervisionEvent>`; it is written by
/// `set_global_supervision_sink` and read by `get_global_supervision_sink`.
static GLOBAL_SUPERVISION_SINK: OnceLock<RwLock<Option<PortHandle<ActorSupervisionEvent>>>> =
    OnceLock::new();
133
134fn sink_cell() -> &'static RwLock<Option<PortHandle<ActorSupervisionEvent>>> {
140 GLOBAL_SUPERVISION_SINK.get_or_init(|| RwLock::new(None))
141}
142
143pub(crate) fn set_global_supervision_sink(
157 sink: PortHandle<ActorSupervisionEvent>,
158) -> Option<PortHandle<ActorSupervisionEvent>> {
159 let cell = sink_cell();
160 let mut guard = cell.write().unwrap();
161 let prev = guard.take();
162 *guard = Some(sink);
163 prev
164}
165
166pub(crate) fn get_global_supervision_sink() -> Option<PortHandle<ActorSupervisionEvent>> {
178 sink_cell().read().unwrap().clone()
179}
180
/// Actor state for the process-wide root client.
///
/// It owns the three receivers that `GlobalClientActor::run` polls in its
/// message loop.
#[derive(Debug)]
pub struct GlobalClientActor {
    // Receiver for `Signal`s; `run` receives and discards them.
    signal_rx: PortReceiver<Signal>,
    // Receiver for supervision events delivered to this actor.
    supervision_rx: PortReceiver<ActorSupervisionEvent>,
    // Queue of work items to run against this actor.
    work_rx: mpsc::UnboundedReceiver<WorkCell<Self>>,
}
187
impl GlobalClientActor {
    /// Spawns a tokio task that drives this actor's message loop and returns
    /// the task's `JoinHandle`.
    ///
    /// The loop selects over the work queue, the signal port and the
    /// supervision port. It only exits with an `ActorError`; that error is
    /// then converted into an `ActorSupervisionEvent` and handed to the
    /// instance's proc.
    fn run(mut self, instance: &'static Instance<Self>) -> JoinHandle<()> {
        tokio::spawn(async move {
            let err = 'messages: loop {
                tokio::select! {
                    work = self.work_rx.recv() => {
                        // A closed work queue is treated as a broken invariant.
                        let work = work.expect("inconsistent work queue state");
                        if let Err(err) = work.handle(&mut self, instance).await {
                            // Before reporting the handler failure, process any
                            // supervision events that are already queued; a failure
                            // while handling one of them takes precedence.
                            for supervision_event in self.supervision_rx.drain() {
                                if let Err(err) = instance.handle_supervision_event(&mut self, supervision_event).await {
                                    break 'messages err;
                                }
                            }
                            let kind = ActorErrorKind::processing(err);
                            break ActorError {
                                actor_id: Box::new(instance.self_id().clone()),
                                kind: Box::new(kind),
                            };
                        }
                    }
                    _ = self.signal_rx.recv() => {
                        // Signals are received and intentionally ignored.
                    }
                    Ok(supervision_event) = self.supervision_rx.recv() => {
                        if let Err(err) = instance.handle_supervision_event(&mut self, supervision_event).await {
                            break err;
                        }
                    }
                };
            };
            // Unhandled supervision events are forwarded unchanged; any other
            // error kind is wrapped in a generic-failure event for this actor.
            let event = match *err.kind {
                ActorErrorKind::UnhandledSupervisionEvent(event) => *event,
                _ => {
                    let status = ActorStatus::generic_failure(err.kind.to_string());
                    ActorSupervisionEvent::new(
                        instance.self_id().clone(),
                        Some("testclient".into()),
                        status,
                        None,
                    )
                }
            };
            instance.proc().handle_supervision_event(event);
        })
    }
}
234
// `GlobalClientActor` relies entirely on the `Actor` trait's provided items.
impl Actor for GlobalClientActor {}
236
/// Handling of mesh failures that reach the global client.
#[async_trait]
impl Handler<MeshFailure> for GlobalClientActor {
    /// Logs the failure at error level and then panics with the same text.
    ///
    /// # Panics
    ///
    /// Always panics; this handler never returns.
    async fn handle(&mut self, _cx: &Context<Self>, message: MeshFailure) -> anyhow::Result<()> {
        tracing::error!("supervision failure reached global client: {}", message);
        panic!("supervision failure reached global client: {}", message);
    }
}
244
/// Builds the root client proc and actor instance, stores them in a
/// function-local static, starts the actor's run loop and returns `'static`
/// references to the instance and its handle.
///
/// # Panics
///
/// Panics if the proc or the actor instance cannot be created, or if the
/// static has already been populated (i.e. on a second call).
fn fresh_instance() -> (
    &'static Instance<GlobalClientActor>,
    &'static ActorHandle<GlobalClientActor>,
) {
    static INSTANCE: OnceLock<(Instance<GlobalClientActor>, ActorHandle<GlobalClientActor>)> =
        OnceLock::new();
    // Create the client proc on the default bind address, passing the global
    // router as its third argument.
    // NOTE(review): presumably that argument is the default forwarder for
    // messages the proc cannot deliver locally — confirm against `Proc::direct_with_default`.
    let client_proc = Proc::direct_with_default(
        default_bind_spec().binding_addr(),
        "mesh_root_client_proc".into(),
        router::global().clone().boxed(),
    )
    .unwrap();

    // Register the new proc with the global router under its own proc id.
    router::global().bind(client_proc.proc_id().clone().into(), client_proc.clone());

    let (client, handle, supervision_rx, signal_rx, work_rx) = client_proc
        .actor_instance::<GlobalClientActor>("client")
        .expect("root instance create");

    // Open a port for undeliverable messages and start supervising it. The
    // sink is passed as a function, not a value; the callback looks it up
    // again only to log whether one is present.
    let (_undeliverable_tx, undeliverable_rx) =
        client.open_port::<Undeliverable<MessageEnvelope>>();
    hyperactor::mailbox::supervise_undeliverable_messages_with(
        undeliverable_rx,
        crate::proc_mesh::get_global_supervision_sink,
        |env| {
            let sink_present = crate::proc_mesh::get_global_supervision_sink().is_some();
            tracing::info!(
                actor_id = %env.dest().actor_id(),
                "global root client undeliverable observed with headers {:?} {}", env.headers(), sink_present
            );
        },
    );

    // Move the instance and handle into the static so that `'static`
    // references to them can be handed out.
    INSTANCE
        .set((client, handle))
        .map_err(|_| "already initialized root client instance")
        .unwrap();
    let (instance, handle) = INSTANCE.get().unwrap();
    let client = GlobalClientActor {
        signal_rx,
        supervision_rx,
        work_rx,
    };
    // The returned JoinHandle is dropped, so the run loop is detached.
    client.run(instance);
    (instance, handle)
}
311
312pub fn global_root_client() -> &'static Instance<GlobalClientActor> {
316 static GLOBAL_INSTANCE: OnceLock<(
317 &'static Instance<GlobalClientActor>,
318 &'static ActorHandle<GlobalClientActor>,
319 )> = OnceLock::new();
320 GLOBAL_INSTANCE.get_or_init(fresh_instance).0
321}
322
/// Shared map from actor mesh name to the sender used to deliver supervision
/// events to that actor mesh.
type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
324
/// A mesh of procs, backed by either the V0 implementation or a wrapped
/// `v1::ProcMeshRef`.
pub struct ProcMesh {
    // Version-specific state.
    inner: ProcMeshKind,
    // Lazily computed shape, filled in on the first `ProcMesh::shape` call.
    shape: OnceLock<Shape>,
}
331
/// Version-specific state behind a `ProcMesh`.
enum ProcMeshKind {
    /// State built by `ProcMesh::allocate_boxed_inner`.
    V0 {
        // Alloc and supervision receiver; taken by the first `events()` call.
        event_state: Option<EventState>,
        // Per-actor-mesh supervision senders, keyed by actor mesh name.
        actor_event_router: ActorEventRouter,
        // Shape of the allocation this mesh was built from.
        shape: Shape,
        // One entry per proc, in rank order: create key, proc id, address and mesh agent.
        ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef<ProcMeshAgent>)>,
        #[allow(dead_code)] client_proc: Proc,
        // Client instance used to send messages on behalf of the mesh.
        client: Instance<()>,
        // One comm actor per proc, in rank order.
        comm_actors: Vec<ActorRef<CommActor>>,
        // World id of the underlying allocation.
        world_id: WorldId,
    },

    /// A wrapped v1 proc mesh reference.
    V1(v1::ProcMeshRef),
}
349
/// The event sources of a V0 mesh, handed to `ProcEvents` when
/// `ProcMesh::events` is called.
struct EventState {
    // The allocation; polled for proc state changes.
    alloc: Box<dyn Alloc + Send + Sync>,
    // Receiver for supervision events reported to the mesh's supervisor port.
    supervision_events: PortReceiver<ActorSupervisionEvent>,
}
354
355impl From<v1::ProcMeshRef> for ProcMesh {
356 fn from(proc_mesh: v1::ProcMeshRef) -> Self {
357 ProcMesh {
358 inner: ProcMeshKind::V1(proc_mesh),
359 shape: OnceLock::new(),
360 }
361 }
362}
363
364impl ProcMesh {
365 #[hyperactor::instrument(fields(name = "ProcMesh::allocate"))]
366 pub async fn allocate(
367 alloc: impl Alloc + Send + Sync + 'static,
368 ) -> Result<Self, AllocatorError> {
369 ProcMesh::allocate_boxed(Box::new(alloc)).await
370 }
371
372 #[track_caller]
376 pub fn allocate_boxed(
377 alloc: Box<dyn Alloc + Send + Sync>,
378 ) -> impl std::future::Future<Output = Result<Self, AllocatorError>> {
379 Self::allocate_boxed_inner(alloc, Location::caller())
380 }
381
382 fn alloc_counter() -> &'static AtomicUsize {
383 static C: OnceLock<AtomicUsize> = OnceLock::new();
384 C.get_or_init(|| AtomicUsize::new(0))
385 }
386
    /// Builds a V0 `ProcMesh` from `alloc`.
    ///
    /// Steps, in order: initialize the alloc; set up a dial router for the
    /// allocated procs; create and serve a client proc; create the supervisor
    /// and client instances; serve the router; configure every mesh agent and
    /// wait for all of them to report back; spawn one comm actor per proc and
    /// put them in mesh mode.
    ///
    /// `loc` is only used to log the caller's file and line.
    ///
    /// # Errors
    ///
    /// Returns an `AllocatorError` if any of the steps above fails.
    #[hyperactor::instrument]
    #[hyperactor::observe_result("ProcMesh")]
    async fn allocate_boxed_inner(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        loc: &'static Location<'static>,
    ) -> Result<Self, AllocatorError> {
        // Allocation ids are 1-based: the counter's previous value plus one.
        let alloc_id = Self::alloc_counter().fetch_add(1, Ordering::Relaxed) + 1;
        let world = alloc.world_id().name().to_string();
        tracing::info!(
            name = "ProcMesh::Allocate::Attempt",
            %world,
            alloc_id,
            caller = %format!("{}:{}", loc.file(), loc.line()),
            shape = ?alloc.shape(),
            "allocating proc mesh"
        );

        // Bring up the allocation; the result describes every allocated proc.
        let running = alloc
            .initialize()
            .instrument(span!(
                Level::INFO,
                "ProcMesh::Allocate::Initialize",
                alloc_id
            ))
            .await?;

        // Per-mesh router with the global router as its default. Procs whose
        // id reports `is_direct()` are not bound here.
        let router = DialMailboxRouter::new_with_default(router::global().boxed());
        for AllocatedProc { proc_id, addr, .. } in running.iter() {
            if proc_id.is_direct() {
                continue;
            }
            router.bind(proc_id.clone().into(), addr.clone());
        }

        // The client proc lives in its own "<world>_client" world at rank 0
        // and listens on the alloc's transport.
        let client_proc_id =
            ProcId::Ranked(WorldId(format!("{}_client", alloc.world_id().name())), 0);
        let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(alloc.transport()))
            .map_err(|err| AllocatorError::Other(err.into()))?;
        tracing::info!(
            name = "ProcMesh::Allocate::ChannelServe",
            alloc_id = alloc_id,
            "client proc started listening on addr: {client_proc_addr}"
        );
        let client_proc = Proc::new(
            client_proc_id.clone(),
            BoxedMailboxSender::new(router.clone()),
        );
        client_proc.clone().serve(client_rx);
        router.bind(client_proc_id.clone().into(), client_proc_addr.clone());

        // Make this mesh's routes known to the global router.
        router::global().bind_dial_router(&router);

        let (supervisor, _supervisor_handle) = client_proc.instance("supervisor")?;
        let (supervision_port, supervision_events) =
            supervisor.open_port::<ActorSupervisionEvent>();

        // Install this mesh's supervision port as the global sink; the
        // previously installed sink, if any, is discarded.
        let _prev = set_global_supervision_sink(supervision_port.clone());

        // Undeliverable messages returned to the client are supervised using
        // the mesh's supervision port.
        let (client, _handle) = client_proc.instance("client")?;
        let (_undeliverable_messages, client_undeliverable_receiver) =
            client.bind_actor_port::<Undeliverable<MessageEnvelope>>();
        hyperactor::mailbox::supervise_undeliverable_messages(
            supervision_port.clone(),
            client_undeliverable_receiver,
            |env| {
                tracing::info!(actor=%env.dest().actor_id(), "per-mesh client undeliverable observed");
            },
        );

        // Serve the router on the address chosen by the alloc; the resulting
        // address is handed to every mesh agent below.
        let (router_channel_addr, router_rx) = channel::serve(alloc.client_router_addr())
            .map_err(|e| AllocatorError::Other(e.into()))?;
        router.serve(router_rx);
        tracing::info!("router channel started listening on addr: {router_channel_addr}");

        // Map from each mesh agent's proc id to that proc's address.
        let address_book: HashMap<_, _> = running
            .iter()
            .map(
                |AllocatedProc {
                     addr, mesh_agent, ..
                 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
            )
            .collect();

        // Configure every agent with its rank, then wait until each rank has
        // acknowledged on the config port.
        let (config_handle, mut config_receiver) = client.open_port();
        for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
            mesh_agent
                .configure(
                    &client,
                    rank,
                    router_channel_addr.clone(),
                    Some(supervision_port.bind()),
                    address_book.clone(),
                    config_handle.bind(),
                    false,
                )
                .await?;
        }
        let mut completed = Ranks::new(running.len());
        while !completed.is_full() {
            let rank = config_receiver
                .recv()
                .await
                .map_err(|err| AllocatorError::Other(err.into()))?;
            if completed.insert(rank, rank).is_some() {
                tracing::warn!("multiple completions received for rank {}", rank);
            }
        }

        // Named helper rather than a closure for projecting the agent ref.
        fn project_mesh_agent_ref(allocated_proc: &AllocatedProc) -> ActorRef<ProcMeshAgent> {
            allocated_proc.mesh_agent.clone()
        }

        // Spawn a comm actor named "comm" on every proc and tell each one its
        // rank together with the rank -> comm actor map.
        let comm_actors = Self::spawn_on_procs::<CommActor>(
            &client,
            running.iter().map(project_mesh_agent_ref),
            "comm",
            &Default::default(),
        )
        .await?;
        let address_book: HashMap<_, _> = comm_actors.iter().cloned().enumerate().collect();
        for (rank, comm_actor) in comm_actors.iter().enumerate() {
            comm_actor
                .send(&client, CommActorMode::Mesh(rank, address_book.clone()))
                .map_err(anyhow::Error::from)?;
        }

        let shape = alloc.shape().clone();
        let world_id = alloc.world_id().clone();
        metrics::PROC_MESH_ALLOCATION.add(
            running.len() as u64,
            hyperactor_telemetry::kv_pairs!("alloc_id" => alloc_id.to_string()),
        );

        Ok(Self {
            inner: ProcMeshKind::V0 {
                event_state: Some(EventState {
                    alloc,
                    supervision_events,
                }),
                actor_event_router: Arc::new(DashMap::new()),
                shape,
                ranks: running
                    .into_iter()
                    .map(
                        |AllocatedProc {
                             create_key,
                             proc_id,
                             addr,
                             mesh_agent,
                         }| (create_key, proc_id, addr, mesh_agent),
                    )
                    .collect(),
                client_proc,
                client,
                comm_actors,
                world_id,
            },
            shape: OnceLock::new(),
        })
    }
605
606 async fn spawn_on_procs<A: RemoteSpawn>(
615 cx: &impl context::Actor,
616 agents: impl IntoIterator<Item = ActorRef<ProcMeshAgent>> + '_,
617 actor_name: &str,
618 params: &A::Params,
619 ) -> Result<Vec<ActorRef<A>>, anyhow::Error>
620 where
621 A::Params: RemoteMessage,
622 {
623 let remote = Remote::collect();
624 let actor_type = remote
625 .name_of::<A>()
626 .ok_or(anyhow::anyhow!("actor not registered"))?
627 .to_string();
628
629 let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
630 let mut n = 0;
631 for agent in agents {
632 agent
633 .gspawn(
634 cx,
635 actor_type.clone(),
636 actor_name.to_string(),
637 bincode::serialize(params)?,
638 completed_handle.bind(),
639 )
640 .await?;
641 n += 1;
642 }
643 let mut completed = Ranks::new(n);
644 while !completed.is_full() {
645 let result = completed_receiver.recv().await?;
646 match result {
647 GspawnResult::Success { rank, actor_id } => {
648 if completed.insert(rank, actor_id).is_some() {
649 tracing::warn!("multiple completions received for rank {}", rank);
650 }
651 }
652 GspawnResult::Error(error_msg) => {
653 metrics::PROC_MESH_ACTOR_FAILURES.add(
654 1,
655 hyperactor_telemetry::kv_pairs!(
656 "actor_name" => actor_name.to_string(),
657 "error" => error_msg.clone(),
658 ),
659 );
660
661 anyhow::bail!("gspawn failed: {}", error_msg);
662 }
663 }
664 }
665
666 Ok(completed
669 .into_iter()
670 .map(Option::unwrap)
671 .map(ActorRef::attest)
672 .collect())
673 }
674
675 fn agents(&self) -> Box<dyn Iterator<Item = ActorRef<ProcMeshAgent>> + '_ + Send> {
676 match &self.inner {
677 ProcMeshKind::V0 { ranks, .. } => {
678 Box::new(ranks.iter().map(|(_, _, _, agent)| agent.clone()))
679 }
680 ProcMeshKind::V1(proc_mesh) => Box::new(
681 proc_mesh
682 .agent_mesh()
683 .iter()
684 .map(|(_point, agent)| agent.clone())
685 .collect::<Vec<_>>()
693 .into_iter(),
694 ),
695 }
696 }
697
698 pub(crate) fn comm_actor(&self) -> &ActorRef<CommActor> {
700 match &self.inner {
701 ProcMeshKind::V0 { comm_actors, .. } => &comm_actors[0],
702 ProcMeshKind::V1(proc_mesh) => proc_mesh.root_comm_actor().unwrap(),
703 }
704 }
705
    /// Spawns an actor mesh of type `A` named `actor_name` on this proc mesh.
    ///
    /// For V0 meshes the spawn is performed with the mesh's own client and
    /// `cx` is not used; for V1 meshes the call is delegated to the wrapped
    /// proc mesh using `cx`.
    ///
    /// # Errors
    ///
    /// Returns an error if spawning on the procs fails.
    pub async fn spawn<A: RemoteSpawn, C: context::Actor>(
        &self,
        cx: &C,
        actor_name: &str,
        params: &A::Params,
    ) -> Result<RootActorMesh<'_, A>, anyhow::Error>
    where
        A::Params: RemoteMessage,
        C::A: Handler<MeshFailure>,
    {
        match &self.inner {
            ProcMeshKind::V0 {
                actor_event_router,
                client,
                ..
            } => {
                let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
                {
                    // The supervision route is registered before the actors are
                    // spawned. NOTE(review): the entry is not removed if the
                    // spawn below fails.
                    actor_event_router.insert(actor_name.to_string(), tx);
                    tracing::info!(
                        name = "router_insert",
                        actor_name = %actor_name,
                        "the length of the router is {}", actor_event_router.len(),
                    );
                }
                let root_mesh = RootActorMesh::new(
                    self,
                    actor_name.to_string(),
                    rx,
                    Self::spawn_on_procs::<A>(client, self.agents(), actor_name, params).await?,
                );
                Ok(root_mesh)
            }
            ProcMeshKind::V1(proc_mesh) => {
                let actor_mesh = proc_mesh.spawn(cx, actor_name, params).await?;
                Ok(RootActorMesh::new_v1(actor_mesh.detach()))
            }
        }
    }
761
762 pub fn client(&self) -> &Instance<()> {
764 match &self.inner {
765 ProcMeshKind::V0 { client, .. } => client,
766 ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client for v1::ProcMesh"),
767 }
768 }
769
770 pub fn client_proc(&self) -> &Proc {
771 match &self.inner {
772 ProcMeshKind::V0 { client_proc, .. } => client_proc,
773 ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client proc for v1::ProcMesh"),
774 }
775 }
776
777 pub fn proc_id(&self) -> &ProcId {
778 self.client_proc().proc_id()
779 }
780
781 pub fn world_id(&self) -> &WorldId {
782 match &self.inner {
783 ProcMeshKind::V0 { world_id, .. } => world_id,
784 ProcMeshKind::V1(_proc_mesh) => unimplemented!("no world_id for v1::ProcMesh"),
785 }
786 }
787
    /// Hands out the event stream of a V0 mesh.
    ///
    /// The event state is moved out of the mesh, so only the first call
    /// returns `Some`; later calls return `None`.
    ///
    /// # Panics
    ///
    /// Panics (`todo!`) for V1 meshes.
    pub fn events(&mut self) -> Option<ProcEvents> {
        match &mut self.inner {
            ProcMeshKind::V0 {
                event_state,
                ranks,
                actor_event_router,
                ..
            } => event_state.take().map(|event_state| ProcEvents {
                event_state,
                // Index the procs by id, remembering each rank and create key.
                ranks: ranks
                    .iter()
                    .enumerate()
                    .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| {
                        (proc_id.clone(), (rank, create_key.clone()))
                    })
                    .collect(),
                // The router is shared with the mesh, not copied.
                actor_event_router: actor_event_router.clone(),
            }),
            #[allow(clippy::todo)]
            ProcMeshKind::V1(_proc_mesh) => todo!(),
        }
    }
812
813 pub fn shape(&self) -> &Shape {
814 self.shape.get_or_init(|| match &self.inner {
817 ProcMeshKind::V0 { shape, .. } => shape.clone(),
818 ProcMeshKind::V1(proc_mesh) => proc_mesh.region().into(),
819 })
820 }
821
    /// Stops the actors of the actor mesh named `mesh_name` on every proc.
    ///
    /// For V0 meshes a stop request is sent to all agents concurrently using
    /// the mesh's own client (`cx` is unused), and each result is only
    /// logged: timeouts, missing actors and errors do not fail the call. For
    /// V1 meshes the call is delegated to the wrapped proc mesh using `cx`.
    ///
    /// # Errors
    ///
    /// Only the V1 path returns errors (invalid name or a failed stop).
    #[hyperactor::observe_result("ProcMesh")]
    pub async fn stop_actor_by_name(
        &self,
        cx: &impl context::Actor,
        mesh_name: &str,
    ) -> Result<(), anyhow::Error> {
        match &self.inner {
            ProcMeshKind::V0 { client, .. } => {
                let timeout =
                    hyperactor_config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
                let results = join_all(self.agents().map(|agent| async move {
                    // The target is actor index 0 of `mesh_name` on the agent's proc.
                    let actor_id =
                        ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0);
                    (
                        actor_id.clone(),
                        agent
                            .clone()
                            .stop_actor(client, actor_id, timeout.as_millis() as u64)
                            .await,
                    )
                }))
                .await;

                for (actor_id, result) in results {
                    match result {
                        Ok(StopActorResult::Timeout) => {
                            tracing::warn!("timed out while stopping actor {}", actor_id);
                        }
                        Ok(StopActorResult::NotFound) => {
                            tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id());
                        }
                        Ok(StopActorResult::Success) => {
                            tracing::info!("stopped actor {}", actor_id);
                        }
                        Err(e) => {
                            tracing::warn!("error stopping actor {}: {}", actor_id, e);
                        }
                    }
                }
                Ok(())
            }
            ProcMeshKind::V1(proc_mesh) => {
                proc_mesh
                    .stop_actor_by_name(cx, Name::new_reserved(mesh_name)?)
                    .await?;
                Ok(())
            }
        }
    }
872}
873
/// An event reported by `ProcEvents::next` for a proc in the mesh.
#[derive(Debug, Clone)]
pub enum ProcEvent {
    /// The proc at the given rank stopped, with the given reason.
    Stopped(usize, ProcStopReason),
    /// An actor failure was reported for the proc at the given rank; the
    /// string is the rendered supervision event.
    Crashed(usize, String),
}
883
/// Names used to tag tracing events emitted while handling supervision
/// events. `AsRefStr` is derived so each variant can be used as a string via
/// `as_ref()`.
#[derive(Debug, Clone, AsRefStr)]
pub enum SupervisionEventState {
    /// An event is being forwarded to a registered actor mesh.
    SupervisionEventForward,
    /// An event could not be forwarded to an actor mesh.
    SupervisionEventForwardFailed,
    /// An event was received from the supervision port.
    SupervisionEventReceived,
    /// A proc-stopped event could not be sent to an actor mesh.
    SupervisionEventTransmitFailed,
}
891
892impl fmt::Display for ProcEvent {
893 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
894 match self {
895 ProcEvent::Stopped(rank, reason) => {
896 write!(f, "Proc at rank {} stopped: {}", rank, reason)
897 }
898 ProcEvent::Crashed(rank, reason) => {
899 write!(f, "Proc at rank {} crashed: {}", rank, reason)
900 }
901 }
902 }
903}
904
/// Name of an actor mesh; used as the key of `ActorEventRouter`.
type ActorMeshName = String;
906
/// The event stream of a V0 proc mesh, obtained from `ProcMesh::events`.
pub struct ProcEvents {
    // The alloc and the supervision receiver that `next` polls.
    event_state: EventState,
    // Rank and create key of every proc, keyed by proc id.
    ranks: HashMap<ProcId, (usize, ShortUuid)>,
    // Supervision senders of the actor meshes spawned on the proc mesh.
    actor_event_router: ActorEventRouter,
}
915
916impl ProcEvents {
917 pub async fn next(&mut self) -> Option<ProcEvent> {
920 loop {
921 tokio::select! {
922 result = self.event_state.alloc.next() => {
923 tracing::debug!(name = "ProcEventReceived", "received ProcEvent alloc update: {result:?}");
924 let Some(alloc_event) = result else {
926 self.actor_event_router.clear();
927 break None;
928 };
929
930 let ProcState::Stopped { create_key, reason } = alloc_event else {
931 continue;
933 };
934
935 let Some((proc_id, (rank, _create_key))) = self.ranks.iter().find(|(_proc_id, (_rank, key))| key == &create_key) else {
936 tracing::warn!("received stop event for unmapped proc {}", create_key);
937 continue;
938 };
939
940 metrics::PROC_MESH_PROC_STOPPED.add(
941 1,
942 hyperactor_telemetry::kv_pairs!(
943 "create_key" => create_key.to_string(),
944 "rank" => rank.to_string(),
945 "reason" => reason.to_string(),
946 ),
947 );
948
949 for entry in self.actor_event_router.iter() {
952 let event = ActorSupervisionEvent::new(
955 proc_id.actor_id("any", 0),
956 None,
957 ActorStatus::generic_failure(format!("proc {} is stopped", proc_id)),
958 None,
959 );
960 tracing::debug!(name = "SupervisionEvent", %event);
961 if entry.value().send(event.clone()).is_err() {
962 tracing::warn!(
963 name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
964 "unable to transmit supervision event to actor {}", entry.key()
965 );
966 }
967 }
968
969 let event = ProcEvent::Stopped(*rank, reason.clone());
970 tracing::debug!(name = "SupervisionEvent", %event);
971
972 break Some(ProcEvent::Stopped(*rank, reason));
973 }
974
975 Ok(event) = self.event_state.supervision_events.recv() => {
986 let had_headers = event.message_headers.is_some();
987 tracing::info!(
988 name = SupervisionEventState::SupervisionEventReceived.as_ref(),
989 actor_id = %event.actor_id,
990 actor_name = %event.actor_id.name(),
991 status = %event.actor_status,
992 "proc supervision: event received with {had_headers} headers"
993 );
994 tracing::debug!(
995 name = "SupervisionEvent",
996 %event,
997 "proc supervision: full event");
998
999 let event = update_event_actor_id(event);
1001
1002 let actor_id = event.actor_id.clone();
1008 let actor_status = event.actor_status.clone();
1009 let reason = event.to_string();
1010 if let Some(tx) = self.actor_event_router.get(actor_id.name()) {
1011 tracing::info!(
1012 name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
1013 actor_id = %actor_id,
1014 actor_name = actor_id.name(),
1015 status = %actor_status,
1016 "proc supervision: delivering event to registered ActorMesh"
1017 );
1018 if tx.send(event).is_err() {
1019 tracing::warn!(
1020 name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
1021 actor_id = %actor_id,
1022 "proc supervision: registered ActorMesh dropped receiver; unable to deliver"
1023 );
1024 }
1025 } else {
1026 let registered_meshes: Vec<_> = self.actor_event_router.iter().map(|e| e.key().clone()).collect();
1027 tracing::warn!(
1028 name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
1029 actor_id = %actor_id,
1030 "proc supervision: no ActorMesh registered for this actor {:?}", registered_meshes,
1031 );
1032 }
1033 let Some((rank, _)) = self.ranks.get(actor_id.proc_id()) else {
1037 tracing::warn!(
1038 actor_id = %actor_id,
1039 "proc supervision: actor belongs to an unmapped proc; dropping event"
1040 );
1041 continue;
1042 };
1043
1044 metrics::PROC_MESH_ACTOR_FAILURES.add(
1045 1,
1046 hyperactor_telemetry::kv_pairs!(
1047 "actor_id" => actor_id.to_string(),
1048 "rank" => rank.to_string(),
1049 "status" => actor_status.to_string(),
1050 ),
1051 );
1052
1053 break Some(ProcEvent::Crashed(*rank, reason))
1056 }
1057 }
1058 }
1059 }
1060
1061 pub fn into_alloc(self) -> Box<dyn Alloc + Send + Sync> {
1062 self.event_state.alloc
1063 }
1064}
1065
/// Spawning of actor meshes from a value that is consumed by the call, so
/// that the returned `RootActorMesh` has a `'static` lifetime.
#[async_trait]
pub trait SharedSpawnable {
    /// Spawns an actor mesh of type `A` named `actor_name`, consuming `self`.
    async fn spawn<A: RemoteSpawn, C: context::Actor>(
        self,
        cx: &C,
        actor_name: &str,
        params: &A::Params,
    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
    where
        A::Params: RemoteMessage,
        C::A: Handler<MeshFailure>;
}
1082
/// Blanket implementation for any `'static`, `Send + Sync` value that
/// dereferences to a `ProcMesh`.
#[async_trait]
impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D {
    /// Spawns an actor mesh on the proc mesh behind `self`.
    ///
    /// For V0 meshes the spawn uses the mesh's own client and `self` is moved
    /// into the returned mesh; for V1 meshes the call is delegated to
    /// `spawn_service` on the wrapped proc mesh using `cx`.
    async fn spawn<A: RemoteSpawn, C: context::Actor>(
        self,
        cx: &C,
        actor_name: &str,
        params: &A::Params,
    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
    where
        A::Params: RemoteMessage,
        C::A: Handler<MeshFailure>,
    {
        match &self.deref().inner {
            ProcMeshKind::V0 {
                actor_event_router,
                client,
                ..
            } => {
                let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
                {
                    // The supervision route is registered before the actors are
                    // spawned. NOTE(review): the entry is not removed if the
                    // spawn below fails.
                    actor_event_router.insert(actor_name.to_string(), tx);
                    tracing::info!(
                        name = "router_insert",
                        actor_name = %actor_name,
                        "the length of the router is {}", actor_event_router.len(),
                    );
                }
                let ranks =
                    ProcMesh::spawn_on_procs::<A>(client, self.agents(), actor_name, params)
                        .await?;
                Ok(RootActorMesh::new_shared(
                    self,
                    actor_name.to_string(),
                    rx,
                    ranks,
                ))
            }
            ProcMeshKind::V1(proc_mesh) => Ok(RootActorMesh::from(
                proc_mesh.spawn_service(cx, actor_name, params).await?,
            )),
        }
    }
}
1129
1130#[async_trait]
1131impl Mesh for ProcMesh {
1132 type Node = ProcId;
1133 type Id = ProcMeshId;
1134 type Sliced<'a> = SlicedProcMesh<'a>;
1135
1136 fn shape(&self) -> &Shape {
1137 ProcMesh::shape(self)
1138 }
1139
1140 fn select<R: Into<Range>>(
1141 &self,
1142 label: &str,
1143 range: R,
1144 ) -> Result<Self::Sliced<'_>, ShapeError> {
1145 Ok(SlicedProcMesh(self, self.shape().select(label, range)?))
1146 }
1147
1148 fn get(&self, rank: usize) -> Option<ProcId> {
1149 match &self.inner {
1150 ProcMeshKind::V0 { ranks, .. } => Some(ranks[rank].1.clone()),
1151 ProcMeshKind::V1(proc_mesh) => proc_mesh.get(rank).map(|proc| proc.proc_id().clone()),
1152 }
1153 }
1154
1155 fn id(&self) -> Self::Id {
1156 match &self.inner {
1157 ProcMeshKind::V0 { world_id, .. } => ProcMeshId(world_id.name().to_string()),
1158 ProcMeshKind::V1(proc_mesh) => ProcMeshId(proc_mesh.name().to_string()),
1159 }
1160 }
1161}
1162
1163impl fmt::Display for ProcMesh {
1164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1165 write!(f, "{{ shape: {} }}", self.shape())
1166 }
1167}
1168
1169impl fmt::Debug for ProcMesh {
1170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1171 match &self.inner {
1172 ProcMeshKind::V0 { shape, ranks, .. } => f
1173 .debug_struct("ProcMesh::V0")
1174 .field("shape", shape)
1175 .field("ranks", ranks)
1176 .field("client_proc", &"<Proc>")
1177 .field("client", &"<Instance>")
1178 .finish(),
1180 ProcMeshKind::V1(proc_mesh) => fmt::Debug::fmt(proc_mesh, f),
1181 }
1182 }
1183}
1184
/// A view of a `ProcMesh` restricted to a sub-shape: the borrowed mesh plus
/// the selected shape.
pub struct SlicedProcMesh<'a>(&'a ProcMesh, Shape);
1186
/// `Mesh` implementation for a sliced view of a `ProcMesh`.
#[async_trait]
impl Mesh for SlicedProcMesh<'_> {
    type Node = ProcId;
    type Id = ProcMeshId;
    type Sliced<'b>
        = SlicedProcMesh<'b>
    where
        Self: 'b;

    /// Returns the shape of this slice.
    fn shape(&self) -> &Shape {
        &self.1
    }

    /// Returns a further restriction of this slice; the result still refers
    /// to the same underlying mesh.
    fn select<R: Into<Range>>(
        &self,
        label: &str,
        range: R,
    ) -> Result<Self::Sliced<'_>, ShapeError> {
        Ok(Self(self.0, self.1.select(label, range)?))
    }

    /// Not implemented for slices; calling this panics.
    fn get(&self, _index: usize) -> Option<ProcId> {
        unimplemented!()
    }

    /// Returns the id of the underlying mesh.
    fn id(&self) -> Self::Id {
        self.0.id()
    }
}
1216
1217#[cfg(test)]
1218mod tests {
1219 use std::assert_matches::assert_matches;
1220
1221 use hyperactor::actor::ActorStatus;
1222 use ndslice::extent;
1223
1224 use super::*;
1225 use crate::actor_mesh::ActorMesh;
1226 use crate::actor_mesh::test_util::Error;
1227 use crate::actor_mesh::test_util::TestActor;
1228 use crate::alloc::AllocSpec;
1229 use crate::alloc::Allocator;
1230 use crate::alloc::local::LocalAllocator;
1231 use crate::sel_from_shape;
1232
    /// Allocates a 4-replica local mesh and checks that rank 0's proc id
    /// reports the alloc's name as its world name.
    #[tokio::test]
    async fn test_basic() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        // Capture the name first: `alloc` is moved into the mesh below.
        let name = alloc.name().to_string();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        assert_eq!(mesh.get(0).unwrap().world_name(), Some(name.as_str()));
    }
1251
    /// Checks that proc stop events from the alloc surface as
    /// `ProcEvent::Stopped`, and that the stream ends once all procs have
    /// stopped.
    #[tokio::test]
    async fn test_propagate_lifecycle_events() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        // Grab the control hooks before the alloc is moved into the mesh.
        let stop = alloc.stopper();
        let monkey = alloc.chaos_monkey();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        // Kill rank 1 and expect the matching stop event.
        monkey(1, ProcStopReason::Killed(1, false));
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Killed(1, false))
        );

        // Stopping the alloc stops the remaining three procs.
        stop();
        for _ in 0..3 {
            assert_matches!(
                events.next().await.unwrap(),
                ProcEvent::Stopped(_, ProcStopReason::Stopped)
            );
        }
        assert!(events.next().await.is_none());
    }
1285
    /// Checks that an actor failure is reported both as a `ProcEvent::Crashed`
    /// on the proc event stream and as a failed supervision event on the
    /// actor mesh's own event stream, and that stopping the alloc afterwards
    /// is reported on both streams as well.
    #[tokio::test]
    async fn test_supervision_failure() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 2),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();
        let stop = alloc.stopper();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        let instance = crate::v1::testing::instance();

        let mut actors: RootActorMesh<TestActor> =
            mesh.spawn(&instance, "failing", &()).await.unwrap();
        let mut actor_events = actors.events().unwrap();

        // Make only the actor at replica 0 fail.
        actors
            .cast(
                mesh.client(),
                sel_from_shape!(actors.shape(), replica = 0),
                Error("failmonkey".to_string()),
            )
            .unwrap();

        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Crashed(0, reason) if reason.contains("failmonkey")
        );

        // The same failure is delivered to the actor mesh.
        let mut event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.1, "failing".to_string());
        assert_eq!(event.actor_id.2, 0);

        stop();
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(0, ProcStopReason::Stopped),
        );
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Stopped),
        );

        // After the proc stream ends, the actor mesh still has a failure
        // event to deliver (sent when the procs stopped).
        assert!(events.next().await.is_none());
        event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.2, 0);
    }
1343
    /// Checks that spawning a second actor mesh under an already used name
    /// fails. The test is bounded by a 5 second timeout.
    #[timed_test::async_timed_test(timeout_secs = 5)]
    async fn test_spawn_twice() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        let instance = crate::v1::testing::instance();
        // The first spawn succeeds; the second one reuses the name "dup".
        let _: RootActorMesh<TestActor> = mesh.spawn(&instance, "dup", &()).await.unwrap();
        let result: Result<RootActorMesh<TestActor>, _> = mesh.spawn(&instance, "dup", &()).await;
        assert!(result.is_err());
    }
1363
1364 mod shim {
1365 use std::collections::HashSet;
1366
1367 use hyperactor::context::Mailbox;
1368 use ndslice::Extent;
1369 use ndslice::Selection;
1370
1371 use super::*;
1372 use crate::sel;
1373
        /// Wraps a v1 proc mesh in a `ProcMesh`, spawns an actor mesh through
        /// it, casts to all actors and checks that every actor replies
        /// exactly once with the expected point, actor ref and sender.
        /// Only compiled for `fbcode_build`.
        #[tokio::test]
        #[cfg(fbcode_build)]
        async fn test_basic() {
            let instance = v1::testing::instance();
            let ext = extent!(host = 4);
            let host_mesh = v1::testing::host_mesh(ext.clone()).await;
            let proc_mesh = host_mesh
                .spawn(instance, "test", Extent::unity())
                .await
                .unwrap();
            // Convert the v1 mesh through the `From<v1::ProcMeshRef>` impl.
            let proc_mesh_v0: ProcMesh = proc_mesh.detach().into();

            let actor_mesh: RootActorMesh<v1::testactor::TestActor> =
                proc_mesh_v0.spawn(instance, "test", &()).await.unwrap();

            let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
            actor_mesh
                .cast(
                    instance,
                    sel!(*),
                    v1::testactor::GetCastInfo {
                        cast_info: cast_info.bind(),
                    },
                )
                .unwrap();

            // Expected (point, actor ref) pairs; each reply removes one.
            let mut point_to_actor: HashSet<_> = actor_mesh
                .iter_actor_refs()
                .enumerate()
                .map(|(rank, actor_ref)| (ext.point_of_rank(rank).unwrap(), actor_ref))
                .collect();
            while !point_to_actor.is_empty() {
                let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
                let key = (point, origin_actor_ref);
                assert!(
                    point_to_actor.remove(&key),
                    "key {:?} not present or removed twice",
                    key
                );
                assert_eq!(&sender_actor_id, instance.self_id());
            }
        }
1416 }
1417}