hyperactor_mesh/v1/
proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::any::type_name;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::fmt;
13use std::ops::Deref;
14use std::panic::Location;
15use std::sync::Arc;
16use std::sync::OnceLock;
17use std::sync::atomic::AtomicUsize;
18use std::sync::atomic::Ordering;
19use std::time::Duration;
20
21use hyperactor::Actor;
22use hyperactor::ActorId;
23use hyperactor::ActorRef;
24use hyperactor::Handler;
25use hyperactor::ProcId;
26use hyperactor::RemoteMessage;
27use hyperactor::RemoteSpawn;
28use hyperactor::accum::ReducerOpts;
29use hyperactor::actor::ActorStatus;
30use hyperactor::actor::Referable;
31use hyperactor::actor::remote::Remote;
32use hyperactor::channel;
33use hyperactor::channel::ChannelAddr;
34use hyperactor::clock::Clock;
35use hyperactor::clock::RealClock;
36use hyperactor::context;
37use hyperactor::mailbox::DialMailboxRouter;
38use hyperactor::mailbox::MailboxServer;
39use hyperactor::supervision::ActorSupervisionEvent;
40use hyperactor_config::CONFIG;
41use hyperactor_config::ConfigAttr;
42use hyperactor_config::attrs::declare_attrs;
43use ndslice::Extent;
44use ndslice::ViewExt as _;
45use ndslice::view;
46use ndslice::view::CollectMeshExt;
47use ndslice::view::MapIntoExt;
48use ndslice::view::Ranked;
49use ndslice::view::Region;
50use serde::Deserialize;
51use serde::Serialize;
52use tokio::sync::Notify;
53use tracing::Instrument;
54use typeuri::Named;
55
56use crate::CommActor;
57use crate::alloc::Alloc;
58use crate::alloc::AllocExt;
59use crate::alloc::AllocatedProc;
60use crate::assign::Ranks;
61use crate::comm::CommActorMode;
62use crate::proc_mesh::mesh_agent;
63use crate::proc_mesh::mesh_agent::ActorState;
64use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
65use crate::proc_mesh::mesh_agent::ProcMeshAgent;
66use crate::proc_mesh::mesh_agent::ReconfigurableMailboxSender;
67use crate::resource;
68use crate::resource::GetRankStatus;
69use crate::resource::Status;
70use crate::supervision::MeshFailure;
71use crate::v1;
72use crate::v1::ActorMesh;
73use crate::v1::ActorMeshRef;
74use crate::v1::Error;
75use crate::v1::HostMeshRef;
76use crate::v1::Name;
77use crate::v1::ValueMesh;
78use crate::v1::host_mesh::mesh_agent::ProcState;
79use crate::v1::host_mesh::mesh_to_rankedvalues_with_default;
80use crate::v1::mesh_controller::ActorMeshController;
81
declare_attrs! {
    /// The maximum idle time between updates while spawning actor
    /// meshes.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_ACTOR_SPAWN_MAX_IDLE".to_string()),
        py_name: Some("actor_spawn_max_idle".to_string()),
    })
    pub attr ACTOR_SPAWN_MAX_IDLE: Duration = Duration::from_secs(30);

    /// The maximum time to wait for each reply from a proc mesh agent when
    /// querying actor states (see `ProcMeshRef::actor_states`). Agents that
    /// have not replied when this expires are treated as unresponsive and
    /// their ranks are reported with a timeout status.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_GET_ACTOR_STATE_MAX_IDLE".to_string()),
        py_name: Some("get_actor_state_max_idle".to_string()),
    })
    pub attr GET_ACTOR_STATE_MAX_IDLE: Duration = Duration::from_mins(1);
}
97
/// A reference to a single [`hyperactor::Proc`].
///
/// Besides the proc's id, a `ProcRef` remembers the rank the proc had when its
/// mesh was created and the [`ProcMeshAgent`] that manages it; queries about
/// the proc are sent to that agent.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProcRef {
    /// The id of the referenced proc.
    proc_id: ProcId,
    /// The rank of this proc at creation. Agent replies are tagged with this
    /// rank, so it is used to match replies to this proc.
    create_rank: usize,
    /// The agent managing this proc.
    agent: ActorRef<ProcMeshAgent>,
}
107
108impl ProcRef {
109    /// Create a new proc ref from the provided id, create rank and agent.
110    pub fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef<ProcMeshAgent>) -> Self {
111        Self {
112            proc_id,
113            create_rank,
114            agent,
115        }
116    }
117
118    /// Pings the proc, returning whether it is alive. This will be replaced by a
119    /// finer-grained lifecycle status in the near future.
120    pub(crate) async fn status(&self, cx: &impl context::Actor) -> v1::Result<bool> {
121        let (port, mut rx) = cx.mailbox().open_port();
122        self.agent
123            .status(cx, port.bind())
124            .await
125            .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e))?;
126        loop {
127            let (rank, status) = rx
128                .recv()
129                .await
130                .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
131            if rank == self.create_rank {
132                break Ok(status);
133            }
134        }
135    }
136
137    /// Get the supervision events for one actor with the given name.
138    #[allow(dead_code)]
139    async fn actor_state(
140        &self,
141        cx: &impl context::Actor,
142        name: Name,
143    ) -> v1::Result<resource::State<ActorState>> {
144        let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
145        self.agent
146            .send(
147                cx,
148                resource::GetState::<ActorState> {
149                    name: name.clone(),
150                    reply: port.bind(),
151                },
152            )
153            .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
154        let state = rx
155            .recv()
156            .await
157            .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
158        if let Some(ref inner) = state.state {
159            let rank = inner.create_rank;
160            if rank == self.create_rank {
161                Ok(state)
162            } else {
163                Err(Error::CallError(
164                    self.agent.actor_id().clone(),
165                    anyhow::anyhow!(
166                        "Rank on mesh agent not matching for Actor {}: returned {}, expected {}",
167                        name,
168                        rank,
169                        self.create_rank
170                    ),
171                ))
172            }
173        } else {
174            Err(Error::CallError(
175                self.agent.actor_id().clone(),
176                anyhow::anyhow!("Actor {} does not exist", name),
177            ))
178        }
179    }
180
181    pub fn proc_id(&self) -> &ProcId {
182        &self.proc_id
183    }
184
185    pub(crate) fn actor_id(&self, name: &Name) -> ActorId {
186        self.proc_id.actor_id(name.to_string(), 0)
187    }
188
189    /// Generic bound: `A: Referable` - required because we return
190    /// an `ActorRef<A>`.
191    pub(crate) fn attest<A: Referable>(&self, name: &Name) -> ActorRef<A> {
192        ActorRef::attest(self.actor_id(name))
193    }
194}
195
196/// A mesh of processes.
197#[derive(Debug)]
198pub struct ProcMesh {
199    #[allow(dead_code)]
200    name: Name,
201    allocation: ProcMeshAllocation,
202    #[allow(dead_code)]
203    comm_actor_name: Option<Name>,
204    current_ref: ProcMeshRef,
205}
206
207impl ProcMesh {
208    async fn create<C: context::Actor>(
209        cx: &C,
210        name: Name,
211        allocation: ProcMeshAllocation,
212        spawn_comm_actor: bool,
213    ) -> v1::Result<Self>
214    where
215        C::A: Handler<MeshFailure>,
216    {
217        let comm_actor_name = if spawn_comm_actor {
218            Some(Name::new("comm").unwrap())
219        } else {
220            None
221        };
222
223        let region = allocation.extent().clone().into();
224        let ranks = allocation.ranks();
225        let root_comm_actor = comm_actor_name.as_ref().map(|name| {
226            ActorRef::attest(
227                ranks
228                    .first()
229                    .expect("root mesh cannot be empty")
230                    .actor_id(name),
231            )
232        });
233        let host_mesh = allocation.hosts();
234        let current_ref = ProcMeshRef::new(
235            name.clone(),
236            region,
237            ranks,
238            host_mesh.cloned(),
239            None, // this is the root mesh
240            None, // comm actor is not alive yet
241        )
242        .unwrap();
243
244        let mut proc_mesh = Self {
245            name,
246            allocation,
247            comm_actor_name: comm_actor_name.clone(),
248            current_ref,
249        };
250
251        if let Some(comm_actor_name) = comm_actor_name {
252            // CommActor satisfies `Actor + Referable`, so it can be
253            // spawned and safely referenced via ActorRef<CommActor>.
254            // It is a system actor that should not have a controller managing it.
255            let comm_actor_mesh: ActorMesh<CommActor> = proc_mesh
256                .spawn_with_name(cx, comm_actor_name, &Default::default(), None, true)
257                .await?;
258            let address_book: HashMap<_, _> = comm_actor_mesh
259                .iter()
260                .map(|(point, actor_ref)| (point.rank(), actor_ref))
261                .collect();
262            // Now that we have all of the spawned comm actors, kick them all into
263            // mesh mode.
264            for (rank, comm_actor) in &address_book {
265                comm_actor
266                    .send(cx, CommActorMode::Mesh(*rank, address_book.clone()))
267                    .map_err(|e| Error::SendingError(comm_actor.actor_id().clone(), Box::new(e)))?
268            }
269
270            // The comm actor is now set up and ready to go.
271            proc_mesh.current_ref.root_comm_actor = root_comm_actor;
272        }
273
274        Ok(proc_mesh)
275    }
276
277    pub(crate) async fn create_owned_unchecked<C: context::Actor>(
278        cx: &C,
279        name: Name,
280        extent: Extent,
281        hosts: HostMeshRef,
282        ranks: Vec<ProcRef>,
283    ) -> v1::Result<Self>
284    where
285        C::A: Handler<MeshFailure>,
286    {
287        Self::create(
288            cx,
289            name,
290            ProcMeshAllocation::Owned {
291                hosts,
292                extent,
293                ranks: Arc::new(ranks),
294            },
295            true,
296        )
297        .await
298    }
299
300    fn alloc_counter() -> &'static AtomicUsize {
301        static C: OnceLock<AtomicUsize> = OnceLock::new();
302        C.get_or_init(|| AtomicUsize::new(0))
303    }
304
305    /// Allocate a new ProcMesh from the provided alloc.
306    /// Allocate does not require an owning actor because references are not owned.
307    #[track_caller]
308    pub async fn allocate<C: context::Actor>(
309        cx: &C,
310        alloc: Box<dyn Alloc + Send + Sync + 'static>,
311        name: &str,
312    ) -> v1::Result<Self>
313    where
314        C::A: Handler<MeshFailure>,
315    {
316        let caller = Location::caller();
317        Self::allocate_inner(cx, alloc, Name::new(name)?, caller).await
318    }
319
320    // Use allocate_inner to set field mesh_name in span
321    #[hyperactor::instrument(fields(proc_mesh=name.to_string()))]
322    async fn allocate_inner<C: context::Actor>(
323        cx: &C,
324        mut alloc: Box<dyn Alloc + Send + Sync + 'static>,
325        name: Name,
326        caller: &'static Location<'static>,
327    ) -> v1::Result<Self>
328    where
329        C::A: Handler<MeshFailure>,
330    {
331        let alloc_id = Self::alloc_counter().fetch_add(1, Ordering::Relaxed) + 1;
332        tracing::info!(
333            name = "ProcMeshStatus",
334            status = "Allocate::Attempt",
335            %caller,
336            alloc_id,
337            shape = ?alloc.shape(),
338            "allocating proc mesh"
339        );
340
341        let running = alloc
342            .initialize()
343            .instrument(tracing::info_span!(
344                "ProcMeshStatus::Allocate::Initialize",
345                alloc_id,
346                proc_mesh = %name
347            ))
348            .await?;
349
350        // Wire the newly created mesh into the proc, so that it is routable.
351        // We route all of the relevant prefixes into the proc's forwarder,
352        // and serve it on the alloc's transport.
353        //
354        // This will be removed with direct addressing.
355        let proc = cx.instance().proc();
356
357        // First make sure we can serve the proc:
358        let proc_channel_addr = {
359            let _guard =
360                tracing::info_span!("allocate_serve_proc", proc_id = %proc.proc_id()).entered();
361            let (addr, rx) = channel::serve(ChannelAddr::any(alloc.transport()))?;
362            proc.clone().serve(rx);
363            tracing::info!(
364                name = "ProcMeshStatus",
365                status = "Allocate::ChannelServe",
366                proc_mesh = %name,
367                %addr,
368                "proc started listening on addr: {addr}"
369            );
370            addr
371        };
372
373        let bind_allocated_procs = |router: &DialMailboxRouter| {
374            // Route all of the allocated procs:
375            for AllocatedProc { proc_id, addr, .. } in running.iter() {
376                if proc_id.is_direct() {
377                    continue;
378                }
379                router.bind(proc_id.clone().into(), addr.clone());
380            }
381        };
382
383        // Temporary for backward compatibility with ranked procs and v0 API.
384        // Proc meshes can be allocated either using the root client proc (which
385        // has a DialMailboxRouter forwarder) or a mesh agent proc (which has a
386        // ReconfigurableMailboxSender forwarder with an inner DialMailboxRouter).
387        if let Some(router) = proc.forwarder().downcast_ref() {
388            bind_allocated_procs(router);
389        } else if let Some(router) = proc
390            .forwarder()
391            .downcast_ref::<ReconfigurableMailboxSender>()
392        {
393            bind_allocated_procs(
394                router
395                    .as_inner()
396                    .map_err(|_| Error::UnroutableMesh())?
397                    .as_configured()
398                    .ok_or(Error::UnroutableMesh())?
399                    .downcast_ref()
400                    .ok_or(Error::UnroutableMesh())?,
401            );
402        } else {
403            return Err(Error::UnroutableMesh());
404        }
405
406        // Set up the mesh agents. Since references are not owned, we don't supervise it.
407        // Instead, we just let procs die when they have unhandled supervision events.
408        let address_book: HashMap<_, _> = running
409            .iter()
410            .map(
411                |AllocatedProc {
412                     addr, mesh_agent, ..
413                 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
414            )
415            .collect();
416
417        let (config_handle, mut config_receiver) = cx.mailbox().open_port();
418        for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
419            mesh_agent
420                .configure(
421                    cx,
422                    rank,
423                    proc_channel_addr.clone(),
424                    None, // no supervisor; we just crash
425                    address_book.clone(),
426                    config_handle.bind(),
427                    true,
428                )
429                .await
430                .map_err(Error::ConfigurationError)?;
431        }
432        let mut completed = Ranks::new(running.len());
433        while !completed.is_full() {
434            let rank = config_receiver
435                .recv()
436                .await
437                .map_err(|err| Error::ConfigurationError(err.into()))?;
438            if completed.insert(rank, rank).is_some() {
439                tracing::warn!("multiple completions received for rank {}", rank);
440            }
441        }
442
443        let ranks: Vec<_> = running
444            .into_iter()
445            .enumerate()
446            .map(|(create_rank, allocated)| ProcRef {
447                proc_id: allocated.proc_id,
448                create_rank,
449                agent: allocated.mesh_agent,
450            })
451            .collect();
452
453        let stop = Arc::new(Notify::new());
454        let extent = alloc.extent().clone();
455        let alloc_name = alloc.world_id().to_string();
456
457        let alloc_task = {
458            let stop = Arc::clone(&stop);
459
460            tokio::spawn(
461                async move {
462                    loop {
463                        tokio::select! {
464                            _ = stop.notified() => {
465                                // If we are explicitly stopped, the alloc is torn down.
466                                if let Err(error) = alloc.stop_and_wait().await {
467                                    tracing::error!(
468                                        name = "ProcMeshStatus",
469                                        alloc_name = %alloc.world_id(),
470                                        status = "FailedToStopAlloc",
471                                        %error,
472                                    );
473                                }
474                                break;
475                            }
476                            // We are mostly just using this to drive allocation events.
477                            proc_state = alloc.next() => {
478                                match proc_state {
479                                    // The alloc was stopped.
480                                    None => break,
481                                    Some(proc_state) => {
482                                        tracing::debug!(
483                                            alloc_name = %alloc.world_id(),
484                                            "unmonitored allocation event: {}", proc_state);
485                                    }
486                                }
487
488                            }
489                        }
490                    }
491                }
492                .instrument(tracing::info_span!("alloc_monitor")),
493            )
494        };
495
496        let mesh = Self::create(
497            cx,
498            name,
499            ProcMeshAllocation::Allocated {
500                alloc_name,
501                stop,
502                extent,
503                ranks: Arc::new(ranks),
504                alloc_task: Some(alloc_task),
505            },
506            true, // alloc-based meshes support comm actors
507        )
508        .await;
509        match &mesh {
510            Ok(_) => tracing::info!(name = "ProcMeshStatus", status = "Allocate::Created"),
511            Err(error) => {
512                tracing::info!(name = "ProcMeshStatus", status = "Allocate::Failed", %error)
513            }
514        }
515        mesh
516    }
517
518    /// Detach the proc mesh from the lifetime of `self`, and return its reference.
519    #[cfg(test)]
520    pub(crate) fn detach(self) -> ProcMeshRef {
521        // This also keeps the ProcMeshAllocation::Allocated alloc task alive.
522        self.current_ref.clone()
523    }
524
525    /// Stop this mesh gracefully.
526    pub async fn stop(&mut self, cx: &impl context::Actor) -> anyhow::Result<()> {
527        let region = self.region.clone();
528        match &mut self.allocation {
529            ProcMeshAllocation::Allocated {
530                stop,
531                alloc_task,
532                alloc_name,
533                ..
534            } => {
535                stop.notify_one();
536                // Wait for the alloc monitor task to complete, ensuring the
537                // alloc has fully stopped before we drop it.
538                if let Some(handle) = alloc_task.take() {
539                    if let Err(e) = handle.await {
540                        tracing::warn!(
541                            name = "ProcMeshStatus",
542                            proc_mesh = %self.name,
543                            alloc_name,
544                            %e,
545                            "alloc monitor task failed"
546                        );
547                    }
548                }
549                tracing::info!(
550                    name = "ProcMeshStatus",
551                    proc_mesh = %self.name,
552                    alloc_name,
553                    status = "StoppingAlloc",
554                    "alloc {alloc_name} has stopped",
555                );
556                Ok(())
557            }
558            ProcMeshAllocation::Owned { hosts, .. } => {
559                let procs = self.current_ref.proc_ids().collect::<Vec<ProcId>>();
560                // We use the proc mesh region rather than the host mesh region
561                // because the host agent stores one entry per proc, not per host.
562                hosts.stop_proc_mesh(cx, &self.name, procs, region).await
563            }
564        }
565    }
566}
567
568impl fmt::Display for ProcMesh {
569    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570        write!(f, "{}", self.current_ref)
571    }
572}
573
574impl Deref for ProcMesh {
575    type Target = ProcMeshRef;
576
577    fn deref(&self) -> &Self::Target {
578        &self.current_ref
579    }
580}
581
582impl Drop for ProcMesh {
583    fn drop(&mut self) {
584        tracing::info!(
585            name = "ProcMeshStatus",
586            proc_mesh = %self.name,
587            status = "Dropped",
588        );
589    }
590}
591
/// Represents different ways ProcMeshes can be allocated.
enum ProcMeshAllocation {
    /// A mesh that has been allocated from an `Alloc`.
    Allocated {
        /// The name of the alloc from which this mesh was allocated.
        alloc_name: String,

        /// Signal used to stop the task keeping the alloc alive; notifying it
        /// makes that task stop the alloc and exit.
        stop: Arc<Notify>,

        /// The extent of the allocated mesh.
        extent: Extent,

        /// The allocated ranks.
        ranks: Arc<Vec<ProcRef>>,

        /// The task handle for the alloc monitor. Used to wait for clean shutdown;
        /// taken (set to `None`) once it has been awaited.
        alloc_task: Option<tokio::task::JoinHandle<()>>,
    },

