hyperactor_mesh/
actor_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)] // until used publically
10
11use std::collections::BTreeSet;
12use std::ops::Deref;
13
14use async_trait::async_trait;
15use hyperactor::Actor;
16use hyperactor::ActorRef;
17use hyperactor::Bind;
18use hyperactor::GangId;
19use hyperactor::GangRef;
20use hyperactor::Message;
21use hyperactor::Named;
22use hyperactor::PortHandle;
23use hyperactor::RemoteHandles;
24use hyperactor::RemoteMessage;
25use hyperactor::Unbind;
26use hyperactor::WorldId;
27use hyperactor::actor::Referable;
28use hyperactor::attrs::Attrs;
29use hyperactor::attrs::declare_attrs;
30use hyperactor::config;
31use hyperactor::context;
32use hyperactor::mailbox::MailboxSenderError;
33use hyperactor::mailbox::PortReceiver;
34use hyperactor::message::Castable;
35use hyperactor::message::IndexedErasedUnbound;
36use hyperactor::supervision::ActorSupervisionEvent;
37use ndslice::Range;
38use ndslice::Selection;
39use ndslice::Shape;
40use ndslice::ShapeError;
41use ndslice::SliceError;
42use ndslice::reshape::Limit;
43use ndslice::reshape::ReshapeError;
44use ndslice::reshape::ReshapeSliceExt;
45use ndslice::reshape::reshape_selection;
46use ndslice::selection;
47use ndslice::selection::EvalOpts;
48use ndslice::selection::ReifySlice;
49use ndslice::selection::normal;
50use serde::Deserialize;
51use serde::Serialize;
52use serde_multipart::Part;
53use tokio::sync::mpsc;
54
55use crate::CommActor;
56use crate::Mesh;
57use crate::comm::multicast::CastMessage;
58use crate::comm::multicast::CastMessageEnvelope;
59use crate::comm::multicast::Uslice;
60use crate::config::MAX_CAST_DIMENSION_SIZE;
61use crate::metrics;
62use crate::proc_mesh::ProcMesh;
63use crate::reference::ActorMeshId;
64use crate::reference::ActorMeshRef;
65use crate::reference::ProcMeshId;
66
declare_attrs! {
    /// Identifies the actor mesh that a message was cast to.
    ///
    /// `actor_mesh_cast` sets this header on every cast message it sends
    /// to the comm actor. It is used for undeliverable message handling,
    /// where the `CastMessageEnvelope` is serialized and its content
    /// cannot be inspected.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
73
74/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast
75/// an `M`-typed message
76#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
77pub(crate) fn actor_mesh_cast<A, M>(
78    cx: &impl context::Actor,
79    actor_mesh_id: ActorMeshId,
80    comm_actor_ref: &ActorRef<CommActor>,
81    selection_of_root: Selection,
82    root_mesh_shape: &Shape,
83    cast_mesh_shape: &Shape,
84    message: M,
85) -> Result<(), CastError>
86where
87    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
88    M: Castable + RemoteMessage,
89{
90    let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
91        "message_type" => M::typename(),
92        "message_variant" => message.arm().unwrap_or_default(),
93    ));
94
95    let message = CastMessageEnvelope::new::<A, M>(
96        actor_mesh_id.clone(),
97        cx.mailbox().actor_id().clone(),
98        cast_mesh_shape.clone(),
99        message,
100    )?;
101
102    // Mesh's shape might have large extents on some dimensions. Those
103    // dimensions would cause large fanout in our comm actor
104    // implementation. To avoid that, we reshape it by increasing
105    // dimensionality and limiting the extent of each dimension. Note
106    // the reshape is only visible to the internal algorithm. The
107    // shape that user sees maintains intact.
108    //
109    // For example, a typical shape is [hosts=1024, gpus=8]. By using
110    // limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
111    // words, it adds 3 extra layers to the comm actor tree, while
112    // keeping the fanout in each layer per dimension at 8 or smaller.
113    //
114    // An important note here is that max dimension size != max fanout.
115    // Rank 0 must send a message to all ranks at index 0 for every dimension.
116    // If our reshaped shape is [8, 8, 8, 2, 8], rank 0 must send
117    // 7 + 7 + 7 + 1 + 7 = 21 messages.
118
119    let slice_of_root = root_mesh_shape.slice();
120
121    let max_cast_dimension_size = config::global::get(MAX_CAST_DIMENSION_SIZE);
122
123    let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
124
125    let selection_of_cast =
126        reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
127
128    let cast_message = CastMessage {
129        dest: Uslice {
130            slice: slice_of_cast,
131            selection: selection_of_cast,
132        },
133        message,
134    };
135
136    let mut headers = Attrs::new();
137    headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
138
139    comm_actor_ref
140        .port()
141        .send_with_headers(cx, headers, cast_message)?;
142
143    Ok(())
144}
145
146#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
147pub(crate) fn cast_to_sliced_mesh<A, M>(
148    cx: &impl context::Actor,
149    actor_mesh_id: ActorMeshId,
150    comm_actor_ref: &ActorRef<CommActor>,
151    sel_of_sliced: &Selection,
152    message: M,
153    sliced_shape: &Shape,
154    root_mesh_shape: &Shape,
155) -> Result<(), CastError>
156where
157    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
158    M: Castable + RemoteMessage,
159{
160    let root_slice = root_mesh_shape.slice();
161
162    // Casting to `*`?
163    let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
164        // Reify this view into base.
165        root_slice.reify_slice(sliced_shape.slice())?
166    } else {
167        // No, fall back on `of_ranks`.
168        let ranks = sel_of_sliced
169            .eval(&EvalOpts::strict(), sliced_shape.slice())?
170            .collect::<BTreeSet<_>>();
171        Selection::of_ranks(root_slice, &ranks)?
172    };
173
174    // Cast.
175    actor_mesh_cast::<A, M>(
176        cx,
177        actor_mesh_id,
178        comm_actor_ref,
179        sel_of_root,
180        root_mesh_shape,
181        sliced_shape,
182        message,
183    )
184}
185
186/// A mesh of actors, all of which reside on the same [`ProcMesh`].
187#[async_trait]
188pub trait ActorMesh: Mesh<Id = ActorMeshId> {
189    /// The type of actor in the mesh.
190    type Actor: Referable;
191
192    /// Cast an `M`-typed message to the ranks selected by `sel` in
193    /// this ActorMesh.
194    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
195    fn cast<M>(
196        &self,
197        cx: &impl context::Actor,
198        selection: Selection,
199        message: M,
200    ) -> Result<(), CastError>
201    where
202        Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
203        M: Castable + RemoteMessage,
204    {
205        actor_mesh_cast::<Self::Actor, M>(
206            cx,                            // actor context
207            self.id(),                     // actor mesh id (destination mesh)
208            self.proc_mesh().comm_actor(), // comm actor
209            selection,                     // the selected actors
210            self.shape(),                  // root mesh shape
211            self.shape(),                  // cast mesh shape
212            message,                       // the message
213        )
214    }
215
216    /// The ProcMesh on top of which this actor mesh is spawned.
217    fn proc_mesh(&self) -> &ProcMesh;
218
219    /// The name given to the actors in this mesh.
220    fn name(&self) -> &str;
221
222    fn world_id(&self) -> &WorldId {
223        self.proc_mesh().world_id()
224    }
225
226    /// Iterate over all `ActorRef<Self::Actor>` in this mesh.
227    fn iter_actor_refs(&self) -> impl Iterator<Item = ActorRef<Self::Actor>> {
228        let gang: GangRef<Self::Actor> =
229            GangId(self.proc_mesh().world_id().clone(), self.name().to_string()).into();
230        self.shape().slice().iter().map(move |rank| gang.rank(rank))
231    }
232
233    async fn stop(&self) -> Result<(), anyhow::Error> {
234        self.proc_mesh().stop_actor_by_name(self.name()).await
235    }
236
237    /// Get a serializeable reference to this mesh similar to ActorHandle::bind
238    fn bind(&self) -> ActorMeshRef<Self::Actor> {
239        ActorMeshRef::attest(
240            ActorMeshId::V0(
241                ProcMeshId(self.world_id().to_string()),
242                self.name().to_string(),
243            ),
244            self.shape().clone(),
245            self.proc_mesh().comm_actor().clone(),
246        )
247    }
248}
249
/// Abstracts over shared and borrowed references to a [`ProcMesh`].
/// Given a shared ProcMesh, we can obtain a `RootActorMesh<'static, _>`
/// for it, useful when lifetime must be managed dynamically.
enum ProcMeshRef<'a> {
    /// The mesh is reached through an owned, boxed value that derefs to
    /// a [`ProcMesh`], so no borrow lifetime is involved.
    Shared(Box<dyn Deref<Target = ProcMesh> + Sync + Send>),
    /// The reference is borrowed with a parameterized
    /// lifetime.
    Borrowed(&'a ProcMesh),
}
260
261impl Deref for ProcMeshRef<'_> {
262    type Target = ProcMesh;
263
264    fn deref(&self) -> &Self::Target {
265        match self {
266            Self::Shared(p) => p,
267            Self::Borrowed(p) => p, // p: &ProcMesh
268        }
269    }
270}
271
/// A mesh of actor instances. ActorMeshes are obtained by spawning an
/// actor on a [`ProcMesh`].
///
/// Generic bound: `A: Referable` — this type hands out typed
/// `ActorRef<A>` handles (see `ranks`), and `ActorRef` is only
/// defined for `A: Referable`.
pub struct RootActorMesh<'a, A: Referable> {
    // The proc mesh on which the actors were spawned; either borrowed
    // or shared (see `ProcMeshRef`).
    proc_mesh: ProcMeshRef<'a>,
    // The name given to the actors in this mesh.
    name: String,
    // The actor references of this mesh; `Mesh::get` indexes into this
    // vector by rank.
    pub(crate) ranks: Vec<ActorRef<A>>, // temporary until we remove `ArcActorMesh`.
    // The receiver of supervision events. It is None if it has been transferred to
    // an actor event observer.
    actor_supervision_rx: Option<mpsc::UnboundedReceiver<ActorSupervisionEvent>>,
}
286
287impl<'a, A: Referable> RootActorMesh<'a, A> {
288    pub(crate) fn new(
289        proc_mesh: &'a ProcMesh,
290        name: String,
291        actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
292        ranks: Vec<ActorRef<A>>,
293    ) -> Self {
294        Self {
295            proc_mesh: ProcMeshRef::Borrowed(proc_mesh),
296            name,
297            ranks,
298            actor_supervision_rx: Some(actor_supervision_rx),
299        }
300    }
301
302    pub(crate) fn new_shared<D: Deref<Target = ProcMesh> + Send + Sync + 'static>(
303        proc_mesh: D,
304        name: String,
305        actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
306        ranks: Vec<ActorRef<A>>,
307    ) -> Self {
308        Self {
309            proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)),
310            name,
311            ranks,
312            actor_supervision_rx: Some(actor_supervision_rx),
313        }
314    }
315
316    /// Open a port on this ActorMesh.
317    pub fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
318        self.proc_mesh.client().open_port()
319    }
320
321    /// An event stream of actor events. Each RootActorMesh can produce only one such
322    /// stream, returning None after the first call.
323    pub fn events(&mut self) -> Option<ActorSupervisionEvents> {
324        self.actor_supervision_rx
325            .take()
326            .map(|actor_supervision_rx| ActorSupervisionEvents {
327                actor_supervision_rx,
328                mesh_id: self.id(),
329            })
330    }
331}
332
/// Supervision event stream for actor mesh. It emits actor supervision events.
pub struct ActorSupervisionEvents {
    // The receiver of supervision events from proc mesh.
    actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
    // The id of the actor mesh this stream was obtained from; it is
    // included in the log line emitted when the stream closes.
    mesh_id: ActorMeshId,
}
340
341impl ActorSupervisionEvents {
342    pub async fn next(&mut self) -> Option<ActorSupervisionEvent> {
343        let result = self.actor_supervision_rx.recv().await;
344        if result.is_none() {
345            tracing::info!(
346                "supervision stream for actor mesh {:?} was closed!",
347                self.mesh_id
348            );
349        }
350        result
351    }
352}
353
354#[async_trait]
355impl<'a, A: Referable> Mesh for RootActorMesh<'a, A> {
356    type Node = ActorRef<A>;
357    type Id = ActorMeshId;
358    type Sliced<'b>
359        = SlicedActorMesh<'b, A>
360    where
361        'a: 'b;
362
363    fn shape(&self) -> &Shape {
364        self.proc_mesh.shape()
365    }
366
367    fn select<R: Into<Range>>(
368        &self,
369        label: &str,
370        range: R,
371    ) -> Result<Self::Sliced<'_>, ShapeError> {
372        Ok(SlicedActorMesh(self, self.shape().select(label, range)?))
373    }
374
375    fn get(&self, rank: usize) -> Option<ActorRef<A>> {
376        self.ranks.get(rank).cloned()
377    }
378
379    fn id(&self) -> Self::Id {
380        ActorMeshId::V0(self.proc_mesh.id(), self.name.clone())
381    }
382}
383
384impl<A: Referable> ActorMesh for RootActorMesh<'_, A> {
385    type Actor = A;
386
387    fn proc_mesh(&self) -> &ProcMesh {
388        &self.proc_mesh
389    }
390
391    fn name(&self) -> &str {
392        &self.name
393    }
394}
395
396pub struct SlicedActorMesh<'a, A: Referable>(&'a RootActorMesh<'a, A>, Shape);
397
398impl<'a, A: Referable> SlicedActorMesh<'a, A> {
399    pub fn new(actor_mesh: &'a RootActorMesh<'a, A>, shape: Shape) -> Self {
400        Self(actor_mesh, shape)
401    }
402
403    pub fn shape(&self) -> &Shape {
404        &self.1
405    }
406}
407
408#[async_trait]
409impl<A: Referable> Mesh for SlicedActorMesh<'_, A> {
410    type Node = ActorRef<A>;
411    type Id = ActorMeshId;
412    type Sliced<'b>
413        = SlicedActorMesh<'b, A>
414    where
415        Self: 'b;
416
417    fn shape(&self) -> &Shape {
418        &self.1
419    }
420
421    fn select<R: Into<Range>>(
422        &self,
423        label: &str,
424        range: R,
425    ) -> Result<Self::Sliced<'_>, ShapeError> {
426        Ok(Self(self.0, self.1.select(label, range)?))
427    }
428
429    fn get(&self, _index: usize) -> Option<ActorRef<A>> {
430        unimplemented!()
431    }
432
433    fn id(&self) -> Self::Id {
434        self.0.id()
435    }
436}
437
438impl<A: Referable> ActorMesh for SlicedActorMesh<'_, A> {
439    type Actor = A;
440
441    fn proc_mesh(&self) -> &ProcMesh {
442        &self.0.proc_mesh
443    }
444
445    fn name(&self) -> &str {
446        &self.0.name
447    }
448
449    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
450    fn cast<M>(&self, cx: &impl context::Actor, sel: Selection, message: M) -> Result<(), CastError>
451    where
452        Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
453        M: Castable + RemoteMessage,
454    {
455        cast_to_sliced_mesh::<A, M>(
456            /*cx=*/ cx,
457            /*actor_mesh_id=*/ self.id(),
458            /*comm_actor_ref*/ self.proc_mesh().comm_actor(),
459            /*sel_of_sliced=*/ &sel,
460            /*message=*/ message,
461            /*sliced_shape=*/ self.shape(),
462            /*root_mesh_shape=*/ self.0.shape(),
463        )
464    }
465}
466
/// The type of error of casting operations.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection was invalid; carries the selection and the shape
    /// error that describes the problem.
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// A send failed for a particular rank; carries the rank and the
    /// mailbox sender error.
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// A mailbox sender error with no rank attached.
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    /// A shape operation failed.
    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    /// A slice operation failed.
    #[error(transparent)]
    SliceError(#[from] SliceError),

    /// Serializing with bincode failed.
    #[error(transparent)]
    SerializationError(#[from] bincode::Error),

    /// Any other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Reshaping a slice or selection failed.
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}
494
495// This has to be compiled outside of test mode because the bootstrap binary
496// is not built in test mode, and requires access to TestActor.
497pub(crate) mod test_util {
498    use std::collections::VecDeque;
499    use std::fmt;
500    use std::fmt::Debug;
501    use std::sync::Arc;
502
503    use anyhow::ensure;
504    use hyperactor::Context;
505    use hyperactor::Handler;
506    use hyperactor::PortRef;
507    use ndslice::extent;
508
509    use super::*;
510    use crate::comm::multicast::CastInfo;
511
    // This can't be defined under a `#[cfg(test)]` because there needs to
    // be an entry in the spawnable actor registry in the executable
    // 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor
    // mesh test suite.
    /// Stateless actor used by the actor mesh tests.
    ///
    /// It exports handlers for `Echo`, `Payload`, `GetRank` and `Error`
    /// (each with `cast = true`), and for `Relay`.
    #[derive(Debug, Default, Actor)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            Echo { cast = true },
            Payload { cast = true },
            GetRank { cast = true },
            Error { cast = true },
            Relay,
        ],
    )]
    pub struct TestActor;
