hyperactor_mesh/
reference.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cmp::Ord;
10use std::cmp::PartialOrd;
11use std::hash::Hash;
12use std::marker::PhantomData;
13
14use hyperactor::ActorRef;
15use hyperactor::Named;
16use hyperactor::RemoteHandles;
17use hyperactor::RemoteMessage;
18use hyperactor::actor::RemoteActor;
19use hyperactor::cap;
20use hyperactor::message::Castable;
21use hyperactor::message::IndexedErasedUnbound;
22use ndslice::Range;
23use ndslice::Selection;
24use ndslice::Shape;
25use ndslice::ShapeError;
26use ndslice::selection::ReifySlice;
27use serde::Deserialize;
28use serde::Serialize;
29
30use crate::CommActor;
31use crate::actor_mesh::CastError;
32use crate::actor_mesh::actor_mesh_cast;
33use crate::actor_mesh::cast_to_sliced_mesh;
34
/// Builds a mesh identifier from bare identifiers.
///
/// * `mesh_id!(proc_mesh)` expands to a `ProcMeshId` whose string is the
///   identifier `proc_mesh` spelled out.
/// * `mesh_id!(proc_mesh.actor_mesh)` expands to an `ActorMeshId` made of the
///   `ProcMeshId` for `proc_mesh` and the actor name `actor_mesh`.
#[macro_export]
macro_rules! mesh_id {
    ($proc_mesh:ident) => {
        // `ProcMeshId` is a single-field tuple struct; it takes only the name.
        $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string())
    };
    ($proc_mesh:ident . $actor_mesh:ident) => {
        $crate::reference::ActorMeshId(
            $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string()),
            // The actor name comes from the identifier after the dot.
            stringify!($actor_mesh).to_string(),
        )
    };
}
47
48#[derive(
49    Debug,
50    Serialize,
51    Deserialize,
52    Clone,
53    PartialEq,
54    Eq,
55    PartialOrd,
56    Hash,
57    Ord,
58    Named
59)]
60pub struct ProcMeshId(pub String);
61
62/// Actor Mesh ID.  Tuple of the ProcMesh ID and actor name.
63#[derive(
64    Debug,
65    Serialize,
66    Deserialize,
67    Clone,
68    PartialEq,
69    Eq,
70    PartialOrd,
71    Hash,
72    Ord,
73    Named
74)]
75pub struct ActorMeshId(pub ProcMeshId, pub String);
76
77/// Types references to Actor Meshes.
78#[derive(Debug, Serialize, Deserialize, PartialEq)]
79pub struct ActorMeshRef<A: RemoteActor> {
80    pub(crate) mesh_id: ActorMeshId,
81    /// The shape of the root mesh.
82    root: Shape,
83    /// If some, it mean this mesh ref points to a sliced mesh, and this field
84    /// is this sliced mesh's shape. If None, it means this mesh ref points to
85    /// the root mesh.
86    sliced: Option<Shape>,
87    /// The reference to the comm actor of the underlying Proc Mesh.
88    comm_actor_ref: ActorRef<CommActor>,
89    phantom: PhantomData<A>,
90}
91
92impl<A: RemoteActor> ActorMeshRef<A> {
93    /// The caller guarantees that the provided mesh ID is also a valid,
94    /// typed reference.  This is usually invoked to provide a guarantee
95    /// that an externally-provided mesh ID (e.g., through a command
96    /// line argument) is a valid reference.
97    pub fn attest(mesh_id: ActorMeshId, root: Shape, comm_actor_ref: ActorRef<CommActor>) -> Self {
98        Self {
99            mesh_id,
100            root,
101            sliced: None,
102            comm_actor_ref,
103            phantom: PhantomData,
104        }
105    }
106
107    /// The Actor Mesh ID corresponding with this reference.
108    pub fn mesh_id(&self) -> &ActorMeshId {
109        &self.mesh_id
110    }
111
112    /// Shape of the Actor Mesh.
113    pub fn shape(&self) -> &Shape {
114        match &self.sliced {
115            Some(s) => s,
116            None => &self.root,
117        }
118    }
119
120    /// Cast an [`M`]-typed message to the ranks selected by `sel`
121    /// in this ActorMesh.
122    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
123    pub fn cast<M>(
124        &self,
125        caps: &(impl cap::CanSend + cap::CanOpenPort),
126        selection: Selection,
127        message: M,
128    ) -> Result<(), CastError>
129    where
130        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
131        M: Castable + RemoteMessage,
132    {
133        match &self.sliced {
134            Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
135                caps,
136                self.mesh_id.clone(),
137                &self.comm_actor_ref,
138                &selection,
139                message,
140                sliced_shape,
141                &self.root,
142            ),
143            None => actor_mesh_cast::<A, M>(
144                caps,
145                self.mesh_id.clone(),
146                &self.comm_actor_ref,
147                selection,
148                &self.root,
149                &self.root,
150                message,
151            ),
152        }
153    }
154
155    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
156        let sliced = self.shape().select(label, range)?;
157        Ok(Self {
158            mesh_id: self.mesh_id.clone(),
159            root: self.root.clone(),
160            sliced: Some(sliced),
161            comm_actor_ref: self.comm_actor_ref.clone(),
162            phantom: PhantomData,
163        })
164    }
165
166    pub fn new_with_shape(&self, new_shape: Shape) -> anyhow::Result<Self> {
167        let base_slice = self.shape().slice();
168        base_slice.reify_slice(new_shape.slice()).map_err(|e| {
169            anyhow::anyhow!(
170                "failed to reify the new shape into the base shape; this \
171                normally means the new shape is not a valid slice of the base \
172                error is: {e:?}"
173            )
174        })?;
175
176        Ok(Self {
177            mesh_id: self.mesh_id.clone(),
178            root: self.root.clone(),
179            sliced: Some(new_shape),
180            comm_actor_ref: self.comm_actor_ref.clone(),
181            phantom: PhantomData,
182        })
183    }
184}
185
186impl<A: RemoteActor> Clone for ActorMeshRef<A> {
187    fn clone(&self) -> Self {
188        Self {
189            mesh_id: self.mesh_id.clone(),
190            root: self.root.clone(),
191            sliced: self.sliced.clone(),
192            comm_actor_ref: self.comm_actor_ref.clone(),
193            phantom: PhantomData,
194        }
195    }
196}
197
198#[cfg(test)]
199mod tests {
200    use async_trait::async_trait;
201    use hyperactor::Actor;
202    use hyperactor::Bind;
203    use hyperactor::Context;
204    use hyperactor::Handler;
205    use hyperactor::PortRef;
206    use hyperactor::Unbind;
207    use hyperactor_mesh_macros::sel;
208    use ndslice::Extent;
209    use ndslice::extent;
210
211    use super::*;
212    use crate::Mesh;
213    use crate::ProcMesh;
214    use crate::RootActorMesh;
215    use crate::actor_mesh::ActorMesh;
216    use crate::alloc::AllocSpec;
217    use crate::alloc::Allocator;
218    use crate::alloc::LocalAllocator;
219
    /// Extent requested for every allocation in this test module, built with
    /// `extent!(replica = 4)`.
    fn extent() -> Extent {
        extent!(replica = 4)
    }
223
224    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
225    struct MeshPingPongMessage(
226        /*ttl:*/ u64,
227        ActorMeshRef<MeshPingPongActor>,
228        /*completed port:*/ #[binding(include)] PortRef<bool>,
229    );
230
    // Test actor that handles `MeshPingPongMessage`, exported with
    // `spawn = true` and a cast-enabled handler for that message.
    #[derive(Debug, Clone)]
    #[hyperactor::export(
        spawn = true,
        handlers = [MeshPingPongMessage { cast = true }],
    )]
    struct MeshPingPongActor {
        // Mesh reference built from the spawn params; the handler attaches a
        // clone of it to every message it forwards.
        mesh_ref: ActorMeshRef<MeshPingPongActor>,
    }
