hyperactor_mesh/
actor_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)] // until used publically
10
11use std::collections::BTreeSet;
12use std::ops::Deref;
13use std::sync::OnceLock;
14
15use async_trait::async_trait;
16use hyperactor::Actor;
17use hyperactor::ActorId;
18use hyperactor::ActorRef;
19use hyperactor::Bind;
20use hyperactor::GangId;
21use hyperactor::GangRef;
22use hyperactor::Message;
23use hyperactor::PortHandle;
24use hyperactor::ProcId;
25use hyperactor::RemoteHandles;
26use hyperactor::RemoteMessage;
27use hyperactor::Unbind;
28use hyperactor::WorldId;
29use hyperactor::actor::Referable;
30use hyperactor::context;
31use hyperactor::mailbox::MailboxSenderError;
32use hyperactor::mailbox::MessageEnvelope;
33use hyperactor::mailbox::PortReceiver;
34use hyperactor::mailbox::Undeliverable;
35use hyperactor::message::Castable;
36use hyperactor::message::IndexedErasedUnbound;
37use hyperactor::supervision::ActorSupervisionEvent;
38use hyperactor_config::attrs::Attrs;
39use hyperactor_config::attrs::declare_attrs;
40use ndslice::Range;
41use ndslice::Selection;
42use ndslice::Shape;
43use ndslice::ShapeError;
44use ndslice::SliceError;
45use ndslice::View;
46use ndslice::reshape::Limit;
47use ndslice::reshape::ReshapeError;
48use ndslice::reshape::ReshapeSliceExt;
49use ndslice::reshape::reshape_selection;
50use ndslice::selection;
51use ndslice::selection::EvalOpts;
52use ndslice::selection::ReifySlice;
53use ndslice::selection::normal;
54use ndslice::view::ViewExt;
55use serde::Deserialize;
56use serde::Serialize;
57use serde_multipart::Part;
58use tokio::sync::mpsc;
59use typeuri::Named;
60
61use crate::CommActor;
62use crate::Mesh;
63use crate::comm::multicast::CAST_ORIGINATING_SENDER;
64use crate::comm::multicast::CastMessage;
65use crate::comm::multicast::CastMessageEnvelope;
66use crate::comm::multicast::Uslice;
67use crate::config::MAX_CAST_DIMENSION_SIZE;
68use crate::metrics;
69use crate::proc_mesh::ProcMesh;
70use crate::reference::ActorMeshId;
71use crate::reference::ActorMeshRef;
72use crate::v1;
73
declare_attrs! {
    /// Which mesh this message was cast to. `actor_mesh_cast` sets it as a
    /// header on the message it sends to the comm actor, and
    /// `update_undeliverable_envelope_for_casting` reads it back during
    /// undeliverable message handling, where the CastMessageEnvelope is
    /// serialized and its content cannot be inspected.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
80
81/// An undeliverable might have its sender address set as the comm actor instead
82/// of the original sender. Update it based on the headers present in the message
83/// so it matches the sender.
84pub fn update_undeliverable_envelope_for_casting(
85    mut envelope: Undeliverable<MessageEnvelope>,
86) -> Undeliverable<MessageEnvelope> {
87    let old_actor = envelope.0.sender().clone();
88    // v1 casting
89    if let Some(actor_id) = envelope.0.headers().get(CAST_ORIGINATING_SENDER).cloned() {
90        tracing::debug!(
91            actor_id = %old_actor,
92            "remapped comm-actor id to id from CAST_ORIGINATING_SENDER {}", actor_id
93        );
94        envelope.0.update_sender(actor_id);
95    // v0 casting
96    } else if let Some(actor_mesh_id) = envelope.0.headers().get(CAST_ACTOR_MESH_ID) {
97        match actor_mesh_id {
98            ActorMeshId::V0(proc_mesh_id, actor_name) => {
99                let actor_id = ActorId(
100                    ProcId::Ranked(WorldId(proc_mesh_id.0.clone()), 0),
101                    actor_name.clone(),
102                    0,
103                );
104                tracing::debug!(
105                    actor_id = %old_actor,
106                    "remapped comm-actor id to mesh id from CAST_ACTOR_MESH_ID {}", actor_id
107                );
108                envelope.0.update_sender(actor_id);
109            }
110            ActorMeshId::V1(_) => {
111                tracing::debug!("headers present but V1 ActorMeshId; leaving actor_id unchanged");
112            }
113        }
114    } else {
115        // Do nothing, it wasn't from a comm actor.
116    }
117    envelope
118}
119
120/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast
121/// an `M`-typed message
122#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
123#[hyperactor::instrument]
124pub(crate) fn actor_mesh_cast<A, M>(
125    cx: &impl context::Actor,
126    actor_mesh_id: ActorMeshId,
127    comm_actor_ref: &ActorRef<CommActor>,
128    selection_of_root: Selection,
129    root_mesh_shape: &Shape,
130    cast_mesh_shape: &Shape,
131    message: M,
132) -> Result<(), CastError>
133where
134    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
135    M: Castable + RemoteMessage,
136{
137    let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
138        "message_type" => M::typename(),
139        "message_variant" => message.arm().unwrap_or_default(),
140    ));
141
142    let message = CastMessageEnvelope::new::<A, M>(
143        actor_mesh_id.clone(),
144        cx.mailbox().actor_id().clone(),
145        cast_mesh_shape.clone(),
146        message,
147    )?;
148
149    // Mesh's shape might have large extents on some dimensions. Those
150    // dimensions would cause large fanout in our comm actor
151    // implementation. To avoid that, we reshape it by increasing
152    // dimensionality and limiting the extent of each dimension. Note
153    // the reshape is only visible to the internal algorithm. The
154    // shape that user sees maintains intact.
155    //
156    // For example, a typical shape is [hosts=1024, gpus=8]. By using
157    // limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
158    // words, it adds 3 extra layers to the comm actor tree, while
159    // keeping the fanout in each layer per dimension at 8 or smaller.
160    //
161    // An important note here is that max dimension size != max fanout.
162    // Rank 0 must send a message to all ranks at index 0 for every dimension.
163    // If our reshaped shape is [8, 8, 8, 2, 8], rank 0 must send
164    // 7 + 7 + 7 + 1 + 7 = 21 messages.
165
166    let slice_of_root = root_mesh_shape.slice();
167
168    let max_cast_dimension_size = hyperactor_config::global::get(MAX_CAST_DIMENSION_SIZE);
169
170    let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
171
172    let selection_of_cast =
173        reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
174
175    let cast_message = CastMessage {
176        dest: Uslice {
177            slice: slice_of_cast,
178            selection: selection_of_cast,
179        },
180        message,
181    };
182
183    let mut headers = Attrs::new();
184    headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
185
186    comm_actor_ref
187        .port()
188        .send_with_headers(cx, headers, cast_message)?;
189
190    Ok(())
191}
192
193#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
194pub(crate) fn cast_to_sliced_mesh<A, M>(
195    cx: &impl context::Actor,
196    actor_mesh_id: ActorMeshId,
197    comm_actor_ref: &ActorRef<CommActor>,
198    sel_of_sliced: &Selection,
199    message: M,
200    sliced_shape: &Shape,
201    root_mesh_shape: &Shape,
202) -> Result<(), CastError>
203where
204    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
205    M: Castable + RemoteMessage,
206{
207    let root_slice = root_mesh_shape.slice();
208
209    // Casting to `*`?
210    let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
211        // Reify this view into base.
212        root_slice.reify_slice(sliced_shape.slice())?
213    } else {
214        // No, fall back on `of_ranks`.
215        let ranks = sel_of_sliced
216            .eval(&EvalOpts::strict(), sliced_shape.slice())?
217            .collect::<BTreeSet<_>>();
218        Selection::of_ranks(root_slice, &ranks)?
219    };
220
221    // Cast.
222    actor_mesh_cast::<A, M>(
223        cx,
224        actor_mesh_id,
225        comm_actor_ref,
226        sel_of_root,
227        root_mesh_shape,
228        sliced_shape,
229        message,
230    )
231}
232
233/// A mesh of actors, all of which reside on the same [`ProcMesh`].
234#[async_trait]
235pub trait ActorMesh: Mesh<Id = ActorMeshId> {
236    /// The type of actor in the mesh.
237    type Actor: Referable;
238
239    /// Cast an `M`-typed message to the ranks selected by `sel` in
240    /// this ActorMesh.
241    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
242    fn cast<M>(
243        &self,
244        cx: &impl context::Actor,
245        selection: Selection,
246        message: M,
247    ) -> Result<(), CastError>
248    where
249        Self::Actor: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
250        M: Castable + RemoteMessage + Clone,
251    {
252        if let Some(v1) = self.v1() {
253            return v1
254                .cast_for_tensor_engine_only_do_not_use(cx, selection, message)
255                .map_err(anyhow::Error::from)
256                .map_err(CastError::from);
257        }
258        actor_mesh_cast::<Self::Actor, M>(
259            cx,                            // actor context
260            self.id(),                     // actor mesh id (destination mesh)
261            self.proc_mesh().comm_actor(), // comm actor
262            selection,                     // the selected actors
263            self.shape(),                  // root mesh shape
264            self.shape(),                  // cast mesh shape
265            message,                       // the message
266        )
267    }
268
269    /// The ProcMesh on top of which this actor mesh is spawned.
270    fn proc_mesh(&self) -> &ProcMesh;
271
272    /// The name given to the actors in this mesh.
273    fn name(&self) -> &str;
274
275    fn world_id(&self) -> &WorldId {
276        self.proc_mesh().world_id()
277    }
278
279    /// Iterate over all `ActorRef<Self::Actor>` in this mesh.
280    fn iter_actor_refs(&self) -> Box<dyn Iterator<Item = ActorRef<Self::Actor>>> {
281        if let Some(v1) = self.v1() {
282            // We collect() here to ensure that the data are owned. Since this is a short-lived
283            // shim, we'll live with it.
284            return Box::new(
285                v1.iter()
286                    .map(|(_point, actor_ref)| actor_ref.clone())
287                    .collect::<Vec<_>>()
288                    .into_iter(),
289            );
290        }
291        let gang: GangRef<Self::Actor> = GangRef::attest(GangId(
292            self.proc_mesh().world_id().clone(),
293            self.name().to_string(),
294        ));
295        Box::new(self.shape().slice().iter().map(move |rank| gang.rank(rank)))
296    }
297
298    async fn stop(&self, cx: &impl context::Actor) -> Result<(), anyhow::Error> {
299        self.proc_mesh().stop_actor_by_name(cx, self.name()).await
300    }
301
302    /// Get a serializeable reference to this mesh similar to ActorHandle::bind
303    fn bind(&self) -> ActorMeshRef<Self::Actor> {
304        ActorMeshRef::attest(
305            self.id(),
306            self.shape().clone(),
307            self.proc_mesh().comm_actor().clone(),
308        )
309    }
310
311    /// Retrieves the v1 mesh for this v0 ActorMesh, if it is available.
312    fn v1(&self) -> Option<v1::ActorMeshRef<Self::Actor>>;
313}
314
/// Abstracts over shared and borrowed references to a [`ProcMesh`].
///
/// With the `Shared` variant no borrow lifetime is involved, so a
/// `RootActorMesh<'static, _>` can be built from a shared handle (see
/// `RootActorMesh::new_shared`). This is useful when the proc mesh's
/// lifetime must be managed dynamically.
enum ProcMeshRef<'a> {
    /// An owned, type-erased handle that derefs to the `ProcMesh`. Because
    /// the handle itself is owned here, no borrow lifetime is required.
    Shared(Box<dyn Deref<Target = ProcMesh> + Sync + Send>),
    /// A plain borrow of a `ProcMesh` for the lifetime `'a`.
    Borrowed(&'a ProcMesh),
}
325
326impl Deref for ProcMeshRef<'_> {
327    type Target = ProcMesh;
328
329    fn deref(&self) -> &Self::Target {
330        match self {
331            Self::Shared(p) => p,
332            Self::Borrowed(p) => p, // p: &ProcMesh
333        }
334    }
335}
336
/// A mesh of actor instances. ActorMeshes are obtained by spawning an
/// actor on a [`ProcMesh`].
///
/// Generic bound: `A: Referable` — this type hands out typed
/// `ActorRef<A>` handles (see `ranks`), and `ActorRef` is only
/// defined for `A: Referable`.
pub struct RootActorMesh<'a, A: Referable> {
    /// The backing representation: a v0 mesh on a `ProcMesh`, or a v1 mesh
    /// reference.
    inner: ActorMeshKind<'a, A>,
    /// Shape cache, filled on the first call to `Mesh::shape` (both variants).
    shape: OnceLock<Shape>,
    /// Cache for the `ProcMesh` converted from a v1 mesh's proc mesh. Only
    /// filled for the v1 variant; the v0 variant holds its own `ProcMesh`.
    proc_mesh: OnceLock<ProcMesh>,
    /// Cache for the v1 mesh's name as a `String`. Only filled for the v1
    /// variant.
    name: OnceLock<String>,
}
349
/// The two possible backings of a [`RootActorMesh`].
enum ActorMeshKind<'a, A: Referable> {
    /// A mesh spawned directly on a (borrowed or shared) `ProcMesh`.
    V0 {
        /// The proc mesh hosting the actors.
        proc_mesh: ProcMeshRef<'a>,
        /// The name given to the actors in this mesh.
        name: String,
        ranks: Vec<ActorRef<A>>, // temporary until we remove `ArcActorMesh`.
        // The receiver of supervision events. It is None if it has been transferred to
        // an actor event observer (see `RootActorMesh::events`).
        actor_supervision_rx: Option<mpsc::UnboundedReceiver<ActorSupervisionEvent>>,
    },

    /// A wrapper around a v1 actor mesh reference.
    V1(v1::ActorMeshRef<A>),
}
362
363impl<'a, A: Referable> From<v1::ActorMeshRef<A>> for RootActorMesh<'a, A> {
364    fn from(actor_mesh: v1::ActorMeshRef<A>) -> Self {
365        Self {
366            inner: ActorMeshKind::V1(actor_mesh),
367            shape: OnceLock::new(),
368            proc_mesh: OnceLock::new(),
369            name: OnceLock::new(),
370        }
371    }
372}
373
374impl<'a, A: Referable> From<v1::ActorMesh<A>> for RootActorMesh<'a, A> {
375    fn from(actor_mesh: v1::ActorMesh<A>) -> Self {
376        actor_mesh.detach().into()
377    }
378}
379
380impl<'a, A: Referable> RootActorMesh<'a, A> {
381    pub(crate) fn new(
382        proc_mesh: &'a ProcMesh,
383        name: String,
384        actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
385        ranks: Vec<ActorRef<A>>,
386    ) -> Self {
387        Self {
388            inner: ActorMeshKind::V0 {
389                proc_mesh: ProcMeshRef::Borrowed(proc_mesh),
390                name,
391                ranks,
392                actor_supervision_rx: Some(actor_supervision_rx),
393            },
394            shape: OnceLock::new(),
395            proc_mesh: OnceLock::new(),
396            name: OnceLock::new(),
397        }
398    }
399
400    pub fn new_v1(actor_mesh: v1::ActorMeshRef<A>) -> Self {
401        Self {
402            inner: ActorMeshKind::V1(actor_mesh),
403            shape: OnceLock::new(),
404            proc_mesh: OnceLock::new(),
405            name: OnceLock::new(),
406        }
407    }
408
409    pub(crate) fn new_shared<D: Deref<Target = ProcMesh> + Send + Sync + 'static>(
410        proc_mesh: D,
411        name: String,
412        actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
413        ranks: Vec<ActorRef<A>>,
414    ) -> Self {
415        Self {
416            inner: ActorMeshKind::V0 {
417                proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)),
418                name,
419                ranks,
420                actor_supervision_rx: Some(actor_supervision_rx),
421            },
422            shape: OnceLock::new(),
423            proc_mesh: OnceLock::new(),
424            name: OnceLock::new(),
425        }
426    }
427
428    /// Open a port on this ActorMesh.
429    pub fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
430        match &self.inner {
431            ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh.client().open_port(),
432            ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"),
433        }
434    }
435
436    /// An event stream of actor events. Each RootActorMesh can produce only one such
437    /// stream, returning None after the first call.
438    pub fn events(&mut self) -> Option<ActorSupervisionEvents> {
439        match &mut self.inner {
440            ActorMeshKind::V0 {
441                actor_supervision_rx,
442                ..
443            } => actor_supervision_rx
444                .take()
445                .map(|actor_supervision_rx| ActorSupervisionEvents {
446                    actor_supervision_rx,
447                    mesh_id: self.id(),
448                }),
449            ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"),
450        }
451    }
452
453    /// Access the ranks field (temporary until we remove `ArcActorMesh`).
454    #[cfg(test)]
455    pub(crate) fn ranks(&self) -> &Vec<ActorRef<A>> {
456        match &self.inner {
457            ActorMeshKind::V0 { ranks, .. } => ranks,
458            ActorMeshKind::V1(_actor_mesh) => unimplemented!("unsupported operation"),
459        }
460    }
461}
462
/// Supervision event stream for actor mesh. It emits actor supervision events.
/// Obtained from `RootActorMesh::events`.
pub struct ActorSupervisionEvents {
    // The receiver of supervision events from proc mesh.
    actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
    // The id of the actor mesh these events belong to; included in the log
    // message emitted when the stream closes.
    mesh_id: ActorMeshId,
}
470
471impl ActorSupervisionEvents {
472    pub async fn next(&mut self) -> Option<ActorSupervisionEvent> {
473        let result = self.actor_supervision_rx.recv().await;
474        if result.is_none() {
475            tracing::info!(
476                "supervision stream for actor mesh {:?} was closed!",
477                self.mesh_id
478            );
479        }
480        result
481    }
482}
483
484#[async_trait]
485impl<'a, A: Referable> Mesh for RootActorMesh<'a, A> {
486    type Node = ActorRef<A>;
487    type Id = ActorMeshId;
488    type Sliced<'b>
489        = SlicedActorMesh<'b, A>
490    where
491        'a: 'b;
492
493    fn shape(&self) -> &Shape {
494        self.shape.get_or_init(|| match &self.inner {
495            ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh.shape().clone(),
496            ActorMeshKind::V1(actor_mesh) => actor_mesh.region().into(),
497        })
498    }
499
500    fn select<R: Into<Range>>(
501        &self,
502        label: &str,
503        range: R,
504    ) -> Result<Self::Sliced<'_>, ShapeError> {
505        Ok(SlicedActorMesh(self, self.shape().select(label, range)?))
506    }
507
508    fn get(&self, rank: usize) -> Option<ActorRef<A>> {
509        match &self.inner {
510            ActorMeshKind::V0 { ranks, .. } => ranks.get(rank).cloned(),
511            ActorMeshKind::V1(actor_mesh) => actor_mesh.get(rank),
512        }
513    }
514
515    fn id(&self) -> Self::Id {
516        match &self.inner {
517            ActorMeshKind::V0 {
518                proc_mesh, name, ..
519            } => ActorMeshId::V0(proc_mesh.id(), name.clone()),
520            ActorMeshKind::V1(actor_mesh) => ActorMeshId::V1(actor_mesh.name().clone()),
521        }
522    }
523}
524
525impl<A: Referable> ActorMesh for RootActorMesh<'_, A> {
526    type Actor = A;
527
528    fn proc_mesh(&self) -> &ProcMesh {
529        match &self.inner {
530            ActorMeshKind::V0 { proc_mesh, .. } => proc_mesh,
531            ActorMeshKind::V1(actor_mesh) => self
532                .proc_mesh
533                .get_or_init(|| actor_mesh.proc_mesh().clone().into()),
534        }
535    }
536
537    fn name(&self) -> &str {
538        match &self.inner {
539            ActorMeshKind::V0 { name, .. } => name,
540            ActorMeshKind::V1(actor_mesh) => {
541                self.name.get_or_init(|| actor_mesh.name().to_string())
542            }
543        }
544    }
545
546    fn v1(&self) -> Option<v1::ActorMeshRef<Self::Actor>> {
547        match &self.inner {
548            ActorMeshKind::V0 { .. } => None,
549            ActorMeshKind::V1(actor_mesh) => Some(actor_mesh.clone()),
550        }
551    }
552}
553
/// A view of a [`RootActorMesh`] restricted to a sub-shape.
///
/// Field `.0` is the root mesh the view was taken from. Field `.1` is the
/// sliced shape, as produced by `select` on the root's shape.
pub struct SlicedActorMesh<'a, A: Referable>(&'a RootActorMesh<'a, A>, Shape);
555
556impl<'a, A: Referable> SlicedActorMesh<'a, A> {
557    pub fn new(actor_mesh: &'a RootActorMesh<'a, A>, shape: Shape) -> Self {
558        Self(actor_mesh, shape)
559    }
560
561    pub fn shape(&self) -> &Shape {
562        &self.1
563    }
564}
565
566#[async_trait]
567impl<A: Referable> Mesh for SlicedActorMesh<'_, A> {
568    type Node = ActorRef<A>;
569    type Id = ActorMeshId;
570    type Sliced<'b>
571        = SlicedActorMesh<'b, A>
572    where
573        Self: 'b;
574
575    fn shape(&self) -> &Shape {
576        &self.1
577    }
578
579    fn select<R: Into<Range>>(
580        &self,
581        label: &str,
582        range: R,
583    ) -> Result<Self::Sliced<'_>, ShapeError> {
584        Ok(Self(self.0, self.1.select(label, range)?))
585    }
586
587    fn get(&self, _index: usize) -> Option<ActorRef<A>> {
588        unimplemented!()
589    }
590
591    fn id(&self) -> Self::Id {
592        self.0.id()
593    }
594}
595
596impl<A: Referable> ActorMesh for SlicedActorMesh<'_, A> {
597    type Actor = A;
598
599    fn proc_mesh(&self) -> &ProcMesh {
600        self.0.proc_mesh()
601    }
602
603    fn name(&self) -> &str {
604        self.0.name()
605    }
606
607    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
608    fn cast<M>(&self, cx: &impl context::Actor, sel: Selection, message: M) -> Result<(), CastError>
609    where
610        Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
611        M: Castable + RemoteMessage,
612    {
613        cast_to_sliced_mesh::<A, M>(
614            /*cx=*/ cx,
615            /*actor_mesh_id=*/ self.id(),
616            /*comm_actor_ref*/ self.proc_mesh().comm_actor(),
617            /*sel_of_sliced=*/ &sel,
618            /*message=*/ message,
619            /*sliced_shape=*/ self.shape(),
620            /*root_mesh_shape=*/ self.0.shape(),
621        )
622    }
623
624    fn v1(&self) -> Option<v1::ActorMeshRef<Self::Actor>> {
625        self.0
626            .v1()
627            .map(|actor_mesh| actor_mesh.subset(self.shape().into()).unwrap())
628    }
629}
630
/// The type of error of casting operations.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection that is invalid for the relevant shape.
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// A send failure attributed to a specific rank.
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// A selection form that the cast machinery does not support.
    #[error("unsupported selection: {0}")]
    SelectionNotSupported(String),

    /// A send failure that is not tied to a rank (converted via `From`;
    /// presumably raised by the send to the comm actor — confirm).
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    #[error(transparent)]
    SliceError(#[from] SliceError),

    #[error(transparent)]
    SerializationError(#[from] bincode::Error),

    /// Catch-all; v1 cast errors are funneled through this variant by
    /// `ActorMesh::cast`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Failure to reshape a selection onto the fan-out-limited slice.
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}
661
662// This has to be compiled outside of test mode because the bootstrap binary
663// is not built in test mode, and requires access to TestActor.
664pub(crate) mod test_util {
665    use std::collections::VecDeque;
666    use std::fmt;
667    use std::fmt::Debug;
668    use std::sync::Arc;
669
670    use anyhow::ensure;
671    use hyperactor::Context;
672    use hyperactor::Handler;
673    use hyperactor::Instance;
674    use hyperactor::PortRef;
675    use hyperactor::RemoteSpawn;
676    use ndslice::extent;
677
678    use super::*;
679    use crate::comm::multicast::CastInfo;
680    use crate::supervision::MeshFailure;
681
682    // This can't be defined under a `#[cfg(test)]` because there needs to
683    // be an entry in the spawnable actor registry in the executable
684    // 'hyperactor_mesh_test_bootstrap' for the `tests::process` actor
685    // mesh test suite.
686    #[derive(Debug, Default)]
687    #[hyperactor::export(
688        spawn = true,
689        handlers = [
690            Echo { cast = true },
691            Payload { cast = true },
692            GetRank { cast = true },
693            Error { cast = true },
694            Relay,
695        ],
696    )]
697    pub struct TestActor;
698
699    impl Actor for TestActor {}
700
701    /// Request message to retrieve the actor's rank.
702    ///
703    /// The `bool` in the tuple controls the outcome of the handler:
704    /// - If `true`, the handler will send the rank and return
705    ///   `Ok(())`.
706    /// - If `false`, the handler will still send the rank, but return
707    ///   an error (`Err(...)`).
708    ///
709    /// This is useful for testing both successful and failing
710    /// responses from a single message type.
711    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
712    pub struct GetRank(pub bool, #[binding(include)] pub PortRef<usize>);
713
714    #[async_trait]
715    impl Handler<GetRank> for TestActor {
716        async fn handle(
717            &mut self,
718            cx: &Context<Self>,
719            GetRank(ok, reply): GetRank,
720        ) -> Result<(), anyhow::Error> {
721            let point = cx.cast_point();
722            reply.send(cx, point.rank())?;
723            anyhow::ensure!(ok, "intentional error!"); // If `!ok` exit with `Err()`.
724            Ok(())
725        }
726    }
727
728    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
729    pub struct Echo(pub String, #[binding(include)] pub PortRef<String>);
730
731    #[async_trait]
732    impl Handler<Echo> for TestActor {
733        async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
734            let Echo(message, reply_port) = message;
735            reply_port.send(cx, message)?;
736            Ok(())
737        }
738    }
739
740    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
741    pub struct Payload {
742        pub part: Part,
743        #[binding(include)]
744        pub reply_port: PortRef<()>,
745    }
746
747    #[async_trait]
748    impl Handler<Payload> for TestActor {
749        async fn handle(
750            &mut self,
751            cx: &Context<Self>,
752            message: Payload,
753        ) -> Result<(), anyhow::Error> {
754            let Payload { reply_port, .. } = message;
755            reply_port.send(cx, ())?;
756            Ok(())
757        }
758    }
759
760    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
761    pub struct Error(pub String);
762
763    #[async_trait]
764    impl Handler<Error> for TestActor {
765        async fn handle(
766            &mut self,
767            _cx: &Context<Self>,
768            Error(error): Error,
769        ) -> Result<(), anyhow::Error> {
770            Err(anyhow::anyhow!("{}", error))
771        }
772    }
773
774    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
775    pub struct Relay(pub usize, pub VecDeque<PortRef<Relay>>);
776
777    #[async_trait]
778    impl Handler<Relay> for TestActor {
779        async fn handle(
780            &mut self,
781            cx: &Context<Self>,
782            Relay(count, mut hops): Relay,
783        ) -> Result<(), anyhow::Error> {
784            ensure!(!hops.is_empty(), "relay must have at least one hop");
785            let next = hops.pop_front().unwrap();
786            next.send(cx, Relay(count + 1, hops))?;
787            Ok(())
788        }
789    }
790
791    // -- ProxyActor
792
793    #[hyperactor::export(
794        spawn = true,
795        handlers = [
796            Echo,
797        ],
798    )]
799    pub struct ProxyActor {
800        proc_mesh: &'static Arc<ProcMesh>,
801        actor_mesh: Option<RootActorMesh<'static, TestActor>>,
802    }
803
804    impl fmt::Debug for ProxyActor {
805        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
806            f.debug_struct("ProxyActor")
807                .field("proc_mesh", &"...")
808                .field("actor_mesh", &"...")
809                .finish()
810        }
811    }
812
813    #[async_trait]
814    impl Actor for ProxyActor {
815        async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
816            self.actor_mesh = Some(self.proc_mesh.spawn(this, "echo", &()).await?);
817            Ok(())
818        }
819    }
820
821    #[async_trait]
822    impl RemoteSpawn for ProxyActor {
823        type Params = ();
824
825        async fn new(_params: Self::Params) -> Result<Self, anyhow::Error> {
826            // The actor creates a mesh.
827            use std::sync::Arc;
828
829            use hyperactor::channel::ChannelTransport;
830
831            use crate::alloc::AllocSpec;
832            use crate::alloc::Allocator;
833            use crate::alloc::LocalAllocator;
834
835            let alloc = LocalAllocator
836                .allocate(AllocSpec {
837                    extent: extent! { replica = 1 },
838                    constraints: Default::default(),
839                    proc_name: None,
840                    transport: ChannelTransport::Local,
841                    proc_allocation_mode: Default::default(),
842                })
843                .await
844                .unwrap();
845
846            let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
847            let leaked: &'static Arc<ProcMesh> = Box::leak(Box::new(proc_mesh));
848            Ok(Self {
849                proc_mesh: leaked,
850                actor_mesh: None,
851            })
852        }
853    }
854
855    #[async_trait]
856    impl Handler<Echo> for ProxyActor {
857        async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
858            if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
859                // test_proxy_mesh
860
861                let actor = self.actor_mesh.as_ref().unwrap().get(0).unwrap();
862
863                // For now, we reply directly to the client.
864                // We will support directly wiring up the meshes later.
865                let (tx, mut rx) = cx.open_port();
866
867                actor.send(cx, Echo(message.0, tx.bind()))?;
868                message.1.send(cx, rx.recv().await.unwrap())?;
869
870                Ok(())
871            } else {
872                // test_router_undeliverable_return
873
874                let actor: ActorRef<_> = self.actor_mesh.as_ref().unwrap().get(0).unwrap();
875                let (tx, mut rx) = cx.open_port::<String>();
876                actor.send(cx, Echo(message.0, tx.bind()))?;
877
878                use tokio::time::Duration;
879                use tokio::time::timeout;
880                #[allow(clippy::disallowed_methods)]
881                if timeout(Duration::from_secs(1), rx.recv()).await.is_ok() {
882                    message
883                        .1
884                        .send(cx, "the impossible happened".to_owned())
885                        .unwrap()
886                }
887
888                Ok(())
889            }
890        }
891    }
892    #[async_trait]
893    impl Handler<MeshFailure> for ProxyActor {
894        async fn handle(
895            &mut self,
896            _cx: &Context<Self>,
897            message: MeshFailure,
898        ) -> Result<(), anyhow::Error> {
899            panic!("unhandled supervision failure: {}", message);
900        }
901    }
902}
903
904#[cfg(test)]
905mod tests {
906    use std::sync::Arc;
907
908    use hyperactor::ActorId;
909    use hyperactor::PortRef;
910    use hyperactor::ProcId;
911    use hyperactor::WorldId;
912    use hyperactor_config::attrs::Attrs;
913    use timed_test::async_timed_test;
914    use wirevalue::Encoding;
915
916    use super::*;
917    use crate::proc_mesh::ProcEvent;
918
919    // These tests are parametric over allocators.
920    #[macro_export]
921    macro_rules! actor_mesh_test_suite {
922        ($allocator:expr) => {
923            use std::assert_matches::assert_matches;
924
925            use ndslice::extent;
926            use $crate::alloc::AllocSpec;
927            use $crate::alloc::Allocator;
928            use $crate::assign::Ranks;
929            use $crate::sel_from_shape;
930            use $crate::sel;
931            use $crate::comm::multicast::set_cast_info_on_headers;
932            use $crate::proc_mesh::SharedSpawnable;
933            use std::collections::VecDeque;
934            use $crate::proc_mesh::default_transport;
935
936            use super::*;
937            use super::test_util::*;
938
939            #[tokio::test]
940            async fn test_proxy_mesh() {
941                use super::test_util::*;
942                use $crate::alloc::AllocSpec;
943                use $crate::alloc::Allocator;
944
945                use ndslice::extent;
946
947                let alloc = $allocator
948                    .allocate(AllocSpec {
949                        extent: extent! { replica = 1 },
950                        constraints: Default::default(),
951                        proc_name: None,
952                        transport: default_transport(),
953                        proc_allocation_mode: Default::default(),
954                    })
955                    .await
956                    .unwrap();
957                let instance = $crate::v1::testing::instance();
958                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
959                let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn(&instance, "proxy", &()).await.unwrap();
960                let proxy_actor = actor_mesh.get(0).unwrap();
961                let (tx, mut rx) = actor_mesh.open_port::<String>();
962                proxy_actor.send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind())).unwrap();
963
964                #[allow(clippy::disallowed_methods)]
965                match tokio::time::timeout(tokio::time::Duration::from_secs(3), rx.recv()).await {
966                    Ok(msg) => assert_eq!(&msg.unwrap(), "hello!"),
967                    Err(_) =>  assert!(false),
968                }
969            }
970
971            #[tokio::test]
972            async fn test_basic() {
973                let alloc = $allocator
974                    .allocate(AllocSpec {
975                        extent: extent!(replica = 4),
976                        constraints: Default::default(),
977                        proc_name: None,
978                        transport: default_transport(),
979                        proc_allocation_mode: Default::default(),
980                    })
981                    .await
982                    .unwrap();
983
984                let instance = $crate::v1::testing::instance();
985                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
986                let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn(&instance, "echo", &()).await.unwrap();
987                let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
988                actor_mesh
989                    .cast(proc_mesh.client(), sel!(*), Echo("Hello".to_string(), reply_handle.bind()))
990                    .unwrap();
991                for _ in 0..4 {
992                    assert_eq!(&reply_receiver.recv().await.unwrap(), "Hello");
993                }
994            }
995
996            #[tokio::test]
997            async fn test_ping_pong() {
998                use hyperactor::test_utils::pingpong::PingPongActor;
999                use hyperactor::test_utils::pingpong::PingPongMessage;
1000
1001                let alloc = $allocator
1002                    .allocate(AllocSpec {
1003                        extent: extent!(replica = 2),
1004                        constraints: Default::default(),
1005                        proc_name: None,
1006                        transport: default_transport(),
1007                        proc_allocation_mode: Default::default(),
1008                    })
1009                    .await
1010                    .unwrap();
1011                let instance = $crate::v1::testing::instance();
1012                let mesh = ProcMesh::allocate(alloc).await.unwrap();
1013
1014                let (undeliverable_msg_tx, _) = mesh.client().open_port();
1015                let actor_mesh: RootActorMesh<PingPongActor> = mesh
1016                    .spawn(&instance, "ping-pong", &(Some(undeliverable_msg_tx.bind()), None, None))
1017                    .await
1018                    .unwrap();
1019
1020                let ping: ActorRef<PingPongActor> = actor_mesh.get(0).unwrap();
1021                let pong: ActorRef<PingPongActor> = actor_mesh.get(1).unwrap();
1022                let (done_tx, done_rx) = mesh.client().open_once_port();
1023                ping.send(mesh.client(), PingPongMessage(4, pong.clone(), done_tx.bind())).unwrap();
1024
1025                assert!(done_rx.recv().await.unwrap());
1026            }
1027
1028            #[tokio::test]
1029            async fn test_pingpong_full_mesh() {
1030                use hyperactor::test_utils::pingpong::PingPongActor;
1031                use hyperactor::test_utils::pingpong::PingPongMessage;
1032
1033                use futures::future::join_all;
1034
1035                const X: usize = 3;
1036                const Y: usize = 3;
1037                const Z: usize = 3;
1038                let alloc = $allocator
1039                    .allocate(AllocSpec {
1040                        extent: extent!(x = X, y = Y, z = Z),
1041                        constraints: Default::default(),
1042                        proc_name: None,
1043                        transport: default_transport(),
1044                        proc_allocation_mode: Default::default(),
1045                    })
1046                    .await
1047                    .unwrap();
1048
1049                let instance = $crate::v1::testing::instance();
1050                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1051                let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port();
1052                let actor_mesh: RootActorMesh<PingPongActor> = proc_mesh.spawn(&instance, "pingpong", &(Some(undeliverable_tx.bind()), None, None)).await.unwrap();
1053                let slice = actor_mesh.shape().slice();
1054
1055                let mut futures = Vec::new();
1056                for rank in slice.iter() {
1057                    let actor = actor_mesh.get(rank).unwrap();
1058                    let coords = (&slice.coordinates(rank).unwrap()[..]).try_into().unwrap();
1059                    let sizes = (&slice.sizes())[..].try_into().unwrap();
1060                    let neighbors = ndslice::utils::stencil::moore_neighbors::<3>();
1061                    for neighbor_coords in ndslice::utils::apply_stencil(&coords, sizes, &neighbors) {
1062                        if let Ok(neighbor_rank) = slice.location(&neighbor_coords) {
1063                            let neighbor = actor_mesh.get(neighbor_rank).unwrap();
1064                            let (done_tx, done_rx) = proc_mesh.client().open_once_port();
1065                            actor
1066                                .send(
1067                                    proc_mesh.client(),
1068                                    PingPongMessage(4, neighbor.clone(), done_tx.bind()),
1069                                )
1070                                .unwrap();
1071                            futures.push(done_rx.recv());
1072                        }
1073                    }
1074                }
1075                let results = join_all(futures).await;
1076                assert_eq!(results.len(), 316); // 5180 messages
1077                for result in results {
1078                    assert_eq!(result.unwrap(), true);
1079                }
1080            }
1081
1082            #[tokio::test]
1083            async fn test_cast() {
1084                let alloc = $allocator
1085                    .allocate(AllocSpec {
1086                        extent: extent!(replica = 2, host = 2, gpu = 8),
1087                        constraints: Default::default(),
1088                        proc_name: None,
1089                        transport: default_transport(),
1090                        proc_allocation_mode: Default::default(),
1091                    })
1092                    .await
1093                    .unwrap();
1094
1095                let instance = $crate::v1::testing::instance();
1096                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1097                let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn(&instance, "echo", &()).await.unwrap();
1098                let dont_simulate_error = true;
1099                let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
1100                actor_mesh
1101                    .cast(proc_mesh.client(), sel!(*), GetRank(dont_simulate_error, reply_handle.bind()))
1102                    .unwrap();
1103                let mut ranks = Ranks::new(actor_mesh.shape().slice().len());
1104                while !ranks.is_full() {
1105                    let rank = reply_receiver.recv().await.unwrap();
1106                    assert!(ranks.insert(rank, rank).is_none(), "duplicate rank {rank}");
1107                }
1108                // Retrieve all GPUs on replica=0, host=0
1109                let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
1110                actor_mesh
1111                    .cast(
1112                        proc_mesh.client(),
1113                        sel_from_shape!(actor_mesh.shape(), replica = 0, host = 0),
1114                        GetRank(dont_simulate_error, reply_handle.bind()),
1115                    )
1116                    .unwrap();
1117                let mut ranks = Ranks::new(8);
1118                while !ranks.is_full() {
1119                    let rank = reply_receiver.recv().await.unwrap();
1120                    assert!(ranks.insert(rank, rank).is_none(), "duplicate rank {rank}");
1121                }
1122            }
1123
1124            #[tokio::test]
1125            async fn test_inter_actor_comms() {
1126                let alloc = $allocator
1127                    .allocate(AllocSpec {
1128                        // Sizes intentionally small to keep the time
1129                        // required for this test in the process case
1130                        // reasonable (< 60s).
1131                        extent: extent!(replica = 2, host = 2, gpu = 8),
1132                        constraints: Default::default(),
1133                        proc_name: None,
1134                        transport: default_transport(),
1135                        proc_allocation_mode: Default::default(),
1136                    })
1137                    .await
1138                    .unwrap();
1139
1140                let instance = $crate::v1::testing::instance();
1141                let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1142                let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn(&instance, "echo", &()).await.unwrap();
1143
1144                // Bounce the message through all actors and return it to the sender (us).
1145                let mut hops: VecDeque<_> = actor_mesh.iter().map(|actor| actor.port()).collect();
1146                let (handle, mut rx) = proc_mesh.client().open_port();
1147                hops.push_back(handle.bind());
1148                hops.pop_front()
1149                    .unwrap()
1150                    .send(proc_mesh.client(), Relay(0, hops))
1151                    .unwrap();
1152                assert_matches!(
1153                    rx.recv().await.unwrap(),
1154                    Relay(count, hops)
1155                        if count == actor_mesh.shape().slice().len()
1156                        && hops.is_empty());
1157            }
1158
1159            #[tokio::test]
1160            async fn test_inter_proc_mesh_comms() {
1161                let mut meshes = Vec::new();
1162                let instance = $crate::v1::testing::instance();
1163                for _ in 0..2 {
1164                    let alloc = $allocator
1165                        .allocate(AllocSpec {
1166                            extent: extent!(replica = 1),
1167                            constraints: Default::default(),
1168                            proc_name: None,
1169                            transport: default_transport(),
1170                            proc_allocation_mode: Default::default(),
1171                        })
1172                        .await
1173                        .unwrap();
1174
1175                    let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
1176                    let proc_mesh_clone = Arc::clone(&proc_mesh);
1177                    let actor_mesh : RootActorMesh<TestActor> = proc_mesh_clone.spawn(&instance, "echo", &()).await.unwrap();
1178                    meshes.push((proc_mesh, actor_mesh));
1179                }
1180
1181                let mut hops: VecDeque<_> = meshes
1182                    .iter()
1183                    .flat_map(|(_proc_mesh, actor_mesh)| actor_mesh.iter())
1184                    .map(|actor| actor.port())
1185                    .collect();
1186                let num_hops = hops.len();
1187
1188                let client = meshes[0].0.client();
1189                let (handle, mut rx) = client.open_port();
1190                hops.push_back(handle.bind());
1191                hops.pop_front()
1192                    .unwrap()
1193                    .send(client, Relay(0, hops))
1194                    .unwrap();
1195                assert_matches!(
1196                    rx.recv().await.unwrap(),
1197                    Relay(count, hops)
1198                        if count == num_hops
1199                        && hops.is_empty());
1200            }
1201
1202            #[async_timed_test(timeout_secs = 60)]
1203            async fn test_actor_mesh_cast() {
1204                // Verify a full broadcast in the mesh. Send a message
1205                // to every actor and check each actor receives it.
1206
1207                use $crate::sel;
1208                use $crate::comm::test_utils::TestActor as CastTestActor;
1209                use $crate::comm::test_utils::TestActorParams as CastTestActorParams;
1210                use $crate::comm::test_utils::TestMessage as CastTestMessage;
1211
1212                let extent = extent!(replica = 4, host = 4, gpu = 4);
1213                let num_actors = extent.len();
1214                let alloc = $allocator
1215                    .allocate(AllocSpec {
1216                        extent,
1217                        constraints: Default::default(),
1218                        proc_name: None,
1219                        transport: default_transport(),
1220                        proc_allocation_mode: Default::default(),
1221                    })
1222                    .await
1223                    .unwrap();
1224
1225                let instance = $crate::v1::testing::instance();
1226                let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1227
1228                let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
1229                let params = CastTestActorParams{ forward_port: tx.bind() };
1230                let actor_mesh: RootActorMesh<CastTestActor> = proc_mesh.spawn(&instance, "actor", &params).await.unwrap();
1231
1232                actor_mesh.cast(proc_mesh.client(), sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap();
1233
1234                for _ in 0..num_actors {
1235                    assert_eq!(rx.recv().await.unwrap(), CastTestMessage::Forward("abc".to_string()));
1236                }
1237
1238                // Attempt to avoid this intermittent fatal error.
1239                // ⚠ Fatal: monarch/hyperactor_mesh:hyperactor_mesh-unittest - \
1240                //            actor_mesh::tests::sim::test_actor_mesh_cast (2.5s)
1241                // Test appears to have passed but the binary exited with a non-zero exit code.
1242                proc_mesh.events().unwrap().into_alloc().stop_and_wait().await.unwrap();
1243            }
1244
1245            #[tokio::test]
1246            async fn test_delivery_failure() {
1247                let alloc = $allocator
1248                    .allocate(AllocSpec {
1249                        extent: extent!(replica = 1 ),
1250                        constraints: Default::default(),
1251                        proc_name: None,
1252                        transport: default_transport(),
1253                        proc_allocation_mode: Default::default(),
1254                    })
1255                    .await
1256                    .unwrap();
1257
1258                let name = alloc.name().to_string();
1259                let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1260                let mut events = mesh.events().unwrap();
1261
1262                // Send a message to a non-existent actor (the proc however exists).
1263                let unmonitored_reply_to = mesh.client().open_port::<usize>().0.bind();
1264                let bad_actor = ActorRef::<TestActor>::attest(ActorId(ProcId::Ranked(WorldId(name.clone()), 0), "foo".into(), 0));
1265                bad_actor.send(mesh.client(), GetRank(true, unmonitored_reply_to)).unwrap();
1266
1267                // The message will be returned!
1268                assert_matches!(
1269                    events.next().await.unwrap(),
1270                    ProcEvent::Crashed(0, reason) if reason.contains("message not delivered")
1271                );
1272
1273                // TODO: Stop the proc.
1274            }
1275
1276            #[tokio::test]
1277            async fn test_send_with_headers() {
1278                let extent = extent!(replica = 3);
1279                let alloc = $allocator
1280                    .allocate(AllocSpec {
1281                        extent: extent.clone(),
1282                        constraints: Default::default(),
1283                        proc_name: None,
1284                        transport: default_transport(),
1285                        proc_allocation_mode: Default::default(),
1286                    })
1287                    .await
1288                    .unwrap();
1289
1290                let instance = $crate::v1::testing::instance();
1291                let mesh = ProcMesh::allocate(alloc).await.unwrap();
1292                let (reply_port_handle, mut reply_port_receiver) = mesh.client().open_port::<usize>();
1293                let reply_port = reply_port_handle.bind();
1294
1295                let actor_mesh: RootActorMesh<TestActor> = mesh.spawn(&instance, "test", &()).await.unwrap();
1296                let actor_ref = actor_mesh.get(0).unwrap();
1297                let mut headers = Attrs::new();
1298                set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone());
1299                actor_ref.send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone())).unwrap();
1300                assert_eq!(0, reply_port_receiver.recv().await.unwrap());
1301
1302                set_cast_info_on_headers(&mut headers, extent.point_of_rank(1).unwrap(), mesh.client().self_id().clone());
1303                actor_ref.port()
1304                    .send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone()))
1305                    .unwrap();
1306                assert_eq!(1, reply_port_receiver.recv().await.unwrap());
1307
1308                set_cast_info_on_headers(&mut headers, extent.point_of_rank(2).unwrap(), mesh.client().self_id().clone());
1309                actor_ref.actor_id()
1310                    .port_id(GetRank::port())
1311                    .send_with_headers(
1312                        mesh.client(),
1313                        wirevalue::Any::serialize(&GetRank(true, reply_port)).unwrap(),
1314                        headers
1315                    );
1316                assert_eq!(2, reply_port_receiver.recv().await.unwrap());
1317                // TODO: Stop the proc.
1318            }
1319        }
1320    }
1321
1322    mod local {
1323        use hyperactor::channel::ChannelTransport;
1324
1325        use crate::alloc::local::LocalAllocator;
1326
1327        actor_mesh_test_suite!(LocalAllocator);
1328
1329        #[tokio::test]
1330        async fn test_send_failure() {
1331            hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
1332
1333            use hyperactor::test_utils::pingpong::PingPongActor;
1334            use hyperactor::test_utils::pingpong::PingPongMessage;
1335
1336            use crate::alloc::ProcStopReason;
1337            use crate::proc_mesh::ProcEvent;
1338
1339            let config = hyperactor_config::global::lock();
1340            let _guard = config.override_key(
1341                hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1342                tokio::time::Duration::from_secs(1),
1343            );
1344
1345            let alloc = LocalAllocator
1346                .allocate(AllocSpec {
1347                    extent: extent!(replica = 2),
1348                    constraints: Default::default(),
1349                    proc_name: None,
1350                    transport: ChannelTransport::Local,
1351                    proc_allocation_mode: Default::default(),
1352                })
1353                .await
1354                .unwrap();
1355            let instance = crate::v1::testing::instance();
1356            let monkey = alloc.chaos_monkey();
1357            let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1358            let mut events = mesh.events().unwrap();
1359
1360            let actor_mesh: RootActorMesh<PingPongActor> = mesh
1361                .spawn(
1362                    &instance,
1363                    "ping-pong",
1364                    &(
1365                        Some(PortRef::attest_message_port(mesh.client().self_id())),
1366                        None,
1367                        None,
1368                    ),
1369                )
1370                .await
1371                .unwrap();
1372
1373            let ping: ActorRef<PingPongActor> = actor_mesh.get(0).unwrap();
1374            let pong: ActorRef<PingPongActor> = actor_mesh.get(1).unwrap();
1375
1376            // Kill ping.
1377            monkey(0, ProcStopReason::Killed(0, false));
1378            assert_matches!(
1379                events.next().await.unwrap(),
1380                ProcEvent::Stopped(0, ProcStopReason::Killed(0, false))
1381            );
1382
1383            // Try to send a message to 'ping'. Since 'ping's mailbox
1384            // is stopped, the send will timeout and fail.
1385            let (unmonitored_done_tx, _) = mesh.client().open_once_port();
1386            ping.send(
1387                mesh.client(),
1388                PingPongMessage(1, pong.clone(), unmonitored_done_tx.bind()),
1389            )
1390            .unwrap();
1391
1392            // The message will be returned!
1393            assert_matches!(
1394                events.next().await.unwrap(),
1395                ProcEvent::Crashed(0, reason) if reason.contains("message not delivered")
1396            );
1397
1398            // Get 'pong' to send 'ping' a message. Since 'ping's
1399            // mailbox is stopped, the send will timeout and fail.
1400            let (unmonitored_done_tx, _) = mesh.client().open_once_port();
1401            pong.send(
1402                mesh.client(),
1403                PingPongMessage(1, ping.clone(), unmonitored_done_tx.bind()),
1404            )
1405            .unwrap();
1406
1407            // The message will be returned!
1408            assert_matches!(
1409                events.next().await.unwrap(),
1410                ProcEvent::Crashed(0, reason) if reason.contains("message not delivered")
1411            );
1412        }
1413
1414        #[tokio::test]
1415        async fn test_cast_failure() {
1416            use crate::alloc::ProcStopReason;
1417            use crate::proc_mesh::ProcEvent;
1418            use crate::sel;
1419
1420            let alloc = LocalAllocator
1421                .allocate(AllocSpec {
1422                    extent: extent!(replica = 1),
1423                    constraints: Default::default(),
1424                    proc_name: None,
1425                    transport: ChannelTransport::Local,
1426                    proc_allocation_mode: Default::default(),
1427                })
1428                .await
1429                .unwrap();
1430            let instance = crate::v1::testing::instance();
1431
1432            let stop = alloc.stopper();
1433            let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1434            let mut events = mesh.events().unwrap();
1435
1436            let actor_mesh: RootActorMesh<TestActor> =
1437                mesh.spawn(&instance, "reply-then-fail", &()).await.unwrap();
1438
1439            // `GetRank` with `false` means exit with error after
1440            // replying with rank.
1441            let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
1442            actor_mesh
1443                .cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
1444                .unwrap();
1445            let rank = reply_receiver.recv().await.unwrap();
1446            assert_eq!(rank, 0);
1447
1448            // The above is expected to trigger a proc crash.
1449            assert_matches!(
1450                events.next().await.unwrap(),
1451                ProcEvent::Crashed(0, reason) if reason.contains("intentional error!")
1452            );
1453
1454            // Cast the message.
1455            let (reply_handle, _) = actor_mesh.open_port();
1456            actor_mesh
1457                .cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
1458                .unwrap();
1459
1460            // The message will be returned!
1461            assert_matches!(
1462                events.next().await.unwrap(),
1463                ProcEvent::Crashed(0, reason) if reason.contains("message not delivered")
1464            );
1465
1466            // Stop the mesh.
1467            stop();
1468            assert_matches!(
1469                events.next().await.unwrap(),
1470                ProcEvent::Stopped(0, ProcStopReason::Stopped),
1471            );
1472            assert!(events.next().await.is_none());
1473        }
1474
1475        #[tracing_test::traced_test]
1476        #[tokio::test]
1477        async fn test_stop_actor_mesh() {
1478            use hyperactor::test_utils::pingpong::PingPongActor;
1479            use hyperactor::test_utils::pingpong::PingPongMessage;
1480
1481            let config = hyperactor_config::global::lock();
1482            let _guard = config.override_key(
1483                hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1484                tokio::time::Duration::from_secs(1),
1485            );
1486
1487            let alloc = LocalAllocator
1488                .allocate(AllocSpec {
1489                    extent: extent!(replica = 2),
1490                    constraints: Default::default(),
1491                    proc_name: None,
1492                    transport: ChannelTransport::Local,
1493                    proc_allocation_mode: Default::default(),
1494                })
1495                .await
1496                .unwrap();
1497            let instance = crate::v1::testing::instance();
1498            let mesh = ProcMesh::allocate(alloc).await.unwrap();
1499
1500            let mesh_one: RootActorMesh<PingPongActor> = mesh
1501                .spawn(
1502                    &instance,
1503                    "mesh_one",
1504                    &(
1505                        Some(PortRef::attest_message_port(mesh.client().self_id())),
1506                        None,
1507                        None,
1508                    ),
1509                )
1510                .await
1511                .unwrap();
1512
1513            let mesh_two: RootActorMesh<PingPongActor> = mesh
1514                .spawn(
1515                    &instance,
1516                    "mesh_two",
1517                    &(
1518                        Some(PortRef::attest_message_port(mesh.client().self_id())),
1519                        None,
1520                        None,
1521                    ),
1522                )
1523                .await
1524                .unwrap();
1525
1526            mesh_two.stop(&instance).await.unwrap();
1527
1528            let ping_two: ActorRef<PingPongActor> = mesh_two.get(0).unwrap();
1529            let pong_two: ActorRef<PingPongActor> = mesh_two.get(1).unwrap();
1530
1531            assert!(logs_contain(&format!(
1532                "stopped actor {}",
1533                ping_two.actor_id()
1534            )));
1535            assert!(logs_contain(&format!(
1536                "stopped actor {}",
1537                pong_two.actor_id()
1538            )));
1539
1540            // Other actor meshes on this proc mesh should still be up and running
1541            let ping_one: ActorRef<PingPongActor> = mesh_one.get(0).unwrap();
1542            let pong_one: ActorRef<PingPongActor> = mesh_one.get(1).unwrap();
1543            let (done_tx, done_rx) = mesh.client().open_once_port();
1544            pong_one
1545                .send(
1546                    mesh.client(),
1547                    PingPongMessage(1, ping_one.clone(), done_tx.bind()),
1548                )
1549                .unwrap();
1550            assert!(done_rx.recv().await.is_ok());
1551        }
1552    } // mod local
1553
1554    mod process {
1555
1556        use bytes::Bytes;
1557        use hyperactor::PortId;
1558        use hyperactor::channel::ChannelTransport;
1559        use hyperactor::clock::Clock;
1560        use hyperactor::clock::RealClock;
1561        use hyperactor::mailbox::MessageEnvelope;
1562        use rand::Rng;
1563        use tokio::process::Command;
1564
1565        use crate::alloc::process::ProcessAllocator;
1566
1567        #[cfg(fbcode_build)]
1568        fn process_allocator() -> ProcessAllocator {
1569            ProcessAllocator::new(Command::new(crate::testresource::get(
1570                "monarch/hyperactor_mesh/bootstrap",
1571            )))
1572        }
1573
1574        #[cfg(fbcode_build)] // we use an external binary, produced by buck
1575        actor_mesh_test_suite!(process_allocator());
1576
1577        // This test is concerned with correctly reporting failures
1578        // when message sizes exceed configured limits.
1579        #[cfg(fbcode_build)]
1580        //#[tracing_test::traced_test]
1581        #[async_timed_test(timeout_secs = 30)]
1582        async fn test_oversized_frames() {
1583            // Reproduced from 'net.rs'.
1584            #[derive(Debug, Serialize, Deserialize, PartialEq)]
1585            enum Frame<M> {
1586                Init(u64),
1587                Message(u64, M),
1588            }
1589            // Calculate the frame length for the given message.
1590            fn frame_length(src: &ActorId, dst: &PortId, pay: &Payload) -> usize {
1591                let serialized = wirevalue::Any::serialize(pay).unwrap();
1592                let mut headers = Attrs::new();
1593                hyperactor::mailbox::headers::set_send_timestamp(&mut headers);
1594                hyperactor::mailbox::headers::set_rust_message_type::<Payload>(&mut headers);
1595                let envelope = MessageEnvelope::new(src.clone(), dst.clone(), serialized, headers);
1596                let frame = Frame::Message(0u64, envelope);
1597                let message = serde_multipart::serialize_bincode(&frame).unwrap();
1598                message.frame_len()
1599            }
1600
1601            // This process: short delivery timeout.
1602            let config = hyperactor_config::global::lock();
1603            // This process (write): max frame len for frame writes.
1604            let _guard2 =
1605                config.override_key(hyperactor::config::CODEC_MAX_FRAME_LENGTH, 1024usize);
1606            // Remote process (read): max frame len for frame reads.
1607            // SAFETY: Ok here but not safe for concurrent access.
1608            unsafe {
1609                std::env::set_var("HYPERACTOR_CODEC_MAX_FRAME_LENGTH", "1024");
1610            };
1611            let _guard3 =
1612                config.override_key(wirevalue::config::DEFAULT_ENCODING, Encoding::Bincode);
1613
1614            let alloc = process_allocator()
1615                .allocate(AllocSpec {
1616                    extent: extent!(replica = 1),
1617                    constraints: Default::default(),
1618                    proc_name: None,
1619                    transport: ChannelTransport::Unix,
1620                    proc_allocation_mode: Default::default(),
1621                })
1622                .await
1623                .unwrap();
1624            let instance = crate::v1::testing::instance();
1625            let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1626            let mut proc_events = proc_mesh.events().unwrap();
1627            let actor_mesh: RootActorMesh<TestActor> =
1628                proc_mesh.spawn(&instance, "ingest", &()).await.unwrap();
1629            let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
1630            let dest = actor_mesh.get(0).unwrap();
1631
1632            // Message sized to exactly max frame length.
1633            let payload = Payload {
1634                part: Part::from(Bytes::from(vec![0u8; 586])),
1635                reply_port: reply_handle.bind(),
1636            };
1637            let frame_len = frame_length(
1638                proc_mesh.client().self_id(),
1639                dest.port::<Payload>().port_id(),
1640                &payload,
1641            );
1642            assert_eq!(frame_len, 1024);
1643
1644            // Send direct. A cast message is > 1024 bytes.
1645            dest.send(proc_mesh.client(), payload).unwrap();
1646            #[allow(clippy::disallowed_methods)]
1647            let result = RealClock
1648                .timeout(Duration::from_secs(2), reply_receiver.recv())
1649                .await;
1650            assert!(result.is_ok(), "Operation should not time out");
1651
1652            // Message sized to max frame length + 1.
1653            let payload = Payload {
1654                part: Part::from(Bytes::from(vec![0u8; 587])),
1655                reply_port: reply_handle.bind(),
1656            };
1657            let frame_len = frame_length(
1658                proc_mesh.client().self_id(),
1659                dest.port::<Payload>().port_id(),
1660                &payload,
1661            );
1662            assert_eq!(frame_len, 1025); // over the max frame len
1663
1664            // Send direct or cast. Either are guaranteed over the
1665            // limit and will fail.
1666            if rand::thread_rng().gen_bool(0.5) {
1667                dest.send(proc_mesh.client(), payload).unwrap();
1668            } else {
1669                actor_mesh
1670                    .cast(proc_mesh.client(), sel!(*), payload)
1671                    .unwrap();
1672            }
1673
1674            // The undeliverable supervision event that happens next
1675            // does not depend on a timeout.
1676            {
1677                let event = proc_events.next().await.unwrap();
1678                assert_matches!(
1679                    event,
1680                    ProcEvent::Crashed(_, _),
1681                    "Should have received crash event"
1682                );
1683            }
1684        }
1685
1686        // Set this test only for `mod process` because it relies on a
1687        // trick to emulate router failure that only works when using
1688        // non-local allocators.
1689        #[cfg(fbcode_build)]
1690        #[tokio::test]
1691        async fn test_router_undeliverable_return() {
1692            // Test that an undeliverable message received by a
1693            // router results in actor mesh supervision events.
1694            use ndslice::extent;
1695
1696            use super::test_util::*;
1697            use crate::alloc::AllocSpec;
1698            use crate::alloc::Allocator;
1699
1700            let alloc = process_allocator()
1701                .allocate(AllocSpec {
1702                    extent: extent! { replica = 1 },
1703                    constraints: Default::default(),
1704                    proc_name: None,
1705                    transport: ChannelTransport::Unix,
1706                    proc_allocation_mode: Default::default(),
1707                })
1708                .await
1709                .unwrap();
1710
1711            // SAFETY: Not multithread safe.
1712            unsafe { std::env::set_var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK", "1") };
1713
1714            let instance = crate::v1::testing::instance();
1715            let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1716            let mut proc_events = proc_mesh.events().unwrap();
1717            let mut actor_mesh: RootActorMesh<'_, ProxyActor> =
1718                { proc_mesh.spawn(&instance, "proxy", &()).await.unwrap() };
1719            let mut actor_events = actor_mesh.events().unwrap();
1720
1721            let proxy_actor = actor_mesh.get(0).unwrap();
1722            let (tx, mut rx) = actor_mesh.open_port::<String>();
1723            proxy_actor
1724                .send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind()))
1725                .unwrap();
1726
1727            #[allow(clippy::disallowed_methods)]
1728            match tokio::time::timeout(tokio::time::Duration::from_secs(3), rx.recv()).await {
1729                Ok(_) => panic!("the impossible happened"),
1730                Err(_) => {
1731                    assert_matches!(
1732                        proc_events.next().await.unwrap(),
1733                        ProcEvent::Crashed(0, reason) if reason.contains("undeliverable")
1734                    );
1735                    assert_eq!(
1736                        actor_events.next().await.unwrap().actor_id.name(),
1737                        actor_mesh.name(),
1738                    );
1739                }
1740            }
1741
1742            // SAFETY: Not multithread safe.
1743            unsafe { std::env::remove_var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK") };
1744        }
1745    }
1746
    /// Runs the shared actor mesh test suite with the allocator returned by
    /// `SimAllocator::new_and_start_simnet()` (which, per its name, also
    /// starts the simulated network).
    mod sim {
        use crate::alloc::sim::SimAllocator;

        // The allocator expression is handed to the suite macro as-is.
        actor_mesh_test_suite!(SimAllocator::new_and_start_simnet());
    }