528
529    /// Request message to retrieve the actor's rank.
530    ///
531    /// The `bool` in the tuple controls the outcome of the handler:
532    /// - If `true`, the handler will send the rank and return
533    ///   `Ok(())`.
534    /// - If `false`, the handler will still send the rank, but return
535    ///   an error (`Err(...)`).
536    ///
537    /// This is useful for testing both successful and failing
538    /// responses from a single message type.
539    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
540    pub struct GetRank(pub bool, #[binding(include)] pub PortRef<usize>);
541
542    #[async_trait]
543    impl Handler<GetRank> for TestActor {
544        async fn handle(
545            &mut self,
546            cx: &Context<Self>,
547            GetRank(ok, reply): GetRank,
548        ) -> Result<(), anyhow::Error> {
549            let point = cx.cast_point();
550            reply.send(cx, point.rank())?;
551            anyhow::ensure!(ok, "intentional error!"); // If `!ok` exit with `Err()`.
552            Ok(())
553        }
554    }
555
556    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
557    pub struct Echo(pub String, #[binding(include)] pub PortRef<String>);
558
559    #[async_trait]
560    impl Handler<Echo> for TestActor {
561        async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
562            let Echo(message, reply_port) = message;
563            reply_port.send(cx, message)?;
564            Ok(())
565        }
566    }
567
568    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
569    pub struct Payload {
570        pub part: Part,
571        #[binding(include)]
572        pub reply_port: PortRef<()>,
573    }
574
575    #[async_trait]
576    impl Handler<Payload> for TestActor {
577        async fn handle(
578            &mut self,
579            cx: &Context<Self>,
580            message: Payload,
581        ) -> Result<(), anyhow::Error> {
582            let Payload { reply_port, .. } = message;
583            reply_port.send(cx, ())?;
584            Ok(())
585        }
586    }
587
588    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
589    pub struct Error(pub String);
590
591    #[async_trait]
592    impl Handler<Error> for TestActor {
593        async fn handle(
594            &mut self,
595            _cx: &Context<Self>,
596            Error(error): Error,
597        ) -> Result<(), anyhow::Error> {
598            Err(anyhow::anyhow!("{}", error))
599        }
600    }
601
602    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
603    pub struct Relay(pub usize, pub VecDeque<PortRef<Relay>>);
604
605    #[async_trait]
606    impl Handler<Relay> for TestActor {
607        async fn handle(
608            &mut self,
609            cx: &Context<Self>,
610            Relay(count, mut hops): Relay,
611        ) -> Result<(), anyhow::Error> {
612            ensure!(!hops.is_empty(), "relay must have at least one hop");
613            let next = hops.pop_front().unwrap();
614            next.send(cx, Relay(count + 1, hops))?;
615            Ok(())
616        }
617    }
618
    // -- ProxyActor

    /// Test actor that, when created, allocates its own proc mesh, spawns
    /// a mesh of `TestActor`s on it, and handles `Echo` by forwarding the
    /// message to rank 0 of that mesh.
    #[hyperactor::export(
        spawn = true,
        handlers = [
            Echo,
        ],
    )]
    pub struct ProxyActor {
        // Shared handle to the proc mesh that `actor_mesh` was spawned on.
        proc_mesh: Arc<ProcMesh>,
        // The mesh of `TestActor`s that `Echo` messages are forwarded to.
        actor_mesh: RootActorMesh<'static, TestActor>,
    }