    /// An owned allocation: this ProcMesh fully owns the set of ranks.
    Owned {
        /// The host mesh from which the proc mesh was spawned.
        hosts: HostMeshRef,
        /// The extent of the proc mesh. This is purely for storage:
        /// `hosts.extent()` returns a computed (by value) extent.
        extent: Extent,
        /// A proc reference for each rank in the mesh.
        ranks: Arc<Vec<ProcRef>>,
    },
}
622
623impl ProcMeshAllocation {
624    fn extent(&self) -> &Extent {
625        match self {
626            ProcMeshAllocation::Allocated { extent, .. } => extent,
627            ProcMeshAllocation::Owned { extent, .. } => extent,
628        }
629    }
630
631    fn ranks(&self) -> Arc<Vec<ProcRef>> {
632        Arc::clone(match self {
633            ProcMeshAllocation::Allocated { ranks, .. } => ranks,
634            ProcMeshAllocation::Owned { ranks, .. } => ranks,
635        })
636    }
637
638    fn hosts(&self) -> Option<&HostMeshRef> {
639        match self {
640            ProcMeshAllocation::Allocated { .. } => None,
641            ProcMeshAllocation::Owned { hosts, .. } => Some(hosts),
642        }
643    }
644}
645
646impl fmt::Debug for ProcMeshAllocation {
647    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
648        match self {
649            ProcMeshAllocation::Allocated { ranks, .. } => f
650                .debug_struct("ProcMeshAllocation::Allocated")
651                .field("alloc", &"<dyn Alloc>")
652                .field("ranks", ranks)
653                .finish(),
654            ProcMeshAllocation::Owned {
655                hosts,
656                ranks,
657                extent: _,
658            } => f
659                .debug_struct("ProcMeshAllocation::Owned")
660                .field("hosts", hosts)
661                .field("ranks", ranks)
662                .finish(),
663        }
664    }
665}
666
/// A reference to a ProcMesh, consisting of a set of ranked [`ProcRef`]s,
/// arranged into a region. ProcMeshes are named, uniquely identifying the
/// ProcMesh from which the reference was derived.
///
/// ProcMeshes can be sliced to create new ProcMeshes with a subset of the
/// original ranks.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
pub struct ProcMeshRef {
    /// The name of the mesh this reference was derived from.
    name: Name,
    /// The region the ranks are arranged in; its size equals `ranks.len()`
    /// when built through `ProcMeshRef::new`.
    region: Region,
    /// One proc reference per rank, in region order.
    ranks: Arc<Vec<ProcRef>>,
    /// Some if this was spawned from a host mesh, else none.
    host_mesh: Option<HostMeshRef>,
    /// Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
    /// should be removed after we remove the v0 code.
    /// The root region of this mesh. None means this mesh itself is the root.
    pub(crate) root_region: Option<Region>,
    /// Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
    /// should be removed after we remove the v0 code.
    /// v0 casting requires root mesh rank 0 as the 1st hop, so we need to provide
    /// it here. For v1, this can be removed since v1 can use any rank.
    pub(crate) root_comm_actor: Option<ActorRef<CommActor>>,
}
wirevalue::register_type!(ProcMeshRef);
691
692impl ProcMeshRef {
693    /// Create a new ProcMeshRef from the given name, region, ranks, and so on.
694    #[allow(clippy::result_large_err)]
695    fn new(
696        name: Name,
697        region: Region,
698        ranks: Arc<Vec<ProcRef>>,
699        host_mesh: Option<HostMeshRef>,
700        root_region: Option<Region>,
701        root_comm_actor: Option<ActorRef<CommActor>>,
702    ) -> v1::Result<Self> {
703        if region.num_ranks() != ranks.len() {
704            return Err(v1::Error::InvalidRankCardinality {
705                expected: region.num_ranks(),
706                actual: ranks.len(),
707            });
708        }
709        Ok(Self {
710            name,
711            region,
712            ranks,
713            host_mesh,
714            root_region,
715            root_comm_actor,
716        })
717    }
718
719    /// Create a singleton ProcMeshRef, given the provided ProcRef and name.
720    /// This is used to support creating local singleton proc meshes to support `this_proc()`
721    /// in python client actors.
722    pub fn new_singleton(name: Name, proc_ref: ProcRef) -> Self {
723        Self {
724            name,
725            region: Extent::unity().into(),
726            ranks: Arc::new(vec![proc_ref]),
727            host_mesh: None,
728            root_region: None,
729            root_comm_actor: None,
730        }
731    }
732
733    pub(crate) fn root_comm_actor(&self) -> Option<&ActorRef<CommActor>> {
734        self.root_comm_actor.as_ref()
735    }
736
737    pub fn name(&self) -> &Name {
738        &self.name
739    }
740
741    pub fn host_mesh_name(&self) -> Option<&Name> {
742        self.host_mesh.as_ref().map(|h| h.name())
743    }
744
745    /// Returns the HostMeshRef that this ProcMeshRef might be backed by.
746    /// Returns None if this ProcMeshRef is backed by an Alloc instead of a host mesh.
747    pub fn hosts(&self) -> Option<&HostMeshRef> {
748        self.host_mesh.as_ref()
749    }
750
751    /// The current statuses of procs in this mesh.
752    pub async fn status(&self, cx: &impl context::Actor) -> v1::Result<ValueMesh<bool>> {
753        let vm: ValueMesh<_> = self.map_into(|proc_ref| {
754            let proc_ref = proc_ref.clone();
755            async move { proc_ref.status(cx).await }
756        });
757        vm.join().await.transpose()
758    }
759
760    pub(crate) fn agent_mesh(&self) -> ActorMeshRef<ProcMeshAgent> {
761        let agent_name = self.ranks.first().unwrap().agent.actor_id().name();
762        // This name must match the ProcMeshAgent name, which can change depending on the allocator.
763        // Since we control the agent_name, it is guaranteed to be a valid mesh identifier.
764        // No controller for the ProcMeshAgent mesh.
765        ActorMeshRef::new(Name::new_reserved(agent_name).unwrap(), self.clone(), None)
766    }
767
768    /// The supervision events of procs in this mesh.
769    pub async fn actor_states(
770        &self,
771        cx: &impl context::Actor,
772        name: Name,
773    ) -> v1::Result<ValueMesh<resource::State<ActorState>>> {
774        let agent_mesh = self.agent_mesh();
775        let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
776        // TODO: Use accumulation to get back a single value (representing whether
777        // *any* of the actors failed) instead of a mesh.
778        agent_mesh.cast(
779            cx,
780            resource::GetState::<ActorState> {
781                name: name.clone(),
782                reply: port.bind(),
783            },
784        )?;
785        let expected = self.ranks.len();
786        let mut states = Vec::with_capacity(expected);
787        let timeout = hyperactor_config::global::get(GET_ACTOR_STATE_MAX_IDLE);
788        for _ in 0..expected {
789            // The agent runs on the same process as the running actor, so if some
790            // fatal event caused the process to crash (e.g. OOM, signal, process exit),
791            // the agent will be unresponsive.
792            // We handle this by setting a timeout on the recv, and if we don't get a
793            // message we assume the agent is dead and return a failed state.
794            let state = RealClock.timeout(timeout, rx.recv()).await;
795            if let Ok(state) = state {
796                // Handle non-timeout receiver error.
797                let state = state?;
798                match state.state {
799                    Some(ref inner) => {
800                        states.push((inner.create_rank, state));
801                    }
802                    None => {
803                        return Err(Error::NotExist(state.name));
804                    }
805                }
806            } else {
807                tracing::error!(
808                    "timeout waiting for a message after {:?} from proc mesh agent in mesh {}",
809                    timeout,
810                    agent_mesh
811                );
812                // Timeout error, stop reading from the receiver and send back what we have so far,
813                // padding with failed states.
814                let all_ranks = (0..self.ranks.len()).collect::<HashSet<_>>();
815                let completed_ranks = states.iter().map(|(rank, _)| *rank).collect::<HashSet<_>>();
816                let mut leftover_ranks = all_ranks.difference(&completed_ranks).collect::<Vec<_>>();
817                assert_eq!(leftover_ranks.len(), expected - states.len());
818                while states.len() < expected {
819                    let rank = *leftover_ranks
820                        .pop()
821                        .expect("leftover ranks should not be empty");
822                    let agent = agent_mesh.get(rank).expect("agent should exist");
823                    let agent_id = agent.actor_id().clone();
824                    states.push((
825                        // We populate with any ranks leftover at the time of the timeout.
826                        rank,
827                        resource::State {
828                            name: name.clone(),
829                            status: resource::Status::Timeout(timeout),
830                            // We don't know the ActorId that used to live on this rank.
831                            // But we do know the mesh agent id, so we'll use that.
832                            state: Some(ActorState {
833                                actor_id: agent_id.clone(),
834                                create_rank: rank,
835                                supervision_events: vec![ActorSupervisionEvent::new(
836                                    agent_id,
837                                    None,
838                                    ActorStatus::generic_failure(format!(
839                                        "timeout waiting for message from proc mesh agent while querying for \"{}\". The process likely crashed",
840                                        name,
841                                    )),
842                                    None,
843                                )],
844                            }),
845                        },
846                    ));
847                }
848                break;
849            }
850        }
851        // Ensure that all ranks have replied. Note that if the mesh is sliced,
852        // not all create_ranks may be in the mesh.
853        // Sort by rank, so that the resulting mesh is ordered.
854        states.sort_by_key(|(rank, _)| *rank);
855        let vm = states
856            .into_iter()
857            .map(|(_, state)| state)
858            .collect_mesh::<ValueMesh<_>>(self.region.clone())?;
859        Ok(vm)
860    }
861
862    pub async fn proc_states(
863        &self,
864        cx: &impl context::Actor,
865    ) -> v1::Result<Option<ValueMesh<resource::State<ProcState>>>> {
866        let names = self.proc_ids().collect::<Vec<ProcId>>();
867        if let Some(host_mesh) = &self.host_mesh {
868            Ok(Some(
869                host_mesh
870                    .proc_states(cx, names, self.region.clone())
871                    .await?,
872            ))
873        } else {
874            Ok(None)
875        }
876    }
877
878    /// Returns an iterator over the proc ids in this mesh.
879    pub(crate) fn proc_ids(&self) -> impl Iterator<Item = ProcId> {
880        self.ranks.iter().map(|proc_ref| proc_ref.proc_id.clone())
881    }
882
883    /// Spawn an actor on all of the procs in this mesh, returning a
884    /// new ActorMesh.
885    ///
886    /// Bounds:
887    /// - `A: Actor` - the actor actually runs inside each proc.
888    /// - `A: Referable` - so we can return typed `ActorRef<A>`s
889    ///   inside the `ActorMesh`.
890    /// - `A::Params: RemoteMessage` - spawn parameters must be
891    ///   serializable and routable.
892    pub async fn spawn<A: RemoteSpawn, C: context::Actor>(
893        &self,
894        cx: &C,
895        name: &str,
896        params: &A::Params,
897    ) -> v1::Result<ActorMesh<A>>
898    where
899        A::Params: RemoteMessage,
900        C::A: Handler<MeshFailure>,
901    {
902        // Spawning from a string is never a system actor.
903        self.spawn_with_name(cx, Name::new(name)?, params, None, false)
904            .await
905    }
906
907    /// Spawn a 'service' actor. Service actors are *singletons*, using
908    /// reserved names. The provided name is used verbatim as the actor's
909    /// name, and thus it may be persistently looked up by constructing
910    /// the appropriate name.
911    ///
912    /// Note: avoid using service actors if possible; the mechanism will
913    /// be replaced by an actor registry.
914    pub async fn spawn_service<A: RemoteSpawn, C: context::Actor>(
915        &self,
916        cx: &C,
917        name: &str,
918        params: &A::Params,
919    ) -> v1::Result<ActorMesh<A>>
920    where
921        A::Params: RemoteMessage,
922        C::A: Handler<MeshFailure>,
923    {
924        self.spawn_with_name(cx, Name::new_reserved(name)?, params, None, false)
925            .await
926    }
927
928    /// Spawn an actor on all procs in this mesh under the given
929    /// [`Name`], returning a new `ActorMesh`.
930    ///
931    /// This is the underlying implementation used by [`spawn`]; it
932    /// differs only in that the actor name is passed explicitly
933    /// rather than as a `&str`.
934    ///
935    /// Bounds:
936    /// - `A: Actor` - the actor actually runs inside each proc.
937    /// - `A: Referable` - so we can return typed `ActorRef<A>`s
938    ///   inside the `ActorMesh`.
939    /// - `A::Params: RemoteMessage` - spawn parameters must be
940    ///   serializable and routable.
941    /// - `C::A: Handler<MeshFailure>` - in order to spawn actors,
942    ///   the actor must accept messages of type `MeshFailure`. This
943    ///   is delivered when the actors spawned in the mesh have a failure that
944    ///   isn't handled.
945    #[hyperactor::instrument(fields(
946        host_mesh=self.host_mesh_name().map(|n| n.to_string()),
947        proc_mesh=self.name.to_string(),
948        actor_name=name.to_string(),
949    ))]
950    pub async fn spawn_with_name<A: RemoteSpawn, C: context::Actor>(
951        &self,
952        cx: &C,
953        name: Name,
954        params: &A::Params,
955        supervision_display_name: Option<String>,
956        is_system_actor: bool,
957    ) -> v1::Result<ActorMesh<A>>
958    where
959        A::Params: RemoteMessage,
960        C::A: Handler<MeshFailure>,
961    {
962        tracing::info!(
963            name = "ProcMeshStatus",
964            status = "ActorMesh::Spawn::Attempt",
965        );
966        tracing::info!(name = "ActorMeshStatus", status = "Spawn::Attempt");
967        let result = self
968            .spawn_with_name_inner(cx, name, params, supervision_display_name, is_system_actor)
969            .await;
970        match &result {
971            Ok(_) => {
972                tracing::info!(
973                    name = "ProcMeshStatus",
974                    status = "ActorMesh::Spawn::Success",
975                );
976                tracing::info!(name = "ActorMeshStatus", status = "Spawn::Success");
977            }
978            Err(error) => {
979                tracing::error!(name = "ProcMeshStatus", status = "ActorMesh::Spawn::Failed", %error);
980                tracing::error!(name = "ActorMeshStatus", status = "Spawn::Failed", %error);
981            }
982        }
983        result
984    }
985
986    async fn spawn_with_name_inner<A: RemoteSpawn, C: context::Actor>(
987        &self,
988        cx: &C,
989        name: Name,
990        params: &A::Params,
991        supervision_display_name: Option<String>,
992        is_system_actor: bool,
993    ) -> v1::Result<ActorMesh<A>>
994    where
995        C::A: Handler<MeshFailure>,
996    {
997        let remote = Remote::collect();
998        // `RemoteSpawn` + `remote!(A)` ensure that `A` has a
999        // `SpawnableActor` entry in this registry, so
1000        // `name_of::<A>()` can resolve its global type name.
1001        let actor_type = remote
1002            .name_of::<A>()
1003            .ok_or(Error::ActorTypeNotRegistered(type_name::<A>().to_string()))?
1004            .to_string();
1005
1006        let serialized_params = bincode::serialize(params)?;
1007        let agent_mesh = self.agent_mesh();
1008
1009        agent_mesh.cast(
1010            cx,
1011            resource::CreateOrUpdate::<mesh_agent::ActorSpec> {
1012                name: name.clone(),
1013                rank: Default::default(),
1014                spec: mesh_agent::ActorSpec {
1015                    actor_type: actor_type.clone(),
1016                    params_data: serialized_params.clone(),
1017                },
1018            },
1019        )?;
1020
1021        let region = self.region().clone();
1022        // Open an accum port that *receives overlays* and *emits full
1023        // meshes*.
1024        //
1025        // NOTE: Mailbox initializes the accumulator state via
1026        // `Default`, which is an *empty* ValueMesh (0 ranks). Our
1027        // Accumulator<ValueMesh<T>> implementation detects this on
1028        // the first update and replaces it with the caller-supplied
1029        // template (the `self` passed into open_accum_port), which we
1030        // seed here as "full NotExist over the target region".
1031        let (port, rx) = cx.mailbox().open_accum_port_opts(
1032            // Initial state for the accumulator: full mesh seeded to
1033            // NotExist.
1034            crate::v1::StatusMesh::from_single(region.clone(), Status::NotExist),
1035            Some(ReducerOpts {
1036                max_update_interval: Some(Duration::from_millis(50)),
1037                initial_update_interval: None,
1038            }),
1039        );
1040
1041        let mut reply = port.bind();
1042        // If this proc dies or some other issue renders the reply undeliverable,
1043        // the reply does not need to be returned to the sender.
1044        reply.return_undeliverable(false);
1045        // Send a message to all ranks. They reply with overlays to
1046        // `port`.
1047        agent_mesh.cast(
1048            cx,
1049            resource::GetRankStatus {
1050                name: name.clone(),
1051                reply,
1052            },
1053        )?;
1054
1055        let start_time = RealClock.now();
1056
1057        // Wait for all ranks to report a terminal or running status.
1058        // If any proc reports a failure (via supervision) or the mesh
1059        // times out, `wait()` returns Err with the final snapshot.
1060        //
1061        // `rx` is the accumulator output stream: each time reduced
1062        // overlays are applied, it emits a new StatusMesh snapshot.
1063        // `wait()` loops on it, deciding when the stream is
1064        // "complete" (no more NotExist) or times out.
1065        let (statuses, mut mesh) = match GetRankStatus::wait(
1066            rx,
1067            self.ranks.len(),
1068            hyperactor_config::global::get(ACTOR_SPAWN_MAX_IDLE),
1069            region.clone(), // fallback
1070        )
1071        .await
1072        {
1073            Ok(statuses) => {
1074                // Spawn succeeds only if no rank has reported a
1075                // supervision/terminal state. This preserves the old
1076                // `first_terminating().is_none()` semantics.
1077                let has_terminating = statuses.values().any(|s| s.is_terminating());
1078                if !has_terminating {
1079                    Ok((statuses, ActorMesh::new(self.clone(), name, None)))
1080                } else {
1081                    let legacy = mesh_to_rankedvalues_with_default(
1082                        &statuses,
1083                        Status::NotExist,
1084                        Status::is_not_exist,
1085                        self.ranks.len(),
1086                    );
1087                    Err(Error::ActorSpawnError { statuses: legacy })
1088                }
1089            }
1090            Err(complete) => {
1091                // Fill remaining ranks with a timeout status, now
1092                // handled via the legacy shim.
1093                let elapsed = start_time.elapsed();
1094                let legacy = mesh_to_rankedvalues_with_default(
1095                    &complete,
1096                    Status::Timeout(elapsed),
1097                    Status::is_not_exist,
1098                    self.ranks.len(),
1099                );
1100                Err(Error::ActorSpawnError { statuses: legacy })
1101            }
1102        }?;
1103        // We don't need controllers for a system actor like the CommActor.
1104        if !is_system_actor {
1105            // Spawn a unique mesh manager for each actor mesh, so the type of the
1106            // mesh can be preserved.
1107            let controller: ActorMeshController<A> = ActorMeshController::new(
1108                mesh.deref().clone(),
1109                supervision_display_name,
1110                Some(cx.instance().port().bind()),
1111                statuses,
1112            );
1113            let controller = controller
1114                .spawn(cx)
1115                .map_err(|e| Error::ControllerActorSpawnError(mesh.name().clone(), e))?;
1116            // Controller and ActorMesh both depend on references from each other, break
1117            // the cycle by setting the controller after the fact.
1118            mesh.set_controller(Some(controller.bind()));
1119        }
1120        Ok(mesh)
1121    }
1122
1123    /// Send stop actors message to all mesh agents for a specific mesh name
1124    #[hyperactor::instrument(fields(
1125        host_mesh = self.host_mesh_name().map(|n| n.to_string()),
1126        proc_mesh = self.name.to_string(),
1127        actor_mesh = mesh_name.to_string(),
1128    ))]
1129    pub(crate) async fn stop_actor_by_name(
1130        &self,
1131        cx: &impl context::Actor,
1132        mesh_name: Name,
1133    ) -> v1::Result<ValueMesh<Status>> {
1134        tracing::info!(name = "ProcMeshStatus", status = "ActorMesh::Stop::Attempt");
1135        tracing::info!(name = "ActorMeshStatus", status = "Stop::Attempt");
1136        let result = self.stop_actor_by_name_inner(cx, mesh_name).await;
1137        match &result {
1138            Ok(_) => {
1139                tracing::info!(name = "ProcMeshStatus", status = "ActorMesh::Stop::Success");
1140                tracing::info!(name = "ActorMeshStatus", status = "Stop::Success");
1141            }
1142            Err(error) => {
1143                tracing::error!(name = "ProcMeshStatus", status = "ActorMesh::Stop::Failed", %error);
1144                tracing::error!(name = "ActorMeshStatus", status = "Stop::Failed", %error);
1145            }
1146        }
1147        result
1148    }
1149
1150    async fn stop_actor_by_name_inner(
1151        &self,
1152        cx: &impl context::Actor,
1153        mesh_name: Name,
1154    ) -> v1::Result<ValueMesh<Status>> {
1155        let region = self.region().clone();
1156        let agent_mesh = self.agent_mesh();
1157        agent_mesh.cast(
1158            cx,
1159            resource::Stop {
1160                name: mesh_name.clone(),
1161            },
1162        )?;
1163
1164        // Open an accum port that *receives overlays* and *emits full
1165        // meshes*.
1166        //
1167        // NOTE: Mailbox initializes the accumulator state via
1168        // `Default`, which is an *empty* ValueMesh (0 ranks). Our
1169        // Accumulator<ValueMesh<T>> implementation detects this on
1170        // the first update and replaces it with the caller-supplied
1171        // template (the `self` passed into open_accum_port), which we
1172        // seed here as "full NotExist over the target region".
1173        let (port, rx) = cx.mailbox().open_accum_port_opts(
1174            // Initial state for the accumulator: full mesh seeded to
1175            // NotExist.
1176            crate::v1::StatusMesh::from_single(region.clone(), Status::NotExist),
1177            Some(ReducerOpts {
1178                max_update_interval: Some(Duration::from_millis(50)),
1179                initial_update_interval: None,
1180            }),
1181        );
1182        agent_mesh.cast(
1183            cx,
1184            resource::GetRankStatus {
1185                name: mesh_name,
1186                reply: port.bind(),
1187            },
1188        )?;
1189        let start_time = RealClock.now();
1190
1191        // Reuse actor spawn idle time.
1192        let max_idle_time = hyperactor_config::global::get(ACTOR_SPAWN_MAX_IDLE);
1193        match GetRankStatus::wait(
1194            rx,
1195            self.ranks.len(),
1196            max_idle_time,
1197            region.clone(), // fallback mesh if nothing arrives
1198        )
1199        .await
1200        {
1201            Ok(statuses) => {
1202                // Check that all actors are in some terminal state.
1203                // Failed is ok, because one of these actors may have failed earlier
1204                // and we're trying to stop the others.
1205                let all_stopped = statuses.values().all(|s| s.is_terminating());
1206                if all_stopped {
1207                    Ok(statuses)
1208                } else {
1209                    let legacy = mesh_to_rankedvalues_with_default(
1210                        &statuses,
1211                        Status::NotExist,
1212                        Status::is_not_exist,
1213                        self.ranks.len(),
1214                    );
1215                    Err(Error::ActorStopError { statuses: legacy })
1216                }
1217            }
1218            Err(complete) => {
1219                // Fill remaining ranks with a timeout status via the
1220                // legacy shim.
1221                let legacy = mesh_to_rankedvalues_with_default(
1222                    &complete,
1223                    Status::Timeout(start_time.elapsed()),
1224                    Status::is_not_exist,
1225                    self.ranks.len(),
1226                );
1227                Err(Error::ActorStopError { statuses: legacy })
1228            }
1229        }
1230    }
1231}
1232
1233impl fmt::Display for ProcMeshRef {
1234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1235        write!(f, "{}{{{}}}", self.name, self.region)
1236    }
1237}
1238
1239impl view::Ranked for ProcMeshRef {
1240    type Item = ProcRef;
1241
1242    fn region(&self) -> &Region {
1243        &self.region
1244    }
1245
1246    fn get(&self, rank: usize) -> Option<&Self::Item> {
1247        self.ranks.get(rank)
1248    }
1249}
1250
1251impl view::RankedSliceable for ProcMeshRef {
1252    fn sliced(&self, region: Region) -> Self {
1253        debug_assert!(region.is_subset(view::Ranked::region(self)));
1254        let ranks = self
1255            .region()
1256            .remap(&region)
1257            .unwrap()
1258            .map(|index| self.get(index).unwrap().clone())
1259            .collect();
1260        Self::new(
1261            self.name.clone(),
1262            region,
1263            Arc::new(ranks),
1264            self.host_mesh.clone(),
1265            Some(self.root_region.as_ref().unwrap_or(&self.region).clone()),
1266            self.root_comm_actor.clone(),
1267        )
1268        .unwrap()
1269    }
1270}
1271
/// Tests for `ProcMeshRef`: allocation, agent reachability, and actor
/// spawning (successful and failing).
#[cfg(test)]
mod tests {
    use hyperactor::Instance;
    use ndslice::ViewExt;
    use ndslice::extent;
    use timed_test::async_timed_test;

