hyperactor_mesh/
proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::fmt;
11use std::ops::Deref;
12use std::panic::Location;
13use std::sync::Arc;
14use std::sync::atomic::AtomicUsize;
15use std::sync::atomic::Ordering;
16
17use async_trait::async_trait;
18use dashmap::DashMap;
19use futures::future::join_all;
20use hyperactor::Actor;
21use hyperactor::ActorHandle;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Instance;
25use hyperactor::RemoteMessage;
26use hyperactor::WorldId;
27use hyperactor::actor::ActorStatus;
28use hyperactor::actor::Referable;
29use hyperactor::actor::remote::Remote;
30use hyperactor::channel;
31use hyperactor::channel::ChannelAddr;
32use hyperactor::channel::ChannelTransport;
33use hyperactor::config;
34use hyperactor::config::CONFIG;
35use hyperactor::config::ConfigAttr;
36use hyperactor::context;
37use hyperactor::declare_attrs;
38use hyperactor::mailbox;
39use hyperactor::mailbox::BoxableMailboxSender;
40use hyperactor::mailbox::BoxedMailboxSender;
41use hyperactor::mailbox::DialMailboxRouter;
42use hyperactor::mailbox::MailboxServer;
43use hyperactor::mailbox::MessageEnvelope;
44use hyperactor::mailbox::PortHandle;
45use hyperactor::mailbox::PortReceiver;
46use hyperactor::mailbox::Undeliverable;
47use hyperactor::metrics;
48use hyperactor::proc::Proc;
49use hyperactor::reference::ProcId;
50use hyperactor::supervision::ActorSupervisionEvent;
51use ndslice::Range;
52use ndslice::Shape;
53use ndslice::ShapeError;
54use ndslice::View;
55use ndslice::ViewExt;
56use strum::AsRefStr;
57use tokio::sync::mpsc;
58use tracing::Instrument;
59use tracing::Level;
60use tracing::span;
61
62use crate::CommActor;
63use crate::Mesh;
64use crate::actor_mesh::RootActorMesh;
65use crate::alloc::Alloc;
66use crate::alloc::AllocExt;
67use crate::alloc::AllocatedProc;
68use crate::alloc::AllocatorError;
69use crate::alloc::ProcState;
70use crate::alloc::ProcStopReason;
71use crate::alloc::serve_with_config;
72use crate::assign::Ranks;
73use crate::comm::CommActorMode;
74use crate::proc_mesh::mesh_agent::GspawnResult;
75use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
76use crate::proc_mesh::mesh_agent::ProcMeshAgent;
77use crate::proc_mesh::mesh_agent::StopActorResult;
78use crate::proc_mesh::mesh_agent::update_event_actor_id;
79use crate::reference::ProcMeshId;
80use crate::router;
81use crate::shortuuid::ShortUuid;
82use crate::v1;
83use crate::v1::Name;
84
85pub mod mesh_agent;
86
87use std::sync::OnceLock;
88use std::sync::RwLock;
89
// Configuration attribute selecting the channel transport used when no
// explicit transport is given. It is exposed through the
// `HYPERACTOR_MESH_DEFAULT_TRANSPORT` environment variable and the
// `default_transport` Python name, and defaults to `ChannelTransport::Unix`.
declare_attrs! {
    /// Default transport type to use across the application.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string()),
        py_name: Some("default_transport".to_string()),
    })
    pub attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix;
}
98
/// Get the default transport type to use across the application.
///
/// Reads (and clones) the current value of the [`DEFAULT_TRANSPORT`]
/// attribute from the global configuration on every call, so the result
/// reflects whatever the global config holds at call time.
pub fn default_transport() -> ChannelTransport {
    config::global::get_cloned(DEFAULT_TRANSPORT)
}
103
/// Single, process-wide supervision sink storage.
///
/// This is a pragmatic "good enough for now" global used to route
/// undeliverables observed by the process-global root client (c.f.
/// [`global_root_client`]) to the *currently active* `ProcMesh`. Newer
/// meshes override older ones ("last sink wins").
///
/// The outer `OnceLock` provides lazy initialization; the inner
/// `RwLock<Option<..>>` holds the sink, with `None` meaning that no
/// sink has been installed yet.
static GLOBAL_SUPERVISION_SINK: OnceLock<RwLock<Option<PortHandle<ActorSupervisionEvent>>>> =
    OnceLock::new();
