1use std::cmp::Ord;
10use std::cmp::PartialOrd;
11use std::hash::Hash;
12use std::marker::PhantomData;
13
14use hyperactor::ActorRef;
15use hyperactor::Named;
16use hyperactor::RemoteHandles;
17use hyperactor::RemoteMessage;
18use hyperactor::actor::RemoteActor;
19use hyperactor::cap;
20use hyperactor::message::Castable;
21use hyperactor::message::IndexedErasedUnbound;
22use ndslice::Range;
23use ndslice::Selection;
24use ndslice::Shape;
25use ndslice::ShapeError;
26use ndslice::selection::ReifySlice;
27use serde::Deserialize;
28use serde::Serialize;
29
30use crate::CommActor;
31use crate::actor_mesh::CastError;
32use crate::actor_mesh::actor_mesh_cast;
33use crate::actor_mesh::cast_to_sliced_mesh;
34
/// Builds a mesh identifier from bare identifiers.
///
/// * `mesh_id!(procs)` expands to `ProcMeshId("procs".to_string())`.
/// * `mesh_id!(procs.actors)` expands to
///   `ActorMeshId(ProcMeshId("procs".to_string()), "actors".to_string())`.
#[macro_export]
macro_rules! mesh_id {
    ($proc_mesh:ident) => {
        // `ProcMeshId` is a single-field tuple struct: only the name is passed.
        $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string())
    };
    ($proc_mesh:ident . $actor_mesh:ident) => {
        $crate::reference::ActorMeshId(
            $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string()),
            // The second component is the actor mesh's own name, not the proc mesh's.
            stringify!($actor_mesh).to_string(),
        )
    };
}
47
48#[derive(
49 Debug,
50 Serialize,
51 Deserialize,
52 Clone,
53 PartialEq,
54 Eq,
55 PartialOrd,
56 Hash,
57 Ord,
58 Named
59)]
60pub struct ProcMeshId(pub String);
61
62#[derive(
64 Debug,
65 Serialize,
66 Deserialize,
67 Clone,
68 PartialEq,
69 Eq,
70 PartialOrd,
71 Hash,
72 Ord,
73 Named
74)]
75pub struct ActorMeshId(pub ProcMeshId, pub String);
76
77#[derive(Debug, Serialize, Deserialize, PartialEq)]
79pub struct ActorMeshRef<A: RemoteActor> {
80 pub(crate) mesh_id: ActorMeshId,
81 root: Shape,
83 sliced: Option<Shape>,
87 comm_actor_ref: ActorRef<CommActor>,
89 phantom: PhantomData<A>,
90}
91
92impl<A: RemoteActor> ActorMeshRef<A> {
93 pub fn attest(mesh_id: ActorMeshId, root: Shape, comm_actor_ref: ActorRef<CommActor>) -> Self {
98 Self {
99 mesh_id,
100 root,
101 sliced: None,
102 comm_actor_ref,
103 phantom: PhantomData,
104 }
105 }
106
107 pub fn mesh_id(&self) -> &ActorMeshId {
109 &self.mesh_id
110 }
111
112 pub fn shape(&self) -> &Shape {
114 match &self.sliced {
115 Some(s) => s,
116 None => &self.root,
117 }
118 }
119
120 #[allow(clippy::result_large_err)] pub fn cast<M>(
124 &self,
125 caps: &(impl cap::CanSend + cap::CanOpenPort),
126 selection: Selection,
127 message: M,
128 ) -> Result<(), CastError>
129 where
130 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
131 M: Castable + RemoteMessage,
132 {
133 match &self.sliced {
134 Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
135 caps,
136 self.mesh_id.clone(),
137 &self.comm_actor_ref,
138 &selection,
139 message,
140 sliced_shape,
141 &self.root,
142 ),
143 None => actor_mesh_cast::<A, M>(
144 caps,
145 self.mesh_id.clone(),
146 &self.comm_actor_ref,
147 selection,
148 &self.root,
149 &self.root,
150 message,
151 ),
152 }
153 }
154
155 pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
156 let sliced = self.shape().select(label, range)?;
157 Ok(Self {
158 mesh_id: self.mesh_id.clone(),
159 root: self.root.clone(),
160 sliced: Some(sliced),
161 comm_actor_ref: self.comm_actor_ref.clone(),
162 phantom: PhantomData,
163 })
164 }
165
166 pub fn new_with_shape(&self, new_shape: Shape) -> anyhow::Result<Self> {
167 let base_slice = self.shape().slice();
168 base_slice.reify_slice(new_shape.slice()).map_err(|e| {
169 anyhow::anyhow!(
170 "failed to reify the new shape into the base shape; this \
171 normally means the new shape is not a valid slice of the base \
172 error is: {e:?}"
173 )
174 })?;
175
176 Ok(Self {
177 mesh_id: self.mesh_id.clone(),
178 root: self.root.clone(),
179 sliced: Some(new_shape),
180 comm_actor_ref: self.comm_actor_ref.clone(),
181 phantom: PhantomData,
182 })
183 }
184}
185
186impl<A: RemoteActor> Clone for ActorMeshRef<A> {
187 fn clone(&self) -> Self {
188 Self {
189 mesh_id: self.mesh_id.clone(),
190 root: self.root.clone(),
191 sliced: self.sliced.clone(),
192 comm_actor_ref: self.comm_actor_ref.clone(),
193 phantom: PhantomData,
194 }
195 }
196}
197
198#[cfg(test)]
199mod tests {
200 use async_trait::async_trait;
201 use hyperactor::Actor;
202 use hyperactor::Bind;
203 use hyperactor::Context;
204 use hyperactor::Handler;
205 use hyperactor::PortRef;
206 use hyperactor::Unbind;
207 use hyperactor_mesh_macros::sel;
208 use ndslice::Extent;
209 use ndslice::extent;
210
211 use super::*;
212 use crate::Mesh;
213 use crate::ProcMesh;
214 use crate::RootActorMesh;
215 use crate::actor_mesh::ActorMesh;
216 use crate::alloc::AllocSpec;
217 use crate::alloc::Allocator;
218 use crate::alloc::LocalAllocator;
219
220 fn extent() -> Extent {
221 extent!(replica = 4)
222 }
223
224 #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
225 struct MeshPingPongMessage(
226 u64,
227 ActorMeshRef<MeshPingPongActor>,
228 #[binding(include)] PortRef<bool>,
229 );
230
231 #[derive(Debug, Clone)]
232 #[hyperactor::export(
233 spawn = true,
234 handlers = [MeshPingPongMessage { cast = true }],
235 )]
236 struct MeshPingPongActor {
237 mesh_ref: ActorMeshRef<MeshPingPongActor>,
238 }
239
240 #[derive(Debug, Serialize, Deserialize, Named, Clone)]
241 struct MeshPingPongActorParams {
242 mesh_id: ActorMeshId,
243 shape: Shape,
244 comm_actor_ref: ActorRef<CommActor>,
245 }
246
247 #[async_trait]
248 impl Actor for MeshPingPongActor {
249 type Params = MeshPingPongActorParams;
250
251 async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
252 Ok(Self {
253 mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
254 })
255 }
256 }
257
258 #[async_trait]
259 impl Handler<MeshPingPongMessage> for MeshPingPongActor {
260 async fn handle(
261 &mut self,
262 cx: &Context<Self>,
263 MeshPingPongMessage(ttl, sender_mesh, done_tx): MeshPingPongMessage,
264 ) -> Result<(), anyhow::Error> {
265 if ttl == 0 {
266 done_tx.send(cx, true)?;
267 return Ok(());
268 }
269 let msg = MeshPingPongMessage(ttl - 1, self.mesh_ref.clone(), done_tx);
270 sender_mesh.cast(cx, sel!(?), msg)?;
271 Ok(())
272 }
273 }
274
275 #[tokio::test]
276 async fn test_inter_mesh_ping_pong() {
277 let alloc_ping = LocalAllocator
278 .allocate(AllocSpec {
279 extent: extent(),
280 constraints: Default::default(),
281 })
282 .await
283 .unwrap();
284 let alloc_pong = LocalAllocator
285 .allocate(AllocSpec {
286 extent: extent(),
287 constraints: Default::default(),
288 })
289 .await
290 .unwrap();
291 let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap();
292 let ping_mesh: RootActorMesh<MeshPingPongActor> = ping_proc_mesh
293 .spawn(
294 "ping",
295 &MeshPingPongActorParams {
296 mesh_id: ActorMeshId(
297 ProcMeshId(ping_proc_mesh.world_id().to_string()),
298 "ping".to_string(),
299 ),
300 shape: ping_proc_mesh.shape().clone(),
301 comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
302 },
303 )
304 .await
305 .unwrap();
306 assert_eq!(ping_proc_mesh.shape(), ping_mesh.shape());
307
308 let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap();
309 let pong_mesh: RootActorMesh<MeshPingPongActor> = pong_proc_mesh
310 .spawn(
311 "pong",
312 &MeshPingPongActorParams {
313 mesh_id: ActorMeshId(
314 ProcMeshId(pong_proc_mesh.world_id().to_string()),
315 "pong".to_string(),
316 ),
317 shape: pong_proc_mesh.shape().clone(),
318 comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
319 },
320 )
321 .await
322 .unwrap();
323
324 let ping_mesh_ref: ActorMeshRef<MeshPingPongActor> = ping_mesh.bind();
325 let pong_mesh_ref: ActorMeshRef<MeshPingPongActor> = pong_mesh.bind();
326
327 let (done_tx, mut done_rx) = ping_proc_mesh.client().open_port::<bool>();
328 ping_mesh_ref
329 .cast(
330 ping_proc_mesh.client(),
331 sel!(?),
332 MeshPingPongMessage(10, pong_mesh_ref, done_tx.bind()),
333 )
334 .unwrap();
335
336 assert!(done_rx.recv().await.unwrap());
337 }
338}