1use std::cmp::Ord;
10use std::cmp::PartialOrd;
11use std::fmt;
12use std::hash::Hash;
13use std::marker::PhantomData;
14use std::str::FromStr;
15
16use hyperactor::ActorRef;
17use hyperactor::AttrValue;
18use hyperactor::Named;
19use hyperactor::RemoteHandles;
20use hyperactor::RemoteMessage;
21use hyperactor::actor::Referable;
22use hyperactor::context;
23use hyperactor::message::Castable;
24use hyperactor::message::IndexedErasedUnbound;
25use ndslice::Range;
26use ndslice::Selection;
27use ndslice::Shape;
28use ndslice::ShapeError;
29use ndslice::selection::ReifySlice;
30use serde::Deserialize;
31use serde::Serialize;
32
33use crate::CommActor;
34use crate::actor_mesh::CastError;
35use crate::actor_mesh::actor_mesh_cast;
36use crate::actor_mesh::cast_to_sliced_mesh;
37use crate::v1::Name;
38
/// Builds a mesh identifier from bare identifiers.
///
/// - `mesh_id!(procs)` expands to `ProcMeshId("procs".to_string())`.
/// - `mesh_id!(procs.actors)` expands to
///   `ActorMeshId::V0(ProcMeshId("procs".to_string()), "actors".to_string())`.
#[macro_export]
macro_rules! mesh_id {
    ($proc_mesh:ident) => {
        // `ProcMeshId` is a single-field newtype; it takes exactly one argument.
        $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string())
    };
    ($proc_mesh:ident . $actor_mesh:ident) => {
        $crate::reference::ActorMeshId::V0(
            $crate::reference::ProcMeshId(stringify!($proc_mesh).to_string()),
            // The actor name comes from the second identifier, not the proc mesh.
            stringify!($actor_mesh).to_string(),
        )
    };
}
51
52#[derive(
53 Debug,
54 Serialize,
55 Deserialize,
56 Clone,
57 PartialEq,
58 Eq,
59 PartialOrd,
60 Hash,
61 Ord,
62 Named
63)]
64pub struct ProcMeshId(pub String);
65
66#[derive(
68 Debug,
69 Serialize,
70 Deserialize,
71 Clone,
72 PartialEq,
73 Eq,
74 PartialOrd,
75 Hash,
76 Ord,
77 Named,
78 AttrValue
79)]
80pub enum ActorMeshId {
81 V0(ProcMeshId, String),
83 V1(Name),
85}
86
87impl fmt::Display for ActorMeshId {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 ActorMeshId::V0(proc_mesh_id, actor_name) => {
91 write!(f, "v0:{},{}", proc_mesh_id.0, actor_name)
92 }
93 ActorMeshId::V1(name) => write!(f, "{}", name),
94 }
95 }
96}
97
98impl FromStr for ActorMeshId {
99 type Err = anyhow::Error;
100
101 #[allow(clippy::manual_strip)]
102 fn from_str(s: &str) -> Result<Self, Self::Err> {
103 if s.starts_with("v0:") {
104 let parts: Vec<_> = s[3..].split(',').collect();
105 if parts.len() != 2 {
106 return Err(anyhow::anyhow!("invalid v0 actor mesh id: {}", s));
107 }
108 let proc_mesh_id = parts[0];
109 let actor_name = parts[1];
110 Ok(ActorMeshId::V0(
111 ProcMeshId(proc_mesh_id.to_string()),
112 actor_name.to_string(),
113 ))
114 } else {
115 Ok(ActorMeshId::V1(Name::from_str(s)?))
116 }
117 }
118}
119
120#[derive(Debug, Serialize, Deserialize, PartialEq)]
122pub struct ActorMeshRef<A: Referable> {
123 pub(crate) mesh_id: ActorMeshId,
124 root: Shape,
126 sliced: Option<Shape>,
130 comm_actor_ref: ActorRef<CommActor>,
132 phantom: PhantomData<A>,
133}
134
135impl<A: Referable> ActorMeshRef<A> {
136 pub fn attest(mesh_id: ActorMeshId, root: Shape, comm_actor_ref: ActorRef<CommActor>) -> Self {
141 Self {
142 mesh_id,
143 root,
144 sliced: None,
145 comm_actor_ref,
146 phantom: PhantomData,
147 }
148 }
149
150 pub fn mesh_id(&self) -> &ActorMeshId {
152 &self.mesh_id
153 }
154
155 pub fn shape(&self) -> &Shape {
157 match &self.sliced {
158 Some(s) => s,
159 None => &self.root,
160 }
161 }
162
163 #[allow(clippy::result_large_err)] pub fn cast<M>(
167 &self,
168 cx: &impl context::Actor,
169 selection: Selection,
170 message: M,
171 ) -> Result<(), CastError>
172 where
173 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
174 M: Castable + RemoteMessage,
175 {
176 match &self.sliced {
177 Some(sliced_shape) => cast_to_sliced_mesh::<A, M>(
178 cx,
179 self.mesh_id.clone(),
180 &self.comm_actor_ref,
181 &selection,
182 message,
183 sliced_shape,
184 &self.root,
185 ),
186 None => actor_mesh_cast::<A, M>(
187 cx,
188 self.mesh_id.clone(),
189 &self.comm_actor_ref,
190 selection,
191 &self.root,
192 &self.root,
193 message,
194 ),
195 }
196 }
197
198 pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
199 let sliced = self.shape().select(label, range)?;
200 Ok(Self {
201 mesh_id: self.mesh_id.clone(),
202 root: self.root.clone(),
203 sliced: Some(sliced),
204 comm_actor_ref: self.comm_actor_ref.clone(),
205 phantom: PhantomData,
206 })
207 }
208
209 pub fn new_with_shape(&self, new_shape: Shape) -> anyhow::Result<Self> {
210 let base_slice = self.shape().slice();
211 base_slice.reify_slice(new_shape.slice()).map_err(|e| {
212 anyhow::anyhow!(
213 "failed to reify the new shape into the base shape; this \
214 normally means the new shape is not a valid slice of the base \
215 error is: {e:?}"
216 )
217 })?;
218
219 Ok(Self {
220 mesh_id: self.mesh_id.clone(),
221 root: self.root.clone(),
222 sliced: Some(new_shape),
223 comm_actor_ref: self.comm_actor_ref.clone(),
224 phantom: PhantomData,
225 })
226 }
227}
228
229impl<A: Referable> Clone for ActorMeshRef<A> {
230 fn clone(&self) -> Self {
231 Self {
232 mesh_id: self.mesh_id.clone(),
233 root: self.root.clone(),
234 sliced: self.sliced.clone(),
235 comm_actor_ref: self.comm_actor_ref.clone(),
236 phantom: PhantomData,
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use hyperactor::channel::ChannelTransport;
    use hyperactor_mesh_macros::sel;
    use ndslice::Extent;
    use ndslice::extent;

    use super::*;
    use crate::Mesh;
    use crate::ProcMesh;
    use crate::RootActorMesh;
    use crate::actor_mesh::ActorMesh;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::LocalAllocator;

    /// Extent used for every proc mesh in these tests: 4 replicas.
    fn extent() -> Extent {
        extent!(replica = 4)
    }

    /// Allocation spec over `extent()` using the local channel transport.
    fn local_alloc_spec() -> AllocSpec {
        AllocSpec {
            extent: extent(),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Local,
        }
    }

    /// Ping-pong message: (remaining hops, mesh to send the next hop to, port
    /// signalled with `true` once the hop count reaches zero).
    #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
    struct MeshPingPongMessage(
        u64,
        ActorMeshRef<MeshPingPongActor>,
        #[binding(include)] PortRef<bool>,
    );

    /// Actor that bounces `MeshPingPongMessage`s back to the sender's mesh.
    #[derive(Debug, Clone)]
    #[hyperactor::export(
        spawn = true,
        handlers = [MeshPingPongMessage { cast = true }],
    )]
    struct MeshPingPongActor {
        /// Reference to the mesh this actor belongs to, sent along as the
        /// reply target.
        mesh_ref: ActorMeshRef<MeshPingPongActor>,
    }

    #[derive(Debug, Serialize, Deserialize, Named, Clone)]
    struct MeshPingPongActorParams {
        mesh_id: ActorMeshId,
        shape: Shape,
        comm_actor_ref: ActorRef<CommActor>,
    }

    #[async_trait]
    impl Actor for MeshPingPongActor {
        type Params = MeshPingPongActorParams;

        async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
            Ok(Self {
                mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref),
            })
        }
    }

    #[async_trait]
    impl Handler<MeshPingPongMessage> for MeshPingPongActor {
        /// At ttl 0, signals `done_tx`. Otherwise casts the message back to
        /// `sender_mesh` with ttl decremented and this actor's mesh as the
        /// new sender.
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            MeshPingPongMessage(ttl, sender_mesh, done_tx): MeshPingPongMessage,
        ) -> Result<(), anyhow::Error> {
            if ttl == 0 {
                done_tx.send(cx, true)?;
                return Ok(());
            }
            let msg = MeshPingPongMessage(ttl - 1, self.mesh_ref.clone(), done_tx);
            sender_mesh.cast(cx, sel!(?), msg)?;
            Ok(())
        }
    }

    /// Two meshes bounce a message back and forth ten times. The test passes
    /// once the final hop signals the done port.
    #[tokio::test]
    async fn test_inter_mesh_ping_pong() {
        let alloc_ping = LocalAllocator
            .allocate(local_alloc_spec())
            .await
            .unwrap();
        let alloc_pong = LocalAllocator
            .allocate(local_alloc_spec())
            .await
            .unwrap();
        let ping_proc_mesh = ProcMesh::allocate(alloc_ping).await.unwrap();
        let ping_mesh: RootActorMesh<MeshPingPongActor> = ping_proc_mesh
            .spawn(
                "ping",
                &MeshPingPongActorParams {
                    mesh_id: ActorMeshId::V0(
                        ProcMeshId(ping_proc_mesh.world_id().to_string()),
                        "ping".to_string(),
                    ),
                    shape: ping_proc_mesh.shape().clone(),
                    comm_actor_ref: ping_proc_mesh.comm_actor().clone(),
                },
            )
            .await
            .unwrap();
        assert_eq!(ping_proc_mesh.shape(), ping_mesh.shape());

        let pong_proc_mesh = ProcMesh::allocate(alloc_pong).await.unwrap();
        let pong_mesh: RootActorMesh<MeshPingPongActor> = pong_proc_mesh
            .spawn(
                "pong",
                &MeshPingPongActorParams {
                    mesh_id: ActorMeshId::V0(
                        ProcMeshId(pong_proc_mesh.world_id().to_string()),
                        "pong".to_string(),
                    ),
                    shape: pong_proc_mesh.shape().clone(),
                    comm_actor_ref: pong_proc_mesh.comm_actor().clone(),
                },
            )
            .await
            .unwrap();

        let ping_mesh_ref: ActorMeshRef<MeshPingPongActor> = ping_mesh.bind();
        let pong_mesh_ref: ActorMeshRef<MeshPingPongActor> = pong_mesh.bind();

        let (done_tx, mut done_rx) = ping_proc_mesh.client().open_port::<bool>();
        ping_mesh_ref
            .cast(
                ping_proc_mesh.client(),
                sel!(?),
                MeshPingPongMessage(10, pong_mesh_ref, done_tx.bind()),
            )
            .unwrap();

        assert!(done_rx.recv().await.unwrap());
    }

    /// `Display` output parses back to an equal `ActorMeshId`.
    #[test]
    fn test_actor_mesh_id_roundtrip() {
        let mesh_ids = &[
            ActorMeshId::V0(
                ProcMeshId("proc_mesh".to_string()),
                "actor_mesh".to_string(),
            ),
            ActorMeshId::V1(Name::new("testing")),
        ];

        for mesh_id in mesh_ids {
            assert_eq!(
                mesh_id,
                &mesh_id.to_string().parse::<ActorMeshId>().unwrap()
            );
        }
    }

    /// A `v0:` id parses into its two parts and must contain exactly one comma.
    #[test]
    fn test_actor_mesh_id_parse_v0() {
        assert_eq!(
            "v0:procs,actors".parse::<ActorMeshId>().unwrap(),
            ActorMeshId::V0(ProcMeshId("procs".to_string()), "actors".to_string())
        );
        assert!("v0:procs".parse::<ActorMeshId>().is_err());
        assert!("v0:procs,actors,extra".parse::<ActorMeshId>().is_err());
    }
}