1752
    /// Property tests for casting to root and sliced actor meshes.
    ///
    /// Each case spawns `EchoActor`s, casts `()` with a generated selection,
    /// and checks that exactly the selected ranks report back. During the cast,
    /// `MAX_CAST_DIMENSION_SIZE` is overridden to 2; presumably this forces the
    /// cast path to reshape larger dimensions.
    mod reshape_cast {
        use async_trait::async_trait;
        use hyperactor::Actor;
        use hyperactor::Context;
        use hyperactor::Handler;
        use hyperactor::RemoteSpawn;
        use hyperactor::channel::ChannelAddr;
        use hyperactor::channel::ChannelTransport;
        use hyperactor::channel::ChannelTx;
        use hyperactor::channel::Rx;
        use hyperactor::channel::Tx;
        use hyperactor::channel::dial;
        use hyperactor::channel::serve;
        use hyperactor::clock::Clock;
        use hyperactor::clock::RealClock;
        use ndslice::Extent;
        use ndslice::Selection;

        use crate::Mesh;
        use crate::ProcMesh;
        use crate::RootActorMesh;
        use crate::actor_mesh::ActorMesh;
        use crate::alloc::AllocSpec;
        use crate::alloc::Allocator;
        use crate::alloc::LocalAllocator;
        use crate::config::MAX_CAST_DIMENSION_SIZE;

        // Test actor holding a channel sender dialed at spawn time. Every `()`
        // it receives (castable, per the export) makes it post its own rank.
        #[derive(Debug)]
        #[hyperactor::export(
            spawn = true,
            handlers = [() { cast = true }],
        )]
        struct EchoActor(ChannelTx<usize>);

        #[async_trait]
        impl Actor for EchoActor {}

        // Spawn parameter is the address to dial for reporting ranks.
        #[async_trait]
        impl RemoteSpawn for EchoActor {
            type Params = ChannelAddr;

            async fn new(params: ChannelAddr) -> Result<Self, anyhow::Error> {
                Ok(Self(dial::<usize>(params)?))
            }
        }

        #[async_trait]
        impl Handler<()> for EchoActor {
            async fn handle(
                &mut self,
                cx: &Context<Self>,
                _message: (),
            ) -> Result<(), anyhow::Error> {
                let Self(port) = self;
                // Report the rank from this actor's id over the dialed channel.
                port.post(cx.self_id().rank());
                Ok(())
            }
        }

        /// Casts `()` to `actor_mesh` using `selection` and checks which
        /// `EchoActor`s report back on `addr`.
        ///
        /// The expected ranks come from evaluating `selection` (strictly)
        /// against the mesh's slice. One report is awaited per expected rank,
        /// each with a one-second timeout. The set of reported ranks must then
        /// equal the expected set. `MAX_CAST_DIMENSION_SIZE` is overridden to 2
        /// for the duration of the call.
        ///
        /// # Panics
        ///
        /// Panics if serving `addr` fails, the selection cannot be evaluated,
        /// the cast fails, a report times out or the channel closes, or the
        /// received set differs from the expected set.
        ///
        /// NOTE(review): reports beyond the expected count are never drained,
        /// so extra deliveries would go unnoticed.
        async fn validate_cast<A>(
            actor_mesh: &A,
            caps: &impl hyperactor::context::Actor,
            addr: ChannelAddr,
            selection: Selection,
        ) where
            A: ActorMesh<Actor = EchoActor>,
        {
            let config = hyperactor_config::global::lock();
            // Kept alive by `_guard` until this function returns.
            let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);

            // Receiving side of the channel. Callers pass the same `addr` that
            // the actors were spawned with (and dialed).
            let (_, mut rx) = serve::<usize>(addr).unwrap();

            let expected_ranks = selection
                .eval(
                    &ndslice::selection::EvalOpts::strict(),
                    actor_mesh.shape().slice(),
                )
                .unwrap()
                .collect::<std::collections::BTreeSet<_>>();

            actor_mesh.cast(caps, selection, ()).unwrap();

            let mut received = std::collections::BTreeSet::new();

            // One receive per expected rank. A duplicate report would leave some
            // rank missing from `received` and fail the assert below.
            for _ in 0..(expected_ranks.len()) {
                received.insert(
                    RealClock
                        .timeout(tokio::time::Duration::from_secs(1), rx.recv())
                        .await
                        .unwrap()
                        .unwrap(),
                );
            }

            assert_eq!(received, expected_ranks);
        }

        use ndslice::strategy::gen_extent;
        use ndslice::strategy::gen_selection;
        use proptest::prelude::*;
        use proptest::test_runner::TestRunner;

        /// Builds a multi-threaded Tokio runtime with two workers. Each
        /// (synchronous) proptest case uses it to `block_on` its async body.
        fn make_tokio_runtime() -> tokio::runtime::Runtime {
            tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .worker_threads(2)
                .build()
                .unwrap()
        }

        proptest! {
            #![proptest_config(ProptestConfig {
                cases: 8, ..ProptestConfig::default()
            })]
            // Spawns `EchoActor`s on a local proc mesh with an extent generated
            // by `gen_extent(1..=4, 8)`. Then validates a cast of a generated
            // selection over the whole root actor mesh.
            #[test]
            fn test_reshaped_actor_mesh_cast(extent in gen_extent(1..=4, 8)) {
                let runtime = make_tokio_runtime();
                async fn inner(extent: Extent) {
                    let alloc = LocalAllocator
                        .allocate(AllocSpec {
                            extent,
                            constraints: Default::default(),
                            proc_name: None,
                            transport: ChannelTransport::Local,
                            proc_allocation_mode: Default::default(),
                        }).await
                        .unwrap();
                    let instance = crate::v1::testing::instance();
                    let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
                    // Actors dial this address at spawn; `validate_cast` serves it.
                    let addr = ChannelAddr::any(ChannelTransport::Unix);
                    let actor_mesh: RootActorMesh<EchoActor> =
                        proc_mesh.spawn(&instance, "echo", &addr).await.unwrap();
                    // Draw a single selection shaped like the mesh's slice.
                    let mut runner = TestRunner::default();
                    let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0)
                        .new_tree(&mut runner)
                        .unwrap()
                        .current();
                    validate_cast(&actor_mesh, actor_mesh.proc_mesh().client(), addr, selection).await;
                }
                runtime.block_on(inner(extent));
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig {
                cases: 8, ..ProptestConfig::default()
            })]
            // Like the test above, but casts to a sliced view of the mesh.
            // The first dimension keeps its full range, and each further
            // dimension (up to the fourth) is narrowed before a selection is
            // generated against the slice's shape.
            #[test]
            fn test_reshaped_actor_mesh_slice_cast(extent in gen_extent(1..=4, 8)) {
                let runtime = make_tokio_runtime();
                async fn inner(extent: Extent) {
                    let alloc = LocalAllocator
                        .allocate(AllocSpec {
                            extent: extent.clone(),
                            constraints: Default::default(),
                            proc_name: None,
                            transport: ChannelTransport::Local,
                            proc_allocation_mode: Default::default(),
                        }).await
                        .unwrap();
                    let instance = crate::v1::testing::instance();
                    let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();

                    let addr = ChannelAddr::any(ChannelTransport::Unix);

                    let actor_mesh: RootActorMesh<EchoActor> =
                        proc_mesh.spawn(&instance, "echo", &addr).await.unwrap();


                    // Start from a slice covering the full first dimension.
                    let first_label = extent.labels().first().unwrap();
                    let slice = actor_mesh.select(first_label, 0..extent.size(first_label).unwrap()).unwrap();

                    // Each further dimension that exists is narrowed to a
                    // non-empty range:
                    //   size >= 3: `1..size-1`
                    //   size == 2: `1..2`
                    //   size == 1: `0..1`
                    // Unfortunately we must do things this way due to borrow checker reasons
                    let slice = if extent.len() >= 2 {
                        let label = &extent.labels()[1];
                        let size = extent.size(label).unwrap();
                        let start = if size > 1 { 1 } else { 0 };
                        let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
                        slice.select(label, start..end).unwrap()
                    } else {
                        slice
                    };

                    let slice = if extent.len() >= 3 {
                        let label = &extent.labels()[2];
                        let size = extent.size(label).unwrap();
                        let start = if size > 1 { 1 } else { 0 };
                        let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
                        slice.select(label, start..end).unwrap()
                    } else {
                        slice
                    };

                    let slice = if extent.len() >= 4 {
                        let label = &extent.labels()[3];
                        let size = extent.size(label).unwrap();
                        let start = if size > 1 { 1 } else { 0 };
                        let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
                        slice.select(label, start..end).unwrap()
                    } else {
                        slice
                    };


                    // The selection is generated against the *sliced* shape.
                    let mut runner = TestRunner::default();
                    let selection = gen_selection(4, slice.shape().slice().sizes().to_vec(), 0)
                        .new_tree(&mut runner)
                        .unwrap()
                        .current();

                    validate_cast(
                        &slice,
                        actor_mesh.proc_mesh().client(),
                        addr,
                        selection
                    ).await;
                }
                runtime.block_on(inner(extent));
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig {
                cases: 8, ..ProptestConfig::default()
            })]
            // NOTE(review): this case does the same things as
            // `test_reshaped_actor_mesh_cast` (same allocation, same
            // whole-mesh selection generation, same validation). Confirm
            // whether it was meant to exercise something different.
             #[test]
             fn test_reshaped_actor_mesh_cast_with_selection(extent in gen_extent(1..=4, 8)) {
                let runtime = make_tokio_runtime();
                async fn inner(extent: Extent) {
                    let alloc = LocalAllocator
                        .allocate(AllocSpec {
                            extent,
                            constraints: Default::default(),
                            proc_name: None,
                            transport: ChannelTransport::Local,
                            proc_allocation_mode: Default::default(),
                        }).await
                        .unwrap();
                    let instance = crate::v1::testing::instance();
                    let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();

                    let addr = ChannelAddr::any(ChannelTransport::Unix);

                    let actor_mesh: RootActorMesh<EchoActor> =
                        proc_mesh.spawn(&instance, "echo", &addr).await.unwrap();

                    let mut runner = TestRunner::default();
                    let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0)
                        .new_tree(&mut runner)
                        .unwrap()
                        .current();

                    validate_cast(
                        &actor_mesh,
                        actor_mesh.proc_mesh().client(),
                        addr,
                        selection
                    ).await;
                }
                runtime.block_on(inner(extent));
            }
        }
    }