112
113/// Returns the lazily-initialized container that holds the current
114/// process-global supervision sink.
115///
116/// Internal helper: callers should use `set_global_supervision_sink`
117/// and `get_global_supervision_sink` instead.
118fn sink_cell() -> &'static RwLock<Option<PortHandle<ActorSupervisionEvent>>> {
119    GLOBAL_SUPERVISION_SINK.get_or_init(|| RwLock::new(None))
120}
121
122/// Install (or replace) the process-global supervision sink.
123///
124/// This function enforces "last sink wins" semantics: if a sink was
125/// already installed, it is replaced and the previous sink is
126/// returned. Called from `ProcMesh::allocate_boxed`, after creating
127/// the mesh's supervision port.
128///
129/// Returns:
130/// - `Some(prev)` if a prior sink was installed, allowing the caller
131///   to log/inspect it if desired;
132/// - `None` if this is the first sink.
133///
134/// Thread-safety: takes a write lock briefly to swap the handle.
135pub(crate) fn set_global_supervision_sink(
136    sink: PortHandle<ActorSupervisionEvent>,
137) -> Option<PortHandle<ActorSupervisionEvent>> {
138    let cell = sink_cell();
139    let mut guard = cell.write().unwrap();
140    let prev = guard.take();
141    *guard = Some(sink);
142    prev
143}
144
145/// Get a clone of the current process-global supervision sink, if
146/// any.
147///
148/// This is used by the process-global root client [c.f.
149/// `global_root_client`] to forward undeliverables once a mesh has
150/// installed its sink. If no sink has been installed yet, returns
151/// `None` and callers should defer/ignore forwarding until one
152/// appears.
153///
154/// Thread-safety: takes a read lock briefly; cloning the `PortHandle`
155/// is cheap.
156pub(crate) fn get_global_supervision_sink() -> Option<PortHandle<ActorSupervisionEvent>> {
157    sink_cell().read().unwrap().clone()
158}
159
160/// Context use by root client to send messages.
161/// This mailbox allows us to open ports before we know which proc the
162/// messages will be sent to.
163pub fn global_root_client() -> &'static Instance<()> {
164    static GLOBAL_INSTANCE: OnceLock<(Instance<()>, ActorHandle<()>)> = OnceLock::new();
165    &GLOBAL_INSTANCE.get_or_init(|| {
166        let client_proc = Proc::direct_with_default(
167            ChannelAddr::any(default_transport()),
168            "mesh_root_client_proc".into(),
169            router::global().clone().boxed(),
170        )
171        .unwrap();
172
173        // Make this proc reachable through the global router, so that we can use the
174        // same client in both direct-addressed and ranked-addressed modes.
175        router::global().bind(client_proc.proc_id().clone().into(), client_proc.clone());
176
177        let (client, handle) = client_proc
178            .instance("client")
179            .expect("root instance create");
180
181        // Bind the global root client's undeliverable port and
182        // forward any undeliverable messages to the currently active
183        // supervision sink.
184        //
185        // The resolver (`get_global_supervision_sink`) is passed as a
186        // function pointer, so each time an undeliverable is
187        // processed, we look up the *latest* sink. This allows the
188        // root client to seamlessly track whichever ProcMesh most
189        // recently installed a supervision sink (e.g., the
190        // application mesh instead of an internal controller mesh).
191        //
192        // The hook logs each undeliverable, along with whether a sink
193        // was present at the time of receipt, which helps diagnose
194        // lost or misrouted events.
195        let (_undeliverable_tx, undeliverable_rx) =
196            client.bind_actor_port::<Undeliverable<MessageEnvelope>>();
197        hyperactor::mailbox::supervise_undeliverable_messages_with(
198            undeliverable_rx,
199            crate::proc_mesh::get_global_supervision_sink,
200            |env| {
201                let sink_present = crate::proc_mesh::get_global_supervision_sink().is_some();
202                tracing::info!(
203                    actor_id = %env.dest().actor_id(),
204                    "global root client undeliverable observed with headers {:?} {}", env.headers(), sink_present
205                );
206            },
207        );
208
209        (client, handle)
210    }).0
211}
212
/// Shared, concurrent map from actor mesh name to the sending half of an
/// unbounded channel of supervision events. `ProcMesh::spawn` inserts an
/// entry for each spawned actor mesh, and the map is cloned into
/// `ProcEvents` when the event stream is taken.
type ActorEventRouter = Arc<DashMap<ActorMeshName, mpsc::UnboundedSender<ActorSupervisionEvent>>>;
214
/// A ProcMesh maintains a mesh of procs whose lifecycles are managed by
/// an allocator.
pub struct ProcMesh {
    /// Version-specific state: either a locally allocated V0 mesh or a
    /// wrapper around a `v1::ProcMeshRef`.
    inner: ProcMeshKind,
    /// Cached shape, filled in on the first call to `shape()`. It is kept
    /// here because V1 meshes do not store a materialized `Shape`.
    shape: OnceLock<Shape>,
}
221
/// The two backing implementations of a [`ProcMesh`].
enum ProcMeshKind {
    /// A mesh allocated and wired up by `ProcMesh::allocate_boxed`.
    V0 {
        // The underlying set of events. It is None if it has been transferred to
        // a proc event observer.
        event_state: Option<EventState>,
        // Routes supervision events to per-actor-mesh channels, keyed by
        // actor mesh name.
        actor_event_router: ActorEventRouter,
        // The shape reported by the alloc at allocation time.
        shape: Shape,
        // One entry per allocated proc, in rank order:
        // (create key, proc id, address, mesh agent reference).
        ranks: Vec<(ShortUuid, ProcId, ChannelAddr, ActorRef<ProcMeshAgent>)>,
        #[allow(dead_code)] // will be used in subsequent diff
        client_proc: Proc,
        // The "client" instance created on `client_proc`; used to talk to
        // the mesh agents.
        client: Instance<()>,
        // One comm actor per proc, in rank order.
        comm_actors: Vec<ActorRef<CommActor>>,
        // The world id reported by the alloc.
        world_id: WorldId,
    },

