1#![allow(dead_code)] use std::collections::BTreeSet;
12use std::ops::Deref;
13
14use async_trait::async_trait;
15use hyperactor::Actor;
16use hyperactor::ActorRef;
17use hyperactor::Bind;
18use hyperactor::GangId;
19use hyperactor::GangRef;
20use hyperactor::Message;
21use hyperactor::Named;
22use hyperactor::PortHandle;
23use hyperactor::RemoteHandles;
24use hyperactor::RemoteMessage;
25use hyperactor::Unbind;
26use hyperactor::WorldId;
27use hyperactor::actor::Referable;
28use hyperactor::attrs::Attrs;
29use hyperactor::attrs::declare_attrs;
30use hyperactor::config;
31use hyperactor::context;
32use hyperactor::mailbox::MailboxSenderError;
33use hyperactor::mailbox::PortReceiver;
34use hyperactor::message::Castable;
35use hyperactor::message::IndexedErasedUnbound;
36use hyperactor::supervision::ActorSupervisionEvent;
37use ndslice::Range;
38use ndslice::Selection;
39use ndslice::Shape;
40use ndslice::ShapeError;
41use ndslice::SliceError;
42use ndslice::reshape::Limit;
43use ndslice::reshape::ReshapeError;
44use ndslice::reshape::ReshapeSliceExt;
45use ndslice::reshape::reshape_selection;
46use ndslice::selection;
47use ndslice::selection::EvalOpts;
48use ndslice::selection::ReifySlice;
49use ndslice::selection::normal;
50use serde::Deserialize;
51use serde::Serialize;
52use serde_multipart::Part;
53use tokio::sync::mpsc;
54
55use crate::CommActor;
56use crate::Mesh;
57use crate::comm::multicast::CastMessage;
58use crate::comm::multicast::CastMessageEnvelope;
59use crate::comm::multicast::Uslice;
60use crate::config::MAX_CAST_DIMENSION_SIZE;
61use crate::metrics;
62use crate::proc_mesh::ProcMesh;
63use crate::reference::ActorMeshId;
64use crate::reference::ActorMeshRef;
65use crate::reference::ProcMeshId;
66
declare_attrs! {
    // Message-header attribute holding the id of the actor mesh a cast is
    // addressed to. `actor_mesh_cast` sets it on the headers of every cast
    // message it hands to the comm actor.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
73
74#[allow(clippy::result_large_err)] pub(crate) fn actor_mesh_cast<A, M>(
78 cx: &impl context::Actor,
79 actor_mesh_id: ActorMeshId,
80 comm_actor_ref: &ActorRef<CommActor>,
81 selection_of_root: Selection,
82 root_mesh_shape: &Shape,
83 cast_mesh_shape: &Shape,
84 message: M,
85) -> Result<(), CastError>
86where
87 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
88 M: Castable + RemoteMessage,
89{
90 let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
91 "message_type" => M::typename(),
92 "message_variant" => message.arm().unwrap_or_default(),
93 ));
94
95 let message = CastMessageEnvelope::new::<A, M>(
96 actor_mesh_id.clone(),
97 cx.mailbox().actor_id().clone(),
98 cast_mesh_shape.clone(),
99 message,
100 )?;
101
102 let slice_of_root = root_mesh_shape.slice();
120
121 let max_cast_dimension_size = config::global::get(MAX_CAST_DIMENSION_SIZE);
122
123 let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
124
125 let selection_of_cast =
126 reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
127
128 let cast_message = CastMessage {
129 dest: Uslice {
130 slice: slice_of_cast,
131 selection: selection_of_cast,
132 },
133 message,
134 };
135
136 let mut headers = Attrs::new();
137 headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
138
139 comm_actor_ref
140 .port()
141 .send_with_headers(cx, headers, cast_message)?;
142
143 Ok(())
144}
145
146#[allow(clippy::result_large_err)] pub(crate) fn cast_to_sliced_mesh<A, M>(
148 cx: &impl context::Actor,
149 actor_mesh_id: ActorMeshId,
150 comm_actor_ref: &ActorRef<CommActor>,
151 sel_of_sliced: &Selection,
152 message: M,
153 sliced_shape: &Shape,
154 root_mesh_shape: &Shape,
155) -> Result<(), CastError>
156where
157 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
158 M: Castable + RemoteMessage,
159{
160 let root_slice = root_mesh_shape.slice();
161
162 let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
164 root_slice.reify_slice(sliced_shape.slice())?
166 } else {
167 let ranks = sel_of_sliced
169 .eval(&EvalOpts::strict(), sliced_shape.slice())?
170 .collect::<BTreeSet<_>>();
171 Selection::of_ranks(root_slice, &ranks)?
172 };
173
174 actor_mesh_cast::<A, M>(
176 cx,
177 actor_mesh_id,
178 comm_actor_ref,
179 sel_of_root,
180 root_mesh_shape,
181 sliced_shape,
182 message,
183 )
184}
185
186#[async_trait]
188pub trait ActorMesh: Mesh<Id = ActorMeshId> {
189 type Actor: Referable;
191
192 #[allow(clippy::result_large_err)] fn cast<M>(
196 &self,
197 cx: &impl context::Actor,
198 selection: Selection,
199 message: M,
200 ) -> Result<(), CastError>
201 where
202 Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
203 M: Castable + RemoteMessage,
204 {
205 actor_mesh_cast::<Self::Actor, M>(
206 cx, self.id(), self.proc_mesh().comm_actor(), selection, self.shape(), self.shape(), message, )
214 }
215
216 fn proc_mesh(&self) -> &ProcMesh;
218
219 fn name(&self) -> &str;
221
222 fn world_id(&self) -> &WorldId {
223 self.proc_mesh().world_id()
224 }
225
226 fn iter_actor_refs(&self) -> impl Iterator<Item = ActorRef<Self::Actor>> {
228 let gang: GangRef<Self::Actor> =
229 GangId(self.proc_mesh().world_id().clone(), self.name().to_string()).into();
230 self.shape().slice().iter().map(move |rank| gang.rank(rank))
231 }
232
233 async fn stop(&self) -> Result<(), anyhow::Error> {
234 self.proc_mesh().stop_actor_by_name(self.name()).await
235 }
236
237 fn bind(&self) -> ActorMeshRef<Self::Actor> {
239 ActorMeshRef::attest(
240 ActorMeshId::V0(
241 ProcMeshId(self.world_id().to_string()),
242 self.name().to_string(),
243 ),
244 self.shape().clone(),
245 self.proc_mesh().comm_actor().clone(),
246 )
247 }
248}
249
250enum ProcMeshRef<'a> {
254 Shared(Box<dyn Deref<Target = ProcMesh> + Sync + Send>),
256 Borrowed(&'a ProcMesh),
259}
260
261impl Deref for ProcMeshRef<'_> {
262 type Target = ProcMesh;
263
264 fn deref(&self) -> &Self::Target {
265 match self {
266 Self::Shared(p) => p,
267 Self::Borrowed(p) => p, }
269 }
270}
271
/// A mesh of actors of type `A` spanning an entire proc mesh.
///
/// The proc mesh is held either by borrow or through a shared handle; see
/// the `new` and `new_shared` constructors.
pub struct RootActorMesh<'a, A: Referable> {
    /// The proc mesh hosting the actors.
    proc_mesh: ProcMeshRef<'a>,
    /// The name of the actor mesh.
    name: String,
    /// Actor references; `Mesh::get(rank)` indexes directly into this vector.
    pub(crate) ranks: Vec<ActorRef<A>>,
    /// Receiver of supervision events. It is moved out by the first call to
    /// `events()` and is `None` afterwards.
    actor_supervision_rx: Option<mpsc::UnboundedReceiver<ActorSupervisionEvent>>,
}
286
287impl<'a, A: Referable> RootActorMesh<'a, A> {
288 pub(crate) fn new(
289 proc_mesh: &'a ProcMesh,
290 name: String,
291 actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
292 ranks: Vec<ActorRef<A>>,
293 ) -> Self {
294 Self {
295 proc_mesh: ProcMeshRef::Borrowed(proc_mesh),
296 name,
297 ranks,
298 actor_supervision_rx: Some(actor_supervision_rx),
299 }
300 }
301
302 pub(crate) fn new_shared<D: Deref<Target = ProcMesh> + Send + Sync + 'static>(
303 proc_mesh: D,
304 name: String,
305 actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
306 ranks: Vec<ActorRef<A>>,
307 ) -> Self {
308 Self {
309 proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)),
310 name,
311 ranks,
312 actor_supervision_rx: Some(actor_supervision_rx),
313 }
314 }
315
316 pub fn open_port<M: Message>(&self) -> (PortHandle<M>, PortReceiver<M>) {
318 self.proc_mesh.client().open_port()
319 }
320
321 pub fn events(&mut self) -> Option<ActorSupervisionEvents> {
324 self.actor_supervision_rx
325 .take()
326 .map(|actor_supervision_rx| ActorSupervisionEvents {
327 actor_supervision_rx,
328 mesh_id: self.id(),
329 })
330 }
331}
332
333pub struct ActorSupervisionEvents {
335 actor_supervision_rx: mpsc::UnboundedReceiver<ActorSupervisionEvent>,
337 mesh_id: ActorMeshId,
339}
340
341impl ActorSupervisionEvents {
342 pub async fn next(&mut self) -> Option<ActorSupervisionEvent> {
343 let result = self.actor_supervision_rx.recv().await;
344 if result.is_none() {
345 tracing::info!(
346 "supervision stream for actor mesh {:?} was closed!",
347 self.mesh_id
348 );
349 }
350 result
351 }
352}
353
354#[async_trait]
355impl<'a, A: Referable> Mesh for RootActorMesh<'a, A> {
356 type Node = ActorRef<A>;
357 type Id = ActorMeshId;
358 type Sliced<'b>
359 = SlicedActorMesh<'b, A>
360 where
361 'a: 'b;
362
363 fn shape(&self) -> &Shape {
364 self.proc_mesh.shape()
365 }
366
367 fn select<R: Into<Range>>(
368 &self,
369 label: &str,
370 range: R,
371 ) -> Result<Self::Sliced<'_>, ShapeError> {
372 Ok(SlicedActorMesh(self, self.shape().select(label, range)?))
373 }
374
375 fn get(&self, rank: usize) -> Option<ActorRef<A>> {
376 self.ranks.get(rank).cloned()
377 }
378
379 fn id(&self) -> Self::Id {
380 ActorMeshId::V0(self.proc_mesh.id(), self.name.clone())
381 }
382}
383
384impl<A: Referable> ActorMesh for RootActorMesh<'_, A> {
385 type Actor = A;
386
387 fn proc_mesh(&self) -> &ProcMesh {
388 &self.proc_mesh
389 }
390
391 fn name(&self) -> &str {
392 &self.name
393 }
394}
395
396pub struct SlicedActorMesh<'a, A: Referable>(&'a RootActorMesh<'a, A>, Shape);
397
398impl<'a, A: Referable> SlicedActorMesh<'a, A> {
399 pub fn new(actor_mesh: &'a RootActorMesh<'a, A>, shape: Shape) -> Self {
400 Self(actor_mesh, shape)
401 }
402
403 pub fn shape(&self) -> &Shape {
404 &self.1
405 }
406}
407
408#[async_trait]
409impl<A: Referable> Mesh for SlicedActorMesh<'_, A> {
410 type Node = ActorRef<A>;
411 type Id = ActorMeshId;
412 type Sliced<'b>
413 = SlicedActorMesh<'b, A>
414 where
415 Self: 'b;
416
417 fn shape(&self) -> &Shape {
418 &self.1
419 }
420
421 fn select<R: Into<Range>>(
422 &self,
423 label: &str,
424 range: R,
425 ) -> Result<Self::Sliced<'_>, ShapeError> {
426 Ok(Self(self.0, self.1.select(label, range)?))
427 }
428
429 fn get(&self, _index: usize) -> Option<ActorRef<A>> {
430 unimplemented!()
431 }
432
433 fn id(&self) -> Self::Id {
434 self.0.id()
435 }
436}
437
438impl<A: Referable> ActorMesh for SlicedActorMesh<'_, A> {
439 type Actor = A;
440
441 fn proc_mesh(&self) -> &ProcMesh {
442 &self.0.proc_mesh
443 }
444
445 fn name(&self) -> &str {
446 &self.0.name
447 }
448
449 #[allow(clippy::result_large_err)] fn cast<M>(&self, cx: &impl context::Actor, sel: Selection, message: M) -> Result<(), CastError>
451 where
452 Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
453 M: Castable + RemoteMessage,
454 {
455 cast_to_sliced_mesh::<A, M>(
456 cx,
457 self.id(),
458 self.proc_mesh().comm_actor(),
459 &sel,
460 message,
461 self.shape(),
462 self.0.shape(),
463 )
464 }
465}
466
/// Errors returned by the cast operations in this module.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection was invalid for a shape; carries the selection and the
    /// shape error.
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// A send failed for a particular rank; carries the rank and the mailbox
    /// error.
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// A mailbox send failed without an associated rank. Produced by `?` on
    /// a `MailboxSenderError`, e.g. the send to the comm actor in
    /// `actor_mesh_cast`.
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    /// A shape operation failed.
    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    /// A slice operation failed.
    #[error(transparent)]
    SliceError(#[from] SliceError),

    /// A `bincode` (de)serialization error.
    #[error(transparent)]
    SerializationError(#[from] bincode::Error),

    /// Any other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Reshaping a selection failed.
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}
494
495pub(crate) mod test_util {
498 use std::collections::VecDeque;
499 use std::fmt;
500 use std::fmt::Debug;
501 use std::sync::Arc;
502
503 use anyhow::ensure;
504 use hyperactor::Context;
505 use hyperactor::Handler;
506 use hyperactor::PortRef;
507 use ndslice::extent;
508
509 use super::*;
510 use crate::comm::multicast::CastInfo;
511
    /// Stateless actor used by the mesh tests.
    ///
    /// It handles `Echo`, `Payload`, `GetRank` and `Error` (all exported with
    /// `cast = true`) as well as `Relay` (exported without it).
    #[derive(Debug, Default, Actor)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            Echo { cast = true },
            Payload { cast = true },
            GetRank { cast = true },
            Error { cast = true },
            Relay,
        ],
    )]
    pub struct TestActor;
