1use std::fmt;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13use std::ops::Deref;
14use std::sync::OnceLock as OnceCell;
15
16use hyperactor::Actor;
17use hyperactor::ActorRef;
18use hyperactor::RemoteHandles;
19use hyperactor::RemoteMessage;
20use hyperactor::actor::Referable;
21use hyperactor::attrs::Attrs;
22use hyperactor::context;
23use hyperactor::message::Castable;
24use hyperactor::message::IndexedErasedUnbound;
25use hyperactor::message::Unbound;
26use hyperactor_mesh_macros::sel;
27use ndslice::Selection;
28use ndslice::ViewExt as _;
29use ndslice::view;
30use ndslice::view::Region;
31use ndslice::view::View;
32use serde::Deserialize;
33use serde::Deserializer;
34use serde::Serialize;
35use serde::Serializer;
36
37use crate::CommActor;
38use crate::actor_mesh as v0_actor_mesh;
39use crate::comm::multicast;
40use crate::proc_mesh::mesh_agent::ActorState;
41use crate::reference::ActorMeshId;
42use crate::resource;
43use crate::v1;
44use crate::v1::Error;
45use crate::v1::Name;
46use crate::v1::ProcMeshRef;
47use crate::v1::ValueMesh;
48
/// An owned handle to a mesh of `A`-typed actors.
///
/// Stores the proc mesh and name the mesh was created with, together with
/// an [`ActorMeshRef`] built from them. The handle dereferences to that
/// reference.
#[derive(Debug)]
pub struct ActorMesh<A: Referable> {
    /// Proc mesh this actor mesh was created over.
    proc_mesh: ProcMeshRef,
    /// Name of the actor mesh.
    name: Name,
    /// Reference built from `proc_mesh` and `name`; this is the `Deref` target.
    current_ref: ActorMeshRef<A>,
}
59
60impl<A: Referable> ActorMesh<A> {
63 pub(crate) fn new(proc_mesh: ProcMeshRef, name: Name) -> Self {
64 let current_ref =
65 ActorMeshRef::with_page_size(name.clone(), proc_mesh.clone(), DEFAULT_PAGE);
66
67 Self {
68 proc_mesh,
69 name,
70 current_ref,
71 }
72 }
73}
74
/// Lets an `ActorMesh` be used wherever `ActorMeshRef` methods are needed
/// by dereferencing to its embedded reference.
impl<A: Referable> Deref for ActorMesh<A> {
    type Target = ActorMeshRef<A>;

    /// Returns the embedded `ActorMeshRef`.
    fn deref(&self) -> &Self::Target {
        &self.current_ref
    }
}
82
83impl<A: Referable> Clone for ActorMesh<A> {
86 fn clone(&self) -> Self {
87 Self {
88 proc_mesh: self.proc_mesh.clone(),
89 name: self.name.clone(),
90 current_ref: self.current_ref.clone(),
91 }
92 }
93}
94
/// Page size (slots per page) used for the lazy `ActorRef` cache when no
/// explicit size is supplied, e.g. by `ActorMeshRef::new` and by
/// deserialization.
const DEFAULT_PAGE: usize = 1024;
99
/// One page of the lazy `ActorRef` cache.
struct Page<A: Referable> {
    /// One write-once slot per rank covered by this page; a slot is filled
    /// the first time its rank is requested.
    slots: Box<[OnceCell<ActorRef<A>>]>,
}
104
105impl<A: Referable> Page<A> {
106 fn new(len: usize) -> Self {
107 let mut v = Vec::with_capacity(len);
108 for _ in 0..len {
109 v.push(OnceCell::new());
110 }
111 Self {
112 slots: v.into_boxed_slice(),
113 }
114 }
115}
116
/// A reference to a mesh of `A`-typed actors, identified by a proc mesh and
/// a mesh name.
///
/// Per-rank `ActorRef`s are created on first access and cached in pages of
/// `page_size` slots. The cache is not part of the value's identity:
/// equality, hashing and serialization use only `proc_mesh` and `name`.
pub struct ActorMeshRef<A: Referable> {
    /// Proc mesh backing this view; supplies the region and the per-rank
    /// proc references.
    proc_mesh: ProcMeshRef,
    /// Name of the actor mesh.
    name: Name,