    /// A shim over a v1 proc mesh reference.
    V1(v1::ProcMeshRef),
}
239
/// The event sources owned by a V0 mesh. This is moved out of the mesh
/// and into `ProcEvents` the first time `ProcMesh::events` is called.
struct EventState {
    /// The allocator backing the mesh; polled for proc state changes.
    alloc: Box<dyn Alloc + Send + Sync>,
    /// Receiving end of the mesh's supervision port.
    supervision_events: PortReceiver<ActorSupervisionEvent>,
}
244
245impl From<v1::ProcMeshRef> for ProcMesh {
246    fn from(proc_mesh: v1::ProcMeshRef) -> Self {
247        ProcMesh {
248            inner: ProcMeshKind::V1(proc_mesh),
249            shape: OnceLock::new(),
250        }
251    }
252}
253
254impl ProcMesh {
255    #[hyperactor::instrument(fields(name = "ProcMesh::allocate"))]
256    pub async fn allocate(
257        alloc: impl Alloc + Send + Sync + 'static,
258    ) -> Result<Self, AllocatorError> {
259        ProcMesh::allocate_boxed(Box::new(alloc)).await
260    }
261
262    /// Allocate a new ProcMesh from the provided allocator. Allocate returns
263    /// after the mesh has been successfully (and fully) allocated, returning
264    /// early on any allocation failure.
265    #[track_caller]
266    pub fn allocate_boxed(
267        alloc: Box<dyn Alloc + Send + Sync>,
268    ) -> impl std::future::Future<Output = Result<Self, AllocatorError>> {
269        Self::allocate_boxed_inner(alloc, Location::caller())
270    }
271
272    fn alloc_counter() -> &'static AtomicUsize {
273        static C: OnceLock<AtomicUsize> = OnceLock::new();
274        C.get_or_init(|| AtomicUsize::new(0))
275    }
276
    /// Does the actual work of `allocate_boxed`.
    ///
    /// `loc` is the call site captured by `allocate_boxed`; it is only
    /// used in the initial log line.
    ///
    /// In order, this:
    /// 1. initializes the alloc to obtain the running procs;
    /// 2. binds every non-direct proc address into a dial router;
    /// 3. creates and serves a `<world>_client` proc for the mesh;
    /// 4. binds the dial router to the global router;
    /// 5. opens the supervision port, installs it as the global
    ///    supervision sink, and wires the per-mesh client's
    ///    undeliverable messages into it;
    /// 6. configures every mesh agent and waits for all of them to
    ///    report completion;
    /// 7. spawns one comm actor per proc and switches them to mesh mode.
    ///
    /// Any failure along the way is returned as an `AllocatorError`.
    #[tracing::instrument(skip_all)]
    #[hyperactor::observe_result("ProcMesh")]
    async fn allocate_boxed_inner(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        loc: &'static Location<'static>,
    ) -> Result<Self, AllocatorError> {
        // Allocation ids start at 1: `fetch_add` returns the previous value.
        let alloc_id = Self::alloc_counter().fetch_add(1, Ordering::Relaxed) + 1;
        let world = alloc.world_id().name().to_string();
        tracing::info!(
            name = "ProcMesh::Allocate::Attempt",
            %world,
            alloc_id,
            caller = %format!("{}:{}", loc.file(), loc.line()),
            shape = ?alloc.shape(),
            "allocating proc mesh"
        );

        // 1. Initialize the alloc, producing the initial set of ranked procs:
        let running = alloc
            .initialize()
            .instrument(span!(
                Level::INFO,
                "ProcMesh::Allocate::Initialize",
                alloc_id
            ))
            .await?;

        // 2. Set up routing to the initialized procs; these require dialing.
        //    The dial router is constructed with the global router as its
        //    default. Direct-addressed procs are not bound here.
        let router = DialMailboxRouter::new_with_default(router::global().boxed());
        for AllocatedProc { proc_id, addr, .. } in running.iter() {
            if proc_id.is_direct() {
                continue;
            }
            router.bind(proc_id.clone().into(), addr.clone());
        }

        // 3. Set up a client proc for the mesh itself, so that we can attach ourselves
        //    to it, and communicate with the agents. We wire it into the same router as
        //    everything else, so now the whole mesh should be able to communicate.
        let client_proc_id =
            ProcId::Ranked(WorldId(format!("{}_client", alloc.world_id().name())), 0);
        let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(alloc.transport()))
            .map_err(|err| AllocatorError::Other(err.into()))?;
        tracing::info!(
            name = "ProcMesh::Allocate::ChannelServe",
            alloc_id = alloc_id,
            "client proc started listening on addr: {client_proc_addr}"
        );
        let client_proc = Proc::new(
            client_proc_id.clone(),
            BoxedMailboxSender::new(router.clone()),
        );
        client_proc.clone().serve(client_rx);
        router.bind(client_proc_id.clone().into(), client_proc_addr.clone());

        // 4. Bind the dial router to the global router, so that everything is
        //    connected to a single root.
        router::global().bind_dial_router(&router);

        // The supervision port receives supervision events for this mesh;
        // the receiving end is handed to `EventState` below.
        let (supervisor, _supervisor_handle) = client_proc.instance("supervisor")?;
        let (supervision_port, supervision_events) =
            supervisor.open_port::<ActorSupervisionEvent>();

        // 5. Install this mesh’s supervision sink.
        //
        // We intentionally use "last sink wins": if multiple
        // ProcMeshes exist in the process (e.g., a hidden
        // controller_controller mesh and the app/test mesh), the most
        // recently allocated mesh’s sink replaces the prior global
        // sink.
        //
        // Scope: this only affects undeliverables that arrive on the
        // `global_root_client()` undeliverable port. Per-mesh client
        // bindings (set up below) are unaffected and continue to
        // forward their own undeliverables to this mesh’s
        // `supervision_port`.
        //
        // NOTE: This is a pragmatic stopgap to restore correct
        // routing with multiple meshes in-process. If/when we move to
        // per-world root clients, this override can be removed.
        let _prev = set_global_supervision_sink(supervision_port.clone());

        // Wire this mesh’s *own* client mailbox to supervision.
        //
        // Attach a client mailbox for this `ProcMesh`, bind its
        // undeliverable port, and forward those undeliverables as
        // `ActorSupervisionEvent` records into this mesh's
        // supervision_port.
        //
        // Scope: covers undeliverables observed on this mesh's client
        // mailbox only. It does not affect other meshes or the
        // `global_root_client()`.
        let (client, _handle) = client_proc.instance("client")?;
        // Bind an undeliverable message port in the client.
        let (_undeliverable_messages, client_undeliverable_receiver) =
            client.bind_actor_port::<Undeliverable<MessageEnvelope>>();
        hyperactor::mailbox::supervise_undeliverable_messages(
            supervision_port.clone(),
            client_undeliverable_receiver,
            |env| {
                tracing::info!(actor=%env.dest().actor_id(), "per-mesh client undeliverable observed");
            },
        );

        // Ensure that the router is served so that agents may reach us.
        let (router_channel_addr, router_rx) =
            serve_with_config(alloc.client_router_addr()).map_err(AllocatorError::Other)?;
        router.serve(router_rx);
        tracing::info!("router channel started listening on addr: {router_channel_addr}");

        // 6. Configure the mesh agents. This transmits the address book to all agents,
        //    so that they can resolve and route traffic to all nodes in the mesh.
        let address_book: HashMap<_, _> = running
            .iter()
            .map(
                |AllocatedProc {
                     addr, mesh_agent, ..
                 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
            )
            .collect();

        // Each agent reports its rank on `config_handle` once configured;
        // we send all configure requests first, then wait for every rank.
        let (config_handle, mut config_receiver) = client.open_port();
        for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
            mesh_agent
                .configure(
                    &client,
                    rank,
                    router_channel_addr.clone(),
                    Some(supervision_port.bind()),
                    address_book.clone(),
                    config_handle.bind(),
                    false,
                )
                .await?;
        }
        let mut completed = Ranks::new(running.len());
        while !completed.is_full() {
            let rank = config_receiver
                .recv()
                .await
                .map_err(|err| AllocatorError::Other(err.into()))?;
            if completed.insert(rank, rank).is_some() {
                tracing::warn!("multiple completions received for rank {}", rank);
            }
        }

        // For reasons I fail to fully understand, the below call fails
        // when invoked from `pyo3_async_runtimes::tokio::future_into_py`
        // when using a closure. It appears to be some subtle failure of
        // the compiler to unify lifetimes. If we use a function instead,
        // it does better.
        //
        // Interestingly, this only appears to fail in *specific* caller
        // contexts (e.g., https://fburl.com/code/evfgtfx1), and the error
        // is reported there as "implementation of `std::ops::FnOnce` is not general enough",
        // suggesting some failure of modularity in the compiler's lifetime
        // unification!
        //
        // Baffling and unsettling.
        fn project_mesh_agent_ref(allocated_proc: &AllocatedProc) -> ActorRef<ProcMeshAgent> {
            allocated_proc.mesh_agent.clone()
        }

        // 7. Start comm actors and set them up to communicate via the same address book.

        // Spawn a comm actor on each proc, so that they can be used
        // to perform tree distribution and accumulation.
        let comm_actors = Self::spawn_on_procs::<CommActor>(
            &client,
            running.iter().map(project_mesh_agent_ref),
            "comm",
            &Default::default(),
        )
        .await?;
        // Rank -> comm actor reference; shadows the proc address book above.
        let address_book: HashMap<_, _> = comm_actors.iter().cloned().enumerate().collect();
        // Now that we have all of the spawned comm actors, kick them all into
        // mesh mode.
        for (rank, comm_actor) in comm_actors.iter().enumerate() {
            comm_actor
                .send(&client, CommActorMode::Mesh(rank, address_book.clone()))
                .map_err(anyhow::Error::from)?;
        }

        let shape = alloc.shape().clone();
        let world_id = alloc.world_id().clone();
        metrics::PROC_MESH_ALLOCATION.add(
            running.len() as u64,
            hyperactor_telemetry::kv_pairs!("alloc_id" => alloc_id.to_string()),
        );

        Ok(Self {
            inner: ProcMeshKind::V0 {
                event_state: Some(EventState {
                    alloc,
                    supervision_events,
                }),
                actor_event_router: Arc::new(DashMap::new()),
                shape,
                ranks: running
                    .into_iter()
                    .map(
                        |AllocatedProc {
                             create_key,
                             proc_id,
                             addr,
                             mesh_agent,
                         }| (create_key, proc_id, addr, mesh_agent),
                    )
                    .collect(),
                client_proc,
                client,
                comm_actors,
                world_id,
            },
            shape: OnceLock::new(),
        })
    }
