hyperactor_mesh/
reference.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cmp::Ord;
10use std::cmp::PartialOrd;
11use std::fmt;
12use std::hash::Hash;
13use std::marker::PhantomData;
14use std::str::FromStr;
15
16use hyperactor::ActorRef;
17use hyperactor::RemoteHandles;
18use hyperactor::RemoteMessage;
19use hyperactor::actor::Referable;
20use hyperactor::context;
21use hyperactor::message::Castable;
22use hyperactor::message::IndexedErasedUnbound;
23use hyperactor_config::AttrValue;
24use ndslice::Range;
25use ndslice::Selection;
26use ndslice::Shape;
27use ndslice::ShapeError;
28use ndslice::selection::ReifySlice;
29use serde::Deserialize;
30use serde::Serialize;
31use typeuri::Named;
32
33use crate::CommActor;
34use crate::actor_mesh::CastError;
35use crate::actor_mesh::actor_mesh_cast;
36use crate::actor_mesh::cast_to_sliced_mesh;
37use crate::v1::Name;
38
/// Builds a mesh identifier from bare identifiers.
///
/// * `mesh_id!(proc_mesh)` expands to a [`ProcMeshId`](crate::reference::ProcMeshId)
///   whose string is the stringified `proc_mesh` identifier.
/// * `mesh_id!(proc_mesh.actor_mesh)` expands to an
///   [`ActorMeshId::V0`](crate::reference::ActorMeshId::V0) made of the proc
///   mesh ID and the stringified `actor_mesh` identifier.
#[macro_export]
macro_rules! mesh_id {
    ($proc_mesh:ident) => {
        // `ProcMeshId` is a single-field tuple struct, so it takes exactly
        // one argument.
        $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string())
    };
    ($proc_mesh:ident . $actor_mesh:ident) => {
        $crate::reference::ActorMeshId::V0(
            $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string()),
            // The actor name comes from the identifier after the dot, not
            // from the proc mesh identifier.
            stringify!($actor_mesh).to_string(),
        )
    };
}
51
52#[derive(
53    Debug,
54    Serialize,
55    Deserialize,
56    Clone,
57    PartialEq,
58    Eq,
59    PartialOrd,
60    Hash,
61    Ord,
62    Named
63)]
64pub struct ProcMeshId(pub String);
65
66/// Actor Mesh ID.  Enum with different versions.
67#[derive(
68    Debug,
69    Serialize,
70    Deserialize,
71    Clone,
72    PartialEq,
73    Eq,
74    PartialOrd,
75    Hash,
76    Ord,
77    Named,
78    AttrValue
79)]
80pub enum ActorMeshId {
81    /// V0: Tuple of the ProcMesh ID and actor name.
82    V0(ProcMeshId, String),
83    /// V1: Name-based actor mesh ID.
84    V1(Name),
85}
86
87impl fmt::Display for ActorMeshId {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            ActorMeshId::V0(proc_mesh_id, actor_name) => {
91                write!(f, "v0:{},{}", proc_mesh_id.0, actor_name)
92            }
93            ActorMeshId::V1(name) => write!(f, "{}", name),
94        }
95    }
96}
97
98impl FromStr for ActorMeshId {
99    type Err = anyhow::Error;
100
101    #[allow(clippy::manual_strip)]
102    fn from_str(s: &str) -> Result<Self, Self::Err> {
103        if s.starts_with("v0:") {
104            let parts: Vec<_> = s[3..].split(',').collect();
105            if parts.len() != 2 {
106                return Err(anyhow::anyhow!("invalid v0 actor mesh id: {}", s));
107            }
108            let proc_mesh_id = parts[0];
109            let actor_name = parts[1];
110            Ok(ActorMeshId::V0(
111                ProcMeshId(proc_mesh_id.to_string()),
112                actor_name.to_string(),
113            ))
114        } else {
115            Ok(ActorMeshId::V1(Name::from_str(s)?))
116        }
117    }
118}
119
120/// Types references to Actor Meshes.
121#[derive(Debug, Serialize, Deserialize, PartialEq)]
122pub struct ActorMeshRef<A: Referable> {
123    pub(crate) mesh_id: ActorMeshId,
124    /// The shape of the root mesh.
125    root: Shape,
126    /// If some, it mean this mesh ref points to a sliced mesh, and this field
127    /// is this sliced mesh's shape. If None, it means this mesh ref points to
128    /// the root mesh.
129    sliced: Option<Shape>,
130    /// The reference to the comm actor of the underlying Proc Mesh.
131    comm_actor_ref: ActorRef<CommActor>,
132    phantom: PhantomData<A>,
133}
134
135impl<A: Referable> ActorMeshRef<A> {
136    /// The caller guarantees that the provided mesh ID is also a valid,
137    /// typed reference.  This is usually invoked to provide a guarantee
138    /// that an externally-provided mesh ID (e.g., through a command
139    /// line argument) is a valid reference.
140    pub fn attest(mesh_id: ActorMeshId, root: Shape, comm_actor_ref: ActorRef<CommActor>) -> Self {
141        Self {
142            mesh_id,
143            root,
144            sliced: None,
145            comm_actor_ref,
146            phantom: PhantomData,
147        }
148    }
149
150    /// The Actor Mesh ID corresponding with this reference.
151    pub fn mesh_id(&self) -> &ActorMeshId {
152        &self.mesh_id
153    }
154
155    /// Shape of the Actor Mesh.
156    pub fn shape(&self) -> &Shape {
157        match &self.sliced {
158            Some(s) => s,
159            None => &self.root,
160        }
161    }
162
163    /// Cast an [`M`]-typed message to the ranks selected by `sel`
164    /// in this ActorMesh.
165    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
166    pub fn cast<M>(
167        &self,
168        cx: &impl context::Actor,
169        selection: Selection,
170        message: M,
171    ) -> Result<(), CastError>
172    where
173        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
174        M: Castable + RemoteMessage,
175    {
176        match &self.sliced {
177            Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
178                cx,
179                self.mesh_id.clone(),
180                &self.comm_actor_ref,
181                &selection,
182                message,
183                sliced_shape,
184                &self.root,
185            ),
186            None => actor_mesh_cast::<A, M>(
187                cx,
188                self.mesh_id.clone(),
189                &self.comm_actor_ref,
190                selection,
191                &self.root,
192                &self.root,
193                message,
194            ),
195        }
196    }
197
198    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
199        let sliced = self.shape().select(label, range)?;
200        Ok(Self {
201            mesh_id: self.mesh_id.clone(),
202            root: self.root.clone(),
203            sliced: Some(sliced),
204            comm_actor_ref: self.comm_actor_ref.clone(),
205            phantom: PhantomData,
206        })
207    }
208
209    pub fn new_with_shape(&self, new_shape: Shape) -> anyhow::Result<Self> {
210        let base_slice = self.shape().slice();
211        base_slice.reify_slice(new_shape.slice()).map_err(|e| {
212            anyhow::anyhow!(
213                "failed to reify the new shape into the base shape; this \
214                normally means the new shape is not a valid slice of the base \
215                error is: {e:?}"
216            )
217        })?;
218
219        Ok(Self {
220            mesh_id: self.mesh_id.clone(),
221            root: self.root.clone(),
222            sliced: Some(new_shape),
223            comm_actor_ref: self.comm_actor_ref.clone(),
224            phantom: PhantomData,
225        })
226    }
227}
228
229impl<A: Referable> Clone for ActorMeshRef<A> {
230    fn clone(&self) -> Self {
231        Self {
232            mesh_id: self.mesh_id.clone(),
233            root: self.root.clone(),
234            sliced: self.sliced.clone(),
235            comm_actor_ref: self.comm_actor_ref.clone(),
236            phantom: PhantomData,
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use hyperactor::channel::ChannelTransport;
    use hyperactor_mesh_macros::sel;
    use ndslice::Extent;
    use ndslice::extent;

    use super::*;
    use crate::Mesh;
    use crate::ProcMesh;
    use crate::RootActorMesh;
    use crate::actor_mesh::ActorMesh;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::LocalAllocator;

    /// Extent shared by the meshes allocated in these tests.
    fn extent() -> Extent {
        extent!(replica = 4)
    }

    /// Allocation spec used for both the ping and the pong mesh.
    fn local_alloc_spec() -> AllocSpec {
        AllocSpec {
            extent: extent(),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Local,
            proc_allocation_mode: Default::default(),
        }
    }

    /// Message bounced between two meshes until its TTL reaches zero.
    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
    struct MeshPingPongMessage(
        /*ttl:*/ u64,
        /*mesh to cast the next message to:*/ ActorMeshRef<MeshPingPongActor>,
        /*completed port:*/ #[binding(include)] PortRef<bool>,
    );

    /// Actor that remembers a reference to its own mesh so that it can hand
    /// it to the peer mesh in each message it forwards.
    #[derive(Debug, Clone)]
    #[hyperactor::export(
        spawn = true,
        handlers = [MeshPingPongMessage { cast = true }],
    )]
    struct MeshPingPongActor {
        mesh_ref: ActorMeshRef<MeshPingPongActor>,
    }

    /// Spawn parameters from which the actor attests its own mesh reference.
    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
    struct MeshPingPongActorParams {
        mesh_id: ActorMeshId,
        shape: Shape,
        comm_actor_ref: ActorRef<CommActor>,
    }

    #[async_trait]
    impl Actor for MeshPingPongActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for MeshPingPongActor {
        type Params = MeshPingPongActorParams;

        async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
            Ok(Self {
                mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
            })
        }
    }

    #[async_trait]
    impl Handler<MeshPingPongMessage> for MeshPingPongActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            MeshPingPongMessage(ttl, sender_mesh, done_tx): MeshPingPongMessage,
        ) -> Result<(), anyhow::Error> {
            // TTL exhausted: report completion instead of forwarding.
            if ttl == 0 {
                done_tx.send(cx, true)?;
                return Ok(());
            }
            // Otherwise cast back to the mesh the message names, passing our
            // own mesh as the next target.
            let msg = MeshPingPongMessage(ttl - 1, self.mesh_ref.clone(), done_tx);
            sender_mesh.cast(cx, sel!(?), msg)?;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_inter_mesh_ping_pong() {
        let alloc_ping = LocalAllocator
            .allocate(local_alloc_spec())
            .await
            .unwrap();
        let alloc_pong = LocalAllocator
            .allocate(local_alloc_spec())
            .await
            .unwrap();
        let instance = crate::v1::testing::instance();
        let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap();
        let ping_mesh: RootActorMesh<MeshPingPongActor> = ping_proc_mesh
            .spawn(
                &instance,
                "ping",
                &MeshPingPongActorParams {
                    mesh_id: ActorMeshId::V0(
                        ProcMeshId(ping_proc_mesh.world_id().to_string()),
                        "ping".to_string(),
                    ),
                    shape: ping_proc_mesh.shape().clone(),
                    comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
                },
            )
            .await
            .unwrap();
        assert_eq!(ping_proc_mesh.shape(), ping_mesh.shape());

        let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap();
        let pong_mesh: RootActorMesh<MeshPingPongActor> = pong_proc_mesh
            .spawn(
                &instance,
                "pong",
                &MeshPingPongActorParams {
                    mesh_id: ActorMeshId::V0(
                        ProcMeshId(pong_proc_mesh.world_id().to_string()),
                        "pong".to_string(),
                    ),
                    shape: pong_proc_mesh.shape().clone(),
                    comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
                },
            )
            .await
            .unwrap();

        let ping_mesh_ref: ActorMeshRef<MeshPingPongActor> = ping_mesh.bind();
        let pong_mesh_ref: ActorMeshRef<MeshPingPongActor> = pong_mesh.bind();

        let (done_tx, mut done_rx) = ping_proc_mesh.client().open_port::<bool>();
        ping_mesh_ref
            .cast(
                ping_proc_mesh.client(),
                sel!(?),
                MeshPingPongMessage(10, pong_mesh_ref, done_tx.bind()),
            )
            .unwrap();

        assert!(done_rx.recv().await.unwrap());
    }

    #[test]
    fn test_actor_mesh_id_roundtrip() {
        let mesh_ids = &[
            ActorMeshId::V0(
                ProcMeshId("proc_mesh".to_string()),
                "actor_mesh".to_string(),
            ),
            ActorMeshId::V1(Name::new("testing").unwrap()),
        ];

        for mesh_id in mesh_ids {
            assert_eq!(
                mesh_id,
                &mesh_id.to_string().parse::<ActorMeshId>().unwrap()
            );
        }
    }

    #[test]
    fn test_actor_mesh_id_rejects_malformed_v0() {
        // A `v0:` ID needs exactly two comma-separated parts.
        assert!("v0:missing_actor_name".parse::<ActorMeshId>().is_err());
        assert!("v0:proc_mesh,actor_mesh,extra".parse::<ActorMeshId>().is_err());
    }
}