    /// Lazily built cache: the outer cell holds one inner cell per page, and
    /// each page is allocated the first time a rank inside it is requested.
    pages: OnceCell<Vec<OnceCell<Box<Page<A>>>>>,
    /// Number of slots per page; constructors clamp it to at least 1.
    page_size: usize,

    /// Ties the type parameter `A` to the struct.
    _phantom: PhantomData<A>,
}
137
138impl<A: Actor + Referable> ActorMeshRef<A> {
139 pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> v1::Result<()>
141 where
142 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
143 M: Castable + RemoteMessage + Clone, {
145 if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
146 self.cast_v0(cx, message, root_comm_actor)
147 } else {
148 for (point, actor) in self.iter() {
149 let create_rank = point.rank();
150 let mut headers = Attrs::new();
151 headers.set(
152 multicast::CAST_ORIGINATING_SENDER,
153 cx.instance().self_id().clone(),
154 );
155 headers.set(multicast::CAST_POINT, point);
156
157 let mut unbound = Unbound::try_from_message(message.clone())
160 .map_err(|e| Error::CastingError(self.name.clone(), e))?;
161 unbound
162 .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
163 *rank = Some(create_rank);
164 Ok(())
165 })
166 .map_err(|e| Error::CastingError(self.name.clone(), e))?;
167 let rebound_message = unbound
168 .bind()
169 .map_err(|e| Error::CastingError(self.name.clone(), e))?;
170 actor
171 .send_with_headers(cx, headers, rebound_message)
172 .map_err(|e| Error::SendingError(actor.actor_id().clone(), Box::new(e)))?;
173 }
174 Ok(())
175 }
176 }
177
178 fn cast_v0<M>(
179 &self,
180 cx: &impl context::Actor,
181 message: M,
182 root_comm_actor: &ActorRef<CommActor>,
183 ) -> v1::Result<()>
184 where
185 A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
186 M: Castable + RemoteMessage + Clone, {
188 let cast_mesh_shape = view::Ranked::region(self).into();
189 let actor_mesh_id = ActorMeshId::V1(self.name.clone());
190 match &self.proc_mesh.root_region {
191 Some(root_region) => {
192 let root_mesh_shape = root_region.into();
193 v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
194 cx,
195 actor_mesh_id,
196 root_comm_actor,
197 &sel!(*),
198 message,
199 &cast_mesh_shape,
200 &root_mesh_shape,
201 )
202 .map_err(|e| Error::CastingError(self.name.clone(), e.into()))
203 }
204 None => v0_actor_mesh::actor_mesh_cast::<A, M>(
205 cx,
206 actor_mesh_id,
207 root_comm_actor,
208 sel!(*),
209 &cast_mesh_shape,
210 &cast_mesh_shape,
211 message,
212 )
213 .map_err(|e| Error::CastingError(self.name.clone(), e.into())),
214 }
215 }
216
217 pub async fn actor_states(
218 &self,
219 cx: &impl context::Actor,
220 ) -> v1::Result<ValueMesh<resource::State<ActorState>>> {
221 self.proc_mesh.actor_states(cx, self.name.clone()).await
222 }
223}
224
225impl<A: Referable> ActorMeshRef<A> {
226 pub(crate) fn new(name: Name, proc_mesh: ProcMeshRef) -> Self {
227 Self::with_page_size(name, proc_mesh, DEFAULT_PAGE)
228 }
229
230 pub(crate) fn with_page_size(name: Name, proc_mesh: ProcMeshRef, page_size: usize) -> Self {
231 Self {
232 proc_mesh,
233 name,
234 pages: OnceCell::new(),
235 page_size: page_size.max(1),
236 _phantom: PhantomData,
237 }
238 }
239
240 #[inline]
241 fn len(&self) -> usize {
242 view::Ranked::region(&self.proc_mesh).num_ranks()
243 }
244
245 fn ensure_pages(&self) -> &Vec<OnceCell<Box<Page<A>>>> {
246 let n = self.len().div_ceil(self.page_size); self.pages
248 .get_or_init(|| (0..n).map(|_| OnceCell::new()).collect())
249 }
250
251 fn materialize(&self, rank: usize) -> Option<&ActorRef<A>> {
252 let len = self.len();
253 if rank >= len {
254 return None;
255 }
256 let p = self.page_size;
257 let page_ix = rank / p;
258 let local_ix = rank % p;
259
260 let pages = self.ensure_pages();
261 let page = pages[page_ix].get_or_init(|| {
262 let base = page_ix * p;
264 let remaining = len - base;
265 let page_len = remaining.min(p);
266 Box::new(Page::<A>::new(page_len))
267 });
268
269 Some(page.slots[local_ix].get_or_init(|| {
270 debug_assert!(rank < self.len(), "rank must be within [0, len)");
278 debug_assert!(
279 self.proc_mesh.get(rank).is_some(),
280 "proc_mesh must be dense/aligned with this view"
281 );
282 let proc_ref = self.proc_mesh.get(rank).expect("rank in-bounds");
283 proc_ref.attest(&self.name)
284 }))
285 }
286}
287
288impl<A: Referable> Clone for ActorMeshRef<A> {
289 fn clone(&self) -> Self {
290 Self {
291 proc_mesh: self.proc_mesh.clone(),
292 name: self.name.clone(),
293 pages: OnceCell::new(), page_size: self.page_size,
295 _phantom: PhantomData,
296 }
297 }
298}
299
300impl<A: Referable> PartialEq for ActorMeshRef<A> {
301 fn eq(&self, other: &Self) -> bool {
302 self.proc_mesh == other.proc_mesh && self.name == other.name
303 }
304}
305impl<A: Referable> Eq for ActorMeshRef<A> {}
306
307impl<A: Referable> Hash for ActorMeshRef<A> {
308 fn hash<H: Hasher>(&self, state: &mut H) {
309 self.proc_mesh.hash(state);
310 self.name.hash(state);
311 }
312}
313
314impl<A: Referable> fmt::Debug for ActorMeshRef<A> {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 f.debug_struct("ActorMeshRef")
317 .field("proc_mesh", &self.proc_mesh)
318 .field("name", &self.name)
319 .field("page_size", &self.page_size)
320 .finish_non_exhaustive() }
322}
323
324impl<A: Referable> Serialize for ActorMeshRef<A> {
326 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
327 where
328 S: Serializer,
329 {
330 (&self.proc_mesh, &self.name).serialize(serializer)
332 }
333}
334
335impl<'de, A: Referable> Deserialize<'de> for ActorMeshRef<A> {
337 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
338 where
339 D: Deserializer<'de>,
340 {
341 let (proc_mesh, name) = <(ProcMeshRef, Name)>::deserialize(deserializer)?;
342 Ok(ActorMeshRef::with_page_size(name, proc_mesh, DEFAULT_PAGE))
343 }
344}
345
/// Exposes the mesh as a ranked view whose items are `ActorRef<A>`.
impl<A: Referable> view::Ranked for ActorMeshRef<A> {
    type Item = ActorRef<A>;