    use crate::resource::RankedValues;
    use crate::resource::Status;
    use crate::v1::testactor;
    use crate::v1::testing;

    /// Allocates a local proc mesh with 4 replicas and checks:
    /// - its extent and rank count;
    /// - that the returned router has at least one prefix;
    /// - that every proc answers a status query, both one proc at a time
    ///   and through the mesh-level `status`.
    #[tokio::test]
    async fn test_proc_mesh_allocate() {
        let (mesh, actor, router) = testing::local_proc_mesh(extent!(replica = 4)).await;
        assert_eq!(mesh.extent(), extent!(replica = 4));
        assert_eq!(mesh.ranks.len(), 4);
        assert!(!router.prefixes().is_empty());

        // All of the agents are alive, and reachable (both ways).
        for proc_ref in mesh.values() {
            assert!(proc_ref.status(&actor).await.unwrap());
        }

        // Same on the proc mesh:
        assert!(
            mesh.status(&actor)
                .await
                .unwrap()
                .values()
                .all(|status| status)
        );
    }

    /// For each proc mesh produced by `testing::proc_meshes` over a
    /// 4 x 2 extent, spawns a test actor mesh and checks its shape with
    /// `testactor::assert_mesh_shape`. Only built in fbcode.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());

        let instance = testing::instance();

        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
            testactor::assert_mesh_shape(proc_mesh.spawn(instance, "test", &()).await.unwrap())
                .await;
        }
    }

    /// Spawns `FailingCreateTestActor` on each proc mesh and expects the
    /// spawn to fail, with every rank reporting `Failed("test failure")`.
    /// Only built in fbcode.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_failing_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());

        let instance = testing::instance();

        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
            let err = proc_mesh
                .spawn::<testactor::FailingCreateTestActor, Instance<testing::TestRootClient>>(
                    instance,
                    "testfail",
                    &(),
                )
                .await
                .unwrap_err();
            let statuses = err.into_actor_spawn_error().unwrap();
            // 4 replicas x 2 hosts = 8 ranks, all of which must have failed.
            assert_eq!(
                statuses,
                RankedValues::from((0..8, Status::Failed("test failure".to_string()))),
            );
        }
    }
}