528
529 #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
540 pub struct GetRank(pub bool, #[binding(include)] pub PortRef<usize>);
541
542 #[async_trait]
543 impl Handler<GetRank> for TestActor {
544 async fn handle(
545 &mut self,
546 cx: &Context<Self>,
547 GetRank(ok, reply): GetRank,
548 ) -> Result<(), anyhow::Error> {
549 let point = cx.cast_point();
550 reply.send(cx, point.rank())?;
551 anyhow::ensure!(ok, "intentional error!"); Ok(())
553 }
554 }
555
556 #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
557 pub struct Echo(pub String, #[binding(include)] pub PortRef<String>);
558
559 #[async_trait]
560 impl Handler<Echo> for TestActor {
561 async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
562 let Echo(message, reply_port) = message;
563 reply_port.send(cx, message)?;
564 Ok(())
565 }
566 }
567
568 #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
569 pub struct Payload {
570 pub part: Part,
571 #[binding(include)]
572 pub reply_port: PortRef<()>,
573 }
574
575 #[async_trait]
576 impl Handler<Payload> for TestActor {
577 async fn handle(
578 &mut self,
579 cx: &Context<Self>,
580 message: Payload,
581 ) -> Result<(), anyhow::Error> {
582 let Payload { reply_port, .. } = message;
583 reply_port.send(cx, ())?;
584 Ok(())
585 }
586 }
587
588 #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)]
589 pub struct Error(pub String);
590
591 #[async_trait]
592 impl Handler<Error> for TestActor {
593 async fn handle(
594 &mut self,
595 _cx: &Context<Self>,
596 Error(error): Error,
597 ) -> Result<(), anyhow::Error> {
598 Err(anyhow::anyhow!("{}", error))
599 }
600 }
601
602 #[derive(Debug, Serialize, Deserialize, Named, Clone)]
603 pub struct Relay(pub usize, pub VecDeque<PortRef<Relay>>);
604
605 #[async_trait]
606 impl Handler<Relay> for TestActor {
607 async fn handle(
608 &mut self,
609 cx: &Context<Self>,
610 Relay(count, mut hops): Relay,
611 ) -> Result<(), anyhow::Error> {
612 ensure!(!hops.is_empty(), "relay must have at least one hop");
613 let next = hops.pop_front().unwrap();
614 next.send(cx, Relay(count + 1, hops))?;
615 Ok(())
616 }
617 }
618
619 #[hyperactor::export(
622 spawn = true,
623 handlers = [
624 Echo,
625 ],
626 )]
627 pub struct ProxyActor {
628 proc_mesh: Arc<ProcMesh>,
629 actor_mesh: RootActorMesh<'static, TestActor>,
630 }
631
632 impl fmt::Debug for ProxyActor {
633 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
634 f.debug_struct("ProxyActor")
635 .field("proc_mesh", &"...")
636 .field("actor_mesh", &"...")
637 .finish()
638 }
639 }
640
641 #[async_trait]
642 impl Actor for ProxyActor {
643 type Params = ();
644
645 async fn new(_params: Self::Params) -> Result<Self, anyhow::Error> {
646 use std::sync::Arc;
648
649 use hyperactor::channel::ChannelTransport;
650
651 use crate::alloc::AllocSpec;
652 use crate::alloc::Allocator;
653 use crate::alloc::LocalAllocator;
654
655 let mut allocator = LocalAllocator;
656 let alloc = allocator
657 .allocate(AllocSpec {
658 extent: extent! { replica = 1 },
659 constraints: Default::default(),
660 proc_name: None,
661 transport: ChannelTransport::Local,
662 })
663 .await
664 .unwrap();
665 let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
666 let leaked: &'static Arc<ProcMesh> = Box::leak(Box::new(proc_mesh));
667 let actor_mesh: RootActorMesh<'static, TestActor> =
668 leaked.spawn("echo", &()).await.unwrap();
669 Ok(Self {
670 proc_mesh: Arc::clone(leaked),
671 actor_mesh,
672 })
673 }
674 }
675
676 #[async_trait]
677 impl Handler<Echo> for ProxyActor {
678 async fn handle(&mut self, cx: &Context<Self>, message: Echo) -> Result<(), anyhow::Error> {
679 if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
680 let actor = self.actor_mesh.get(0).unwrap();
683
684 let (tx, mut rx) = cx.open_port();
687
688 actor.send(cx, Echo(message.0, tx.bind()))?;
689 message.1.send(cx, rx.recv().await.unwrap())?;
690
691 Ok(())
692 } else {
693 let actor: ActorRef<_> = self.actor_mesh.get(0).unwrap();
696 let (tx, mut rx) = cx.open_port::<String>();
697 actor.send(cx, Echo(message.0, tx.bind()))?;
698
699 use tokio::time::Duration;
700 use tokio::time::timeout;
701 #[allow(clippy::disallowed_methods)]
702 if let Ok(_) = timeout(Duration::from_secs(1), rx.recv()).await {
703 message
704 .1
705 .send(cx, "the impossible happened".to_owned())
706 .unwrap()
707 }
708
709 Ok(())
710 }
711 }
712 }
713}
714
715#[cfg(test)]
716mod tests {
717 use std::sync::Arc;
718
719 use hyperactor::ActorId;
720 use hyperactor::PortRef;
721 use hyperactor::ProcId;
722 use hyperactor::WorldId;
723 use hyperactor::attrs::Attrs;
724 use hyperactor::data::Encoding;
725 use timed_test::async_timed_test;
726
727 use super::*;
728 use crate::proc_mesh::ProcEvent;
729
730 #[macro_export]
732 macro_rules! actor_mesh_test_suite {
733 ($allocator:expr) => {
734 use std::assert_matches::assert_matches;
735
736 use ndslice::extent;
737 use $crate::alloc::AllocSpec;
738 use $crate::alloc::Allocator;
739 use $crate::assign::Ranks;
740 use $crate::sel_from_shape;
741 use $crate::sel;
742 use $crate::comm::multicast::set_cast_info_on_headers;
743 use $crate::proc_mesh::SharedSpawnable;
744 use std::collections::VecDeque;
745 use hyperactor::data::Serialized;
746 use $crate::proc_mesh::default_transport;
747
748 use super::*;
749 use super::test_util::*;
750
751 #[tokio::test]
752 async fn test_proxy_mesh() {
753 use super::test_util::*;
754 use $crate::alloc::AllocSpec;
755 use $crate::alloc::Allocator;
756
757 use ndslice::extent;
758
759 let alloc = $allocator
760 .allocate(AllocSpec {
761 extent: extent! { replica = 1 },
762 constraints: Default::default(),
763 proc_name: None,
764 transport: default_transport()
765 })
766 .await
767 .unwrap();
768 let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
769 let actor_mesh: RootActorMesh<'_, ProxyActor> = proc_mesh.spawn("proxy", &()).await.unwrap();
770 let proxy_actor = actor_mesh.get(0).unwrap();
771 let (tx, mut rx) = actor_mesh.open_port::<String>();
772 proxy_actor.send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind())).unwrap();
773
774 #[allow(clippy::disallowed_methods)]
775 match tokio::time::timeout(tokio::time::Duration::from_secs(3), rx.recv()).await {
776 Ok(msg) => assert_eq!(&msg.unwrap(), "hello!"),
777 Err(_) => assert!(false),
778 }
779 }
780
781 #[tokio::test]
782 async fn test_basic() {
783 let alloc = $allocator
784 .allocate(AllocSpec {
785 extent: extent!(replica = 4),
786 constraints: Default::default(),
787 proc_name: None,
788 transport: default_transport()
789 })
790 .await
791 .unwrap();
792
793 let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
794 let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
795 let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
796 actor_mesh
797 .cast(proc_mesh.client(), sel!(*), Echo("Hello".to_string(), reply_handle.bind()))
798 .unwrap();
799 for _ in 0..4 {
800 assert_eq!(&reply_receiver.recv().await.unwrap(), "Hello");
801 }
802 }
803
804 #[tokio::test]
805 async fn test_ping_pong() {
806 use hyperactor::test_utils::pingpong::PingPongActor;
807 use hyperactor::test_utils::pingpong::PingPongMessage;
808 use hyperactor::test_utils::pingpong::PingPongActorParams;
809
810 let alloc = $allocator
811 .allocate(AllocSpec {
812 extent: extent!(replica = 2),
813 constraints: Default::default(),
814 proc_name: None,
815 transport: default_transport(),
816 })
817 .await
818 .unwrap();
819 let mesh = ProcMesh::allocate(alloc).await.unwrap();
820
821 let (undeliverable_msg_tx, _) = mesh.client().open_port();
822 let ping_pong_actor_params = PingPongActorParams::new(Some(undeliverable_msg_tx.bind()), None);
823 let actor_mesh: RootActorMesh<PingPongActor> = mesh
824 .spawn::<PingPongActor>("ping-pong", &ping_pong_actor_params)
825 .await
826 .unwrap();
827
828 let ping: ActorRef<PingPongActor> = actor_mesh.get(0).unwrap();
829 let pong: ActorRef<PingPongActor> = actor_mesh.get(1).unwrap();
830 let (done_tx, done_rx) = mesh.client().open_once_port();
831 ping.send(mesh.client(), PingPongMessage(4, pong.clone(), done_tx.bind())).unwrap();
832
833 assert!(done_rx.recv().await.unwrap());
834 }
835
836 #[tokio::test]
837 async fn test_pingpong_full_mesh() {
838 use hyperactor::test_utils::pingpong::PingPongActor;
839 use hyperactor::test_utils::pingpong::PingPongActorParams;
840 use hyperactor::test_utils::pingpong::PingPongMessage;
841
842 use futures::future::join_all;
843
844 const X: usize = 3;
845 const Y: usize = 3;
846 const Z: usize = 3;
847 let alloc = $allocator
848 .allocate(AllocSpec {
849 extent: extent!(x = X, y = Y, z = Z),
850 constraints: Default::default(),
851 proc_name: None,
852 transport: default_transport(),
853 })
854 .await
855 .unwrap();
856
857 let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
858 let (undeliverable_tx, _undeliverable_rx) = proc_mesh.client().open_port();
859 let params = PingPongActorParams::new(Some(undeliverable_tx.bind()), None);
860 let actor_mesh = proc_mesh.spawn::<PingPongActor>("pingpong", ¶ms).await.unwrap();
861 let slice = actor_mesh.shape().slice();
862
863 let mut futures = Vec::new();
864 for rank in slice.iter() {
865 let actor = actor_mesh.get(rank).unwrap();
866 let coords = (&slice.coordinates(rank).unwrap()[..]).try_into().unwrap();
867 let sizes = (&slice.sizes())[..].try_into().unwrap();
868 let neighbors = ndslice::utils::stencil::moore_neighbors::<3>();
869 for neighbor_coords in ndslice::utils::apply_stencil(&coords, sizes, &neighbors) {
870 if let Ok(neighbor_rank) = slice.location(&neighbor_coords) {
871 let neighbor = actor_mesh.get(neighbor_rank).unwrap();
872 let (done_tx, done_rx) = proc_mesh.client().open_once_port();
873 actor
874 .send(
875 proc_mesh.client(),
876 PingPongMessage(4, neighbor.clone(), done_tx.bind()),
877 )
878 .unwrap();
879 futures.push(done_rx.recv());
880 }
881 }
882 }
883 let results = join_all(futures).await;
884 assert_eq!(results.len(), 316); for result in results {
886 assert_eq!(result.unwrap(), true);
887 }
888 }
889
890 #[tokio::test]
891 async fn test_cast() {
892 let alloc = $allocator
893 .allocate(AllocSpec {
894 extent: extent!(replica = 2, host = 2, gpu = 8),
895 constraints: Default::default(),
896 proc_name: None,
897 transport: default_transport(),
898 })
899 .await
900 .unwrap();
901
902 let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
903 let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
904 let dont_simulate_error = true;
905 let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
906 actor_mesh
907 .cast(proc_mesh.client(), sel!(*), GetRank(dont_simulate_error, reply_handle.bind()))
908 .unwrap();
909 let mut ranks = Ranks::new(actor_mesh.shape().slice().len());
910 while !ranks.is_full() {
911 let rank = reply_receiver.recv().await.unwrap();
912 assert!(ranks.insert(rank, rank).is_none(), "duplicate rank {rank}");
913 }
914 let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
916 actor_mesh
917 .cast(
918 proc_mesh.client(),
919 sel_from_shape!(actor_mesh.shape(), replica = 0, host = 0),
920 GetRank(dont_simulate_error, reply_handle.bind()),
921 )
922 .unwrap();
923 let mut ranks = Ranks::new(8);
924 while !ranks.is_full() {
925 let rank = reply_receiver.recv().await.unwrap();
926 assert!(ranks.insert(rank, rank).is_none(), "duplicate rank {rank}");
927 }
928 }
929
930 #[tokio::test]
931 async fn test_inter_actor_comms() {
932 let alloc = $allocator
933 .allocate(AllocSpec {
934 extent: extent!(replica = 2, host = 2, gpu = 8),
938 constraints: Default::default(),
939 proc_name: None,
940 transport: default_transport(),
941 })
942 .await
943 .unwrap();
944
945 let proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
946 let actor_mesh: RootActorMesh<TestActor> = proc_mesh.spawn("echo", &()).await.unwrap();
947
948 let mut hops: VecDeque<_> = actor_mesh.iter().map(|actor| actor.port()).collect();
950 let (handle, mut rx) = proc_mesh.client().open_port();
951 hops.push_back(handle.bind());
952 hops.pop_front()
953 .unwrap()
954 .send(proc_mesh.client(), Relay(0, hops))
955 .unwrap();
956 assert_matches!(
957 rx.recv().await.unwrap(),
958 Relay(count, hops)
959 if count == actor_mesh.shape().slice().len()
960 && hops.is_empty());
961 }
962
963 #[tokio::test]
964 async fn test_inter_proc_mesh_comms() {
965 let mut meshes = Vec::new();
966 for _ in 0..2 {
967 let alloc = $allocator
968 .allocate(AllocSpec {
969 extent: extent!(replica = 1),
970 constraints: Default::default(),
971 proc_name: None,
972 transport: default_transport(),
973 })
974 .await
975 .unwrap();
976
977 let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
978 let proc_mesh_clone = Arc::clone(&proc_mesh);
979 let actor_mesh : RootActorMesh<TestActor> = proc_mesh_clone.spawn("echo", &()).await.unwrap();
980 meshes.push((proc_mesh, actor_mesh));
981 }
982
983 let mut hops: VecDeque<_> = meshes
984 .iter()
985 .flat_map(|(_proc_mesh, actor_mesh)| actor_mesh.iter())
986 .map(|actor| actor.port())
987 .collect();
988 let num_hops = hops.len();
989
990 let client = meshes[0].0.client();
991 let (handle, mut rx) = client.open_port();
992 hops.push_back(handle.bind());
993 hops.pop_front()
994 .unwrap()
995 .send(client, Relay(0, hops))
996 .unwrap();
997 assert_matches!(
998 rx.recv().await.unwrap(),
999 Relay(count, hops)
1000 if count == num_hops
1001 && hops.is_empty());
1002 }
1003
1004 #[async_timed_test(timeout_secs = 60)]
1005 async fn test_actor_mesh_cast() {
1006 use $crate::sel;
1010 use $crate::comm::test_utils::TestActor as CastTestActor;
1011 use $crate::comm::test_utils::TestActorParams as CastTestActorParams;
1012 use $crate::comm::test_utils::TestMessage as CastTestMessage;
1013
1014 let extent = extent!(replica = 4, host = 4, gpu = 4);
1015 let num_actors = extent.len();
1016 let alloc = $allocator
1017 .allocate(AllocSpec {
1018 extent,
1019 constraints: Default::default(),
1020 proc_name: None,
1021 transport: default_transport(),
1022 })
1023 .await
1024 .unwrap();
1025
1026 let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
1027
1028 let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
1029 let params = CastTestActorParams{ forward_port: tx.bind() };
1030 let actor_mesh: RootActorMesh<CastTestActor> = proc_mesh.spawn("actor", ¶ms).await.unwrap();
1031
1032 actor_mesh.cast(proc_mesh.client(), sel!(*), CastTestMessage::Forward("abc".to_string())).unwrap();
1033
1034 for _ in 0..num_actors {
1035 assert_eq!(rx.recv().await.unwrap(), CastTestMessage::Forward("abc".to_string()));
1036 }
1037
1038 proc_mesh.events().unwrap().into_alloc().stop_and_wait().await.unwrap();
1043 }
1044
1045 #[tokio::test]
1046 async fn test_delivery_failure() {
1047 let alloc = $allocator
1048 .allocate(AllocSpec {
1049 extent: extent!(replica = 1 ),
1050 constraints: Default::default(),
1051 proc_name: None,
1052 transport: default_transport(),
1053 })
1054 .await
1055 .unwrap();
1056
1057 let name = alloc.name().to_string();
1058 let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1059 let mut events = mesh.events().unwrap();
1060
1061 let unmonitored_reply_to = mesh.client().open_port::<usize>().0.bind();
1063 let bad_actor = ActorRef::<TestActor>::attest(ActorId(ProcId::Ranked(WorldId(name.clone()), 0), "foo".into(), 0));
1064 bad_actor.send(mesh.client(), GetRank(true, unmonitored_reply_to)).unwrap();
1065
1066 assert_matches!(
1068 events.next().await.unwrap(),
1069 ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1070 );
1071
1072 }
1074
1075 #[tokio::test]
1076 async fn test_send_with_headers() {
1077 let extent = extent!(replica = 3);
1078 let alloc = $allocator
1079 .allocate(AllocSpec {
1080 extent: extent.clone(),
1081 constraints: Default::default(),
1082 proc_name: None,
1083 transport: default_transport(),
1084 })
1085 .await
1086 .unwrap();
1087
1088 let mesh = ProcMesh::allocate(alloc).await.unwrap();
1089 let (reply_port_handle, mut reply_port_receiver) = mesh.client().open_port::<usize>();
1090 let reply_port = reply_port_handle.bind();
1091
1092 let actor_mesh: RootActorMesh<TestActor> = mesh.spawn("test", &()).await.unwrap();
1093 let actor_ref = actor_mesh.get(0).unwrap();
1094 let mut headers = Attrs::new();
1095 set_cast_info_on_headers(&mut headers, extent.point_of_rank(0).unwrap(), mesh.client().self_id().clone());
1096 actor_ref.send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone())).unwrap();
1097 assert_eq!(0, reply_port_receiver.recv().await.unwrap());
1098
1099 set_cast_info_on_headers(&mut headers, extent.point_of_rank(1).unwrap(), mesh.client().self_id().clone());
1100 actor_ref.port()
1101 .send_with_headers(mesh.client(), headers.clone(), GetRank(true, reply_port.clone()))
1102 .unwrap();
1103 assert_eq!(1, reply_port_receiver.recv().await.unwrap());
1104
1105 set_cast_info_on_headers(&mut headers, extent.point_of_rank(2).unwrap(), mesh.client().self_id().clone());
1106 actor_ref.actor_id()
1107 .port_id(GetRank::port())
1108 .send_with_headers(
1109 mesh.client(),
1110 Serialized::serialize(&GetRank(true, reply_port)).unwrap(),
1111 headers
1112 );
1113 assert_eq!(2, reply_port_receiver.recv().await.unwrap());
1114 }
1116 }
1117 }
1118
1119 mod local {
1120 use hyperactor::channel::ChannelTransport;
1121
1122 use crate::alloc::local::LocalAllocator;
1123
1124 actor_mesh_test_suite!(LocalAllocator);
1125
1126 #[tokio::test]
1127 async fn test_send_failure() {
1128 hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
1129
1130 use hyperactor::test_utils::pingpong::PingPongActor;
1131 use hyperactor::test_utils::pingpong::PingPongActorParams;
1132 use hyperactor::test_utils::pingpong::PingPongMessage;
1133
1134 use crate::alloc::ProcStopReason;
1135 use crate::proc_mesh::ProcEvent;
1136
1137 let config = hyperactor::config::global::lock();
1138 let _guard = config.override_key(
1139 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1140 tokio::time::Duration::from_secs(1),
1141 );
1142
1143 let alloc = LocalAllocator
1144 .allocate(AllocSpec {
1145 extent: extent!(replica = 2),
1146 constraints: Default::default(),
1147 proc_name: None,
1148 transport: ChannelTransport::Local,
1149 })
1150 .await
1151 .unwrap();
1152 let monkey = alloc.chaos_monkey();
1153 let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1154 let mut events = mesh.events().unwrap();
1155
1156 let ping_pong_actor_params = PingPongActorParams::new(
1157 Some(PortRef::attest_message_port(mesh.client().self_id())),
1158 None,
1159 );
1160 let actor_mesh: RootActorMesh<PingPongActor> = mesh
1161 .spawn::<PingPongActor>("ping-pong", &ping_pong_actor_params)
1162 .await
1163 .unwrap();
1164
1165 let ping: ActorRef<PingPongActor> = actor_mesh.get(0).unwrap();
1166 let pong: ActorRef<PingPongActor> = actor_mesh.get(1).unwrap();
1167
1168 monkey(0, ProcStopReason::Killed(0, false));
1170 assert_matches!(
1171 events.next().await.unwrap(),
1172 ProcEvent::Stopped(0, ProcStopReason::Killed(0, false))
1173 );
1174
1175 let (unmonitored_done_tx, _) = mesh.client().open_once_port();
1178 ping.send(
1179 mesh.client(),
1180 PingPongMessage(1, pong.clone(), unmonitored_done_tx.bind()),
1181 )
1182 .unwrap();
1183
1184 assert_matches!(
1186 events.next().await.unwrap(),
1187 ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1188 );
1189
1190 let (unmonitored_done_tx, _) = mesh.client().open_once_port();
1193 pong.send(
1194 mesh.client(),
1195 PingPongMessage(1, ping.clone(), unmonitored_done_tx.bind()),
1196 )
1197 .unwrap();
1198
1199 assert_matches!(
1201 events.next().await.unwrap(),
1202 ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1203 );
1204 }
1205
1206 #[tokio::test]
1207 async fn test_cast_failure() {
1208 use crate::alloc::ProcStopReason;
1209 use crate::proc_mesh::ProcEvent;
1210 use crate::sel;
1211
1212 let alloc = LocalAllocator
1213 .allocate(AllocSpec {
1214 extent: extent!(replica = 1),
1215 constraints: Default::default(),
1216 proc_name: None,
1217 transport: ChannelTransport::Local,
1218 })
1219 .await
1220 .unwrap();
1221
1222 let stop = alloc.stopper();
1223 let mut mesh = ProcMesh::allocate(alloc).await.unwrap();
1224 let mut events = mesh.events().unwrap();
1225
1226 let actor_mesh = mesh
1227 .spawn::<TestActor>("reply-then-fail", &())
1228 .await
1229 .unwrap();
1230
1231 let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
1234 actor_mesh
1235 .cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
1236 .unwrap();
1237 let rank = reply_receiver.recv().await.unwrap();
1238 assert_eq!(rank, 0);
1239
1240 assert_matches!(
1242 events.next().await.unwrap(),
1243 ProcEvent::Crashed(0, reason) if reason.contains("intentional error!")
1244 );
1245
1246 let (reply_handle, _) = actor_mesh.open_port();
1248 actor_mesh
1249 .cast(mesh.client(), sel!(*), GetRank(false, reply_handle.bind()))
1250 .unwrap();
1251
1252 assert_matches!(
1254 events.next().await.unwrap(),
1255 ProcEvent::Crashed(0, reason) if reason.contains("failed: message not delivered")
1256 );
1257
1258 stop();
1260 assert_matches!(
1261 events.next().await.unwrap(),
1262 ProcEvent::Stopped(0, ProcStopReason::Stopped),
1263 );
1264 assert!(events.next().await.is_none());
1265 }
1266
1267 #[tracing_test::traced_test]
1268 #[tokio::test]
1269 async fn test_stop_actor_mesh() {
1270 use hyperactor::test_utils::pingpong::PingPongActor;
1271 use hyperactor::test_utils::pingpong::PingPongActorParams;
1272 use hyperactor::test_utils::pingpong::PingPongMessage;
1273
1274 let config = hyperactor::config::global::lock();
1275 let _guard = config.override_key(
1276 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1277 tokio::time::Duration::from_secs(1),
1278 );
1279
1280 let alloc = LocalAllocator
1281 .allocate(AllocSpec {
1282 extent: extent!(replica = 2),
1283 constraints: Default::default(),
1284 proc_name: None,
1285 transport: ChannelTransport::Local,
1286 })
1287 .await
1288 .unwrap();
1289 let mesh = ProcMesh::allocate(alloc).await.unwrap();
1290
1291 let ping_pong_actor_params = PingPongActorParams::new(
1292 Some(PortRef::attest_message_port(mesh.client().self_id())),
1293 None,
1294 );
1295 let mesh_one: RootActorMesh<PingPongActor> = mesh
1296 .spawn::<PingPongActor>("mesh_one", &ping_pong_actor_params)
1297 .await
1298 .unwrap();
1299
1300 let mesh_two: RootActorMesh<PingPongActor> = mesh
1301 .spawn::<PingPongActor>("mesh_two", &ping_pong_actor_params)
1302 .await
1303 .unwrap();
1304
1305 mesh_two.stop().await.unwrap();
1306
1307 let ping_two: ActorRef<PingPongActor> = mesh_two.get(0).unwrap();
1308 let pong_two: ActorRef<PingPongActor> = mesh_two.get(1).unwrap();
1309
1310 assert!(logs_contain(&format!(
1311 "stopped actor {}",
1312 ping_two.actor_id()
1313 )));
1314 assert!(logs_contain(&format!(
1315 "stopped actor {}",
1316 pong_two.actor_id()
1317 )));
1318
1319 let ping_one: ActorRef<PingPongActor> = mesh_one.get(0).unwrap();
1321 let pong_one: ActorRef<PingPongActor> = mesh_one.get(1).unwrap();
1322 let (done_tx, done_rx) = mesh.client().open_once_port();
1323 pong_one
1324 .send(
1325 mesh.client(),
1326 PingPongMessage(1, ping_one.clone(), done_tx.bind()),
1327 )
1328 .unwrap();
1329 assert!(done_rx.recv().await.is_ok());
1330 }
1331 } mod process {
1334
1335 use bytes::Bytes;
1336 use hyperactor::PortId;
1337 use hyperactor::channel::ChannelTransport;
1338 use hyperactor::clock::Clock;
1339 use hyperactor::clock::RealClock;
1340 use hyperactor::mailbox::MessageEnvelope;
1341 use rand::Rng;
1342 use tokio::process::Command;
1343
1344 use crate::alloc::process::ProcessAllocator;
1345
1346 fn process_allocator() -> ProcessAllocator {
1347 ProcessAllocator::new(Command::new(crate::testresource::get(
1348 "monarch/hyperactor_mesh/bootstrap",
1349 )))
1350 }
1351
// Instantiate the `actor_mesh_test_suite!` tests with `process_allocator()`
// as the allocator expression. Only compiled when the `fbcode_build` cfg is
// set.
// NOTE(review): presumably gated because the bootstrap binary resource is
// only available in that build — confirm.
#[cfg(fbcode_build)]
actor_mesh_test_suite!(process_allocator());
1354
// Checks how messages at and just above the configured maximum frame
// length are handled.
//
// With the codec frame limit overridden to 1024 bytes, a payload whose
// computed frame length is exactly 1024 is sent and a reply is expected
// within two seconds. A payload whose computed frame length is 1025 is
// then sent (directly or via cast, chosen at random) and a
// `ProcEvent::Crashed` event is expected.
#[cfg(fbcode_build)]
#[async_timed_test(timeout_secs = 30)]
async fn test_oversized_frames() {
    // Local frame type used only to compute serialized sizes below.
    // NOTE(review): presumably mirrors the channel layer's wire frame so
    // that `frame_length` matches what is actually sent — confirm.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    enum Frame<M> {
        Init(u64),
        Message(u64, M),
    }
    // Computes the frame length of `pay` when wrapped in a message
    // envelope from `src` to `dst` (with a send-timestamp header) and in
    // a `Frame::Message` with sequence number 0.
    fn frame_length(src: &ActorId, dst: &PortId, pay: &Payload) -> usize {
        let serialized = Serialized::serialize(pay).unwrap();
        let mut headers = Attrs::new();
        hyperactor::mailbox::headers::set_send_timestamp(&mut headers);
        let envelope = MessageEnvelope::new(src.clone(), dst.clone(), serialized, headers);
        let frame = Frame::Message(0u64, envelope);
        let message = serde_multipart::serialize_illegal_bincode(&frame).unwrap();
        message.frame_len()
    }

    // Hold the global config lock for the whole test; the override guards
    // below stay alive until the end of the function.
    let config = hyperactor::config::global::lock();
    // Limit frames to 1024 bytes in this process.
    let _guard2 =
        config.override_key(hyperactor::config::CODEC_MAX_FRAME_LENGTH, 1024usize);
    // NOTE(review): mutating the process environment is not thread-safe,
    // and this variable is never removed again. Presumably it is set so
    // that the spawned child processes pick up the same limit — confirm.
    unsafe {
        std::env::set_var("HYPERACTOR_CODEC_MAX_FRAME_LENGTH", "1024");
    };
    // Use bincode encoding and disable multipart channels.
    // NOTE(review): presumably required for the byte counts asserted
    // below to hold — confirm.
    let _guard3 =
        config.override_key(hyperactor::config::DEFAULT_ENCODING, Encoding::Bincode);
    let _guard4 = config.override_key(hyperactor::config::CHANNEL_MULTIPART, false);

    let alloc = process_allocator()
        .allocate(AllocSpec {
            extent: extent!(replica = 1),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
        })
        .await
        .unwrap();
    let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
    let mut proc_events = proc_mesh.events().unwrap();
    let actor_mesh: RootActorMesh<TestActor> =
        proc_mesh.spawn("ingest", &()).await.unwrap();
    let (reply_handle, mut reply_receiver) = actor_mesh.open_port();
    let dest = actor_mesh.get(0).unwrap();

    // A 698-byte part yields a frame of exactly 1024 bytes (asserted
    // below), i.e. right at the configured limit.
    let payload = Payload {
        part: Part::from(Bytes::from(vec![0u8; 698])),
        reply_port: reply_handle.bind(),
    };
    let frame_len = frame_length(
        proc_mesh.client().self_id(),
        dest.port::<Payload>().port_id(),
        &payload,
    );
    assert_eq!(frame_len, 1024);

    // The at-limit message is expected to be answered within 2 seconds.
    dest.send(proc_mesh.client(), payload).unwrap();
    #[allow(clippy::disallowed_methods)]
    let result = RealClock
        .timeout(Duration::from_secs(2), reply_receiver.recv())
        .await;
    assert!(result.is_ok(), "Operation should not time out");

    // One more byte pushes the frame to 1025 bytes, just over the limit.
    let payload = Payload {
        part: Part::from(Bytes::from(vec![0u8; 699])),
        reply_port: reply_handle.bind(),
    };
    let frame_len = frame_length(
        proc_mesh.client().self_id(),
        dest.port::<Payload>().port_id(),
        &payload,
    );
    assert_eq!(frame_len, 1025);
    // Randomly pick a direct send or a cast; the same crash event is
    // expected afterwards in either case.
    if rand::thread_rng().gen_bool(0.5) {
        dest.send(proc_mesh.client(), payload).unwrap();
    } else {
        actor_mesh
            .cast(proc_mesh.client(), sel!(*), payload)
            .unwrap();
    }

    {
        // The oversized message must surface as a proc crash event.
        let event = proc_events.next().await.unwrap();
        assert_matches!(
            event,
            ProcEvent::Crashed(_, _),
            "Should have received crash event"
        );
    }
}
1461
/// Checks that an undeliverable message is reported through supervision
/// events instead of producing a reply.
///
/// The test sets `HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK=1`, asks a
/// `ProxyActor` to `Echo` a string, and expects the reply port to stay
/// silent for three seconds. It then expects a `ProcEvent::Crashed` for
/// rank 0 whose reason mentions "undeliverable", and an actor supervision
/// event whose actor name equals the mesh name.
///
/// The environment variable is removed by a drop guard, so it is cleared
/// even when one of the assertions panics and cannot leak into other
/// tests running in the same process.
#[cfg(fbcode_build)]
#[tokio::test]
async fn test_router_undeliverable_return() {
    use ndslice::extent;

    use super::test_util::*;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;

    const NO_GLOBAL_FALLBACK: &str = "HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK";

    /// Removes the named environment variable when dropped.
    struct EnvVarReset(&'static str);

    impl Drop for EnvVarReset {
        fn drop(&mut self) {
            // NOTE(review): environment mutation is not thread-safe. This
            // inherits the assumption of the `remove_var` call it
            // replaces, namely that nothing else touches the environment
            // concurrently.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    let alloc = process_allocator()
        .allocate(AllocSpec {
            extent: extent! { replica = 1 },
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
        })
        .await
        .unwrap();

    // NOTE(review): environment mutation is not thread-safe; see the note
    // on `EnvVarReset::drop`.
    unsafe { std::env::set_var(NO_GLOBAL_FALLBACK, "1") };
    // From here on the variable is cleared on every exit path, including
    // unwinding out of a failed assertion.
    let env_reset = EnvVarReset(NO_GLOBAL_FALLBACK);

    let mut proc_mesh = ProcMesh::allocate(alloc).await.unwrap();
    let mut proc_events = proc_mesh.events().unwrap();
    let mut actor_mesh: RootActorMesh<'_, ProxyActor> =
        { proc_mesh.spawn("proxy", &()).await.unwrap() };
    let mut actor_events = actor_mesh.events().unwrap();

    let proxy_actor = actor_mesh.get(0).unwrap();
    let (tx, mut rx) = actor_mesh.open_port::<String>();
    proxy_actor
        .send(proc_mesh.client(), Echo("hello!".to_owned(), tx.bind()))
        .unwrap();

    #[allow(clippy::disallowed_methods)]
    match tokio::time::timeout(tokio::time::Duration::from_secs(3), rx.recv()).await {
        // A reply would mean the message was delivered after all.
        Ok(_) => panic!("the impossible happened"),
        Err(_) => {
            assert_matches!(
                proc_events.next().await.unwrap(),
                ProcEvent::Crashed(0, reason) if reason.contains("undeliverable")
            );
            assert_eq!(
                actor_events.next().await.unwrap().actor_id.name(),
                &actor_mesh.name
            );
        }
    }

    // Clear the variable before the meshes are torn down, matching the
    // point at which it was previously removed on the success path.
    drop(env_reset);
}
1519 }
1520
// Instantiates the `actor_mesh_test_suite!` tests with
// `SimAllocator::new_and_start_simnet()` as the allocator expression.
mod sim {
    use crate::alloc::sim::SimAllocator;

    actor_mesh_test_suite!(SimAllocator::new_and_start_simnet());
}
1526
1527 mod reshape_cast {
1528 use async_trait::async_trait;
1529 use hyperactor::Actor;
1530 use hyperactor::Context;
1531 use hyperactor::Handler;
1532 use hyperactor::channel::ChannelAddr;
1533 use hyperactor::channel::ChannelTransport;
1534 use hyperactor::channel::ChannelTx;
1535 use hyperactor::channel::Rx;
1536 use hyperactor::channel::Tx;
1537 use hyperactor::channel::dial;
1538 use hyperactor::channel::serve;
1539 use hyperactor::clock::Clock;
1540 use hyperactor::clock::RealClock;
1541 use ndslice::Selection;
1542
1543 use crate::Mesh;
1544 use crate::ProcMesh;
1545 use crate::RootActorMesh;
1546 use crate::actor_mesh::ActorMesh;
1547 use crate::alloc::AllocSpec;
1548 use crate::alloc::Allocator;
1549 use crate::alloc::LocalAllocator;
1550 use crate::config::MAX_CAST_DIMENSION_SIZE;
1551
1552 #[derive(Debug)]
1553 #[hyperactor::export(
1554 spawn = true,
1555 handlers = [() { cast = true }],
1556 )]
1557 struct EchoActor(ChannelTx<usize>);
1558
1559 #[async_trait]
1560 impl Actor for EchoActor {
1561 type Params = ChannelAddr;
1562
1563 async fn new(params: ChannelAddr) -> Result<Self, anyhow::Error> {
1564 Ok(Self(dial::<usize>(params)?))
1565 }
1566 }
1567
1568 #[async_trait]
1569 impl Handler<()> for EchoActor {
1570 async fn handle(
1571 &mut self,
1572 cx: &Context<Self>,
1573 _message: (),
1574 ) -> Result<(), anyhow::Error> {
1575 let Self(port) = self;
1576 port.post(cx.self_id().rank());
1577 Ok(())
1578 }
1579 }
1580
1581 async fn validate_cast<A>(
1582 actor_mesh: &A,
1583 caps: &impl hyperactor::context::Actor,
1584 addr: ChannelAddr,
1585 selection: Selection,
1586 ) where
1587 A: ActorMesh<Actor = EchoActor>,
1588 {
1589 let config = hyperactor::config::global::lock();
1590 let _guard = config.override_key(MAX_CAST_DIMENSION_SIZE, 2);
1591
1592 let (_, mut rx) = serve::<usize>(addr).unwrap();
1593
1594 let expected_ranks = selection
1595 .eval(
1596 &ndslice::selection::EvalOpts::strict(),
1597 actor_mesh.shape().slice(),
1598 )
1599 .unwrap()
1600 .collect::<std::collections::BTreeSet<_>>();
1601
1602 actor_mesh.cast(caps, selection, ()).unwrap();
1603
1604 let mut received = std::collections::BTreeSet::new();
1605
1606 for _ in 0..(expected_ranks.len()) {
1607 received.insert(
1608 RealClock
1609 .timeout(tokio::time::Duration::from_secs(1), rx.recv())
1610 .await
1611 .unwrap()
1612 .unwrap(),
1613 );
1614 }
1615
1616 assert_eq!(received, expected_ranks);
1617 }
1618
1619 use ndslice::strategy::gen_extent;
1620 use ndslice::strategy::gen_selection;
1621 use proptest::prelude::*;
1622 use proptest::test_runner::TestRunner;
1623
1624 fn make_tokio_runtime() -> tokio::runtime::Runtime {
1625 tokio::runtime::Builder::new_multi_thread()
1626 .enable_all()
1627 .worker_threads(2)
1628 .build()
1629 .unwrap()
1630 }
1631
1632 proptest! {
1633 #![proptest_config(ProptestConfig {
1634 cases: 8, ..ProptestConfig::default()
1635 })]
1636 #[test]
1637 fn test_reshaped_actor_mesh_cast(extent in gen_extent(1..=4, 8)) {
1638 let runtime = make_tokio_runtime();
1639 let alloc = runtime.block_on(LocalAllocator
1640 .allocate(AllocSpec {
1641 extent,
1642 constraints: Default::default(),
1643 proc_name: None,
1644 transport: ChannelTransport::Local
1645 }))
1646 .unwrap();
1647 let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
1648
1649 let addr = ChannelAddr::any(ChannelTransport::Unix);
1650
1651 let actor_mesh: RootActorMesh<EchoActor> =
1652 runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap();
1653
1654 let mut runner = TestRunner::default();
1655 let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0)
1656 .new_tree(&mut runner)
1657 .unwrap()
1658 .current();
1659
1660 runtime.block_on(validate_cast(&actor_mesh, actor_mesh.proc_mesh().client(), addr, selection));
1661 }
1662 }
1663
1664 proptest! {
1665 #![proptest_config(ProptestConfig {
1666 cases: 8, ..ProptestConfig::default()
1667 })]
1668 #[test]
1669 fn test_reshaped_actor_mesh_slice_cast(extent in gen_extent(1..=4, 8)) {
1670 let runtime = make_tokio_runtime();
1671 let alloc = runtime.block_on(LocalAllocator
1672 .allocate(AllocSpec {
1673 extent: extent.clone(),
1674 constraints: Default::default(),
1675 proc_name: None,
1676 transport: ChannelTransport::Local
1677 }))
1678 .unwrap();
1679 let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
1680
1681 let addr = ChannelAddr::any(ChannelTransport::Unix);
1682
1683 let actor_mesh: RootActorMesh<EchoActor> =
1684 runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap();
1685
1686
1687 let first_label = extent.labels().first().unwrap();
1688 let slice = actor_mesh.select(first_label, 0..extent.size(first_label).unwrap()).unwrap();
1689
1690 let slice = if extent.len() >= 2 {
1692 let label = &extent.labels()[1];
1693 let size = extent.size(label).unwrap();
1694 let start = if size > 1 { 1 } else { 0 };
1695 let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
1696 slice.select(label, start..end).unwrap()
1697 } else {
1698 slice
1699 };
1700
1701 let slice = if extent.len() >= 3 {
1702 let label = &extent.labels()[2];
1703 let size = extent.size(label).unwrap();
1704 let start = if size > 1 { 1 } else { 0 };
1705 let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
1706 slice.select(label, start..end).unwrap()
1707 } else {
1708 slice
1709 };
1710
1711 let slice = if extent.len() >= 4 {
1712 let label = &extent.labels()[3];
1713 let size = extent.size(label).unwrap();
1714 let start = if size > 1 { 1 } else { 0 };
1715 let end = (if size > 1 { size - 1 } else { 1 }).max(start + 1);
1716 slice.select(label, start..end).unwrap()
1717 } else {
1718 slice
1719 };
1720
1721
1722 let mut runner = TestRunner::default();
1723 let selection = gen_selection(4, slice.shape().slice().sizes().to_vec(), 0)
1724 .new_tree(&mut runner)
1725 .unwrap()
1726 .current();
1727
1728 runtime.block_on(validate_cast(
1729 &slice,
1730 actor_mesh.proc_mesh().client(),
1731 addr,
1732 selection
1733 ));
1734 }
1735 }
1736
1737 proptest! {
1738 #![proptest_config(ProptestConfig {
1739 cases: 8, ..ProptestConfig::default()
1740 })]
1741 #[test]
1742 fn test_reshaped_actor_mesh_cast_with_selection(extent in gen_extent(1..=4, 8)) {
1743 let runtime = make_tokio_runtime();
1744 let alloc = runtime.block_on(LocalAllocator
1745 .allocate(AllocSpec {
1746 extent,
1747 constraints: Default::default(),
1748 proc_name: None,
1749 transport: ChannelTransport::Local
1750 }))
1751 .unwrap();
1752 let proc_mesh = runtime.block_on(ProcMesh::allocate(alloc)).unwrap();
1753
1754 let addr = ChannelAddr::any(ChannelTransport::Unix);
1755
1756 let actor_mesh: RootActorMesh<EchoActor> =
1757 runtime.block_on(proc_mesh.spawn("echo", &addr)).unwrap();
1758
1759 let mut runner = TestRunner::default();
1760 let selection = gen_selection(4, actor_mesh.shape().slice().sizes().to_vec(), 0)
1761 .new_tree(&mut runner)
1762 .unwrap()
1763 .current();
1764
1765 runtime.block_on(validate_cast(
1766 &actor_mesh,
1767 actor_mesh.proc_mesh().client(),
1768 addr,
1769 selection
1770 ));
1771 }
1772 }
1773 }
1774}