495
496    /// Bounds:
497    /// - `A: Actor` - we actually spawn this concrete actor on each
498    ///   proc.
499    /// - `A: Referable` - required because we return
500    ///   `Vec<ActorRef<A>>`, and `ActorRef` is only defined for `A:
501    ///   Referable`.
502    /// - `A::Params: RemoteMessage` - params must serialize for
503    ///   cross-proc spawn.
504    async fn spawn_on_procs<A: Actor + Referable>(
505        cx: &impl context::Actor,
506        agents: impl IntoIterator<Item = ActorRef<ProcMeshAgent>> + '_,
507        actor_name: &str,
508        params: &A::Params,
509    ) -> Result<Vec<ActorRef<A>>, anyhow::Error>
510    where
511        A::Params: RemoteMessage,
512    {
513        let remote = Remote::collect();
514        let actor_type = remote
515            .name_of::<A>()
516            .ok_or(anyhow::anyhow!("actor not registered"))?
517            .to_string();
518
519        let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
520        let mut n = 0;
521        for agent in agents {
522            agent
523                .gspawn(
524                    cx,
525                    actor_type.clone(),
526                    actor_name.to_string(),
527                    bincode::serialize(params)?,
528                    completed_handle.bind(),
529                )
530                .await?;
531            n += 1;
532        }
533        let mut completed = Ranks::new(n);
534        while !completed.is_full() {
535            let result = completed_receiver.recv().await?;
536            match result {
537                GspawnResult::Success { rank, actor_id } => {
538                    if completed.insert(rank, actor_id).is_some() {
539                        tracing::warn!("multiple completions received for rank {}", rank);
540                    }
541                }
542                GspawnResult::Error(error_msg) => {
543                    metrics::PROC_MESH_ACTOR_FAILURES.add(
544                        1,
545                        hyperactor_telemetry::kv_pairs!(
546                            "actor_name" => actor_name.to_string(),
547                            "error" => error_msg.clone(),
548                        ),
549                    );
550
551                    anyhow::bail!("gspawn failed: {}", error_msg);
552                }
553            }
554        }
555
556        // `Ranks` really should have some way to convert into a "completed" rank
557        // in a one-shot way; the API here is too awkward otherwise.
558        Ok(completed
559            .into_iter()
560            .map(Option::unwrap)
561            .map(ActorRef::attest)
562            .collect())
563    }
564
565    fn agents(&self) -> Box<dyn Iterator<Item = ActorRef<ProcMeshAgent>> + '_ + Send> {
566        match &self.inner {
567            ProcMeshKind::V0 { ranks, .. } => {
568                Box::new(ranks.iter().map(|(_, _, _, agent)| agent.clone()))
569            }
570            ProcMeshKind::V1(proc_mesh) => Box::new(
571                proc_mesh
572                    .agent_mesh()
573                    .iter()
574                    .map(|(_point, agent)| agent.clone())
575                    // We need to collect here so that we can return an iterator
576                    // that fully owns the data and does not reference temporary
577                    // values.
578                    //
579                    // Because this is a shim that we expect to be short-lived,
580                    // we'll leave this hack as is; a proper solution here would
581                    // be to have implement an owning iterator (into_iter) for views.
582                    .collect::<Vec<_>>()
583                    .into_iter(),
584            ),
585        }
586    }
587
588    /// Return the comm actor to which casts should be forwarded.
589    pub(crate) fn comm_actor(&self) -> &ActorRef<CommActor> {
590        match &self.inner {
591            ProcMeshKind::V0 { comm_actors, .. } => &comm_actors[0],
592            ProcMeshKind::V1(proc_mesh) => proc_mesh.root_comm_actor().unwrap(),
593        }
594    }
595
596    /// Spawn an `ActorMesh` by launching the same actor type on all
597    /// agents, using the **same** parameters instance for every
598    /// actor.
599    ///
600    /// - `actor_name`: Name for all spawned actors.
601    /// - `params`: Reference to the parameter struct, reused for all
602    ///   actors.
603    ///
604    /// Bounds:
605    /// - `A: Actor` — we actually spawn this type on each agent.
606    /// - `A: Referable` — we return a `RootActorMesh<'_, A>` that
607    ///   contains `ActorRef<A>`s; those exist only for `A:
608    ///   Referable`.
609    /// - `A::Params: RemoteMessage` — params must be serializable to
610    ///   cross proc boundaries when launching each actor.
611    pub async fn spawn<A: Actor + Referable>(
612        &self,
613        cx: &impl context::Actor,
614        actor_name: &str,
615        params: &A::Params,
616    ) -> Result<RootActorMesh<'_, A>, anyhow::Error>
617    where
618        A::Params: RemoteMessage,
619    {
620        match &self.inner {
621            ProcMeshKind::V0 {
622                actor_event_router,
623                client,
624                ..
625            } => {
626                let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
627                {
628                    // Instantiate supervision routing BEFORE spawning the actor mesh.
629                    actor_event_router.insert(actor_name.to_string(), tx);
630                    tracing::info!(
631                        name = "router_insert",
632                        actor_name = %actor_name,
633                        "the length of the router is {}", actor_event_router.len(),
634                    );
635                }
636                let root_mesh = RootActorMesh::new(
637                    self,
638                    actor_name.to_string(),
639                    rx,
640                    Self::spawn_on_procs::<A>(client, self.agents(), actor_name, params).await?,
641                );
642                Ok(root_mesh)
643            }
644            ProcMeshKind::V1(proc_mesh) => {
645                let actor_mesh = proc_mesh.spawn(cx, actor_name, params).await?;
646                Ok(RootActorMesh::new_v1(actor_mesh.detach()))
647            }
648        }
649    }
650
651    /// A client actor used to communicate with any member of this mesh.
652    pub fn client(&self) -> &Instance<()> {
653        match &self.inner {
654            ProcMeshKind::V0 { client, .. } => client,
655            ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client for v1::ProcMesh"),
656        }
657    }
658
659    pub fn client_proc(&self) -> &Proc {
660        match &self.inner {
661            ProcMeshKind::V0 { client_proc, .. } => client_proc,
662            ProcMeshKind::V1(_proc_mesh) => unimplemented!("no client proc for v1::ProcMesh"),
663        }
664    }
665
666    pub fn proc_id(&self) -> &ProcId {
667        self.client_proc().proc_id()
668    }
669
670    pub fn world_id(&self) -> &WorldId {
671        match &self.inner {
672            ProcMeshKind::V0 { world_id, .. } => world_id,
673            ProcMeshKind::V1(_proc_mesh) => unimplemented!("no world_id for v1::ProcMesh"),
674        }
675    }
676
677    /// An event stream of proc events. Each ProcMesh can produce only one such
678    /// stream, returning None after the first call.
679    pub fn events(&mut self) -> Option<ProcEvents> {
680        match &mut self.inner {
681            ProcMeshKind::V0 {
682                event_state,
683                ranks,
684                actor_event_router,
685                ..
686            } => event_state.take().map(|event_state| ProcEvents {
687                event_state,
688                ranks: ranks
689                    .iter()
690                    .enumerate()
691                    .map(|(rank, (create_key, proc_id, _addr, _mesh_agent))| {
692                        (proc_id.clone(), (rank, create_key.clone()))
693                    })
694                    .collect(),
695                actor_event_router: actor_event_router.clone(),
696            }),
697            #[allow(clippy::todo)]
698            ProcMeshKind::V1(_proc_mesh) => todo!(),
699        }
700    }
701
702    pub fn shape(&self) -> &Shape {
703        // We store the shape here, only because it isn't materialized in
704        // V1 meshes.
705        self.shape.get_or_init(|| match &self.inner {
706            ProcMeshKind::V0 { shape, .. } => shape.clone(),
707            ProcMeshKind::V1(proc_mesh) => proc_mesh.region().into(),
708        })
709    }
710
711    /// Send stop actors message to all mesh agents for a specific mesh name
712    #[hyperactor::observe_result("ProcMesh")]
713    pub async fn stop_actor_by_name(
714        &self,
715        cx: &impl context::Actor,
716        mesh_name: &str,
717    ) -> Result<(), anyhow::Error> {
718        match &self.inner {
719            ProcMeshKind::V0 { client, .. } => {
720                let timeout =
721                    hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
722                let results = join_all(self.agents().map(|agent| async move {
723                    let actor_id =
724                        ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0);
725                    (
726                        actor_id.clone(),
727                        agent
728                            .clone()
729                            .stop_actor(client, actor_id, timeout.as_millis() as u64)
730                            .await,
731                    )
732                }))
733                .await;
734
735                for (actor_id, result) in results {
736                    match result {
737                        Ok(StopActorResult::Timeout) => {
738                            tracing::warn!("timed out while stopping actor {}", actor_id);
739                        }
740                        Ok(StopActorResult::NotFound) => {
741                            tracing::warn!("no actor {} on proc {}", actor_id, actor_id.proc_id());
742                        }
743                        Ok(StopActorResult::Success) => {
744                            tracing::info!("stopped actor {}", actor_id);
745                        }
746                        Err(e) => {
747                            tracing::warn!("error stopping actor {}: {}", actor_id, e);
748                        }
749                    }
750                }
751                Ok(())
752            }
753            ProcMeshKind::V1(proc_mesh) => {
754                proc_mesh
755                    .stop_actor_by_name(cx, Name::new_reserved(mesh_name))
756                    .await?;
757                Ok(())
758            }
759        }
760    }
761}
762
/// Proc lifecycle events.
///
/// In both variants the first field is the rank of the affected proc.
#[derive(Debug, Clone)]
pub enum ProcEvent {
    /// The proc of the given rank was stopped with the provided reason.
    Stopped(usize, ProcStopReason),
    /// The proc crashed, with the provided "reason". This is reserved for
    /// unhandled supervision events.
    Crashed(usize, String),
}
772
/// Labels for the stages of supervision event handling.
///
/// NOTE(review): the uses of this enum are not visible in this part of
/// the file; the variant names suggest they tag log or metric records
/// for receiving and forwarding supervision events. Confirm against
/// the call sites.
#[derive(Debug, Clone, AsRefStr)]
pub enum SupervisionEventState {
    SupervisionEventForward,
    SupervisionEventForwardFailed,
    SupervisionEventReceived,
    SupervisionEventTransmitFailed,
}
780
781impl fmt::Display for ProcEvent {
782    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783        match self {
784            ProcEvent::Stopped(rank, reason) => {
785                write!(f, "Proc at rank {} stopped: {}", rank, reason)
786            }
787            ProcEvent::Crashed(rank, reason) => {
788                write!(f, "Proc at rank {} crashed: {}", rank, reason)
789            }
790        }
791    }
792}
793
/// Name of an actor mesh. This is the key type of the supervision-event
/// routing table (`actor_event_router`).
type ActorMeshName = String;
795
796/// An event stream of [`ProcEvent`]
797// TODO: consider using streams for this.
798pub struct ProcEvents {
799    event_state: EventState,
800    // Proc id to its rank and create key.
801    ranks: HashMap<ProcId, (usize, ShortUuid)>,
802    actor_event_router: ActorEventRouter,
803}
804
805impl ProcEvents {
806    /// Get the next lifecycle event. The stream is closed when this method
807    /// returns `None`.
808    pub async fn next(&mut self) -> Option<ProcEvent> {
809        loop {
810            tokio::select! {
811                result = self.event_state.alloc.next() => {
812                    tracing::debug!(name = "ProcEventReceived", "received ProcEvent alloc update: {result:?}");
813                    // Don't disable the outer branch on None: this is always terminal.
814                    let Some(alloc_event) = result else {
815                        self.actor_event_router.clear();
816                        break None;
817                    };
818
819                    let ProcState::Stopped { create_key, reason } = alloc_event else {
820                        // Ignore non-stopped events for now.
821                        continue;
822                    };
823
824                    let Some((proc_id, (rank, _create_key))) = self.ranks.iter().find(|(_proc_id, (_rank, key))| key == &create_key) else {
825                        tracing::warn!("received stop event for unmapped proc {}", create_key);
826                        continue;
827                    };
828
829                    metrics::PROC_MESH_PROC_STOPPED.add(
830                        1,
831                        hyperactor_telemetry::kv_pairs!(
832                            "create_key" => create_key.to_string(),
833                            "rank" => rank.to_string(),
834                            "reason" => reason.to_string(),
835                        ),
836                    );
837
838                    // Need to send this event to actor meshes to notify them of the proc's death.
839                    // TODO(albertli): only send this event to all root actor meshes if any of them use this proc.
840                    for entry in self.actor_event_router.iter() {
841                        // Make a dummy actor supervision event, all actors on the proc are affected if a proc stops.
842                        // TODO(T231868026): find a better way to represent all actors in a proc for supervision event
843                        let event = ActorSupervisionEvent::new(
844                            proc_id.actor_id("any", 0),
845                            None,
846                            ActorStatus::generic_failure(format!("proc {} is stopped", proc_id)),
847                            None,
848                        );
849                        tracing::debug!(name = "SupervisionEvent", %event);
850                        if entry.value().send(event.clone()).is_err() {
851                            tracing::warn!(
852                                name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
853                                "unable to transmit supervision event to actor {}", entry.key()
854                            );
855                        }
856                    }
857
858                    let event = ProcEvent::Stopped(*rank, reason.clone());
859                    tracing::debug!(name = "SupervisionEvent", %event);
860
861                    break Some(ProcEvent::Stopped(*rank, reason));
862                }
863
864                // Supervision events for this ProcMesh, delivered on
865                // the client's "supervisor" port. Some failures are
866                // observed while messages are routed through the
867                // comm-actor tree; in those cases the event's
868                // `actor_id` points at a comm actor rather than the
869                // logical actor-mesh. When the `CAST_ACTOR_MESH_ID`
870                // header is present, we normalize the event by
871                // rewriting `actor_id` to a synthetic mesh-level id
872                // so that routing reaches the correct `ActorMesh`
873                // subscribers.
874                Ok(event) = self.event_state.supervision_events.recv() => {
875                    let had_headers = event.message_headers.is_some();
876                    tracing::info!(
877                        name = SupervisionEventState::SupervisionEventReceived.as_ref(),
878                        actor_id = %event.actor_id,
879                        actor_name = %event.actor_id.name(),
880                        status = %event.actor_status,
881                        "proc supervision: event received with {had_headers} headers"
882                    );
883                    tracing::debug!(
884                        name = "SupervisionEvent",
885                        %event,
886                        "proc supervision: full event");
887
888                    // Normalize events that came via the comm tree.
889                    let event = update_event_actor_id(event);
890
891                    // Forward the supervision event to the ActorMesh (keyed by its mesh name)
892                    // that registered for events in this ProcMesh. The routing table
893                    // (actor_event_router) is keyed by ActorMeshName, which we obtain from
894                    // actor_id.name(). If no matching mesh is found, log the current table
895                    // to aid diagnosis.
896                    let actor_id = event.actor_id.clone();
897                    let actor_status = event.actor_status.clone();
898                    let reason = event.to_string();
899                    if let Some(tx) = self.actor_event_router.get(actor_id.name()) {
900                        tracing::info!(
901                            name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
902                            actor_id = %actor_id,
903                            actor_name = actor_id.name(),
904                            status = %actor_status,
905                            "proc supervision: delivering event to registered ActorMesh"
906                        );
907                        if tx.send(event).is_err() {
908                            tracing::warn!(
909                                name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
910                                actor_id = %actor_id,
911                                "proc supervision: registered ActorMesh dropped receiver; unable to deliver"
912                            );
913                        }
914                    } else {
915                        let registered_meshes: Vec<_> = self.actor_event_router.iter().map(|e| e.key().clone()).collect();
916                        tracing::warn!(
917                            name = SupervisionEventState::SupervisionEventForwardFailed.as_ref(),
918                            actor_id = %actor_id,
919                            "proc supervision: no ActorMesh registered for this actor {:?}", registered_meshes,
920                        );
921                    }
922                    // Ensure we have a known rank for the proc
923                    // containing this actor. If we don't, we can't
924                    // attribute the failure to a known process.
925                    let Some((rank, _)) = self.ranks.get(actor_id.proc_id()) else {
926                        tracing::warn!(
927                            actor_id = %actor_id,
928                            "proc supervision: actor belongs to an unmapped proc; dropping event"
929                        );
930                        continue;
931                    };
932
933                    metrics::PROC_MESH_ACTOR_FAILURES.add(
934                        1,
935                        hyperactor_telemetry::kv_pairs!(
936                            "actor_id" => actor_id.to_string(),
937                            "rank" => rank.to_string(),
938                            "status" => actor_status.to_string(),
939                        ),
940                    );
941
942                    // Send this event to Python proc mesh to keep its
943                    // health status up to date.
944                    break Some(ProcEvent::Crashed(*rank, reason))
945                }
946            }
947        }
948    }
949
950    pub fn into_alloc(self) -> Box<dyn Alloc + Send + Sync> {
951        self.event_state.alloc
952    }
953}
954
/// Spawns from shared ([`Arc`]) proc meshes, providing [`ActorMesh`]es with
/// static lifetimes.
#[async_trait]
pub trait SharedSpawnable {
    /// Spawn an actor mesh of type `A` named `actor_name`, passing
    /// `params` to the spawn.
    ///
    /// Takes `self` by value, and the returned [`RootActorMesh`] has a
    /// `'static` lifetime. `cx` is the calling actor context.
    // `Actor`: the type actually runs in the mesh;
    // `Referable`: so we can hand back ActorRef<A> in RootActorMesh
    async fn spawn<A: Actor + Referable>(
        self,
        cx: &impl context::Actor,
        actor_name: &str,
        params: &A::Params,
    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
    where
        A::Params: RemoteMessage;
}
970
971#[async_trait]
972impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D {
973    // `Actor`: the type actually runs in the mesh;
974    // `Referable`: so we can hand back ActorRef<A> in RootActorMesh
975    async fn spawn<A: Actor + Referable>(
976        self,
977        cx: &impl context::Actor,
978        actor_name: &str,
979        params: &A::Params,
980    ) -> Result<RootActorMesh<'static, A>, anyhow::Error>
981    where
982        A::Params: RemoteMessage,
983    {
984        match &self.deref().inner {
985            ProcMeshKind::V0 {
986                actor_event_router,
987                client,
988                ..
989            } => {
990                let (tx, rx) = mpsc::unbounded_channel::<ActorSupervisionEvent>();
991                {
992                    // Instantiate supervision routing BEFORE spawning the actor mesh.
993                    actor_event_router.insert(actor_name.to_string(), tx);
994                    tracing::info!(
995                        name = "router_insert",
996                        actor_name = %actor_name,
997                        "the length of the router is {}", actor_event_router.len(),
998                    );
999                }
1000                let ranks =
1001                    ProcMesh::spawn_on_procs::<A>(client, self.agents(), actor_name, params)
1002                        .await?;
1003                Ok(RootActorMesh::new_shared(
1004                    self,
1005                    actor_name.to_string(),
1006                    rx,
1007                    ranks,
1008                ))
1009            }
1010            ProcMeshKind::V1(proc_mesh) => Ok(RootActorMesh::from(
1011                proc_mesh.spawn_service(cx, actor_name, params).await?,
1012            )),
1013        }
1014    }
1015}
1016
1017#[async_trait]
1018impl Mesh for ProcMesh {
1019    type Node = ProcId;
1020    type Id = ProcMeshId;
1021    type Sliced<'a> = SlicedProcMesh<'a>;
1022
1023    fn shape(&self) -> &Shape {
1024        ProcMesh::shape(self)
1025    }
1026
1027    fn select<R: Into<Range>>(
1028        &self,
1029        label: &str,
1030        range: R,
1031    ) -> Result<Self::Sliced<'_>, ShapeError> {
1032        Ok(SlicedProcMesh(self, self.shape().select(label, range)?))
1033    }
1034
1035    fn get(&self, rank: usize) -> Option<ProcId> {
1036        match &self.inner {
1037            ProcMeshKind::V0 { ranks, .. } => Some(ranks[rank].1.clone()),
1038            ProcMeshKind::V1(proc_mesh) => proc_mesh.get(rank).map(|proc| proc.proc_id().clone()),
1039        }
1040    }
1041
1042    fn id(&self) -> Self::Id {
1043        match &self.inner {
1044            ProcMeshKind::V0 { world_id, .. } => ProcMeshId(world_id.name().to_string()),
1045            ProcMeshKind::V1(proc_mesh) => ProcMeshId(proc_mesh.name().to_string()),
1046        }
1047    }
1048}
1049
1050impl fmt::Display for ProcMesh {
1051    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1052        write!(f, "{{ shape: {} }}", self.shape())
1053    }
1054}
1055
1056impl fmt::Debug for ProcMesh {
1057    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1058        match &self.inner {
1059            ProcMeshKind::V0 {
1060                shape,
1061                ranks,
1062                client_proc,
1063                ..
1064            } => f
1065                .debug_struct("ProcMesh::V0")
1066                .field("shape", shape)
1067                .field("ranks", ranks)
1068                .field("client_proc", client_proc)
1069                .field("client", &"<Instance>")
1070                // Skip the alloc field since it doesn't implement Debug
1071                .finish(),
1072            ProcMeshKind::V1(proc_mesh) => fmt::Debug::fmt(proc_mesh, f),
1073        }
1074    }
1075}
1076
1077pub struct SlicedProcMesh<'a>(&'a ProcMesh, Shape);
1078
1079#[async_trait]
1080impl Mesh for SlicedProcMesh<'_> {
1081    type Node = ProcId;
1082    type Id = ProcMeshId;
1083    type Sliced<'b>
1084        = SlicedProcMesh<'b>
1085    where
1086        Self: 'b;
1087
1088    fn shape(&self) -> &Shape {
1089        &self.1
1090    }
1091
1092    fn select<R: Into<Range>>(
1093        &self,
1094        label: &str,
1095        range: R,
1096    ) -> Result<Self::Sliced<'_>, ShapeError> {
1097        Ok(Self(self.0, self.1.select(label, range)?))
1098    }
1099
1100    fn get(&self, _index: usize) -> Option<ProcId> {
1101        unimplemented!()
1102    }
1103
1104    fn id(&self) -> Self::Id {
1105        self.0.id()
1106    }
1107}
1108
// Tests for `ProcMesh`: allocation, lifecycle-event propagation,
// supervision failures, and duplicate spawns. All but the `shim` test
// allocate procs with `LocalAllocator` over the `Local` transport.
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::actor::ActorStatus;
    use ndslice::extent;

    use super::*;
    use crate::actor_mesh::ActorMesh;
    use crate::actor_mesh::test_util::Error;
    use crate::actor_mesh::test_util::TestActor;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::local::LocalAllocator;
    use crate::sel_from_shape;

    // A freshly allocated mesh reports procs whose world name matches the
    // name of the alloc it was built from.
    #[tokio::test]
    async fn test_basic() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let name = alloc.name().to_string();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        assert_eq!(mesh.get(0).unwrap().world_name(), Some(name.as_str()));
    }

    // Proc stops reported by the alloc surface as `ProcEvent::Stopped`,
    // and the event stream ends once every proc has stopped.
    #[tokio::test]
    async fn test_propagate_lifecycle_events() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 4),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let stop = alloc.stopper();
        let monkey = alloc.chaos_monkey();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        // Kill rank 1 first; it must be reported with the same reason.
        monkey(1, ProcStopReason::Killed(1, false));
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Killed(1, false))
        );

        // Stopping the alloc stops the three remaining procs.
        stop();
        for _ in 0..3 {
            assert_matches!(
                events.next().await.unwrap(),
                ProcEvent::Stopped(_, ProcStopReason::Stopped)
            );
        }
        assert!(events.next().await.is_none());
    }

    // An actor failure is reported both as a `ProcEvent::Crashed` on the
    // proc mesh and as a supervision event on the actor mesh.
    #[tokio::test]
    async fn test_supervision_failure() {
        // For now, we propagate all actor failures to the proc.

        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 2),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();
        let stop = alloc.stopper();
        let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
        let mut events = mesh.events().unwrap();

        let instance = crate::v1::testing::instance().await;

        let mut actors = mesh
            .spawn::<TestActor>(&instance, "failing", &())
            .await
            .unwrap();
        let mut actor_events = actors.events().unwrap();

        // Make the actor on replica 0 fail.
        actors
            .cast(
                mesh.client(),
                sel_from_shape!(actors.shape(), replica = 0),
                Error("failmonkey".to_string()),
            )
            .unwrap();

        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Crashed(0, reason) if reason.contains("failmonkey")
        );

        let mut event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.1, "failing".to_string());
        assert_eq!(event.actor_id.2, 0);

        stop();
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(0, ProcStopReason::Stopped),
        );
        assert_matches!(
            events.next().await.unwrap(),
            ProcEvent::Stopped(1, ProcStopReason::Stopped),
        );

        assert!(events.next().await.is_none());
        // After the procs stop, the actor mesh still has a failure event
        // pending.
        event = actor_events.next().await.unwrap();
        assert_matches!(event.actor_status, ActorStatus::Failed(_));
        assert_eq!(event.actor_id.2, 0);
    }

    // Spawning a second actor mesh under an already-used name fails.
    #[timed_test::async_timed_test(timeout_secs = 5)]
    async fn test_spawn_twice() {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();
        let mesh = ProcMesh::allocate(alloc).await.unwrap();

        let instance = crate::v1::testing::instance().await;
        mesh.spawn::<TestActor>(&instance, "dup", &())
            .await
            .unwrap();
        let result = mesh.spawn::<TestActor>(&instance, "dup", &()).await;
        assert!(result.is_err());
    }

    // Tests for a `ProcMesh` obtained by converting a v1 proc mesh.
    mod shim {
        use std::collections::HashSet;

        use hyperactor::context::Mailbox;
        use ndslice::Extent;
        use ndslice::Selection;

        use super::*;
        use crate::sel;

        // Cast to every actor of a mesh spawned on the converted proc
        // mesh and check that each (point, actor ref) pair replies
        // exactly once, naming the test instance as the sender.
        #[tokio::test]
        #[cfg(fbcode_build)]
        async fn test_basic() {
            let instance = v1::testing::instance().await;
            let ext = extent!(host = 4);
            let host_mesh = v1::testing::host_mesh(ext.clone()).await;
            let proc_mesh = host_mesh
                .spawn(instance, "test", Extent::unity())
                .await
                .unwrap();
            let proc_mesh_v0: ProcMesh = proc_mesh.detach().into();

            let actor_mesh = proc_mesh_v0
                .spawn::<v1::testactor::TestActor>(instance, "test", &())
                .await
                .unwrap();

            let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
            actor_mesh
                .cast(
                    instance,
                    sel!(*),
                    v1::testactor::GetCastInfo {
                        cast_info: cast_info.bind(),
                    },
                )
                .unwrap();

            // Expected replies; each is removed as it arrives.
            let mut point_to_actor: HashSet<_> = actor_mesh
                .iter_actor_refs()
                .enumerate()
                .map(|(rank, actor_ref)| (ext.point_of_rank(rank).unwrap(), actor_ref))
                .collect();
            while !point_to_actor.is_empty() {
                let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
                let key = (point, origin_actor_ref);
                assert!(
                    point_to_actor.remove(&key),
                    "key {:?} not present or removed twice",
                    key
                );
                assert_eq!(&sender_actor_id, instance.self_id());
            }
        }
    }
}