239
240    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
241    struct MeshPingPongActorParams {
242        mesh_id: ActorMeshId,
243        shape: Shape,
244        comm_actor_ref: ActorRef<CommActor>,
245    }
246
247    #[async_trait]
248    impl Actor for MeshPingPongActor {
249        type Params = MeshPingPongActorParams;
250
251        async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
252            Ok(Self {
253                mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
254            })
255        }
256    }
257
258    #[async_trait]
259    impl Handler<MeshPingPongMessage> for MeshPingPongActor {
260        async fn handle(
261            &mut self,
262            cx: &Context<Self>,
263            MeshPingPongMessage(ttl, sender_mesh, done_tx): MeshPingPongMessage,
264        ) -> Result<(), anyhow::Error> {
265            if ttl == 0 {
266                done_tx.send(cx, true)?;
267                return Ok(());
268            }
269            let msg = MeshPingPongMessage(ttl - 1, self.mesh_ref.clone(), done_tx);
270            sender_mesh.cast(cx, sel!(?), msg)?;
271            Ok(())
272        }
273    }
274
275    #[tokio::test]
276    async fn test_inter_mesh_ping_pong() {
277        let alloc_ping = LocalAllocator
278            .allocate(AllocSpec {
279                extent: extent(),
280                constraints: Default::default(),
281            })
282            .await
283            .unwrap();
284        let alloc_pong = LocalAllocator
285            .allocate(AllocSpec {
286                extent: extent(),
287                constraints: Default::default(),
288            })
289            .await
290            .unwrap();
291        let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap();
292        let ping_mesh: RootActorMesh<MeshPingPongActor> = ping_proc_mesh
293            .spawn(
294                "ping",
295                &MeshPingPongActorParams {
296                    mesh_id: ActorMeshId(
297                        ProcMeshId(ping_proc_mesh.world_id().to_string()),
298                        "ping".to_string(),
299                    ),
300                    shape: ping_proc_mesh.shape().clone(),
301                    comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
302                },
303            )
304            .await
305            .unwrap();
306        assert_eq!(ping_proc_mesh.shape(), ping_mesh.shape());
307
308        let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap();
309        let pong_mesh: RootActorMesh<MeshPingPongActor> = pong_proc_mesh
310            .spawn(
311                "pong",
312                &MeshPingPongActorParams {
313                    mesh_id: ActorMeshId(
314                        ProcMeshId(pong_proc_mesh.world_id().to_string()),
315                        "pong".to_string(),
316                    ),
317                    shape: pong_proc_mesh.shape().clone(),
318                    comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
319                },
320            )
321            .await
322            .unwrap();
323
324        let ping_mesh_ref: ActorMeshRef<MeshPingPongActor> = ping_mesh.bind();
325        let pong_mesh_ref: ActorMeshRef<MeshPingPongActor> = pong_mesh.bind();
326
327        let (done_tx, mut done_rx) = ping_proc_mesh.client().open_port::<bool>();
328        ping_mesh_ref
329            .cast(
330                ping_proc_mesh.client(),
331                sel!(?),
332                MeshPingPongMessage(10, pong_mesh_ref, done_tx.bind()),
333            )
334            .unwrap();
335
336        assert!(done_rx.recv().await.unwrap());
337    }
338}