hyperactor_mesh/v1/proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::any::type_name;
10use std::collections::HashMap;
11use std::fmt;
12use std::ops::Deref;
13use std::sync::Arc;
14
15use hyperactor::Actor;
16use hyperactor::ActorId;
17use hyperactor::ActorRef;
18use hyperactor::Named;
19use hyperactor::ProcId;
20use hyperactor::RemoteMessage;
21use hyperactor::actor::Referable;
22use hyperactor::actor::remote::Remote;
23use hyperactor::channel;
24use hyperactor::channel::ChannelAddr;
25use hyperactor::context;
26use hyperactor::mailbox::DialMailboxRouter;
27use hyperactor::mailbox::MailboxServer;
28use ndslice::Extent;
29use ndslice::ViewExt as _;
30use ndslice::view;
31use ndslice::view::CollectMeshExt;
32use ndslice::view::MapIntoExt;
33use ndslice::view::Ranked;
34use ndslice::view::Region;
35use serde::Deserialize;
36use serde::Serialize;
37
38use crate::CommActor;
39use crate::alloc::Alloc;
40use crate::alloc::AllocExt;
41use crate::alloc::AllocatedProc;
42use crate::assign::Ranks;
43use crate::comm::CommActorMode;
44use crate::proc_mesh::mesh_agent;
45use crate::proc_mesh::mesh_agent::ActorState;
46use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
47use crate::proc_mesh::mesh_agent::ProcMeshAgent;
48use crate::proc_mesh::mesh_agent::ReconfigurableMailboxSender;
49use crate::resource;
50use crate::resource::RankedValues;
51use crate::v1;
52use crate::v1::ActorMesh;
53use crate::v1::ActorMeshRef;
54use crate::v1::Error;
55use crate::v1::HostMeshRef;
56use crate::v1::Name;
57use crate::v1::ValueMesh;
58
/// A reference to a single [`hyperactor::Proc`].
///
/// A `ProcRef` pairs the proc's identity with the rank it was assigned when
/// its mesh was created and the [`ProcMeshAgent`] that manages it. It is
/// serializable, and a [`ProcMeshRef`] holds one per rank.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProcRef {
    /// The identity of the referenced proc.
    proc_id: ProcId,
    /// The rank of this proc at creation. Slicing a mesh clones `ProcRef`s
    /// unchanged, so this remains the rank in the original (root) mesh rather
    /// than the rank within a slice.
    create_rank: usize,
    /// The agent managing this proc.
    agent: ActorRef<ProcMeshAgent>,
}
68
69impl ProcRef {
70    pub(crate) fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef<ProcMeshAgent>) -> Self {
71        Self {
72            proc_id,
73            create_rank,
74            agent,
75        }
76    }
77
78    /// Pings the proc, returning whether it is alive. This will be replaced by a
79    /// finer-grained lifecycle status in the near future.
80    pub(crate) async fn status(&self, cx: &impl context::Actor) -> v1::Result<bool> {
81        let (port, mut rx) = cx.mailbox().open_port();
82        self.agent
83            .status(cx, port.bind())
84            .await
85            .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e))?;
86        loop {
87            let (rank, status) = rx
88                .recv()
89                .await
90                .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
91            if rank == self.create_rank {
92                break Ok(status);
93            }
94        }
95    }
96
97    /// Get the supervision events for one actor with the given name.
98    #[allow(dead_code)]
99    async fn actor_state(
100        &self,
101        cx: &impl context::Actor,
102        name: Name,
103    ) -> v1::Result<resource::State<ActorState>> {
104        let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
105        self.agent
106            .send(
107                cx,
108                resource::GetState::<ActorState> {
109                    name: name.clone(),
110                    reply: port.bind(),
111                },
112            )
113            .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
114        let state = rx
115            .recv()
116            .await
117            .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
118        if let Some(ref inner) = state.state {
119            let rank = inner.create_rank;
120            if rank == self.create_rank {
121                Ok(state)
122            } else {
123                Err(Error::CallError(
124                    self.agent.actor_id().clone(),
125                    anyhow::anyhow!(
126                        "Rank on mesh agent not matching for Actor {}: returned {}, expected {}",
127                        name,
128                        rank,
129                        self.create_rank
130                    ),
131                ))
132            }
133        } else {
134            Err(Error::CallError(
135                self.agent.actor_id().clone(),
136                anyhow::anyhow!("Actor {} does not exist", name),
137            ))
138        }
139    }
140
141    pub(crate) fn actor_id(&self, name: &Name) -> ActorId {
142        self.proc_id.actor_id(name.to_string(), 0)
143    }
144
145    /// Generic bound: `A: Referable` - required because we return
146    /// an `ActorRef<A>`.
147    pub(crate) fn attest<A: Referable>(&self, name: &Name) -> ActorRef<A> {
148        ActorRef::attest(self.actor_id(name))
149    }
150}
151
/// A mesh of processes.
///
/// A `ProcMesh` holds on to the allocation that produced its procs for as long
/// as it lives, and dereferences to the [`ProcMeshRef`] describing its current
/// set of ranked procs.
// NOTE(review): `dead_code` is allowed presumably because some fields (e.g.
// `name`, `comm_actor_name`) are stored but not yet read — confirm.
#[allow(dead_code)]
#[derive(Debug)]
pub struct ProcMesh {
    /// The name this mesh was created with.
    name: Name,
    /// The allocation backing this mesh.
    allocation: ProcMeshAllocation,
    /// The name under which comm actors were spawned on every rank, if any.
    comm_actor_name: Option<Name>,
    /// The reference exposed through `Deref`.
    current_ref: ProcMeshRef,
}
161
162impl ProcMesh {
163    async fn create(
164        cx: &impl context::Actor,
165        name: Name,
166        allocation: ProcMeshAllocation,
167        spawn_comm_actor: bool,
168    ) -> v1::Result<Self> {
169        let comm_actor_name = if spawn_comm_actor {
170            Some(Name::new("comm"))
171        } else {
172            None
173        };
174
175        let region = allocation.extent().clone().into();
176        let ranks = allocation.ranks();
177        let root_comm_actor = comm_actor_name.as_ref().map(|name| {
178            ActorRef::attest(
179                ranks
180                    .first()
181                    .expect("root mesh cannot be empty")
182                    .actor_id(name),
183            )
184        });
185        let current_ref = ProcMeshRef::new(
186            name.clone(),
187            region,
188            ranks,
189            None, // this is the root mesh
190            None, // comm actor is not alive yet
191        )
192        .unwrap();
193
194        let mut proc_mesh = Self {
195            name,
196            allocation,
197            comm_actor_name: comm_actor_name.clone(),
198            current_ref,
199        };
200
201        if let Some(comm_actor_name) = comm_actor_name {
202            // CommActor satisfies `Actor + Referable`, so it can be
203            // spawned and safely referenced via ActorRef<CommActor>.
204            let comm_actor_mesh = proc_mesh
205                .spawn_with_name::<CommActor>(cx, comm_actor_name, &Default::default())
206                .await?;
207            let address_book: HashMap<_, _> = comm_actor_mesh
208                .iter()
209                .map(|(point, actor_ref)| (point.rank(), actor_ref))
210                .collect();
211            // Now that we have all of the spawned comm actors, kick them all into
212            // mesh mode.
213            for (rank, comm_actor) in &address_book {
214                comm_actor
215                    .send(cx, CommActorMode::Mesh(*rank, address_book.clone()))
216                    .map_err(|e| Error::SendingError(comm_actor.actor_id().clone(), Box::new(e)))?
217            }
218
219            // The comm actor is now set up and ready to go.
220            proc_mesh.current_ref.root_comm_actor = root_comm_actor;
221        }
222
223        Ok(proc_mesh)
224    }
225
226    pub(crate) async fn create_owned_unchecked(
227        cx: &impl context::Actor,
228        name: Name,
229        extent: Extent,
230        hosts: HostMeshRef,
231        ranks: Vec<ProcRef>,
232    ) -> v1::Result<Self> {
233        Self::create(
234            cx,
235            name,
236            ProcMeshAllocation::Owned {
237                hosts,
238                extent,
239                ranks: Arc::new(ranks),
240            },
241            true,
242        )
243        .await
244    }
245
246    /// Allocate a new ProcMesh from the provided alloc.
247    /// Allocate does not require an owning actor because references are not owned.
248    /// Allocate a new ProcMesh from the provided alloc.
249    pub async fn allocate(
250        cx: &impl context::Actor,
251        mut alloc: Box<dyn Alloc + Send + Sync + 'static>,
252        name: &str,
253    ) -> v1::Result<Self> {
254        let running = alloc.initialize().await?;
255
256        // Wire the newly created mesh into the proc, so that it is routable.
257        // We route all of the relevant prefixes into the proc's forwarder,
258        // and serve it on the alloc's transport.
259        //
260        // This will be removed with direct addressing.
261        let proc = cx.instance().proc();
262
263        // First make sure we can serve the proc:
264        let (proc_channel_addr, rx) = channel::serve(ChannelAddr::any(alloc.transport()))?;
265        proc.clone().serve(rx);
266
267        let bind_allocated_procs = |router: &DialMailboxRouter| {
268            // Route all of the allocated procs:
269            for AllocatedProc { proc_id, addr, .. } in running.iter() {
270                if proc_id.is_direct() {
271                    continue;
272                }
273                router.bind(proc_id.clone().into(), addr.clone());
274            }
275        };
276
277        // Temporary for backward compatibility with ranked procs and v0 API.
278        // Proc meshes can be allocated either using the root client proc (which
279        // has a DialMailboxRouter forwarder) or a mesh agent proc (which has a
280        // ReconfigurableMailboxSender forwarder with an inner DialMailboxRouter).
281        if let Some(router) = proc.forwarder().downcast_ref() {
282            bind_allocated_procs(router);
283        } else if let Some(router) = proc
284            .forwarder()
285            .downcast_ref::<ReconfigurableMailboxSender>()
286        {
287            bind_allocated_procs(
288                router
289                    .as_inner()
290                    .map_err(|_| Error::UnroutableMesh())?
291                    .as_configured()
292                    .ok_or(Error::UnroutableMesh())?
293                    .downcast_ref()
294                    .ok_or(Error::UnroutableMesh())?,
295            );
296        } else {
297            return Err(Error::UnroutableMesh());
298        }
299
300        // Set up the mesh agents. Since references are not owned, we don't supervise it.
301        // Instead, we just let procs die when they have unhandled supervision events.
302        let address_book: HashMap<_, _> = running
303            .iter()
304            .map(
305                |AllocatedProc {
306                     addr, mesh_agent, ..
307                 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
308            )
309            .collect();
310
311        let (config_handle, mut config_receiver) = cx.mailbox().open_port();
312        for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
313            mesh_agent
314                .configure(
315                    cx,
316                    rank,
317                    proc_channel_addr.clone(),
318                    None, // no supervisor; we just crash
319                    address_book.clone(),
320                    config_handle.bind(),
321                    true,
322                )
323                .await
324                .map_err(Error::ConfigurationError)?;
325        }
326        let mut completed = Ranks::new(running.len());
327        while !completed.is_full() {
328            let rank = config_receiver
329                .recv()
330                .await
331                .map_err(|err| Error::ConfigurationError(err.into()))?;
332            if completed.insert(rank, rank).is_some() {
333                tracing::warn!("multiple completions received for rank {}", rank);
334            }
335        }
336
337        let ranks: Vec<_> = running
338            .into_iter()
339            .enumerate()
340            .map(|(create_rank, allocated)| ProcRef {
341                proc_id: allocated.proc_id,
342                create_rank,
343                agent: allocated.mesh_agent,
344            })
345            .collect();
346
347        Self::create(
348            cx,
349            Name::new(name),
350            ProcMeshAllocation::Allocated {
351                alloc,
352                ranks: Arc::new(ranks),
353            },
354            true, // alloc-based meshes support comm actors
355        )
356        .await
357    }
358}
359
360impl Deref for ProcMesh {
361    type Target = ProcMeshRef;
362
363    fn deref(&self) -> &Self::Target {
364        &self.current_ref
365    }
366}
367
/// Represents different ways ProcMeshes can be allocated.
///
/// Both variants hold the [`ProcRef`] for each rank in an `Arc`, which is
/// shared with the mesh's [`ProcMeshRef`]. They differ in where the procs came
/// from and in how the mesh's extent is obtained.
enum ProcMeshAllocation {
    /// A mesh that has been allocated from an `Alloc`.
    Allocated {
        /// We have to hold on to the alloc for the duration of the mesh lifetime.
        /// The procmesh inherits the alloc's extent.
        alloc: Box<dyn Alloc + Send + Sync + 'static>,

        /// The allocated ranks.
        ranks: Arc<Vec<ProcRef>>,
    },

