hyperactor_mesh/v1/actor_mesh.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::OnceLock as OnceCell;

use hyperactor::Actor;
use hyperactor::ActorRef;
use hyperactor::RemoteHandles;
use hyperactor::RemoteMessage;
use hyperactor::actor::Referable;
use hyperactor::attrs::Attrs;
use hyperactor::context;
use hyperactor::message::Castable;
use hyperactor::message::IndexedErasedUnbound;
use hyperactor::message::Unbound;
use hyperactor_mesh_macros::sel;
use ndslice::Selection;
use ndslice::ViewExt as _;
use ndslice::view;
use ndslice::view::Region;
use ndslice::view::View;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;

use crate::CommActor;
use crate::actor_mesh as v0_actor_mesh;
use crate::comm::multicast;
use crate::proc_mesh::mesh_agent::ActorState;
use crate::reference::ActorMeshId;
use crate::resource;
use crate::v1;
use crate::v1::Error;
use crate::v1::Name;
use crate::v1::ProcMeshRef;
use crate::v1::ValueMesh;

/// An ActorMesh is a collection of ranked A-typed actors.
///
/// Bound note: `A: Referable` because the mesh stores/returns
/// `ActorRef<A>`, which is only defined for `A: Referable`.
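///
/// # Example
///
/// Illustrative sketch only (not compiled as a doctest): `proc_mesh`,
/// `instance`, and the reply `port` are assumed to be in scope; see the
/// tests at the bottom of this file for working usage.
///
/// ```ignore
/// // Spawn one TestActor per proc, then cast to every rank through the
/// // `ActorMeshRef` that the mesh derefs to.
/// let actors: ActorMesh<testactor::TestActor> =
///     proc_mesh.spawn(instance, "workers", &()).await?;
/// actors.cast(instance, testactor::GetCastInfo { cast_info: port.bind() })?;
/// ```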
#[derive(Debug)]
pub struct ActorMesh<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,
    current_ref: ActorMeshRef<A>,
}

// `A: Referable` for the same reason as the struct: the mesh holds
// `ActorRef<A>`.
impl<A: Referable> ActorMesh<A> {
    pub(crate) fn new(proc_mesh: ProcMeshRef, name: Name) -> Self {
        let current_ref =
            ActorMeshRef::with_page_size(name.clone(), proc_mesh.clone(), DEFAULT_PAGE);

        Self {
            proc_mesh,
            name,
            current_ref,
        }
    }
}

impl<A: Referable> Deref for ActorMesh<A> {
    type Target = ActorMeshRef<A>;

    fn deref(&self) -> &Self::Target {
        &self.current_ref
    }
}

/// Manual implementation of `Clone`: `A` itself doesn't need to implement
/// `Clone`, but we still want to be able to clone the `ActorMesh`.
impl<A: Referable> Clone for ActorMesh<A> {
    fn clone(&self) -> Self {
        Self {
            proc_mesh: self.proc_mesh.clone(),
            name: self.name.clone(),
            current_ref: self.current_ref.clone(),
        }
    }
}

/// Influences paging behavior for the lazy cache. Smaller pages
/// reduce over-allocation for sparse access; larger pages reduce the
/// number of heap allocations for contiguous scans.
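///
/// Worked example (illustrative): with 6 ranks and a page size of 2, the
/// cache spans `6.div_ceil(2) = 3` pages covering ranks [0, 1], [2, 3], and
/// [4, 5]; with the default page size of 1024, the same 6 ranks fit in one
/// partially filled page.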
const DEFAULT_PAGE: usize = 1024;

/// A lazily materialized page of ActorRefs.
struct Page<A: Referable> {
    slots: Box<[OnceCell<ActorRef<A>>]>,
}

impl<A: Referable> Page<A> {
    fn new(len: usize) -> Self {
        let mut v = Vec::with_capacity(len);
        for _ in 0..len {
            v.push(OnceCell::new());
        }
        Self {
            slots: v.into_boxed_slice(),
        }
    }
}

/// A reference to a stable snapshot of an [`ActorMesh`].
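///
/// Unlike [`ActorMesh`], a ref carries only the mesh identity (`proc_mesh`
/// and `name`): it is cheap to clone, hash, compare, and serialize, and the
/// per-rank `ActorRef` cache below is rebuilt lazily wherever the ref lands.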
pub struct ActorMeshRef<A: Referable> {
    proc_mesh: ProcMeshRef,
    name: Name,

    /// Lazily allocated collection of pages:
    /// - The outer `OnceCell` defers creating the vector until first
    ///   use.
    /// - The `Vec` holds slots for multiple pages.
    /// - Each slot is itself a `OnceCell<Box<Page<A>>>`, so that each
    ///   page can be initialized on demand.
    /// - A `Page<A>` is a boxed slice of `OnceCell<ActorRef<A>>`,
    ///   i.e. the actual storage for actor references within that
    ///   page.
    pages: OnceCell<Vec<OnceCell<Box<Page<A>>>>>,
    // Page-size knob (not serialized; reset to the default after
    // deserialization).
    page_size: usize,

    _phantom: PhantomData<A>,
}
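// Lookup sketch (illustrative): `materialize` (below) maps a rank to a page
// and a slot via `page_ix = rank / page_size` and `local_ix = rank % page_size`;
// with `page_size = 2`, rank 5 lands in page 2, slot 1. Both the page and the
// slot are initialized on first access and reused afterwards.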

impl<A: Actor + Referable> ActorMeshRef<A> {
    /// Cast a message to all actors in this mesh.
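    ///
    /// Illustrative sketch (not compiled; mirrors `test_cast` below — the
    /// `instance` handle and reply `port` are assumed to be in scope):
    ///
    /// ```ignore
    /// // Every rank receives its own copy of the message.
    /// actor_mesh.cast(instance, testactor::GetCastInfo { cast_info: port.bind() })?;
    /// ```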
    pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> v1::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone, // Clone is required until we have fully moved onto the comm actor.
    {
        if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
            self.cast_v0(cx, message, root_comm_actor)
        } else {
            for (point, actor) in self.iter() {
                let create_rank = point.rank();
                let mut headers = Attrs::new();
                headers.set(
                    multicast::CAST_ORIGINATING_SENDER,
                    cx.instance().self_id().clone(),
                );
                headers.set(multicast::CAST_POINT, point);

                // Make sure that we re-bind ranks, as these may be used for
                // bootstrapping comm actors.
                let mut unbound = Unbound::try_from_message(message.clone())
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                unbound
                    .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
                        *rank = Some(create_rank);
                        Ok(())
                    })
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                let rebound_message = unbound
                    .bind()
                    .map_err(|e| Error::CastingError(self.name.clone(), e))?;
                actor
                    .send_with_headers(cx, headers, rebound_message)
                    .map_err(|e| Error::SendingError(actor.actor_id().clone(), Box::new(e)))?;
            }
            Ok(())
        }
    }

    fn cast_v0<M>(
        &self,
        cx: &impl context::Actor,
        message: M,
        root_comm_actor: &ActorRef<CommActor>,
    ) -> v1::Result<()>
    where
        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: Castable + RemoteMessage + Clone, // Clone is required until we have fully moved onto the comm actor.
    {
        let cast_mesh_shape = view::Ranked::region(self).into();
        let actor_mesh_id = ActorMeshId::V1(self.name.clone());
        match &self.proc_mesh.root_region {
            Some(root_region) => {
                let root_mesh_shape = root_region.into();
                v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
                    cx,
                    actor_mesh_id,
                    root_comm_actor,
                    &sel!(*),
                    message,
                    &cast_mesh_shape,
                    &root_mesh_shape,
                )
                .map_err(|e| Error::CastingError(self.name.clone(), e.into()))
            }
            None => v0_actor_mesh::actor_mesh_cast::<A, M>(
                cx,
                actor_mesh_id,
                root_comm_actor,
                sel!(*),
                &cast_mesh_shape,
                &cast_mesh_shape,
                message,
            )
            .map_err(|e| Error::CastingError(self.name.clone(), e.into())),
        }
    }

    pub async fn actor_states(
        &self,
        cx: &impl context::Actor,
    ) -> v1::Result<ValueMesh<resource::State<ActorState>>> {
        self.proc_mesh.actor_states(cx, self.name.clone()).await
    }
}

impl<A: Referable> ActorMeshRef<A> {
    pub(crate) fn new(name: Name, proc_mesh: ProcMeshRef) -> Self {
        Self::with_page_size(name, proc_mesh, DEFAULT_PAGE)
    }

    pub(crate) fn with_page_size(name: Name, proc_mesh: ProcMeshRef, page_size: usize) -> Self {
        Self {
            proc_mesh,
            name,
            pages: OnceCell::new(),
            page_size: page_size.max(1),
            _phantom: PhantomData,
        }
    }

    #[inline]
    fn len(&self) -> usize {
        view::Ranked::region(&self.proc_mesh).num_ranks()
    }

    fn ensure_pages(&self) -> &Vec<OnceCell<Box<Page<A>>>> {
        let n = self.len().div_ceil(self.page_size); // ⌈len / page_size⌉
        self.pages
            .get_or_init(|| (0..n).map(|_| OnceCell::new()).collect())
    }

    fn materialize(&self, rank: usize) -> Option<&ActorRef<A>> {
        let len = self.len();
        if rank >= len {
            return None;
        }
        let p = self.page_size;
        let page_ix = rank / p;
        let local_ix = rank % p;

        let pages = self.ensure_pages();
        let page = pages[page_ix].get_or_init(|| {
            // Last page may be partial.
            let base = page_ix * p;
            let remaining = len - base;
            let page_len = remaining.min(p);
            Box::new(Page::<A>::new(page_len))
        });

        Some(page.slots[local_ix].get_or_init(|| {
            // Invariant: `proc_mesh` and this view share the same
            // dense rank space:
            //   - ranks are contiguous [0, self.len()) with no gaps
            //     or reordering
            //   - for every rank r, `proc_mesh.get(r)` is Some(..)
            // Therefore we can index `proc_mesh` with `rank`
            // directly.
            debug_assert!(rank < self.len(), "rank must be within [0, len)");
            debug_assert!(
                self.proc_mesh.get(rank).is_some(),
                "proc_mesh must be dense/aligned with this view"
            );
            let proc_ref = self.proc_mesh.get(rank).expect("rank in-bounds");
            proc_ref.attest(&self.name)
        }))
    }
}

impl<A: Referable> Clone for ActorMeshRef<A> {
    fn clone(&self) -> Self {
        Self {
            proc_mesh: self.proc_mesh.clone(),
            name: self.name.clone(),
            pages: OnceCell::new(), // The cache is not cloned.
            page_size: self.page_size,
            _phantom: PhantomData,
        }
    }
}

impl<A: Referable> PartialEq for ActorMeshRef<A> {
    fn eq(&self, other: &Self) -> bool {
        self.proc_mesh == other.proc_mesh && self.name == other.name
    }
}
impl<A: Referable> Eq for ActorMeshRef<A> {}

impl<A: Referable> Hash for ActorMeshRef<A> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.proc_mesh.hash(state);
        self.name.hash(state);
    }
}

impl<A: Referable> fmt::Debug for ActorMeshRef<A> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ActorMeshRef")
            .field("proc_mesh", &self.proc_mesh)
            .field("name", &self.name)
            .field("page_size", &self.page_size)
            .finish_non_exhaustive() // The cache is not printed.
    }
}

// Implement Serialize manually, without requiring A: Serialize
impl<A: Referable> Serialize for ActorMeshRef<A> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Serialize only the fields that don't depend on A
        (&self.proc_mesh, &self.name).serialize(serializer)
    }
}

// Implement Deserialize manually, without requiring A: Deserialize
impl<'de, A: Referable> Deserialize<'de> for ActorMeshRef<A> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let (proc_mesh, name) = <(ProcMeshRef, Name)>::deserialize(deserializer)?;
        Ok(ActorMeshRef::with_page_size(name, proc_mesh, DEFAULT_PAGE))
    }
}
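// Note: a serialize/deserialize round trip preserves only the mesh identity
// (`proc_mesh`, `name`); the page cache is dropped and the page size resets
// to `DEFAULT_PAGE`, so actor refs are re-attested lazily on first access.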

impl<A: Referable> view::Ranked for ActorMeshRef<A> {
    type Item = ActorRef<A>;

    #[inline]
    fn region(&self) -> &Region {
        view::Ranked::region(&self.proc_mesh)
    }

    #[inline]
    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.materialize(rank)
    }
}

impl<A: Referable> view::RankedSliceable for ActorMeshRef<A> {
    fn sliced(&self, region: Region) -> Self {
        debug_assert!(region.is_subset(view::Ranked::region(self)));
        let proc_mesh = self.proc_mesh.subset(region).unwrap();
        Self::with_page_size(self.name.clone(), proc_mesh, self.page_size)
    }
}

#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;

    use hyperactor::actor::ActorStatus;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::context::Mailbox as _;
    use hyperactor::mailbox;
    use ndslice::Extent;
    use ndslice::ViewExt;
    use ndslice::extent;
    use ndslice::view::Ranked;
    use timed_test::async_timed_test;
    use tokio::time::Duration;

    use super::ActorMesh;
    use crate::proc_mesh::mesh_agent::ActorState;
    use crate::resource;
    use crate::v1::ActorMeshRef;
    use crate::v1::Name;
    use crate::v1::ProcMesh;
    use crate::v1::testactor;
    use crate::v1::testing;

    #[tokio::test]
    async fn test_actor_mesh_ref_lazy_materialization() {
        // 1) Bring up procs and spawn actors.
        let instance = testing::instance().await;
        // Small mesh so the test runs fast, but larger than the page size so
        // we cross a page boundary.
        let extent = extent!(replicas = 3, hosts = 2); // 6 ranks
        let pm: ProcMesh = testing::proc_meshes(instance, extent.clone())
            .await
            .into_iter()
            .next()
            .expect("at least one proc mesh");
        let am: ActorMesh<testactor::TestActor> = pm.spawn(instance, "test", &()).await.unwrap();

        // 2) Build our ActorMeshRef with a tiny page size (2) to
        // force multiple pages:
        // page 0: ranks [0,1], page 1: [2,3], page 2: [4,5]
        let page_size = 2;
        let amr: ActorMeshRef<testactor::TestActor> =
            ActorMeshRef::with_page_size(am.name.clone(), pm.clone(), page_size);
        assert_eq!(amr.extent(), extent);
        assert_eq!(amr.region().num_ranks(), 6);

        // 3) Within-rank pointer stability (OnceLock caches &ActorRef)
        let p0_a = amr.get(0).expect("rank 0 exists") as *const _;
        let p0_b = amr.get(0).expect("rank 0 exists") as *const _;
        assert_eq!(p0_a, p0_b, "same rank should return same cached pointer");

        // 4) Same page, different rank (both materialize fine)
        let p1_a = amr.get(1).expect("rank 1 exists") as *const _;
        let p1_b = amr.get(1).expect("rank 1 exists") as *const _;
        assert_eq!(p1_a, p1_b, "same rank should return same cached pointer");
        // They're different ranks, so the pointers are different
        // (distinct OnceLocks in the page)
        assert_ne!(p0_a, p1_a, "different ranks have different cache slots");

        // 5) Cross a page boundary (rank 2 is in a different page than rank 0/1)
        let p2_a = amr.get(2).expect("rank 2 exists") as *const _;
        let p2_b = amr.get(2).expect("rank 2 exists") as *const _;
        assert_eq!(p2_a, p2_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p2_a, "different pages have different cache slots");

        // 6) Clone should drop the cache but keep identity (actor_id)
        let amr_clone = amr.clone();
        let orig_id_0 = amr.get(0).unwrap().actor_id().clone();
        let clone_id_0 = amr_clone.get(0).unwrap().actor_id().clone();
        assert_eq!(orig_id_0, clone_id_0, "clone preserves identity");
        let p0_clone = amr_clone.get(0).unwrap() as *const _;
        assert_ne!(
            p0_a, p0_clone,
            "cloned ActorMeshRef has a fresh cache (different pointer)"
        );

        // 7) Slicing preserves page_size and clears cache
        // (RankedSliceable::sliced)
        let sliced = amr.range("replicas", 1..).expect("slice should be valid"); // leaves 4 ranks
        assert_eq!(sliced.region().num_ranks(), 4);
        // First access materializes a new cache for the sliced view.
        let sp0_a = sliced.get(0).unwrap() as *const _;
        let sp0_b = sliced.get(0).unwrap() as *const _;
        assert_eq!(sp0_a, sp0_b, "sliced view has its own cache slot per rank");
        // Cross-page inside the slice too (page_size = 2 => pages are
        // [0..2), [2..4)).
        let sp2 = sliced.get(2).unwrap() as *const _;
        assert_ne!(sp0_a, sp2, "sliced view crosses its own page boundary");

        // 8) Hash/Eq ignore cache state; identical identity collapses
        // to one set entry.
        let mut set = HashSet::new();
        set.insert(amr.clone());
        set.insert(amr.clone());
        assert_eq!(set.len(), 1, "cache state must not affect Hash/Eq");

        // 9) As a sanity check, send to a couple of ranks to ensure the refs
        // are indeed usable/live.
        let (port, mut rx) = mailbox::open_port(instance);
        // Send to rank 0 and rank 3 (extent 3x2 => at least 4 ranks
        // exist).
        amr.get(0)
            .expect("rank 0 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 0 should succeed");
        amr.get(3)
            .expect("rank 3 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 3 should succeed");
        let id_a = RealClock
            .timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for first reply")
            .expect("channel closed before first reply");
        let id_b = RealClock
            .timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for second reply")
            .expect("channel closed before second reply");
        assert_ne!(id_a, id_b, "two different ranks responded");
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_actor_states() {
        hyperactor_telemetry::initialize_logging_for_test();

        let instance = testing::instance().await;
        // Listen for supervision events sent to the parent instance.
        let (supervision_port, mut supervision_receiver) =
            instance.open_port::<resource::State<ActorState>>();
        let supervisor = supervision_port.bind();
        let num_replicas = 4;
        let meshes = testing::proc_meshes(instance, extent!(replicas = num_replicas)).await;
        let proc_mesh = &meshes[1];
        let child_name = Name::new("child");

        let actor_mesh = proc_mesh
            .spawn_with_name::<testactor::TestActor>(instance, child_name.clone(), &())
            .await
            .unwrap();

        actor_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent(testactor::SupervisionEventType::Panic),
            )
            .unwrap();

        // Wait for the cast message to cause a panic on all actors. We can't
        // use a reply port because the handler, by definition, panics before
        // it can send a reply.
        #[allow(clippy::disallowed_methods)]
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        // Now that all ranks have panicked, fetch the actor states and
        // forward each reported state to the supervision port; unhealthy
        // actors should surface their supervision events there.
        let supervision_task = tokio::spawn(async move {
            match actor_mesh.actor_states(&instance).await {
                Ok(events) => {
                    for state in events.values() {
                        supervisor.send(instance, state.clone()).unwrap();
                    }
                }
                Err(e) => {
                    println!("error: {:?}", e);
                }
            };
        });
        // Make sure the task completes first without a panic.
        supervision_task.await.unwrap();

        for _ in 0..num_replicas {
            let state = supervision_receiver.recv().await.unwrap();
            if let resource::Status::Failed(s) = state.status {
                assert!(s.contains("supervision events"));
            } else {
                panic!("Not failed: {:?}", state.status);
            }
            if let Some(ref inner) = state.state {
                assert!(!inner.supervision_events.is_empty());
                for event in &inner.supervision_events {
                    println!("receiving event: {:?}", event);
                    assert_eq!(event.actor_id.name(), format!("{}", child_name.clone()));
                    assert_matches!(event.actor_status, ActorStatus::Failed(_));
                }
            }
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast() {
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);

        let instance = testing::instance().await;
        let host_mesh = testing::host_mesh(extent!(host = 4)).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", Extent::unity())
            .await
            .unwrap();
        let actor_mesh = proc_mesh
            .spawn::<testactor::TestActor>(instance, "test", &())
            .await
            .unwrap();

        let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
        actor_mesh
            .cast(
                instance,
                testactor::GetCastInfo {
                    cast_info: cast_info.bind(),
                },
            )
            .unwrap();

        let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
        while !point_to_actor.is_empty() {
            let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
            let key = (point, origin_actor_ref);
            assert!(
                point_to_actor.remove(&key),
                "key {:?} not present or removed twice",
                key
            );
            assert_eq!(&sender_actor_id, instance.self_id());
        }

        let _ = host_mesh.shutdown(&instance).await;
    }
}