2016
2017    mod shim {
2018        use std::collections::HashSet;
2019
2020        use hyperactor::context::Mailbox;
2021        use ndslice::Extent;
2022        use ndslice::extent;
2023
2024        use super::*;
2025        use crate::sel;
2026
2027        #[tokio::test]
2028        #[cfg(fbcode_build)]
2029        async fn test_basic() {
2030            let instance = v1::testing::instance();
2031            let host_mesh = v1::testing::host_mesh(extent!(host = 4)).await;
2032            let proc_mesh = host_mesh
2033                .spawn(instance, "test", Extent::unity())
2034                .await
2035                .unwrap();
2036            let actor_mesh: v1::ActorMesh<v1::testactor::TestActor> =
2037                proc_mesh.spawn(instance, "test", &()).await.unwrap();
2038
2039            let actor_mesh_v0: RootActorMesh<'_, _> = actor_mesh.clone().into();
2040
2041            let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
2042            actor_mesh_v0
2043                .cast(
2044                    instance,
2045                    sel!(*),
2046                    v1::testactor::GetCastInfo {
2047                        cast_info: cast_info.bind(),
2048                    },
2049                )
2050                .unwrap();
2051
2052            let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
2053            while !point_to_actor.is_empty() {
2054                let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
2055                let key = (point, origin_actor_ref);
2056                assert!(
2057                    point_to_actor.remove(&key),
2058                    "key {:?} not present or removed twice",
2059                    key
2060                );
2061                assert_eq!(&sender_actor_id, instance.self_id());
2062            }
2063        }
2064    }
2065}