    /// An owned allocation: this ProcMesh fully owns the set of ranks.
    Owned {
        /// The host mesh from which the proc mesh was spawned.
        hosts: HostMeshRef,
        /// This is purely for storage: `hosts.extent()` returns a computed (by
        /// value) extent, whereas `ProcMeshAllocation::extent` hands out a
        /// reference.
        extent: Extent,
        /// A proc reference for each rank in the mesh.
        ranks: Arc<Vec<ProcRef>>,
    },
}
391
392impl ProcMeshAllocation {
393    fn extent(&self) -> &Extent {
394        match self {
395            ProcMeshAllocation::Allocated { alloc, .. } => alloc.extent(),
396            ProcMeshAllocation::Owned { extent, .. } => extent,
397        }
398    }
399
400    fn ranks(&self) -> Arc<Vec<ProcRef>> {
401        Arc::clone(match self {
402            ProcMeshAllocation::Allocated { ranks, .. } => ranks,
403            ProcMeshAllocation::Owned { ranks, .. } => ranks,
404        })
405    }
406}
407
408impl fmt::Debug for ProcMeshAllocation {
409    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
410        match self {
411            ProcMeshAllocation::Allocated { ranks, .. } => f
412                .debug_struct("ProcMeshAllocation::Allocated")
413                .field("alloc", &"<dyn Alloc>")
414                .field("ranks", ranks)
415                .finish(),
416            ProcMeshAllocation::Owned {
417                hosts,
418                ranks,
419                extent: _,
420            } => f
421                .debug_struct("ProcMeshAllocation::Owned")
422                .field("hosts", hosts)
423                .field("ranks", ranks)
424                .finish(),
425        }
426    }
427}
428
/// A reference to a ProcMesh, consisting of a set of ranked [`ProcRef`]s,
/// arranged into a region. ProcMeshes are named, uniquely identifying the
/// ProcMesh from which the reference was derived.
///
/// A ProcMeshRef can be sliced to create a new ProcMeshRef with a subset of
/// the original ranks; the slice keeps the original name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
pub struct ProcMeshRef {
    /// The name of the ProcMesh this reference was derived from.
    name: Name,
    /// The region spanned by this reference. `ProcMeshRef::new` enforces that
    /// it contains exactly `ranks.len()` ranks.
    region: Region,
    /// The proc at each rank of `region`, indexed by rank.
    ranks: Arc<Vec<ProcRef>>,
    // Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
    // should be removed after we remove the v0 code.
    // The root region of this mesh. None means this mesh itself is the root.
    pub(crate) root_region: Option<Region>,
    // Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
    // should be removed after we remove the v0 code.
    // v0 casting requires root mesh rank 0 as the 1st hop, so we need to provide
    // it here. For v1, this can be removed since v1 can use any rank.
    pub(crate) root_comm_actor: Option<ActorRef<CommActor>>,
}
450
451impl ProcMeshRef {
452    /// Create a new ProcMeshRef from the given name, region, ranks, and so on.
453    fn new(
454        name: Name,
455        region: Region,
456        ranks: Arc<Vec<ProcRef>>,
457        root_region: Option<Region>,
458        root_comm_actor: Option<ActorRef<CommActor>>,
459    ) -> v1::Result<Self> {
460        if region.num_ranks() != ranks.len() {
461            return Err(v1::Error::InvalidRankCardinality {
462                expected: region.num_ranks(),
463                actual: ranks.len(),
464            });
465        }
466        Ok(Self {
467            name,
468            region,
469            ranks,
470            root_region,
471            root_comm_actor,
472        })
473    }
474
475    pub(crate) fn root_comm_actor(&self) -> Option<&ActorRef<CommActor>> {
476        self.root_comm_actor.as_ref()
477    }
478
479    /// The current statuses of procs in this mesh.
480    pub async fn status(&self, cx: &impl context::Actor) -> v1::Result<ValueMesh<bool>> {
481        let vm: ValueMesh<_> = self.map_into(|proc_ref| {
482            let proc_ref = proc_ref.clone();
483            async move { proc_ref.status(cx).await }
484        });
485        vm.join().await.transpose()
486    }
487
488    fn agent_mesh(&self) -> ActorMeshRef<ProcMeshAgent> {
489        let agent_name = self.ranks.first().unwrap().agent.actor_id().name();
490        // This name must match the ProcMeshAgent name, which can change depending on the allocator.
491        ActorMeshRef::new(Name::new_reserved(agent_name), self.clone())
492    }
493
494    /// The supervision events of procs in this mesh.
495    pub async fn actor_states(
496        &self,
497        cx: &impl context::Actor,
498        name: Name,
499    ) -> v1::Result<ValueMesh<resource::State<ActorState>>> {
500        let agent_mesh = self.agent_mesh();
501        let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
502        // TODO: Use accumulation to get back a single value (representing whether
503        // *any* of the actors failed) instead of a mesh.
504        agent_mesh.cast(
505            cx,
506            resource::GetState::<ActorState> {
507                name: name.clone(),
508                reply: port.bind(),
509            },
510        )?;
511        let expected = self.ranks.len();
512        let mut states = Vec::with_capacity(expected);
513        for _ in 0..expected {
514            let state = rx.recv().await?;
515            match state.state {
516                Some(ref inner) => {
517                    states.push((inner.create_rank, state));
518                }
519                None => {
520                    return Err(Error::NotExist(state.name));
521                }
522            }
523        }
524        // Sort by rank, so that the resulting mesh is ordered.
525        states.sort_by_key(|(rank, _)| *rank);
526        let vm = states
527            .into_iter()
528            .map(|(_, state)| state)
529            .collect_mesh::<ValueMesh<_>>(self.region.clone())?;
530        Ok(vm)
531    }
532
533    /// Spawn an actor on all of the procs in this mesh, returning a
534    /// new ActorMesh.
535    ///
536    /// Bounds:
537    /// - `A: Actor` - the actor actually runs inside each proc.
538    /// - `A: Referable` - so we can return typed `ActorRef<A>`s
539    ///   inside the `ActorMesh`.
540    /// - `A::Params: RemoteMessage` - spawn parameters must be
541    ///   serializable and routable.
542    pub async fn spawn<A: Actor + Referable>(
543        &self,
544        cx: &impl context::Actor,
545        name: &str,
546        params: &A::Params,
547    ) -> v1::Result<ActorMesh<A>>
548    where
549        A::Params: RemoteMessage,
550    {
551        self.spawn_with_name(cx, Name::new(name), params).await
552    }
553
554    /// Spawn an actor on all procs in this mesh under the given
555    /// [`Name`], returning a new `ActorMesh`.
556    ///
557    /// This is the underlying implementation used by [`spawn`]; it
558    /// differs only in that the actor name is passed explicitly
559    /// rather than as a `&str`.
560    ///
561    /// Bounds:
562    /// - `A: Actor` - the actor actually runs inside each proc.
563    /// - `A: Referable` - so we can return typed `ActorRef<A>`s
564    ///   inside the `ActorMesh`.
565    /// - `A::Params: RemoteMessage` — spawn parameters must be
566    ///   serializable and routable.
567    pub(crate) async fn spawn_with_name<A: Actor + Referable>(
568        &self,
569        cx: &impl context::Actor,
570        name: Name,
571        params: &A::Params,
572    ) -> v1::Result<ActorMesh<A>>
573    where
574        A::Params: RemoteMessage,
575    {
576        let remote = Remote::collect();
577        // `Referable` ensures the type `A` is registered with
578        // `Remote`.
579        let actor_type = remote
580            .name_of::<A>()
581            .ok_or(Error::ActorTypeNotRegistered(type_name::<A>().to_string()))?
582            .to_string();
583
584        let serialized_params = bincode::serialize(params)?;
585
586        self.agent_mesh().cast(
587            cx,
588            resource::CreateOrUpdate::<mesh_agent::ActorSpec> {
589                name: name.clone(),
590                rank: Default::default(),
591                spec: mesh_agent::ActorSpec {
592                    actor_type: actor_type.clone(),
593                    params_data: serialized_params.clone(),
594                },
595            },
596        )?;
597
598        let (port, mut rx) = cx.mailbox().open_accum_port(RankedValues::default());
599
600        self.agent_mesh().cast(
601            cx,
602            resource::GetRankStatus {
603                name: name.clone(),
604                reply: port.bind(),
605            },
606        )?;
607
608        // Wait for everyone to report back.
609        // TODO: move out of critical path
610        let statuses = loop {
611            let statuses = rx.recv().await?;
612            if statuses.rank(self.ranks.len()) == self.ranks.len() {
613                break statuses;
614            }
615        };
616
617        let failed: Vec<_> = statuses
618            .iter()
619            .filter_map(|(ranks, status)| {
620                if status.is_terminating() {
621                    Some(ranks.clone())
622                } else {
623                    None
624                }
625            })
626            .flatten()
627            .collect();
628        if !failed.is_empty() {
629            return Err(Error::GspawnError(
630                name,
631                format!("failed ranks: {:?}", failed,),
632            ));
633        }
634
635        Ok(ActorMesh::new(self.clone(), name))
636    }
637}
638
639impl view::Ranked for ProcMeshRef {
640    type Item = ProcRef;
641
642    fn region(&self) -> &Region {
643        &self.region
644    }
645
646    fn get(&self, rank: usize) -> Option<&Self::Item> {
647        self.ranks.get(rank)
648    }
649}
650
651impl view::RankedSliceable for ProcMeshRef {
652    fn sliced(&self, region: Region) -> Self {
653        debug_assert!(region.is_subset(view::Ranked::region(self)));
654        let ranks = self
655            .region()
656            .remap(&region)
657            .unwrap()
658            .map(|index| self.get(index).unwrap().clone())
659            .collect();
660        Self::new(
661            self.name.clone(),
662            region,
663            Arc::new(ranks),
664            Some(self.root_region.as_ref().unwrap_or(&self.region).clone()),
665            self.root_comm_actor.clone(),
666        )
667        .unwrap()
668    }
669}
670
#[cfg(test)]
mod tests {
    // Tests for proc mesh allocation, liveness reporting, and actor spawning.
    use std::assert_matches::assert_matches;

