hyperactor_mesh/
proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::fmt;
12use std::ops::Deref;
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use dashmap::DashMap;
17use futures::future::join_all;
18use hyperactor::Actor;
19use hyperactor::ActorHandle;
20use hyperactor::ActorId;
21use hyperactor::ActorRef;
22use hyperactor::Instance;
23use hyperactor::Mailbox;
24use hyperactor::Named;
25use hyperactor::RemoteMessage;
26use hyperactor::WorldId;
27use hyperactor::actor::ActorStatus;
28use hyperactor::actor::RemoteActor;
29use hyperactor::actor::remote::Remote;
30use hyperactor::cap;
31use hyperactor::channel;
32use hyperactor::channel::ChannelAddr;
33use hyperactor::mailbox;
34use hyperactor::mailbox::BoxableMailboxSender;
35use hyperactor::mailbox::BoxedMailboxSender;
36use hyperactor::mailbox::DialMailboxRouter;
37use hyperactor::mailbox::MailboxRouter;
38use hyperactor::mailbox::MailboxServer;
39use hyperactor::mailbox::MessageEnvelope;
40use hyperactor::mailbox::PortReceiver;
41use hyperactor::mailbox::Undeliverable;
42use hyperactor::metrics;
43use hyperactor::proc::Proc;
44use hyperactor::reference::ProcId;
45use hyperactor::reference::Reference;
46use hyperactor::supervision::ActorSupervisionEvent;
47use ndslice::Range;
48use ndslice::Shape;
49use ndslice::ShapeError;
50use tokio::sync::mpsc;
51
52use crate::CommActor;
53use crate::Mesh;
54use crate::actor_mesh::CAST_ACTOR_MESH_ID;
55use crate::actor_mesh::RootActorMesh;
56use crate::alloc::Alloc;
57use crate::alloc::AllocatorError;
58use crate::alloc::ProcState;
59use crate::alloc::ProcStopReason;
60use crate::assign::Ranks;
61use crate::comm::CommActorMode;
62use crate::proc_mesh::mesh_agent::GspawnResult;
63use crate::proc_mesh::mesh_agent::MeshAgent;
64use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
65use crate::proc_mesh::mesh_agent::StopActorResult;
66use crate::reference::ProcMeshId;
67use crate::shortuuid::ShortUuid;
68
69pub mod mesh_agent;
70
71use std::sync::OnceLock;
72
73/// A global router shared by all meshes managed in this process;
74/// this allows different meshes to communicate with each other.
75///
76/// This is definitely a "good enough for now" solution; in the future,
77/// we'll likely have some form of truly global registration for meshes,
78/// also benefitting tooling, etc.
79pub(crate) fn global_router() -> &'static MailboxRouter {
80    static GLOBAL_ROUTER: OnceLock<MailboxRouter> = OnceLock::new();
81    GLOBAL_ROUTER.get_or_init(MailboxRouter::new)
82}
83
84/// Context use by root client to send messages.
85/// This mailbox allows us to open ports before we know which proc the
86/// messages will be sent to.
87pub fn global_root_client() -> &'static Instance<()> {
88    static GLOBAL_INSTANCE: OnceLock<(Instance<()>, ActorHandle<()>)> = OnceLock::new();
89    let (instance, _) = GLOBAL_INSTANCE.get_or_init(|| {
90        let world_id = WorldId(ShortUuid::generate().to_string());
91        let client_proc_id = ProcId::Ranked(world_id.clone(), 0);
92        let client_proc = Proc::new(
93            client_proc_id.clone(),
94            BoxedMailboxSender::new(global_router().clone()),
95        );
96        global_router().bind(world_id.clone().into(), client_proc.clone());
97        client_proc
98            .instance("client")
99            .expect("root instance create")
100    });
101    instance
102}
103
104type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
105/// A ProcMesh maintains a mesh of procs whose lifecycles are managed by
106/// an allocator.
107pub struct ProcMesh {
108    // The underlying set of events. It is None if it has been transferred to
109    // a proc event observer.
110    event_state: Option<EventState>,
111    actor_event_router: ActorEventRouter,
112    shape: Shape,
113    ranks: Vec<(ProcId, (ChannelAddr, ActorRef<MeshAgent>))>,
114    #[allow(dead_code)] // will be used in subsequent diff
115    client_proc: Proc,
116    client: Mailbox,
117    comm_actors: Vec<ActorRef<CommActor>>,
118    world_id: WorldId,
119}
120
121struct EventState {
122    alloc: Box<dyn Alloc + Send + Sync>,
123    supervision_events: PortReceiver<ActorSupervisionEvent>,
124}
125
126impl ProcMesh {
127    pub async fn allocate(
128        alloc: impl Alloc + Send + Sync + 'static,
129    ) -> Result<Self, AllocatorError> {
130        ProcMesh::allocate_boxed(Box::new(alloc)).await
131    }
132    /// Allocate a new ProcMesh from the provided allocator. Allocate returns
133    /// after the mesh has been successfully (and fully) allocated, returning
134    /// early on any allocation failure.
135    pub async fn allocate_boxed(
136        mut alloc: Box<dyn Alloc + Send + Sync>,
137    ) -> Result<Self, AllocatorError> {
138        // We wait for the full allocation to be running before returning the mesh.
139        let shape = alloc.shape().clone();
140
141        let mut proc_ids = Ranks::new(shape.slice().len());
142        let mut running = Ranks::new(shape.slice().len());
143
144        while !running.is_full() {
145            let Some(state) = alloc.next().await else {
146                // Alloc finished before it was fully allocated.
147                return Err(AllocatorError::Incomplete(alloc.extent().clone()));
148            };
149
150            match state {
151                ProcState::Created { proc_id, point, .. } => {
152                    let rank = shape
153                        .slice()
154                        .location(point.coords())
155                        .map_err(|err| AllocatorError::Other(err.into()))?;
156                    if let Some(old_proc_id) = proc_ids.insert(rank, proc_id.clone()) {
157                        tracing::warn!("rank {rank} reassigned from {old_proc_id} to {proc_id}");
158                    }
159                    tracing::info!("proc {} rank {}: created", proc_id, rank);
160                }
161                ProcState::Running {
162                    proc_id,
163                    mesh_agent,
164                    addr,
165                } => {
166                    let Some(rank) = proc_ids.rank(&proc_id) else {
167                        tracing::warn!("proc id {proc_id} running, but not created");
168                        continue;
169                    };
170
171                    if let Some((old_addr, old_mesh_agent)) =
172                        running.insert(*rank, (addr.clone(), mesh_agent.clone()))
173                    {
174                        tracing::warn!(
175                            "duplicate running notifications for {proc_id}, addr:{addr}, mesh_agent:{mesh_agent}, old addr:{old_addr}, old mesh_agent:{old_mesh_agent}"
176                        )
177                    }
178                    tracing::info!(
179                        "proc {} rank {}: running at addr:{addr} mesh_agent:{mesh_agent}",
180                        proc_id,
181                        rank
182                    );
183                }
184                // TODO: We should push responsibility to the allocator, which
185                // can choose to either provide a new proc or emit a
186                // ProcState::Failed to fail the whole allocation.
187                ProcState::Stopped { proc_id, reason } => {
188                    tracing::error!("allocation failed for proc_id {}: {}", proc_id, reason);
189                    return Err(AllocatorError::Other(anyhow::Error::msg(reason)));
190                }
191                ProcState::Failed {
192                    world_id,
193                    description,
194                } => {
195                    tracing::error!("allocation failed for world {}: {}", world_id, description);
196                    return Err(AllocatorError::Other(anyhow::Error::msg(description)));
197                }
198            }
199        }
200
201        // We collect all the ranks at this point of completion, so that we can
202        // avoid holding Rcs across awaits.
203        let running: Vec<_> = running.into_iter().map(Option::unwrap).collect();
204
205        // All procs are running, so we now configure them.
206        let mut world_ids = HashSet::new();
207
208        let (router_channel_addr, router_rx) = channel::serve(ChannelAddr::any(alloc.transport()))
209            .await
210            .map_err(|err| AllocatorError::Other(err.into()))?;
211        tracing::info!("router channel started listening on addr: {router_channel_addr}");
212        let router = DialMailboxRouter::new_with_default(global_router().boxed());
213        for (rank, (addr, _agent)) in running.iter().enumerate() {
214            let proc_id = proc_ids.get(rank).unwrap().clone();
215            router.bind(Reference::Proc(proc_id.clone()), addr.clone());
216            // Work around for Allocs that have more than one world.
217            world_ids.insert(
218                proc_id
219                    .world_id()
220                    .expect("proc in running state must be ranked")
221                    .clone(),
222            );
223        }
224        router.clone().serve(router_rx);
225
226        // Set up a client proc for the mesh itself, so that we can attach ourselves
227        // to it, and communicate with the agents. We wire it into the same router as
228        // everything else, so now the whole mesh should be able to communicate.
229        let client_proc_id =
230            ProcId::Ranked(WorldId(format!("{}_manager", alloc.world_id().name())), 0);
231        let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(alloc.transport()))
232            .await
233            .map_err(|err| AllocatorError::Other(err.into()))?;
234        tracing::info!("client proc started listening on addr: {client_proc_addr}");
235        let client_proc = Proc::new(
236            client_proc_id.clone(),
237            BoxedMailboxSender::new(router.clone()),
238        );
239        client_proc.clone().serve(client_rx);
240        router.bind(client_proc_id.clone().into(), client_proc_addr.clone());
241
242        // Bind this router to the global router, to enable cross-mesh routing.
243        // TODO: unbind this when we incorporate mesh destruction too.
244        for world_id in world_ids {
245            global_router().bind(world_id.into(), router.clone());
246        }
247        global_router().bind(alloc.world_id().clone().into(), router.clone());
248        global_router().bind(client_proc_id.into(), router.clone());
249
250        // TODO: No actor bound to "supervisor" yet.
251        let supervisor = client_proc.attach("supervisor")?;
252        let (supervision_port, supervision_events) =
253            supervisor.open_port::<ActorSupervisionEvent>();
254
255        // Now, configure the full mesh, so that the local agents are
256        // wired up to our router.
257        // TODO: No actor bound to "client" yet.
258        let client = client_proc.attach("client")?;
259        // Bind an undeliverable message port in the client.
260        let (undeliverable_messages, client_undeliverable_receiver) =
261            client.open_port::<Undeliverable<MessageEnvelope>>();
262        undeliverable_messages.bind_to(Undeliverable::<MessageEnvelope>::port());
263        // Monitor undeliverable messages from the client and emit
264        // corresponding actor supervision events via the supervision
265        // port.
266        hyperactor::mailbox::supervise_undeliverable_messages(
267            supervision_port.clone(),
268            client_undeliverable_receiver,
269        );
270
271        // Map of procs -> channel addresses
272        let address_book: HashMap<_, _> = running
273            .iter()
274            .map(|(addr, agent)| (agent.actor_id().proc_id().clone(), addr.clone()))
275            .collect();
276
277        let (config_handle, mut config_receiver) = client.open_port();
278        for (rank, (_, agent)) in running.iter().enumerate() {
279            agent
280                .configure(
281                    &client,
282                    rank,
283                    router_channel_addr.clone(),
284                    supervision_port.bind(),
285                    address_book.clone(),
286                    config_handle.bind(),
287                )
288                .await?;
289        }
290        let mut completed = Ranks::new(shape.slice().len());
291        while !completed.is_full() {
292            let rank = config_receiver
293                .recv()
294                .await
295                .map_err(|err| AllocatorError::Other(err.into()))?;
296            if completed.insert(rank, rank).is_some() {
297                tracing::warn!("multiple completions received for rank {}", rank);
298            }
299        }
300
301        // For reasons I fail to fully understand, the below call fails
302        // when invoked from `pyo3_async_runtimes::tokio::future_into_py`
303        // when using a closure. It appears to be some subtle failure of
304        // the compiler to unify lifetimes. If we use a function instead,
305        // it does better.
306        //
307        // Interestingly, this only appears to fail in *specific* caller
308        // contexts (e.g., https://fburl.com/code/evfgtfx1), and the error
309        // is reported there as "implementation of `std::ops::FnOnce` is not general enough",
310        // suggesting some failure of modularity in the compiler's lifetime
311        // unification!
312        //
313        // Baffling and unsettling.
314        fn project_actor_ref(pair: &(ChannelAddr, ActorRef<MeshAgent>)) -> ActorRef<MeshAgent> {
315            pair.1.clone()
316        }
317
318        // Spawn a comm actor on each proc, so that they can be used
319        // to perform tree distribution and accumulation.
320        let comm_actors = Self::spawn_on_procs::<CommActor>(
321            &client,
322            running.iter().map(project_actor_ref),
323            "comm",
324            &Default::default(),
325        )
326        .await?;
327        let address_book: HashMap<_, _> = comm_actors.iter().cloned().enumerate().collect();
328        // Now that we have all of the spawned comm actors, kick them all into
329        // mesh mode.
330        for (rank, comm_actor) in comm_actors.iter().enumerate() {
331            comm_actor
332                .send(&client, CommActorMode::Mesh(rank, address_book.clone()))
333                .map_err(anyhow::Error::from)?;
334        }
335
336        let shape = alloc.shape().clone();
337        let world_id = alloc.world_id().clone();
338        metrics::PROC_MESH_ALLOCATION.add(1, hyperactor_telemetry::kv_pairs!());
339
340        Ok(Self {
341            event_state: Some(EventState {
342                alloc,
343                supervision_events,
344            }),
345            actor_event_router: Arc::new(DashMap::new()),
346            shape,
347            ranks: proc_ids
348                .into_iter()
349                .map(Option::unwrap)
350                .zip(running.into_iter())
351                .collect(),
352            client_proc,
353            client,
354            comm_actors,
355            world_id,
356        })
357    }
358
359    async fn spawn_on_procs<A: Actor + RemoteActor>(
360        cx: &(impl cap::CanSend + cap::CanOpenPort),
361        agents: impl IntoIterator<Item = ActorRef<MeshAgent>> + '_,
362        actor_name: &str,
363        params: &A::Params,
364    ) -> Result<Vec<ActorRef<A>>, anyhow::Error>
365    where
366        A::Params: RemoteMessage,
367    {
368        let remote = Remote::collect();
369        let actor_type = remote
370            .name_of::<A>()
371            .ok_or(anyhow::anyhow!("actor not registered"))?
372            .to_string();
373
374        let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
375        let mut n = 0;
376        for agent in agents {
377            agent
378                .gspawn(
379                    cx,
380                    actor_type.clone(),
381                    actor_name.to_string(),
382                    bincode::serialize(params)?,
383                    completed_handle.bind(),
384                )
385                .await?;
386            n += 1;
387        }
388        let mut completed = Ranks::new(n);
389        while !completed.is_full() {
390            let result = completed_receiver.recv().await?;
391            match result {
392                GspawnResult::Success { rank, actor_id } => {
393                    if completed.insert(rank, actor_id).is_some() {
394                        tracing::warn!("multiple completions received for rank {}", rank);
395                    }
396                }
397                GspawnResult::Error(error_msg) => {
398                    metrics::PROC_MESH_ACTOR_FAILURES.add(
399                        1,
400                        hyperactor_telemetry::kv_pairs!(
401                            "actor_name" => actor_name.to_string(),
402                            "error" => error_msg.clone(),
403                        ),
404                    );
405
406                    anyhow::bail!("gspawn failed: {}", error_msg);
407                }
408            }
409        }
410
411        // `Ranks` really should have some way to convert into a "completed" rank
412        // in a one-shot way; the API here is too awkward otherwise.
413        Ok(completed
414            .into_iter()
415            .map(Option::unwrap)
416            .map(ActorRef::attest)
417            .collect())
418    }
419
420    fn agents(&self) -> impl Iterator<Item = ActorRef<MeshAgent>> + '_ {
421        self.ranks.iter().map(|(_, (_, agent))| agent.clone())
422    }
423
424    /// Return the comm actor to which casts should be forwarded.
425    pub(crate) fn comm_actor(&self) -> &ActorRef<CommActor> {
426        &self.comm_actors[0]
427    }
428
429    /// Spawn an `ActorMesh` by launching the same actor type on all
430    /// agents, using the **same** parameters instance for every
431    /// actor.
432    ///
433    /// - `actor_name`: Name for all spawned actors.
434    /// - `params`: Reference to the parameter struct, reused for all
435    ///   actors.
436    pub async fn spawn<A: Actor + RemoteActor>(
437        &self,
438        actor_name: &str,
439        params: &A::Params,
440    ) -> Result<RootActorMesh<'_, A>, anyhow::Error>
441    where
442        A::Params: RemoteMessage,
443    {
444        let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
445        {
446            // Instantiate supervision routing BEFORE spawning the actor mesh.
447            self.actor_event_router.insert(actor_name.to_string(), tx);
448        }
449        let root_mesh = RootActorMesh::new(
450            self,
451            actor_name.to_string(),
452            rx,
453            Self::spawn_on_procs::<A>(&self.client, self.agents(), actor_name, params).await?,
454        );
455        Ok(root_mesh)
456    }
457
458    /// A client used to communicate with any member of this mesh.
459    pub fn client(&self) -> &Mailbox {
460        &self.client
461    }
462
463    pub fn client_proc(&self) -> &Proc {
464        &self.client_proc
465    }
466
467    pub fn proc_id(&self) -> &ProcId {
468        self.client_proc.proc_id()
469    }
470
471    pub fn world_id(&self) -> &WorldId {
472        &self.world_id
473    }
474
475    /// An event stream of proc events. Each ProcMesh can produce only one such
476    /// stream, returning None after the first call.
477    pub fn events(&mut self) -> Option<ProcEvents> {
478        self.event_state.take().map(|event_state| ProcEvents {
479            event_state,
480            ranks: self
481                .ranks
482                .iter()
483                .enumerate()
484                .map(|(rank, (proc_id, _))| (proc_id.clone(), rank))
485                .collect(),
486            actor_event_router: self.actor_event_router.clone(),
487        })
488    }
489
490    pub fn shape(&self) -> &Shape {
491        &self.shape
492    }
493
494    /// Send stop actors message to all mesh agents for a specific mesh name
495    pub async fn stop_actor_by_name(&self, mesh_name: &str) -> Result<(), anyhow::Error> {
496        let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
497        let results = join_all(self.agents().map(|agent| async move {
498            let actor_id = ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0);
499            (
500                actor_id.clone(),
501                agent
502                    .clone()
503                    .stop_actor(&self.client, actor_id, timeout.as_millis() as u64)
504                    .await,
505            )
506        }))
507        .await;
508
509        for (actor_id, result) in results {
510            match result {
511                Ok(StopActorResult::Timeout) => {
512                    tracing::warn!("timed out while stopping actor {}", actor_id);
513                }
514                Ok(StopActorResult::NotFound) => {
515                    tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id());
516                }
517                Ok(StopActorResult::Success) => {
518                    tracing::info!("stopped actor {}", actor_id);
519                }
520                Err(e) => {
521                    tracing::warn!("error stopping actor {}: {}", actor_id, e);
522                }
523            }
524        }
525        Ok(())
526    }
527}
528
529/// Proc lifecycle events.
530#[derive(Debug, Clone)]
531pub enum ProcEvent {
532    /// The proc of the given rank was stopped with the provided reason.
533    Stopped(usize, ProcStopReason),
534    /// The proc crashed, with the provided "reason". This is reserved for
535    /// unhandled supervision events.
536    Crashed(usize, String),
537}
538
539impl fmt::Display for ProcEvent {
540    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
541        match self {
542            ProcEvent::Stopped(rank, reason) => {
543                write!(f, "Proc at rank {} stopped: {}", rank, reason)
544            }
545            ProcEvent::Crashed(rank, reason) => {
546                write!(f, "Proc at rank {} crashed: {}", rank, reason)
547            }
548        }
549    }
550}
551
/// Name of an actor mesh, used as the key for routing supervision events.
type ActorMeshName = String;
553
554/// An event stream of [`ProcEvent`]
555// TODO: consider using streams for this.
556pub struct ProcEvents {
557    event_state: EventState,
558    ranks: HashMap<ProcId, usize>,
559    actor_event_router: ActorEventRouter,
560}
561
562impl ProcEvents {
563    /// Get the next lifecycle event. The stream is closed when this method
564    /// returns `None`.
565    pub async fn next(&mut self) -> Option<ProcEvent> {
566        loop {
567            tokio::select! {
568                result = self.event_state.alloc.next() => {
569                    tracing::debug!("received ProcEvent alloc update: {result:?}");
570                    // Don't disable the outer branch on None: this is always terminal.
571                    let Some(alloc_event) = result else {
572                        self.actor_event_router.clear();
573                        break None;
574                    };
575
576                    let ProcState::Stopped { proc_id, reason } = alloc_event else {
577                        // Ignore non-stopped events for now.
578                        continue;
579                    };
580
581                    let Some(rank) = self.ranks.get(&proc_id) else {
582                        tracing::warn!("received stop event for unmapped proc {}", proc_id);
583                        continue;
584                    };
585
586                    metrics::PROC_MESH_PROC_STOPPED.add(
587                        1,
588                        hyperactor_telemetry::kv_pairs!(
589                            "proc_id" => proc_id.to_string(),
590                            "rank" => rank.to_string(),
591                            "reason" => reason.to_string(),
592                        ),
593                    );
594
595                    // Need to send this event to actor meshes to notify them of the proc's death.
596                    // TODO(albertli): only send this event to all root actor meshes if any of them use this proc.
597                    for entry in self.actor_event_router.iter() {
598                        // Make a dummy actor supervision event, all actors on the proc are affected if a proc stops.
599                        // TODO(T231868026): find a better way to represent all actors in a proc for supervision event
600                        let event = ActorSupervisionEvent {
601                            actor_id: proc_id.actor_id("any", 0),
602                            actor_status: ActorStatus::Failed(format!("proc {} is stopped", proc_id)),
603                            message_headers: None,
604                            caused_by: None,
605                        };
606                        if entry.value().send(event).is_err() {
607                            tracing::warn!("unable to transmit supervision event to actor {}", entry.key());
608                        }
609                    }
610
611                    break Some(ProcEvent::Stopped(*rank, reason));
612                }
613                Ok(mut event) = self.event_state.supervision_events.recv() => {
614                    tracing::debug!("received ProcEvent supervision event: {event:?}");
615                    // Cast message might fail to deliver when it is propagated
616                    // through the comm actor tree. In this case, the event is
617                    // for the actor mesh, not the comm actor. In that case,
618                    // we update the event with the actor mesh id, so it can be
619                    // forwarded to the mesh.
620                    if let Some(headers) = &event.message_headers
621                        && let Some(actor_mesh_id) = headers.get(CAST_ACTOR_MESH_ID)
622                    {
623                        // Make a dummy actor id to represent the mesh in ActorSupervisionEvent.
624                        // TODO(T231868026): find a better way to represent all actors in an actor
625                        // mesh for supervision event
626                        event.actor_id = ActorId(
627                            ProcId::Ranked(WorldId(actor_mesh_id.0.0.clone()), 0),
628                            actor_mesh_id.1.clone(),
629                            0,
630                        );
631                    };
632                    let actor_id = event.actor_id.clone();
633                    let actor_status = event.actor_status.clone();
634                    let reason = event.to_string();
635                    let Some(rank) = self.ranks.get(actor_id.proc_id()) else {
636                        tracing::warn!("received supervision event for unmapped actor {}", actor_id);
637                        continue;
638                    };
639                    // transmit to the correct root actor mesh.
640                    {
641                        if let Some(tx) = self.actor_event_router.get(actor_id.name()) {
642                            if tx.send(event).is_err() {
643                                tracing::warn!("unable to transmit supervision event to actor {}", actor_id);
644                            }
645                        } else {
646                            tracing::warn!("received supervision event for unregistered actor {}", actor_id);
647                        }
648                    }
649                    metrics::PROC_MESH_ACTOR_FAILURES.add(
650                        1,
651                        hyperactor_telemetry::kv_pairs!(
652                            "actor_id" => actor_id.to_string(),
653                            "rank" => rank.to_string(),
654                            "status" => actor_status.to_string(),
655                        ),
656                    );
657
658                    // Send this event to Python proc mesh to keep its health status up to date.
659                    break Some(ProcEvent::Crashed(*rank, reason))
660                }
661            }
662        }
663    }
664
665    pub fn into_alloc(self) -> Box<dyn Alloc + Send + Sync> {
666        self.event_state.alloc
667    }
668}
669
670/// Spawns from shared ([`Arc`]) proc meshes, providing [`ActorMesh`]es with
671/// static lifetimes.
672#[async_trait]
673pub trait SharedSpawnable {
674    async fn spawn<A: Actor + RemoteActor>(
675        self,
676        actor_name: &str,
677        params: &A::Params,
678    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
679    where
680        A::Params: RemoteMessage;
681}
682
683#[async_trait]
684impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D {
685    async fn spawn<A: Actor + RemoteActor>(
686        self,
687        actor_name: &str,
688        params: &A::Params,
689    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
690    where
691        A::Params: RemoteMessage,
692    {
693        let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
694        {
695            // Instantiate supervision routing BEFORE spawning the actor mesh.
696            self.actor_event_router.insert(actor_name.to_string(), tx);
697        }
698        let ranks =
699            ProcMesh::spawn_on_procs::<A>(&self.client, self.agents(), actor_name, params).await?;
700        Ok(RootActorMesh::new_shared(
701            self,
702            actor_name.to_string(),
703            rx,
704            ranks,
705        ))
706    }
707}
708
709#[async_trait]
710impl Mesh for ProcMesh {
711    type Node = ProcId;
712    type Id = ProcMeshId;
713    type Sliced<'a> = SlicedProcMesh<'a>;
714
715    fn shape(&self) -> &Shape {
716        &self.shape
717    }
718
719    fn select<R: Into<Range>>(
720        &self,
721        label: &str,
722        range: R,
723    ) -> Result<Self::Sliced<'_>, ShapeError> {
724        Ok(SlicedProcMesh(self, self.shape().select(label, range)?))
725    }
726
727    fn get(&self, rank: usize) -> Option<ProcId> {
728        Some(self.ranks[rank].0.clone())
729    }
730
731    fn id(&self) -> Self::Id {
732        ProcMeshId(self.world_id().name().to_string())
733    }
734}
735
736impl fmt::Display for ProcMesh {
737    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
738        write!(f, "{{ shape: {} }}", self.shape())
739    }
740}
741
742impl fmt::Debug for ProcMesh {
743    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
744        f.debug_struct("ProcMesh")
745            .field("shape", &self.shape())
746            .field("ranks", &self.ranks)
747            .field("client_proc", &self.client_proc)
748            .field("client", &self.client)
749            // Skip the alloc field since it doesn't implement Debug
750            .finish()
751    }
752}
753
754pub struct SlicedProcMesh<'a>(&'a ProcMesh, Shape);
755
756#[async_trait]
757impl Mesh for SlicedProcMesh<'_> {
758    type Node = ProcId;
759    type Id = ProcMeshId;
760    type Sliced<'b>
761        = SlicedProcMesh<'b>
762    where
763        Self: 'b;
764
765    fn shape(&self) -> &Shape {
766        &self.1
767    }
768
769    fn select<R: Into<Range>>(
770        &self,
771        label: &str,
772        range: R,
773    ) -> Result<Self::Sliced<'_>, ShapeError> {
774        Ok(Self(self.0, self.1.select(label, range)?))
775    }
776
777    fn get(&self, _index: usize) -> Option<ProcId> {
778        unimplemented!()
779    }
780
781    fn id(&self) -> Self::Id {
782        self.0.id()
783    }
784}
785
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::actor::ActorStatus;
    use ndslice::extent;

    use super::*;
    use crate::actor_mesh::ActorMesh;
    use crate::actor_mesh::test_util::Error;
    use crate::actor_mesh::test_util::TestActor;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::local::LocalAllocator;
    use crate::sel_from_shape;

    /// A mesh allocated from a 4-replica local allocation reports, for rank 0,
    /// a proc whose world name equals the allocation's name.
    #[tokio::test]
    async fn test_basic() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        // Capture the name before `alloc` is moved into the mesh.
        let name = alloc.name().to_string();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        assert_eq!(mesh.get(0).unwrap().world_name(), Some(name.as_str()));
    }

    /// Proc lifecycle events from the allocation appear on the mesh's event
    /// stream. An injected stop of proc 1 is reported first. Stopping the
    /// allocation then reports 3 more `Stopped` events, after which the stream
    /// ends.
    #[tokio::test]
    async fn test_propagate_lifecycle_events() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        // Grab the control handles before `alloc` is moved into the mesh.
        let stop = alloc.stopper();
        let monkey = alloc.chaos_monkey();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        // Inject a stop of proc 1 with reason `Killed(1, false)`.
        monkey(1, ProcStopReason::Killed(1, false));
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Stopped(1, false))
        );

        // Stopping the allocation reports the 3 procs that were still running.
        stop();
        for _ in 0..3 {
            assert_matches!(
                events.next().await.unwrap(),
                ProcEvent::Stopped(_, ProcStopReason::Stopped)
            );
        }
        assert!(events.next().await.is_none());
    }

    /// An actor failure on rank 0 is reported in two places. It appears as a
    /// `Crashed` event for proc 0 whose reason contains the failure message,
    /// and as a `Failed` event on the actor mesh's event stream. After
    /// stopping, both procs report `Stopped` and the proc stream ends.
    #[tokio::test]
    async fn test_supervision_failure() {
        // For now, we propagate all actor failures to the proc.

        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 2),
                constraints: Default::default(),
            })
            .await
            .unwrap();
        let stop = alloc.stopper();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        let mut actors = mesh.spawn::<TestActor>("failing", &()).await.unwrap();
        let mut actor_events = actors.events().unwrap();

        // Send the `Error` message only to replica 0. The assertions below
        // expect that actor to fail as a result.
        actors
            .cast(
                mesh.client(),
                sel_from_shape!(actors.shape(), replica = 0),
                Error("failmonkey".to_string()),
            )
            .unwrap();

        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Crashed(0, reason) if reason.contains("failmonkey")
        );

        // The actor mesh sees the failure of actor "failing" at rank 0.
        let mut event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.1, "failing".to_string());
        assert_eq!(event.actor_id.2, 0);

        stop();
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(0, ProcStopReason::Stopped),
        );
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Stopped),
        );

        assert!(events.next().await.is_none());
        // The actor event stream then yields another `Failed` event for rank 0.
        event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.2, 0);
    }

    /// Spawning a second actor mesh with an already-used name ("dup") on the
    /// same proc mesh returns an error.
    #[timed_test::async_timed_test(timeout_secs = 5)]
    async fn test_spawn_twice() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 1),
                constraints: Default::default(),
            })
            .await
            .unwrap();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        mesh.spawn::<TestActor>("dup", &()).await.unwrap();
        let result = mesh.spawn::<TestActor>("dup", &()).await;
        assert!(result.is_err());
    }
}