    /// Returns the region of the underlying proc mesh.
    #[inline]
    fn region(&self) -> &Region {
        view::Ranked::region(&self.proc_mesh)
    }

    /// Returns the actor reference at `rank`, creating and caching it on
    /// first access; `None` if `rank` is out of range.
    #[inline]
    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.materialize(rank)
    }
}
359
360impl<A: Referable> view::RankedSliceable for ActorMeshRef<A> {
361 fn sliced(&self, region: Region) -> Self {
362 debug_assert!(region.is_subset(view::Ranked::region(self)));
363 let proc_mesh = self.proc_mesh.subset(region).unwrap();
364 Self::with_page_size(self.name.clone(), proc_mesh, self.page_size)
365 }
366}
367
368#[cfg(test)]
369mod tests {
370 use std::assert_matches::assert_matches;
371 use std::collections::HashSet;
372
373 use hyperactor::actor::ActorStatus;
374 use hyperactor::clock::Clock;
375 use hyperactor::clock::RealClock;
376 use hyperactor::context::Mailbox as _;
377 use hyperactor::mailbox;
378 use ndslice::Extent;
379 use ndslice::ViewExt;
380 use ndslice::extent;
381 use ndslice::view::Ranked;
382 use timed_test::async_timed_test;
383 use tokio::time::Duration;
384
385 use super::ActorMesh;
386 use crate::proc_mesh::mesh_agent::ActorState;
387 use crate::resource;
388 use crate::v1::ActorMeshRef;
389 use crate::v1::Name;
390 use crate::v1::ProcMesh;
391 use crate::v1::testactor;
392 use crate::v1::testing;
393
    // Exercises the lazy per-rank cache of `ActorMeshRef`: pointer stability
    // of cached entries, independence of clones and slices, Hash/Eq
    // behavior, and that the materialized refs can actually be sent to.
    #[tokio::test]
    async fn test_actor_mesh_ref_lazy_materialization() {
        let instance = testing::instance().await;
        // 3 replicas x 2 hosts = 6 ranks.
        let extent = extent!(replicas = 3, hosts = 2);
        let pm: ProcMesh = testing::proc_meshes(instance, extent.clone())
            .await
            .into_iter()
            .next()
            .expect("at least one proc mesh");
        let am: ActorMesh<testactor::TestActor> = pm.spawn(instance, "test", &()).await.unwrap();

        // A small page size so that 6 ranks span several pages
        // (ranks 0-1, 2-3 and 4-5).
        let page_size = 2;
        let amr: ActorMeshRef<testactor::TestActor> =
            ActorMeshRef::with_page_size(am.name.clone(), pm.clone(), page_size);
        assert_eq!(amr.extent(), extent);
        assert_eq!(amr.region().num_ranks(), 6);

        // Rank 0: repeated lookups return the same cached entry.
        let p0_a = amr.get(0).expect("rank 0 exists") as *const _;
        let p0_b = amr.get(0).expect("rank 0 exists") as *const _;
        assert_eq!(p0_a, p0_b, "same rank should return same cached pointer");

        // Rank 1: same page as rank 0, but a different slot.
        let p1_a = amr.get(1).expect("rank 1 exists") as *const _;
        let p1_b = amr.get(1).expect("rank 1 exists") as *const _;
        assert_eq!(p1_a, p1_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p1_a, "different ranks have different cache slots");

        // Rank 2: first slot of the next page.
        let p2_a = amr.get(2).expect("rank 2 exists") as *const _;
        let p2_b = amr.get(2).expect("rank 2 exists") as *const _;
        assert_eq!(p2_a, p2_b, "same rank should return same cached pointer");
        assert_ne!(p0_a, p2_a, "different pages have different cache slots");

        // A clone refers to the same actors but does not share the cache.
        let amr_clone = amr.clone();
        let orig_id_0 = amr.get(0).unwrap().actor_id().clone();
        let clone_id_0 = amr_clone.get(0).unwrap().actor_id().clone();
        assert_eq!(orig_id_0, clone_id_0, "clone preserves identity");
        let p0_clone = amr_clone.get(0).unwrap() as *const _;
        assert_ne!(
            p0_a, p0_clone,
            "cloned ActorMeshRef has a fresh cache (different pointer)"
        );

        // Dropping replica 0 leaves 2 replicas x 2 hosts = 4 ranks.
        let sliced = amr.range("replicas", 1..).expect("slice should be valid");
        assert_eq!(sliced.region().num_ranks(), 4);
        let sp0_a = sliced.get(0).unwrap() as *const _;
        let sp0_b = sliced.get(0).unwrap() as *const _;
        assert_eq!(sp0_a, sp0_b, "sliced view has its own cache slot per rank");
        // With the inherited page size of 2, sliced ranks 0 and 2 fall on
        // different pages.
        let sp2 = sliced.get(2).unwrap() as *const _;
        assert_ne!(sp0_a, sp2, "sliced view crosses its own page boundary");

        // `amr` has a populated cache while each fresh clone does not; both
        // clones must still land in the same hash-set entry.
        let mut set = HashSet::new();
        set.insert(amr.clone());
        set.insert(amr.clone());
        assert_eq!(set.len(), 1, "cache state must not affect Hash/Eq");

        // Send to two different ranks through the cached refs and check
        // that the two replies carry different actor ids.
        let (port, mut rx) = mailbox::open_port(instance);
        amr.get(0)
            .expect("rank 0 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 0 should succeed");
        amr.get(3)
            .expect("rank 3 exists")
            .send(instance, testactor::GetActorId(port.bind()))
            .expect("send to rank 3 should succeed");
        let id_a = RealClock
            .timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for first reply")
            .expect("channel closed before first reply");
        let id_b = RealClock
            .timeout(Duration::from_secs(3), rx.recv())
            .await
            .expect("timed out waiting for second reply")
            .expect("channel closed before second reply");
        assert_ne!(id_a, id_b, "two different ranks responded");
    }
492
493 #[async_timed_test(timeout_secs = 30)]
494 async fn test_actor_states() {
495 hyperactor_telemetry::initialize_logging_for_test();
496
497 let instance = testing::instance().await;
498 let (supervision_port, mut supervision_receiver) =
500 instance.open_port::<resource::State<ActorState>>();
501 let supervisor = supervision_port.bind();
502 let num_replicas = 4;
503 let meshes = testing::proc_meshes(instance, extent!(replicas = num_replicas)).await;
504 let proc_mesh = &meshes[1];
505 let child_name = Name::new("child");
506
507 let actor_mesh = proc_mesh
508 .spawn_with_name::<testactor::TestActor>(instance, child_name.clone(), &())
509 .await
510 .unwrap();
511
512 actor_mesh
513 .cast(
514 instance,
515 testactor::CauseSupervisionEvent(testactor::SupervisionEventType::Panic),
516 )
517 .unwrap();
518
519 #[allow(clippy::disallowed_methods)]
523 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
524
525 let supervision_task = tokio::spawn(async move {
529 match actor_mesh.actor_states(&instance).await {
530 Ok(events) => {
531 for state in events.values() {
532 supervisor.send(instance, state.clone()).unwrap();
533 }
534 }
535 Err(e) => {
536 println!("error: {:?}", e);
537 }
538 };
539 });
540 supervision_task.await.unwrap();
542
543 for _ in 0..num_replicas {
544 let state = supervision_receiver.recv().await.unwrap();
545 if let resource::Status::Failed(s) = state.status {
546 assert!(s.contains("supervision events"));
547 } else {
548 panic!("Not failed: {:?}", state.status);
549 }
550 if let Some(ref inner) = state.state {
551 assert!(!inner.supervision_events.is_empty());
552 for event in &inner.supervision_events {
553 println!("receiving event: {:?}", event);
554 assert_eq!(event.actor_id.name(), format!("{}", child_name.clone()));
555 assert_matches!(event.actor_status, ActorStatus::Failed(_));
556 }
557 }
558 }
559 }
560
561 #[async_timed_test(timeout_secs = 30)]
562 async fn test_cast() {
563 let config = hyperactor::config::global::lock();
564 let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
565
566 let instance = testing::instance().await;
567 let host_mesh = testing::host_mesh(extent!(host = 4)).await;
568 let proc_mesh = host_mesh
569 .spawn(instance, "test", Extent::unity())
570 .await
571 .unwrap();
572 let actor_mesh = proc_mesh
573 .spawn::<testactor::TestActor>(instance, "test", &())
574 .await
575 .unwrap();
576
577 let (cast_info, mut cast_info_rx) = instance.mailbox().open_port();
578 actor_mesh
579 .cast(
580 instance,
581 testactor::GetCastInfo {
582 cast_info: cast_info.bind(),
583 },
584 )
585 .unwrap();
586
587 let mut point_to_actor: HashSet<_> = actor_mesh.iter().collect();
588 while !point_to_actor.is_empty() {
589 let (point, origin_actor_ref, sender_actor_id) = cast_info_rx.recv().await.unwrap();
590 let key = (point, origin_actor_ref);
591 assert!(
592 point_to_actor.remove(&key),
593 "key {:?} not present or removed twice",
594 key
595 );
596 assert_eq!(&sender_actor_id, instance.self_id());
597 }
598
599 let _ = host_mesh.shutdown(&instance).await;
600 }
601}