    use ndslice::ViewExt;
    use ndslice::extent;
    use timed_test::async_timed_test;

    use crate::v1;
    use crate::v1::testactor;
    use crate::v1::testing;

    /// Allocates a local 4-rank proc mesh, then checks its extent and rank count,
    /// that routes were bound in the router, and that every proc reports alive,
    /// both per-proc and through the mesh-wide `status`.
    #[tokio::test]
    async fn test_proc_mesh_allocate() {
        let (mesh, actor, router) = testing::local_proc_mesh(extent!(replica = 4)).await;
        assert_eq!(mesh.extent(), extent!(replica = 4));
        assert_eq!(mesh.ranks.len(), 4);
        assert!(!router.prefixes().is_empty());

        // All of the agents are alive, and reachable (both ways).
        for proc_ref in mesh.values() {
            assert!(proc_ref.status(&actor).await.unwrap());
        }

        // Same on the proc mesh:
        assert!(
            mesh.status(&actor)
                .await
                .unwrap()
                .values()
                .all(|status| status)
        );
    }

    /// Spawns a test actor on each proc mesh produced by `testing::proc_meshes`
    /// and checks the resulting actor mesh with `testactor::assert_mesh_shape`.
    #[async_timed_test(timeout_secs = 30)]
    async fn test_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());

        let instance = testing::instance().await;

        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
            testactor::assert_mesh_shape(proc_mesh.spawn(instance, "test", &()).await.unwrap())
                .await;
        }
    }

    /// Spawning `FailingCreateTestActor` (presumably an actor whose creation
    /// fails) must surface as `v1::Error::GspawnError` on every proc mesh.
    #[tokio::test]
    async fn test_failing_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());

        let instance = testing::instance().await;

        for proc_mesh in testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await {
            let err = proc_mesh
                .spawn::<testactor::FailingCreateTestActor>(instance, "testfail", &())
                .await
                .unwrap_err();
            assert_matches!(err, v1::Error::GspawnError(_, _))
        }
    }
}