hyperactor_mesh/
reference.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cmp::Ord;
10use std::cmp::PartialOrd;
11use std::fmt;
12use std::hash::Hash;
13use std::marker::PhantomData;
14use std::str::FromStr;
15
16use hyperactor::ActorRef;
17use hyperactor::AttrValue;
18use hyperactor::Named;
19use hyperactor::RemoteHandles;
20use hyperactor::RemoteMessage;
21use hyperactor::actor::Referable;
22use hyperactor::context;
23use hyperactor::message::Castable;
24use hyperactor::message::IndexedErasedUnbound;
25use ndslice::Range;
26use ndslice::Selection;
27use ndslice::Shape;
28use ndslice::ShapeError;
29use ndslice::selection::ReifySlice;
30use serde::Deserialize;
31use serde::Serialize;
32
33use crate::CommActor;
34use crate::actor_mesh::CastError;
35use crate::actor_mesh::actor_mesh_cast;
36use crate::actor_mesh::cast_to_sliced_mesh;
37use crate::v1::Name;
38
/// Builds a mesh ID from bare identifiers.
///
/// * `mesh_id!(proc_mesh)` expands to a `ProcMeshId` whose string is the
///   identifier's name.
/// * `mesh_id!(proc_mesh.actor_mesh)` expands to an `ActorMeshId::V0` made of
///   the proc mesh ID and the actor mesh identifier's name.
#[macro_export]
macro_rules! mesh_id {
    ($proc_mesh:ident) => {
        // `ProcMeshId` is a single-field tuple struct.
        $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string())
    };
    ($proc_mesh:ident . $actor_mesh:ident) => {
        $crate::reference::ActorMeshId::V0(
            $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string()),
            // The second component is the actor mesh name, not the proc mesh name.
            stringify!($actor_mesh).to_string(),
        )
    };
}
51
52#[derive(
53    Debug,
54    Serialize,
55    Deserialize,
56    Clone,
57    PartialEq,
58    Eq,
59    PartialOrd,
60    Hash,
61    Ord,
62    Named
63)]
64pub struct ProcMeshId(pub String);
65
66/// Actor Mesh ID.  Enum with different versions.
67#[derive(
68    Debug,
69    Serialize,
70    Deserialize,
71    Clone,
72    PartialEq,
73    Eq,
74    PartialOrd,
75    Hash,
76    Ord,
77    Named,
78    AttrValue
79)]
80pub enum ActorMeshId {
81    /// V0: Tuple of the ProcMesh ID and actor name.
82    V0(ProcMeshId, String),
83    /// V1: Name-based actor mesh ID.
84    V1(Name),
85}
86
87impl fmt::Display for ActorMeshId {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        match self {
90            ActorMeshId::V0(proc_mesh_id, actor_name) => {
91                write!(f, "v0:{},{}", proc_mesh_id.0, actor_name)
92            }
93            ActorMeshId::V1(name) => write!(f, "{}", name),
94        }
95    }
96}
97
98impl FromStr for ActorMeshId {
99    type Err = anyhow::Error;
100
101    #[allow(clippy::manual_strip)]
102    fn from_str(s: &str) -> Result<Self, Self::Err> {
103        if s.starts_with("v0:") {
104            let parts: Vec<_> = s[3..].split(',').collect();
105            if parts.len() != 2 {
106                return Err(anyhow::anyhow!("invalid v0 actor mesh id: {}", s));
107            }
108            let proc_mesh_id = parts[0];
109            let actor_name = parts[1];
110            Ok(ActorMeshId::V0(
111                ProcMeshId(proc_mesh_id.to_string()),
112                actor_name.to_string(),
113            ))
114        } else {
115            Ok(ActorMeshId::V1(Name::from_str(s)?))
116        }
117    }
118}
119
120/// Types references to Actor Meshes.
121#[derive(Debug, Serialize, Deserialize, PartialEq)]
122pub struct ActorMeshRef<A: Referable> {
123    pub(crate) mesh_id: ActorMeshId,
124    /// The shape of the root mesh.
125    root: Shape,
126    /// If some, it mean this mesh ref points to a sliced mesh, and this field
127    /// is this sliced mesh's shape. If None, it means this mesh ref points to
128    /// the root mesh.
129    sliced: Option<Shape>,
130    /// The reference to the comm actor of the underlying Proc Mesh.
131    comm_actor_ref: ActorRef<CommActor>,
132    phantom: PhantomData<A>,
133}
134
135impl<A: Referable> ActorMeshRef<A> {
136    /// The caller guarantees that the provided mesh ID is also a valid,
137    /// typed reference.  This is usually invoked to provide a guarantee
138    /// that an externally-provided mesh ID (e.g., through a command
139    /// line argument) is a valid reference.
140    pub fn attest(mesh_id: ActorMeshId, root: Shape, comm_actor_ref: ActorRef<CommActor>) -> Self {
141        Self {
142            mesh_id,
143            root,
144            sliced: None,
145            comm_actor_ref,
146            phantom: PhantomData,
147        }
148    }
149
150    /// The Actor Mesh ID corresponding with this reference.
151    pub fn mesh_id(&self) -> &ActorMeshId {
152        &self.mesh_id
153    }
154
155    /// Shape of the Actor Mesh.
156    pub fn shape(&self) -> &Shape {
157        match &self.sliced {
158            Some(s) => s,
159            None => &self.root,
160        }
161    }
162
163    /// Cast an [`M`]-typed message to the ranks selected by `sel`
164    /// in this ActorMesh.
165    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
166    pub fn cast<M>(
167        &self,
168        cx: &impl context::Actor,
169        selection: Selection,
170        message: M,
171    ) -> Result<(), CastError>
172    where
173        A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
174        M: Castable + RemoteMessage,
175    {
176        match &self.sliced {
177            Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
178                cx,
179                self.mesh_id.clone(),
180                &self.comm_actor_ref,
181                &selection,
182                message,
183                sliced_shape,
184                &self.root,
185            ),
186            None => actor_mesh_cast::<A, M>(
187                cx,
188                self.mesh_id.clone(),
189                &self.comm_actor_ref,
190                selection,
191                &self.root,
192                &self.root,
193                message,
194            ),
195        }
196    }
197
198    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
199        let sliced = self.shape().select(label, range)?;
200        Ok(Self {
201            mesh_id: self.mesh_id.clone(),
202            root: self.root.clone(),
203            sliced: Some(sliced),
204            comm_actor_ref: self.comm_actor_ref.clone(),
205            phantom: PhantomData,
206        })
207    }
208
209    pub fn new_with_shape(&self, new_shape: Shape) -> anyhow::Result<Self> {
210        let base_slice = self.shape().slice();
211        base_slice.reify_slice(new_shape.slice()).map_err(|e| {
212            anyhow::anyhow!(
213                "failed to reify the new shape into the base shape; this \
214                normally means the new shape is not a valid slice of the base \
215                error is: {e:?}"
216            )
217        })?;
218
219        Ok(Self {
220            mesh_id: self.mesh_id.clone(),
221            root: self.root.clone(),
222            sliced: Some(new_shape),
223            comm_actor_ref: self.comm_actor_ref.clone(),
224            phantom: PhantomData,
225        })
226    }
227}
228
229impl<A: Referable> Clone for ActorMeshRef<A> {
230    fn clone(&self) -> Self {
231        Self {
232            mesh_id: self.mesh_id.clone(),
233            root: self.root.clone(),
234            sliced: self.sliced.clone(),
235            comm_actor_ref: self.comm_actor_ref.clone(),
236            phantom: PhantomData,
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use hyperactor::channel::ChannelTransport;
    use hyperactor_mesh_macros::sel;
    use ndslice::Extent;
    use ndslice::extent;

    use super::*;
    use crate::Mesh;
    use crate::ProcMesh;
    use crate::RootActorMesh;
    use crate::actor_mesh::ActorMesh;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::LocalAllocator;

    /// Extent used for both meshes in the ping-pong test.
    fn extent() -> Extent {
        extent!(replica = 4)
    }

    /// Allocation spec shared by the ping and pong allocations.
    fn local_alloc_spec() -> AllocSpec {
        AllocSpec {
            extent: extent(),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Local,
        }
    }

    /// Message bounced between the two meshes until its ttl reaches zero.
    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
    struct MeshPingPongMessage(
        // ttl
        u64,
        // mesh the next message is cast to
        ActorMeshRef<MeshPingPongActor>,
        // completed port
        #[binding(include)] PortRef<bool>,
    );

    /// Actor holding a reference to its own mesh, which it hands to the peer
    /// mesh when forwarding a message.
    #[derive(Debug, Clone)]
    #[hyperactor::export(
        spawn = true,
        handlers = [MeshPingPongMessage { cast = true }],
    )]
    struct MeshPingPongActor {
        mesh_ref: ActorMeshRef<MeshPingPongActor>,
    }

    /// Everything needed to attest the actor's own mesh reference.
    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
    struct MeshPingPongActorParams {
        mesh_id: ActorMeshId,
        shape: Shape,
        comm_actor_ref: ActorRef<CommActor>,
    }

    #[async_trait]
    impl Actor for MeshPingPongActor {
        type Params = MeshPingPongActorParams;

        async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
            let MeshPingPongActorParams {
                mesh_id,
                shape,
                comm_actor_ref,
            } = params;
            let mesh_ref = ActorMeshRef::attest(mesh_id, shape, comm_actor_ref);
            Ok(Self { mesh_ref })
        }
    }

    #[async_trait]
    impl Handler<MeshPingPongMessage> for MeshPingPongActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            message: MeshPingPongMessage,
        ) -> Result<(), anyhow::Error> {
            let MeshPingPongMessage(ttl, sender_mesh, done_tx) = message;
            if ttl > 0 {
                // Forward to the sending mesh with one fewer hop remaining.
                let reply = MeshPingPongMessage(ttl - 1, self.mesh_ref.clone(), done_tx);
                sender_mesh.cast(cx, sel!(?), reply)?;
            } else {
                // Out of hops: report completion.
                done_tx.send(cx, true)?;
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_inter_mesh_ping_pong() {
        let alloc_ping = LocalAllocator
            .allocate(local_alloc_spec())
            .await
            .unwrap();
        let alloc_pong = LocalAllocator
            .allocate(local_alloc_spec())
            .await
            .unwrap();

        let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap();
        let ping_params = MeshPingPongActorParams {
            mesh_id: ActorMeshId::V0(
                ProcMeshId(ping_proc_mesh.world_id().to_string()),
                "ping".to_string(),
            ),
            shape: ping_proc_mesh.shape().clone(),
            comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
        };
        let ping_mesh: RootActorMesh<MeshPingPongActor> = ping_proc_mesh
            .spawn("ping", &ping_params)
            .await
            .unwrap();
        assert_eq!(ping_proc_mesh.shape(), ping_mesh.shape());

        let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap();
        let pong_params = MeshPingPongActorParams {
            mesh_id: ActorMeshId::V0(
                ProcMeshId(pong_proc_mesh.world_id().to_string()),
                "pong".to_string(),
            ),
            shape: pong_proc_mesh.shape().clone(),
            comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
        };
        let pong_mesh: RootActorMesh<MeshPingPongActor> = pong_proc_mesh
            .spawn("pong", &pong_params)
            .await
            .unwrap();

        let ping_mesh_ref: ActorMeshRef<MeshPingPongActor> = ping_mesh.bind();
        let pong_mesh_ref: ActorMeshRef<MeshPingPongActor> = pong_mesh.bind();

        // Start the exchange on the ping mesh and wait for the completion signal.
        let (done_tx, mut done_rx) = ping_proc_mesh.client().open_port::<bool>();
        let first = MeshPingPongMessage(10, pong_mesh_ref, done_tx.bind());
        ping_mesh_ref
            .cast(ping_proc_mesh.client(), sel!(?), first)
            .unwrap();

        assert!(done_rx.recv().await.unwrap());
    }

    #[test]
    fn test_actor_mesh_id_roundtrip() {
        let mesh_ids = [
            ActorMeshId::V0(
                ProcMeshId("proc_mesh".to_string()),
                "actor_mesh".to_string(),
            ),
            ActorMeshId::V1(Name::new("testing")),
        ];

        // Display followed by FromStr must give back the original ID.
        for mesh_id in &mesh_ids {
            let rendered = mesh_id.to_string();
            let parsed: ActorMeshId = rendered.parse().unwrap();
            assert_eq!(mesh_id, &parsed);
        }
    }
}