631
632    impl fmt::Debug for ProxyActor {
633        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
634            f.debug_struct("ProxyActor")
635                .field("proc_mesh", &"...")
636                .field("actor_mesh", &"...")
637                .finish()
638        }
639    }
640
641    #[async_trait]
642    impl Actor for ProxyActor {
643        type Params = ();
644
645        async fn new(_params: Self::Params) -> Result<Self, anyhow::Error> {
646            // The actor creates a mesh.
647            use std::sync::Arc;
648
649            use hyperactor::channel::ChannelTransport;
650
651            use crate::alloc::AllocSpec;
652            use crate::alloc::Allocator;
653            use crate::alloc::LocalAllocator;
654
655            let mut allocator = LocalAllocator;
656            let alloc = allocator
657                .allocate(AllocSpec {
658                    extent: extent! { replica = 1 },
659                    constraints: Default::default(),
660                    proc_name: None,
661                    transport: ChannelTransport::Local,
662                })
663                .await
664                .unwrap();
665            let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
666            let leaked: &'static Arc<ProcMesh> = Box::leak(Box::new(proc_mesh));
667            let actor_mesh: RootActorMesh<'static, TestActor> =
668                leaked.spawn("echo", &()).await.unwrap();
669            Ok(Self {
670                proc_mesh: Arc::clone(leaked),
671                actor_mesh,
672            })
673        }
674    }
675
676    #[async_trait]
677    impl Handler<Echo> for ProxyActor {
678        async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
679            if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
680                // test_proxy_mesh
681
682                let actor = self.actor_mesh.get(0).unwrap();
683
684                // For now, we reply directly to the client.
685                // We will support directly wiring up the meshes later.
686                let (tx, mut rx) = cx.open_port();
687
688                actor.send(cx, Echo(message.0, tx.bind()))?;
689                message.1.send(cx, rx.recv().await.unwrap())?;
690
691                Ok(())
692            } else {
693                // test_router_undeliverable_return
694
695                let actor: ActorRef<_> = self.actor_mesh.get(0).unwrap();
696                let (tx, mut rx) = cx.open_port::<String>();
697                actor.send(cx, Echo(message.0, tx.bind()))?;
698
699                use tokio::time::Duration;
700                use tokio::time::timeout;
701                #[allow(clippy::disallowed_methods)]
702                if let Ok(_) = timeout(Duration::from_secs(1), rx.recv()).await {
703                    message
704                        .1
705                        .send(cx, "the impossible happened".to_owned())
706                        .unwrap()
707                }
708
709                Ok(())
710            }
711        }
712    }
713}
714
715#[cfg(test)]
716mod tests {
717    use std::sync::Arc;
718
719    use hyperactor::ActorId;
720    use hyperactor::PortRef;
721    use hyperactor::ProcId;
722    use hyperactor::WorldId;
723    use hyperactor::attrs::Attrs;
724    use hyperactor::data::Encoding;
725    use timed_test::async_timed_test;
726
727    use super::*;
728    use crate::proc_mesh::ProcEvent;
729
730    // These tests are parametric over allocators.
731    #[macro_export]
732    macro_rules! actor_mesh_test_suite {
733        ($allocator:expr) => {
734            use std::assert_matches::assert_matches;
735
736            use ndslice::extent;
737            use $crate::alloc::AllocSpec;
738            use $crate::alloc::Allocator;
739            use $crate::assign::Ranks;
740            use $crate::sel_from_shape;
741            use $crate::sel;
742            use $crate::comm::multicast::set_cast_info_on_headers;
743            use $crate::proc_mesh::SharedSpawnable;
744            use std::collections::VecDeque;
745            use hyperactor::data::Serialized;
746            use $crate::proc_mesh::default_transport;
747
748            use super::*;
749            use super::test_util::*;
750
751            #[tokio::test]
752            async fn test_proxy_mesh() {
753                use super::test_util::*;
754                use $crate::alloc::AllocSpec;
755                use $crate::alloc::Allocator;
756
757                use ndslice::extent;
758
759                let alloc = $allocator
760                    .allocate(AllocSpec {
761                        extent: extent! { replica = 1 },
762                        constraints: Default::default(),
763                        proc_name: None,
764                        transport: default_transport()
765                    })
766                    .await
767                    .unwrap();
768                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
769                let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn("proxy", &()).await.unwrap();
770                let proxy_actor = actor_mesh.get(0).unwrap();
771                let (tx, mut rx) = actor_mesh.open_port::<String>();
772                proxy_actor.send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind())).unwrap();
773
774                #[allow(clippy::disallowed_methods)]
775                match tokio::time::timeout(tokio::time::Duration::from_secs(3), rx.recv()).await {
776                    Ok(msg) => assert_eq!(&msg.unwrap(), "hello!"),
777                    Err(_) =>  assert!(false),
778                }
779            }
780
781            #[tokio::test]
782            async fn test_basic() {
783                let alloc = $allocator
784                    .allocate(AllocSpec {
785                        extent: extent!(replica = 4),
786                        constraints: Default::default(),
787                        proc_name: None,
788                        transport: default_transport()
789                    })
790                    .await
791                    .unwrap();
792
793                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
794                let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
795                let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
796                actor_mesh
797                    .cast(proc_mesh.client(), sel!(*), Echo("Hello".to_string(), reply_handle.bind()))
798                    .unwrap();
799                for _ in 0..4 {
800                    assert_eq!(&reply_receiver.recv().await.unwrap(), "Hello");
801                }
802            }
803
804            #[tokio::test]
805            async fn test_ping_pong() {
806                use hyperactor::test_utils::pingpong::PingPongActor;
807                use hyperactor::test_utils::pingpong::PingPongMessage;
808                use hyperactor::test_utils::pingpong::PingPongActorParams;
809
810                let alloc = $allocator
811                    .allocate(AllocSpec {
812                        extent: extent!(replica = 2),
813                        constraints: Default::default(),
814                        proc_name: None,
815                        transport: default_transport(),
816                    })
817                    .await
818                    .unwrap();
819                let mesh = ProcMesh::allocate(alloc).await.unwrap();
820
821                let (undeliverable_msg_tx, _) = mesh.client().open_port();
822                let ping_pong_actor_params = PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None);
823                let actor_mesh: RootActorMesh<PingPongActor> = mesh
824                    .spawn::<PingPongActor>("ping-pong", &ping_pong_actor_params)
825                    .await
826                    .unwrap();
827
828                let ping: ActorRef<PingPongActor> = actor_mesh.get(0).unwrap();
829                let pong: ActorRef<PingPongActor> = actor_mesh.get(1).unwrap();
830                let (done_tx, done_rx) = mesh.client().open_once_port();
831                ping.send(mesh.client(), PingPongMessage(4, pong.clone(), done_tx.bind())).unwrap();
832
833                assert!(done_rx.recv().await.unwrap());
834            }
835
836            #[tokio::test]
837            async fn test_pingpong_full_mesh() {
838                use hyperactor::test_utils::pingpong::PingPongActor;
839                use hyperactor::test_utils::pingpong::PingPongActorParams;
840                use hyperactor::test_utils::pingpong::PingPongMessage;
841
842                use futures::future::join_all;
843
844                const X: usize = 3;
845                const Y: usize = 3;
846                const Z: usize = 3;
847                let alloc = $allocator
848                    .allocate(AllocSpec {
849                        extent: extent!(x = X, y = Y, z = Z),
850                        constraints: Default::default(),
851                        proc_name: None,
852                        transport: default_transport(),
853                    })
854                    .await
855                    .unwrap();
856
857                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
858                let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port();
859                let params = PingPongActorParams::new(Some(undeliverable_tx.bind()), None);
860                let actor_mesh = proc_mesh.spawn::<PingPongActor>("pingpong", &params).await.unwrap();
861                let slice = actor_mesh.shape().slice();
862
863                let mut futures = Vec::new();
864                for rank in slice.iter() {
865                    let actor = actor_mesh.get(rank).unwrap();
866                    let coords = (&slice.coordinates(rank).unwrap()[..]).try_into().unwrap();
867                    let sizes = (&slice.sizes())[..].try_into().unwrap();
868                    let neighbors = ndslice::utils::stencil::moore_neighbors::<3>();
869                    for neighbor_coords in ndslice::utils::apply_stencil(&coords, sizes, &neighbors) {
870                        if let Ok(neighbor_rank) = slice.location(&neighbor_coords) {
871                            let neighbor = actor_mesh.get(neighbor_rank).unwrap();
872                            let (done_tx, done_rx) = proc_mesh.client().open_once_port();
873                            actor
874                                .send(
875                                    proc_mesh.client(),
876                                    PingPongMessage(4, neighbor.clone(), done_tx.bind()),
877                                )
878                                .unwrap();
879                            futures.push(done_rx.recv());
880                        }
881                    }
882                }
883                let results = join_all(futures).await;
884                assert_eq!(results.len(), 316); // 5180 messages
885                for result in results {
886                    assert_eq!(result.unwrap(), true);
887                }
888            }
889
890            #[tokio::test]
891            async fn test_cast() {
892                let alloc = $allocator
893                    .allocate(AllocSpec {
894                        extent: extent!(replica = 2, host = 2, gpu = 8),
895                        constraints: Default::default(),
896                        proc_name: None,
897                        transport: default_transport(),
898                    })
899                    .await
900                    .unwrap();
901
902                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
903                let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
904                let dont_simulate_error = true;
905                let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
906                actor_mesh
907                    .cast(proc_mesh.client(), sel!(*), GetRank(dont_simulate_error, reply_handle.bind()))
908                    .unwrap();
909                let mut ranks = Ranks::new(actor_mesh.shape().slice().len());
910                while !ranks.is_full() {
911                    let rank = reply_receiver.recv().await.unwrap();
912                    assert!(ranks.insert(rank, rank).is_none(), "duplicate rank {rank}");
913                }
914                // Retrieve all GPUs on replica=0, host=0
915                let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
916                actor_mesh
917                    .cast(
918                        proc_mesh.client(),
919                        sel_from_shape!(actor_mesh.shape(), replica = 0, host = 0),
920                        GetRank(dont_simulate_error, reply_handle.bind()),
921                    )
922                    .unwrap();
923                let mut ranks = Ranks::new(8);
924                while !ranks.is_full() {
925                    let rank = reply_receiver.recv().await.unwrap();
926                    assert!(ranks.insert(rank, rank).is_none(), "duplicate rank {rank}");
927                }
928            }
929
930            #[tokio::test]
931            async fn test_inter_actor_comms() {
932                let alloc = $allocator
933                    .allocate(AllocSpec {
934                        // Sizes intentionally small to keep the time
935                        // required for this test in the process case
936                        // reasonable (< 60s).
937                        extent: extent!(replica = 2, host = 2, gpu = 8),
938                        constraints: Default::default(),
939                        proc_name: None,
940                        transport: default_transport(),
941                    })
942                    .await
943                    .unwrap();
944
945                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
946                let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
947
948                // Bounce the message through all actors and return it to the sender (us).
949                let mut hops: VecDeque<_> = actor_mesh.iter().map(|actor| actor.port()).collect();
950                let (handle, mut rx) = proc_mesh.client().open_port();
951                hops.push_back(handle.bind());
952                hops.pop_front()
953                    .unwrap()
954                    .send(proc_mesh.client(), Relay(0, hops))
955                    .unwrap();
956                assert_matches!(
957                    rx.recv().await.unwrap(),
958                    Relay(count, hops)
959                        if count == actor_mesh.shape().slice().len()
960                        && hops.is_empty());
961            }
962
963            #[tokio::test]
964            async fn test_inter_proc_mesh_comms() {
965                let mut meshes = Vec::new();
966                for _ in 0..2 {
967                    let alloc = $allocator
968                        .allocate(AllocSpec {
969                            extent: extent!(replica = 1),
970                            constraints: Default::default(),
971                            proc_name: None,
972                            transport: default_transport(),
973                        })
974                        .await
975                        .unwrap();
976
977                    let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
978                    let proc_mesh_clone = Arc::clone(&proc_mesh);
979                    let actor_mesh : RootActorMesh<TestActor> = proc_mesh_clone.spawn("echo", &()).await.unwrap();
980                    meshes.push((proc_mesh, actor_mesh));
981                }
982
983                let mut hops: VecDeque<_> = meshes
984                    .iter()
985                    .flat_map(|(_proc_mesh, actor_mesh)| actor_mesh.iter())
986                    .map(|actor| actor.port())
987                    .collect();
988                let num_hops = hops.len();
989
990                let client = meshes[0].0.client();
991                let (handle, mut rx) = client.open_port();
992                hops.push_back(handle.bind());
993                hops.pop_front()
994                    .unwrap()
995                    .send(client, Relay(0, hops))
996                    .unwrap();
997                assert_matches!(
998                    rx.recv().await.unwrap(),
999                    Relay(count, hops)
1000                        if count == num_hops
1001                        && hops.is_empty());
1002            }
1003
1004            #[async_timed_test(timeout_secs = 60)]
1005            async fn test_actor_mesh_cast() {
1006                // Verify a full broadcast in the mesh. Send a message
1007                // to every actor and check each actor receives it.
1008
1009                use $crate::sel;
1010                use $crate::comm::test_utils::TestActor as CastTestActor;
1011                use $crate::comm::test_utils::TestActorParams as CastTestActorParams;
1012                use $crate::comm::test_utils::TestMessage as CastTestMessage;
1013
1014                let extent = extent!(replica = 4, host = 4, gpu = 4);
1015                let num_actors = extent.len();
1016                let alloc = $allocator
1017                    .allocate(AllocSpec {
1018                        extent,
1019                        constraints: Default::default(),
1020                        proc_name: None,
1021                        transport: default_transport(),
1022                    })
1023                    .await
1024                    .unwrap();
1025
1026                let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1027
1028                let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
1029                let params = CastTestActorParams{ forward_port: tx.bind() };
1030                let actor_mesh: RootActorMesh<CastTestActor> = proc_mesh.spawn("actor", &params).await.unwrap();
1031
1032                actor_mesh.cast(proc_mesh.client(), sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap();
1033
1034                for _ in 0..num_actors {
1035                    assert_eq!(rx.recv().await.unwrap(), CastTestMessage::Forward("abc".to_string()));
1036                }
1037
1038                // Attempt to avoid this intermittent fatal error.
1039                // ⚠ Fatal: monarch/hyperactor_mesh:hyperactor_mesh-unittest - \
1040                //            actor_mesh::tests::sim::test_actor_mesh_cast (2.5s)
1041                // Test appears to have passed but the binary exited with a non-zero exit code.
1042                proc_mesh.events().unwrap().into_alloc().stop_and_wait().await.unwrap();
1043            }
1044
1045            #[tokio::test]
1046            async fn test_delivery_failure() {
1047                let alloc = $allocator
1048                    .allocate(AllocSpec {
1049                        extent: extent!(replica = 1 ),
1050                        constraints: Default::default(),
1051                        proc_name: None,
1052                        transport: default_transport(),
1053                    })
1054                    .await
1055                    .unwrap();
1056
1057                let name = alloc.name().to_string();
1058                let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1059                let mut events = mesh.events().unwrap();
1060
1061                // Send a message to a non-existent actor (the proc however exists).
1062                let unmonitored_reply_to = mesh.client().open_port::<usize>().0.bind();
1063                let bad_actor = ActorRef::<TestActor>::attest(ActorId(ProcId::Ranked(WorldId(name.clone()), 0), "foo".into(), 0));
1064                bad_actor.send(mesh.client(), GetRank(true, unmonitored_reply_to)).unwrap();
1065
1066                // The message will be returned!
1067                assert_matches!(
1068                    events.next().await.unwrap(),
1069                    ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1070                );
1071
1072                // TODO: Stop the proc.
1073            }
1074
1075            #[tokio::test]
1076            async fn test_send_with_headers() {
1077                let extent = extent!(replica = 3);
1078                let alloc = $allocator
1079                    .allocate(AllocSpec {
1080                        extent: extent.clone(),
1081                        constraints: Default::default(),
1082                        proc_name: None,
1083                        transport: default_transport(),
1084                    })
1085                    .await
1086                    .unwrap();
1087
1088                let mesh = ProcMesh::allocate(alloc).await.unwrap();
1089                let (reply_port_handle, mut reply_port_receiver) = mesh.client().open_port::<usize>();
1090                let reply_port = reply_port_handle.bind();
1091
1092                let actor_mesh: RootActorMesh<TestActor> = mesh.spawn("test", &()).await.unwrap();
1093                let actor_ref = actor_mesh.get(0).unwrap();
1094                let mut headers = Attrs::new();
1095                set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone());
1096                actor_ref.send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone())).unwrap();
1097                assert_eq!(0, reply_port_receiver.recv().await.unwrap());
1098
1099                set_cast_info_on_headers(&mut headers, extent.point_of_rank(1).unwrap(), mesh.client().self_id().clone());
1100                actor_ref.port()
1101                    .send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone()))
1102                    .unwrap();
1103                assert_eq!(1, reply_port_receiver.recv().await.unwrap());
1104
1105                set_cast_info_on_headers(&mut headers, extent.point_of_rank(2).unwrap(), mesh.client().self_id().clone());
1106                actor_ref.actor_id()
1107                    .port_id(GetRank::port())
1108                    .send_with_headers(
1109                        mesh.client(),
1110                        Serialized::serialize(&GetRank(true, reply_port)).unwrap(),
1111                        headers
1112                    );
1113                assert_eq!(2, reply_port_receiver.recv().await.unwrap());
1114                // TODO: Stop the proc.
1115            }
1116        }
1117    }
1118
1119    mod local {
1120        use hyperactor::channel::ChannelTransport;
1121
1122        use crate::alloc::local::LocalAllocator;
1123
1124        actor_mesh_test_suite!(LocalAllocator);
1125
1126        #[tokio::test]
1127        async fn test_send_failure() {
1128            hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
1129
1130            use hyperactor::test_utils::pingpong::PingPongActor;
1131            use hyperactor::test_utils::pingpong::PingPongActorParams;
1132            use hyperactor::test_utils::pingpong::PingPongMessage;
1133
1134            use crate::alloc::ProcStopReason;
1135            use crate::proc_mesh::ProcEvent;
1136
1137            let config = hyperactor::config::global::lock();
1138            let _guard = config.override_key(
1139                hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1140                tokio::time::Duration::from_secs(1),
1141            );
1142
1143            let alloc = LocalAllocator
1144                .allocate(AllocSpec {
1145                    extent: extent!(replica = 2),
1146                    constraints: Default::default(),
1147                    proc_name: None,
1148                    transport: ChannelTransport::Local,
1149                })
1150                .await
1151                .unwrap();
1152            let monkey = alloc.chaos_monkey();
1153            let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1154            let mut events = mesh.events().unwrap();
1155
1156            let ping_pong_actor_params = PingPongActorParams::new(
1157                Some(PortRef::attest_message_port(mesh.client().self_id())),
1158                None,
1159            );
1160            let actor_mesh: RootActorMesh<PingPongActor> = mesh
1161                .spawn::<PingPongActor>("ping-pong", &ping_pong_actor_params)
1162                .await
1163                .unwrap();
1164
1165            let ping: ActorRef<PingPongActor> = actor_mesh.get(0).unwrap();
1166            let pong: ActorRef<PingPongActor> = actor_mesh.get(1).unwrap();
1167
1168            // Kill ping.
1169            monkey(0, ProcStopReason::Killed(0, false));
1170            assert_matches!(
1171                events.next().await.unwrap(),
1172                ProcEvent::Stopped(0, ProcStopReason::Killed(0, false))
1173            );
1174
1175            // Try to send a message to 'ping'. Since 'ping's mailbox
1176            // is stopped, the send will timeout and fail.
1177            let (unmonitored_done_tx, _) = mesh.client().open_once_port();
1178            ping.send(
1179                mesh.client(),
1180                PingPongMessage(1, pong.clone(), unmonitored_done_tx.bind()),
1181            )
1182            .unwrap();
1183
1184            // The message will be returned!
1185            assert_matches!(
1186                events.next().await.unwrap(),
1187                ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1188            );
1189
1190            // Get 'pong' to send 'ping' a message. Since 'ping's
1191            // mailbox is stopped, the send will timeout and fail.
1192            let (unmonitored_done_tx, _) = mesh.client().open_once_port();
1193            pong.send(
1194                mesh.client(),
1195                PingPongMessage(1, ping.clone(), unmonitored_done_tx.bind()),
1196            )
1197            .unwrap();
1198
1199            // The message will be returned!
1200            assert_matches!(
1201                events.next().await.unwrap(),
1202                ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1203            );
1204        }
1205
1206        #[tokio::test]
1207        async fn test_cast_failure() {
1208            use crate::alloc::ProcStopReason;
1209            use crate::proc_mesh::ProcEvent;
1210            use crate::sel;
1211
1212            let alloc = LocalAllocator
1213                .allocate(AllocSpec {
1214                    extent: extent!(replica = 1),
1215                    constraints: Default::default(),
1216                    proc_name: None,
1217                    transport: ChannelTransport::Local,
1218                })
1219                .await
1220                .unwrap();
1221
1222            let stop = alloc.stopper();
1223            let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1224            let mut events = mesh.events().unwrap();
1225
1226            let actor_mesh = mesh
1227                .spawn::<TestActor>("reply-then-fail", &())
1228                .await
1229                .unwrap();
1230
1231            // `GetRank` with `false` means exit with error after
1232            // replying with rank.
1233            let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
1234            actor_mesh
1235                .cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
1236                .unwrap();
1237            let rank = reply_receiver.recv().await.unwrap();
1238            assert_eq!(rank, 0);
1239
1240            // The above is expected to trigger a proc crash.
1241            assert_matches!(
1242                events.next().await.unwrap(),
1243                ProcEvent::Crashed(0, reason) if reason.contains("intentional error!")
1244            );
1245
1246            // Cast the message.
1247            let (reply_handle, _) = actor_mesh.open_port();
1248            actor_mesh
1249                .cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
1250                .unwrap();
1251
1252            // The message will be returned!
1253            assert_matches!(
1254                events.next().await.unwrap(),
1255                ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1256            );
1257
1258            // Stop the mesh.
1259            stop();
1260            assert_matches!(
1261                events.next().await.unwrap(),
1262                ProcEvent::Stopped(0, ProcStopReason::Stopped),
1263            );
1264            assert!(events.next().await.is_none());
1265        }
1266
1267        #[tracing_test::traced_test]
1268        #[tokio::test]
1269        async fn test_stop_actor_mesh() {
1270            use hyperactor::test_utils::pingpong::PingPongActor;
1271            use hyperactor::test_utils::pingpong::PingPongActorParams;
1272            use hyperactor::test_utils::pingpong::PingPongMessage;
1273
1274            let config = hyperactor::config::global::lock();
1275            let _guard = config.override_key(
1276                hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1277                tokio::time::Duration::from_secs(1),
1278            );
1279
1280            let alloc = LocalAllocator
1281                .allocate(AllocSpec {
1282                    extent: extent!(replica = 2),
1283                    constraints: Default::default(),
1284                    proc_name: None,
1285                    transport: ChannelTransport::Local,
1286                })
1287                .await
1288                .unwrap();
1289            let mesh = ProcMesh::allocate(alloc).await.unwrap();
1290
1291            let ping_pong_actor_params = PingPongActorParams::new(
1292                Some(PortRef::attest_message_port(mesh.client().self_id())),
1293                None,
1294            );
1295            let mesh_one: RootActorMesh<PingPongActor> = mesh
1296                .spawn::<PingPongActor>("mesh_one", &ping_pong_actor_params)
1297                .await
1298                .unwrap();
1299
1300            let mesh_two: RootActorMesh<PingPongActor> = mesh
1301                .spawn::<PingPongActor>("mesh_two", &ping_pong_actor_params)
1302                .await
1303                .unwrap();
1304
1305            mesh_two.stop().await.unwrap();
1306
1307            let ping_two: ActorRef<PingPongActor> = mesh_two.get(0).unwrap();
1308            let pong_two: ActorRef<PingPongActor> = mesh_two.get(1).unwrap();
1309
1310            assert!(logs_contain(&format!(
1311                "stopped actor {}",
1312                ping_two.actor_id()
1313            )));
1314            assert!(logs_contain(&format!(
1315                "stopped actor {}",
1316                pong_two.actor_id()
1317            )));
1318
1319            // Other actor meshes on this proc mesh should still be up and running
1320            let ping_one: ActorRef<PingPongActor> = mesh_one.get(0).unwrap();
1321            let pong_one: ActorRef<PingPongActor> = mesh_one.get(1).unwrap();
1322            let (done_tx, done_rx) = mesh.client().open_once_port();
1323            pong_one
1324                .send(
1325                    mesh.client(),
1326                    PingPongMessage(1, ping_one.clone(), done_tx.bind()),
1327                )
1328                .unwrap();
1329            assert!(done_rx.recv().await.is_ok());
1330        }
1331    } // mod local
1332
1333    mod process {
1334
1335        use bytes::Bytes;
1336        use hyperactor::PortId;
1337        use hyperactor::channel::ChannelTransport;
1338        use hyperactor::clock::Clock;
1339        use hyperactor::clock::RealClock;
1340        use hyperactor::mailbox::MessageEnvelope;
1341        use rand::Rng;
1342        use tokio::process::Command;
1343
1344        use crate::alloc::process::ProcessAllocator;
1345
1346        fn process_allocator() -> ProcessAllocator {
1347            ProcessAllocator::new(Command::new(crate::testresource::get(
1348                "monarch/hyperactor_mesh/bootstrap",
1349            )))
1350        }
1351
1352        #[cfg(fbcode_build)] // we use an external binary, produced by buck
1353        actor_mesh_test_suite!(process_allocator());
1354
// This test is concerned with correctly reporting failures
// when message sizes exceed configured limits.
//
// A payload whose frame is exactly the configured maximum (1024
// bytes) must be delivered and answered; a payload one byte larger
// must surface as a `ProcEvent::Crashed` event.
#[cfg(fbcode_build)]
//#[tracing_test::traced_test]
#[async_timed_test(timeout_secs = 30)]
async fn test_oversized_frames() {
    // Reproduced from 'net.rs'.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    enum Frame<M> {
        Init(u64),
        Message(u64, M),
    }
    // Calculate the frame length for the given message.
    fn frame_length(src: &ActorId, dst: &PortId, pay: &Payload) -> usize {
        let serialized = Serialized::serialize(pay).unwrap();
        let mut headers = Attrs::new();
        hyperactor::mailbox::headers::set_send_timestamp(&mut headers);
        let envelope = MessageEnvelope::new(src.clone(), dst.clone(), serialized, headers);
        let frame = Frame::Message(0u64, envelope);
        let message = serde_multipart::serialize_illegal_bincode(&frame).unwrap();
        message.frame_len()
    }

    // Name of the environment variable used to pass the frame limit
    // to the remote process.
    const REMOTE_MAX_FRAME_LENGTH_VAR: &str = "HYPERACTOR_CODEC_MAX_FRAME_LENGTH";

    // Removes the named environment variable when dropped, so the
    // limit set for this test does not leak into tests that run later
    // in the same process, including when this test panics.
    struct RemoveEnvVarOnDrop(&'static str);
    impl Drop for RemoveEnvVarOnDrop {
        fn drop(&mut self) {
            // SAFETY: Same caveat as the `set_var` below: ok here but
            // not safe for concurrent access.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    // Take the global config lock; the overrides below last as long
    // as their guards.
    let config = hyperactor::config::global::lock();
    // This process (write): max frame len for frame writes.
    let _guard2 =
        config.override_key(hyperactor::config::CODEC_MAX_FRAME_LENGTH, 1024usize);
    // Remote process (read): max frame len for frame reads.
    // SAFETY: Ok here but not safe for concurrent access.
    unsafe {
        std::env::set_var(REMOTE_MAX_FRAME_LENGTH_VAR, "1024");
    };
    // Declared after `config`, so it is dropped (and the variable
    // removed) while the config lock is still held.
    let _env_cleanup = RemoveEnvVarOnDrop(REMOTE_MAX_FRAME_LENGTH_VAR);
    let _guard3 =
        config.override_key(hyperactor::config::DEFAULT_ENCODING, Encoding::Bincode);
    let _guard4 = config.override_key(hyperactor::config::CHANNEL_MULTIPART, false);

    let alloc = process_allocator()
        .allocate(AllocSpec {
            extent: extent!(replica = 1),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
        })
        .await
        .unwrap();
    let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
    let mut proc_events = proc_mesh.events().unwrap();
    let actor_mesh: RootActorMesh<TestActor> =
        proc_mesh.spawn("ingest", &()).await.unwrap();
    let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
    let dest = actor_mesh.get(0).unwrap();

    // Message sized to exactly max frame length.
    let payload = Payload {
        part: Part::from(Bytes::from(vec![0u8; 698])),
        reply_port: reply_handle.bind(),
    };
    let frame_len = frame_length(
        proc_mesh.client().self_id(),
        dest.port::<Payload>().port_id(),
        &payload,
    );
    assert_eq!(frame_len, 1024);

    // Send direct. A cast message is > 1024 bytes.
    dest.send(proc_mesh.client(), payload).unwrap();
    #[allow(clippy::disallowed_methods)]
    let result = RealClock
        .timeout(Duration::from_secs(2), reply_receiver.recv())
        .await;
    assert!(result.is_ok(), "Operation should not time out");

    // Message sized to max frame length + 1.
    let payload = Payload {
        part: Part::from(Bytes::from(vec![0u8; 699])),
        reply_port: reply_handle.bind(),
    };
    let frame_len = frame_length(
        proc_mesh.client().self_id(),
        dest.port::<Payload>().port_id(),
        &payload,
    );
    assert_eq!(frame_len, 1025); // over the max frame len

    // Send direct or cast. Either are guaranteed over the
    // limit and will fail.
    if rand::thread_rng().gen_bool(0.5) {
        dest.send(proc_mesh.client(), payload).unwrap();
    } else {
        actor_mesh
            .cast(proc_mesh.client(), sel!(*), payload)
            .unwrap();
    }

    // The undeliverable supervision event that happens next
    // does not depend on a timeout.
    {
        let event = proc_events.next().await.unwrap();
        assert_matches!(
            event,
            ProcEvent::Crashed(_, _),
            "Should have received crash event"
        );
    }
}
1461
// Set this test only for `mod process` because it relies on a
// trick to emulate router failure that only works when using
// non-local allocators.
#[cfg(fbcode_build)]
#[tokio::test]
async fn test_router_undeliverable_return() {
    // Test that an undeliverable message received by a
    // router results in actor mesh supervision events.
    use ndslice::extent;

    use super::test_util::*;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;

    // Environment variable that disables the router's global fallback.
    const NO_GLOBAL_FALLBACK_VAR: &str = "HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK";

    // Removes the named environment variable when dropped. This keeps
    // the variable from leaking into later tests when an assertion in
    // this test fails before the end of the function is reached.
    struct RemoveEnvVarOnDrop(&'static str);
    impl Drop for RemoveEnvVarOnDrop {
        fn drop(&mut self) {
            // SAFETY: Not multithread safe.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    let alloc = process_allocator()
        .allocate(AllocSpec {
            extent: extent! { replica = 1 },
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
        })
        .await
        .unwrap();

    // SAFETY: Not multithread safe.
    unsafe { std::env::set_var(NO_GLOBAL_FALLBACK_VAR, "1") };
    let env_cleanup = RemoveEnvVarOnDrop(NO_GLOBAL_FALLBACK_VAR);

    let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
    let mut proc_events = proc_mesh.events().unwrap();
    let mut actor_mesh: RootActorMesh<'_, ProxyActor> =
        { proc_mesh.spawn("proxy", &()).await.unwrap() };
    let mut actor_events = actor_mesh.events().unwrap();

    let proxy_actor = actor_mesh.get(0).unwrap();
    let (tx, mut rx) = actor_mesh.open_port::<String>();
    proxy_actor
        .send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind()))
        .unwrap();

    // No reply is expected: the wait must time out, after which both
    // the proc mesh and the actor mesh must report an event.
    #[allow(clippy::disallowed_methods)]
    match tokio::time::timeout(tokio::time::Duration::from_secs(3), rx.recv()).await {
        Ok(_) => panic!("the impossible happened"),
        Err(_) => {
            assert_matches!(
                proc_events.next().await.unwrap(),
                ProcEvent::Crashed(0, reason) if reason.contains("undeliverable")
            );
            assert_eq!(
                actor_events.next().await.unwrap().actor_id.name(),
                &actor_mesh.name
            );
        }
    }

    // Remove the variable here on the success path; on a panic the
    // guard's `Drop` does the same during unwinding.
    drop(env_cleanup);
}
1519    }
1520
// Instantiates the shared `actor_mesh_test_suite!` tests with the
// allocator produced by `SimAllocator::new_and_start_simnet()`.
1521    mod sim {
1522        use crate::alloc::sim::SimAllocator;
1523
1524        actor_mesh_test_suite!(SimAllocator::new_and_start_simnet());
1525    }
1526
1527    mod reshape_cast {
1528        use async_trait::async_trait;
1529        use hyperactor::Actor;
1530        use hyperactor::Context;
1531        use hyperactor::Handler;
1532        use hyperactor::channel::ChannelAddr;
1533        use hyperactor::channel::ChannelTransport;
1534        use hyperactor::channel::ChannelTx;
1535        use hyperactor::channel::Rx;
1536        use hyperactor::channel::Tx;
1537        use hyperactor::channel::dial;
1538        use hyperactor::channel::serve;
1539        use hyperactor::clock::Clock;
1540        use hyperactor::clock::RealClock;
1541        use ndslice::Selection;
1542
1543        use crate::Mesh;
1544        use crate::ProcMesh;
1545        use crate::RootActorMesh;
1546        use crate::actor_mesh::ActorMesh;
1547        use crate::alloc::AllocSpec;
1548        use crate::alloc::Allocator;
1549        use crate::alloc::LocalAllocator;
1550        use crate::config::MAX_CAST_DIMENSION_SIZE;
1551
// Test actor wrapping a `ChannelTx<usize>`. It is exported with
// `spawn = true` and a `()` handler marked `cast = true`.
1552        #[derive(Debug)]
1553        #[hyperactor::export(
1554            spawn = true,
1555            handlers = [() { cast = true }],
1556        )]
1557        struct EchoActor(ChannelTx<usize>);
1558
1559        #[async_trait]
1560        impl Actor for EchoActor {
1561            type Params = ChannelAddr;
1562
1563            async fn new(params: ChannelAddr) -> Result<Self, anyhow::Error> {
1564                Ok(Self(dial::<usize>(params)?))
1565            }
1566        }
1567
1568        #[async_trait]
1569        impl Handler<()> for EchoActor {
1570            async fn handle(
1571                &mut self,
1572                cx: &Context<Self>,
1573                _message: (),
1574            ) -> Result<(), anyhow::Error> {
1575                let Self(port) = self;
1576                port.post(cx.self_id().rank());
1577                Ok(())
1578            }
1579        }
1580
1581        async fn validate_cast<A>(
1582            actor_mesh: &A,
1583            caps: &impl hyperactor::context::Actor,
1584            addr: ChannelAddr,
1585            selection: Selection,
1586        ) where
1587            A: ActorMesh<Actor = EchoActor>,
1588        {
1589            let config = hyperactor::config::global::lock();
1590            let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);
1591
1592            let (_, mut rx) = serve::<usize>(addr).unwrap();
1593
1594            let expected_ranks = selection
1595                .eval(
1596                    &ndslice::selection::EvalOpts::strict(),
1597                    actor_mesh.shape().slice(),
1598                )
1599                .unwrap()
1600                .collect::<std::collections::BTreeSet<_>>();
1601
1602            actor_mesh.cast(caps, selection, ()).unwrap();
1603
1604            let mut received = std::collections::BTreeSet::new();
1605
1606            for _ in 0..(expected_ranks.len()) {
1607                received.insert(
1608                    RealClock
1609                        .timeout(tokio::time::Duration::from_secs(1), rx.recv())
1610                        .await
1611                        .unwrap()
1612                        .unwrap(),
1613                );
1614            }
1615
1616            assert_eq!(received, expected_ranks);
1617        }
1618
1619        use ndslice::strategy::gen_extent;
1620        use ndslice::strategy::gen_selection;
1621        use proptest::prelude::*;
1622        use proptest::test_runner::TestRunner;
1623
1624        fn make_tokio_runtime() -> tokio::runtime::Runtime {
1625            tokio::runtime::Builder::new_multi_thread()
1626                .enable_all()
1627                .worker_threads(2)
1628                .build()
1629                .unwrap()
1630        }
1631
1632        proptest! {
1633            #![proptest_config(ProptestConfig {
1634                cases: 8, ..ProptestConfig::default()
1635            })]
1636            #[test]
1637            fn test_reshaped_actor_mesh_cast(extent in gen_extent(1..=4, 8)) {
1638                let runtime = make_tokio_runtime();
1639                let alloc = runtime.block_on(LocalAllocator
1640                    .allocate(AllocSpec {
1641                        extent,
1642                        constraints: Default::default(),
1643                        proc_name: None,
1644                        transport: ChannelTransport::Local
1645                    }))
1646                    .unwrap();
1647                let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
1648
1649                let addr = ChannelAddr::any(ChannelTransport::Unix);
1650
1651                let actor_mesh: RootActorMesh<EchoActor> =
1652                    runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap();
1653
1654                let mut runner = TestRunner::default();
1655                let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0)
1656                    .new_tree(&mut runner)
1657                    .unwrap()
1658                    .current();
1659
1660                runtime.block_on(validate_cast(&actor_mesh, actor_mesh.proc_mesh().client(), addr, selection));
1661            }
1662        }
1663
1664        proptest! {
1665            #![proptest_config(ProptestConfig {
1666                cases: 8, ..ProptestConfig::default()
1667            })]
1668            #[test]
1669            fn test_reshaped_actor_mesh_slice_cast(extent in gen_extent(1..=4, 8)) {
1670                let runtime = make_tokio_runtime();
1671                let alloc = runtime.block_on(LocalAllocator
1672                    .allocate(AllocSpec {
1673                        extent: extent.clone(),
1674                        constraints: Default::default(),
1675                        proc_name: None,
1676                        transport: ChannelTransport::Local
1677                    }))
1678                    .unwrap();
1679                let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
1680
1681                let addr = ChannelAddr::any(ChannelTransport::Unix);
1682
1683                let actor_mesh: RootActorMesh<EchoActor> =
1684                    runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap();
1685
1686
1687                let first_label = extent.labels().first().unwrap();
1688                let slice = actor_mesh.select(first_label, 0..extent.size(first_label).unwrap()).unwrap();
1689
1690                // Unfortunately we must do things this way due to borrow checker reasons
1691                let slice = if extent.len() >= 2 {
1692                    let label = &extent.labels()[1];
1693                    let size = extent.size(label).unwrap();
1694                    let start = if size > 1 { 1 } else { 0 };
1695                    let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
1696                    slice.select(label, start..end).unwrap()
1697                } else {
1698                    slice
1699                };
1700
1701                let slice = if extent.len() >= 3 {
1702                    let label = &extent.labels()[2];
1703                    let size = extent.size(label).unwrap();
1704                    let start = if size > 1 { 1 } else { 0 };
1705                    let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
1706                    slice.select(label, start..end).unwrap()
1707                } else {
1708                    slice
1709                };
1710
1711                let slice = if extent.len() >= 4 {
1712                    let label = &extent.labels()[3];
1713                    let size = extent.size(label).unwrap();
1714                    let start = if size > 1 { 1 } else { 0 };
1715                    let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
1716                    slice.select(label, start..end).unwrap()
1717                } else {
1718                    slice
1719                };
1720
1721
1722                let mut runner = TestRunner::default();
1723                let selection = gen_selection(4, slice.shape().slice().sizes().to_vec(), 0)
1724                    .new_tree(&mut runner)
1725                    .unwrap()
1726                    .current();
1727
1728                runtime.block_on(validate_cast(
1729                    &slice,
1730                    actor_mesh.proc_mesh().client(),
1731                    addr,
1732                    selection
1733                ));
1734            }
1735        }
1736
1737        proptest! {
1738            #![proptest_config(ProptestConfig {
1739                cases: 8, ..ProptestConfig::default()
1740            })]
1741             #[test]
1742             fn test_reshaped_actor_mesh_cast_with_selection(extent in gen_extent(1..=4, 8)) {
1743                let runtime = make_tokio_runtime();
1744                let alloc = runtime.block_on(LocalAllocator
1745                    .allocate(AllocSpec {
1746                        extent,
1747                        constraints: Default::default(),
1748                        proc_name: None,
1749                        transport: ChannelTransport::Local
1750                    }))
1751                    .unwrap();
1752                let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
1753
1754                let addr = ChannelAddr::any(ChannelTransport::Unix);
1755
1756                let actor_mesh: RootActorMesh<EchoActor> =
1757                    runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap();
1758
1759                let mut runner = TestRunner::default();
1760                let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0)
1761                    .new_tree(&mut runner)
1762                    .unwrap()
1763                    .current();
1764
1765                runtime.block_on(validate_cast(
1766                    &actor_mesh,
1767                    actor_mesh.proc_mesh().client(),
1768                    addr,
1769                    selection
1770                ));
1771            }
1772        }
1773    }
1774}