1use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11use crate::resource;
12pub mod multicast;
13
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use anyhow::Result;
19use async_trait::async_trait;
20use hyperactor::Actor;
21use hyperactor::ActorId;
22use hyperactor::ActorRef;
23use hyperactor::Context;
24use hyperactor::Handler;
25use hyperactor::Instance;
26use hyperactor::PortRef;
27use hyperactor::WorldId;
28use hyperactor::mailbox::DeliveryError;
29use hyperactor::mailbox::MailboxSender;
30use hyperactor::mailbox::Undeliverable;
31use hyperactor::mailbox::UndeliverableMailboxSender;
32use hyperactor::mailbox::UndeliverableMessageError;
33use hyperactor::mailbox::monitored_return_handle;
34use hyperactor::reference::UnboundPort;
35use ndslice::selection::routing::RoutingFrame;
36use serde::Deserialize;
37use serde::Serialize;
38use typeuri::Named;
39
40use crate::comm::multicast::CastMessage;
41use crate::comm::multicast::CastMessageEnvelope;
42use crate::comm::multicast::ForwardMessage;
43use crate::comm::multicast::set_cast_info_on_headers;
44
/// Parameters for a [`CommActor`]. The struct currently has no fields.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
// Registers the type with `wirevalue` (presumably so it can travel over the wire — confirm
// against the `wirevalue` docs).
wirevalue::register_type!(CommActorParams);
49
/// A forwarded cast message that arrived before its predecessor and is parked
/// until the predecessor has been handled. The `Handler<ForwardMessage>` impl
/// stores these in `ReceiveState::buffer`, keyed by the message's `last_seq`.
#[derive(Debug)]
struct Buffered {
    /// Sequence number carried by the buffered message.
    seq: usize,
    /// Whether routing resolved this message for delivery on this comm actor's proc.
    deliver_here: bool,
    /// Routing frames still to be forwarded, keyed by the rank of the peer to forward to.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The cast message being held.
    message: CastMessageEnvelope,
}
62
/// Per-stream receive bookkeeping. `CommActor::recv_state` holds one of these
/// for each `stream_key()` seen on an incoming `ForwardMessage`.
#[derive(Debug, Default)]
struct ReceiveState {
    /// Sequence number of the most recently handled message on this stream.
    /// An incoming message is handled only when its `last_seq` equals this value.
    seq: usize,
    /// Out-of-order messages, keyed by the `last_seq` they are waiting on.
    buffer: HashMap<usize, Buffered>,
    /// For each peer rank, the `seq` of the last message forwarded to that peer
    /// on this stream; sent along as `last_seq` of the next forwarded message.
    last_seqs: HashMap<usize, usize>,
}
75
/// Actor that routes cast messages: it accepts a `CastMessage`, turns it into a
/// sequenced `ForwardMessage`, and delivers and/or forwards `ForwardMessage`s
/// according to their routing frames. A `CommActorMode` message replaces the
/// current addressing mode.
#[derive(Debug, Default)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Send-side sequence counter per stream (keyed by the cast message's `stream_key()`).
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Receive-side ordering state per stream (same key as `send_seq`).
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// How this actor determines its own rank and addresses its peers.
    mode: CommActorMode,
}
96
/// How a [`CommActor`] resolves its own rank and the comm actors of its peers.
/// Defaults to [`CommActorMode::Implicit`].
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
#[derive(Default)]
pub enum CommActorMode {
    /// Explicit configuration: this actor's own rank, plus a map from peer rank
    /// to that peer's comm actor reference.
    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),

    /// Rank and peers are derived from the actor's own id: the rank comes from
    /// its proc id, and a peer is the root actor with the same name on the proc
    /// of the requested rank in the same world.
    #[default]
    Implicit,

    /// Like `Implicit`, but peers are looked up in the supplied world instead of
    /// the world of the actor's own proc. The actor's own rank is still taken
    /// from its proc id.
    ImplicitWithWorldId(WorldId),
}
wirevalue::register_type!(CommActorMode);
121
122impl CommActorMode {
123 fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
126 match self {
127 Self::Mesh(_self_rank, peers) => peers
128 .get(&rank)
129 .cloned()
130 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
131 Self::Implicit => {
132 let world_id = self_id
133 .proc_id()
134 .world_id()
135 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
136 let proc_id = world_id.proc_id(rank);
137 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
138 Ok(ActorRef::<CommActor>::attest(actor_id))
139 }
140 Self::ImplicitWithWorldId(world_id) => {
141 let proc_id = world_id.proc_id(rank);
142 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
143 Ok(ActorRef::<CommActor>::attest(actor_id))
144 }
145 }
146 }
147
148 fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
150 match self {
151 Self::Mesh(rank, _) => Ok(*rank),
152 Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
153 .proc_id()
154 .rank()
155 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
156 }
157 }
158}
159
#[async_trait]
impl Actor for CommActor {
    /// Handles a message that could not be delivered.
    ///
    /// Three cases are tried in order:
    /// 1. The envelope deserializes as a `ForwardMessage`: the envelope is
    ///    returned to the message port of the cast message's sender.
    /// 2. The envelope's headers contain `CAST_ORIGINATING_SENDER`: the envelope
    ///    is returned to that sender's message port.
    /// 3. Otherwise the envelope is posted to `UndeliverableMailboxSender`.
    ///
    /// In cases 1 and 2 a `DeliveryError::Multicast` is recorded on the envelope
    /// before it is returned.
    ///
    /// # Errors
    /// If returning the envelope fails, a `DeliveryError::BrokenLink` is recorded
    /// on it and `UndeliverableMessageError::ReturnFailure` is returned.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // Case 1: a ForwardMessage that this comm actor failed to forward to a peer.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to forward the cast message; return to \
                its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                // A clone is sent so the original envelope is still available
                // for the error path below.
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning ForwardMessage to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::ReturnFailure {
                        envelope: message_envelope,
                    }
                })?;
            return Ok(());
        }

        // Case 2: a message tagged with the originating sender of a cast
        // (the header is set via `set_cast_info_on_headers` when delivering locally).
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to deliver the cast message to the dest \
                actor; return to its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning cast message to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::ReturnFailure {
                        envelope: message_envelope,
                    }
                })?;
            return Ok(());
        }

        // Case 3: not cast-related; hand the envelope to the generic
        // undeliverable sender.
        UndeliverableMailboxSender
            .post(message_envelope, monitored_return_handle());
        Ok(())
    }
}
231
232impl CommActor {
233 fn forward(
235 cx: &Instance<Self>,
236 mode: &CommActorMode,
237 rank: usize,
238 message: ForwardMessage,
239 ) -> Result<()> {
240 let child = mode.peer_for_rank(cx.self_id(), rank)?;
241 child.send(cx, message)?;
242 Ok(())
243 }
244
245 fn handle_message(
246 cx: &Context<Self>,
247 mode: &CommActorMode,
248 deliver_here: bool,
249 next_steps: HashMap<usize, Vec<RoutingFrame>>,
250 sender: ActorId,
251 mut message: CastMessageEnvelope,
252 seq: usize,
253 last_seqs: &mut HashMap<usize, usize>,
254 ) -> Result<()> {
255 message.data_mut().visit_mut::<UnboundPort>(
259 |UnboundPort(port_id, reducer_spec, reducer_opts, return_undeliverable)| {
260 let split = port_id.split(
261 cx,
262 reducer_spec.clone(),
263 reducer_opts.clone(),
264 *return_undeliverable,
265 )?;
266
267 #[cfg(test)]
268 tests::collect_split_port(port_id, &split, deliver_here);
269
270 *port_id = split;
271 Ok(())
272 },
273 )?;
274
275 if deliver_here {
277 let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
278 let cast_rank = message.relative_rank(rank_on_root_mesh)?;
279 message
281 .data_mut()
282 .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
283 *rank = Some(cast_rank);
284 Ok(())
285 })?;
286 let cast_shape = message.shape();
287 let point = cast_shape
288 .extent()
289 .point_of_rank(cast_rank)
290 .expect("rank out of bounds");
291 let mut headers = cx.headers().clone();
292 set_cast_info_on_headers(&mut headers, point, message.sender().clone());
293 cx.post(
294 cx.self_id()
295 .proc_id()
296 .actor_id(message.dest_port().actor_name(), 0)
297 .port_id(message.dest_port().port()),
298 headers,
299 wirevalue::Any::serialize(message.data())?,
300 );
301 }
302
303 next_steps
305 .into_iter()
306 .map(|(peer, dests)| {
307 let last_seq = last_seqs.entry(peer).or_default();
308 Self::forward(
309 cx,
310 mode,
311 peer,
312 ForwardMessage {
313 dests,
314 sender: sender.clone(),
315 message: message.clone(),
316 seq,
317 last_seq: *last_seq,
318 },
319 )?;
320 *last_seq = seq;
321 Ok(())
322 })
323 .collect::<Result<Vec<_>>>()?;
324
325 Ok(())
326 }
327}
328
#[async_trait]
impl Handler<CommActorMode> for CommActor {
    /// Replaces the actor's current mode with the received one. Always succeeds.
    async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
        self.mode = mode;
        Ok(())
    }
}
336
337#[async_trait]
339impl Handler<CastMessage> for CommActor {
340 #[hyperactor::instrument]
341 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
342 let slice = cast_message.dest.slice.clone();
344 let selection = cast_message.dest.selection.clone();
345 let frame = RoutingFrame::root(selection, slice);
346 let rank = frame.slice.location(&frame.here)?;
347 let seq = self
348 .send_seq
349 .entry(cast_message.message.stream_key())
350 .or_default();
351 let last_seq = *seq;
352 *seq += 1;
353
354 let fwd_message = ForwardMessage {
355 dests: vec![frame],
356 sender: cx.self_id().clone(),
357 message: cast_message.message,
358 seq: *seq,
359 last_seq,
360 };
361
362 if self
365 .mode
366 .self_rank(cx.self_id())
367 .is_ok_and(|self_rank| self_rank == rank)
368 {
369 Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
370 } else {
371 Self::forward(cx, &self.mode, rank, fwd_message)?;
372 }
373 Ok(())
374 }
375}
376
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handles a forwarded cast message, enforcing per-stream ordering.
    ///
    /// The message's `last_seq` is compared with the stream's current `seq`:
    /// - equal: the message is handled now, followed by any buffered messages
    ///   that have become next in line;
    /// - stream `seq` is behind `last_seq`: the message is buffered;
    /// - stream `seq` is ahead of `last_seq`: the message is logged as a
    ///   duplicate and dropped.
    ///
    /// # Panics
    /// Panics if routing resolution encounters a choice.
    #[hyperactor::instrument]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve, for this rank, whether to deliver locally and which peers
        // receive which routing frames next.
        let rank = self.mode.self_rank(cx.self_id())?;
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            // The message's predecessor is exactly the last one handled: process it.
            Ordering::Equal => {
                Self::handle_message(
                    cx,
                    &self.mode,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Drain buffered messages whose predecessor has now been handled.
                // The buffer is keyed by `last_seq`, so look up the current `seq`.
                // NOTE(review): buffered messages are forwarded with the current
                // message's `sender`; `Buffered` does not store its own — confirm
                // this is always the same sender within a stream.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        &self.mode,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // The predecessor has not been handled yet: park the message under
            // the sequence number it is waiting for.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // The stream has already moved past this message's predecessor:
            // log it as a duplicate and drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
464
/// Actor and message types used by this module's tests to exercise casting.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorId;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;
    use typeuri::Named;

    use super::*;

    /// Reply payload identifying the replying actor and carrying a value.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: ActorId,
        pub value: u64,
    }

    /// Message cast to `TestActor`s in tests.
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Not marked `#[binding(include)]`; the tests in this file expect this
            // port to arrive at the destination unchanged.
            reply_to0: PortRef<String>,
            // Marked `#[binding(include)]`; the tests expect these two ports to be
            // replaced by ports owned by comm actors.
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
    }

    /// Test actor that relays every `TestMessage` it receives to `forward_port`.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Port that every received message is sent on to.
        forward_port: PortRef<TestMessage>,
    }

    /// Spawn parameters for `TestActor`.
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for TestActor {
        type Params = TestActorParams;

        /// Builds the actor from its params; never fails.
        async fn new(params: Self::Params) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Sends the received message unchanged to `forward_port`.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
542
543#[cfg(test)]
544mod tests {
545 use std::collections::BTreeMap;
546 use std::collections::HashSet;
547 use std::fmt::Display;
548 use std::hash::Hash;
549 use std::ops::Deref;
550 use std::ops::DerefMut;
551 use std::sync::Arc;
552 use std::sync::Mutex;
553 use std::sync::OnceLock;
554
555 use hyperactor::PortId;
556 use hyperactor::PortRef;
557 use hyperactor::accum;
558 use hyperactor::accum::Accumulator;
559 use hyperactor::accum::ReducerSpec;
560 use hyperactor::channel::ChannelTransport;
561 use hyperactor::clock::Clock;
562 use hyperactor::clock::RealClock;
563 use hyperactor::context;
564 use hyperactor::context::Mailbox;
565 use hyperactor::mailbox::PortReceiver;
566 use hyperactor::mailbox::open_port;
567 use hyperactor::reference::Index;
568 use hyperactor_config;
569 use hyperactor_mesh_macros::sel;
570 use maplit::btreemap;
571 use maplit::hashmap;
572 use ndslice::Extent;
573 use ndslice::Selection;
574 use ndslice::ViewExt as _;
575 use ndslice::extent;
576 use ndslice::selection::test_utils::collect_commactor_routing_tree;
577 use test_utils::*;
578 use timed_test::async_timed_test;
579 use tokio::time::Duration;
580
581 use super::*;
582 use crate::ProcMesh;
583 use crate::actor_mesh::ActorMesh;
584 use crate::actor_mesh::RootActorMesh;
585 use crate::alloc::AllocSpec;
586 use crate::alloc::Allocator;
587 use crate::alloc::LocalAllocator;
588 use crate::proc_mesh::SharedSpawnable;
589 use crate::v1;
590 use crate::v1::testing;
591
    /// A directed edge `from -> to` in a tree; `is_leaf` marks `to` as a leaf
    /// destination (see `build_paths`, which builds one path per such node).
    struct Edge<T> {
        from: T,
        to: T,
        is_leaf: bool,
    }

    /// Builds an `Edge` from a `(from, to, is_leaf)` tuple.
    impl<T> From<(T, T, bool)> for Edge<T> {
        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
            Self { from, to, is_leaf }
        }
    }
603
    // Process-wide record of (original port -> split port) edges, appended to by
    // `collect_split_port` and emptied by `clear_collected_tree`.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();
607
608 pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
611 let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
612 let mut tree = mutex.lock().unwrap();
613
614 tree.deref_mut().push(Edge {
615 from: original.clone(),
616 to: split.clone(),
617 is_leaf: deliver_here,
618 });
619 }
620
621 fn clear_collected_tree() {
625 if let Some(tree) = SPLIT_PORT_TREE.get() {
626 let mut tree = tree.lock().unwrap();
627 tree.clear();
628 }
629 }
630
631 #[derive(PartialEq)]
635 struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
636
637 impl<T: Display> Debug for PathToLeaves<T> {
639 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
640 fn vec_to_string<T: Display>(v: &[T]) -> String {
641 v.iter()
642 .map(ToString::to_string)
643 .collect::<Vec<String>>()
644 .join(", ")
645 }
646
647 for (src, path) in &self.0 {
648 write!(f, "{} -> {}\n", src, vec_to_string(path))?;
649 }
650 Ok(())
651 }
652 }
653
654 fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
655 let mut child_parent_map = HashMap::new();
656 let mut all_nodes = HashSet::new();
657 let mut parents = HashSet::new();
658 let mut children = HashSet::new();
659 let mut dests = HashSet::new();
660
661 for Edge { from, to, is_leaf } in edges {
663 child_parent_map.insert(to.clone(), from.clone());
664 all_nodes.insert(from.clone());
665 all_nodes.insert(to.clone());
666 parents.insert(from.clone());
667 children.insert(to.clone());
668 if *is_leaf {
669 dests.insert(to.clone());
670 }
671 }
672
673 let mut result = BTreeMap::new();
675 for dest in dests {
676 let mut path = vec![dest.clone()];
677 let mut current = dest.clone();
678 while let Some(parent) = child_parent_map.get(¤t) {
679 path.push(parent.clone());
680 current = parent.clone();
681 }
682 path.reverse();
683 result.insert(dest, path);
684 }
685
686 PathToLeaves(result)
687 }
688
689 #[test]
690 fn test_build_paths() {
691 let edges: Vec<_> = [
698 (0, 1, false),
699 (1, 2, true),
700 (1, 3, true),
701 (0, 4, true),
702 (4, 5, true),
703 ]
704 .into_iter()
705 .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
706 .collect();
707
708 let paths = build_paths(&edges);
709
710 let expected = btreemap! {
711 2 => vec![0, 1, 2],
712 3 => vec![0, 1, 3],
713 4 => vec![0, 4],
714 5 => vec![0, 4, 5],
715 };
716
717 assert_eq!(paths.0, expected);
718 }
719
720 fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
740 let ranks = paths
741 .0
742 .into_iter()
743 .map(|(dst, mut path)| {
744 let first = path.remove(0);
745 assert_eq!(&first, client_reply);
747 assert!(dst.actor_id().name().contains("comm"));
750 let actor_path = path
751 .into_iter()
752 .map(|p| {
753 assert!(p.actor_id().name().contains("comm"));
754 p.actor_id().rank()
755 })
756 .collect();
757 (dst.into_actor_id().rank(), actor_path)
758 })
759 .collect();
760 PathToLeaves(ranks)
761 }
762
    /// Everything `setup_mesh` hands back to a test.
    struct MeshSetup {
        // The mesh of `TestActor`s the message was cast to.
        actor_mesh: RootActorMesh<'static, TestActor>,
        // Client-side receivers for the two reply ports that get split.
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        // The (reply_to1, reply_to2) ports each destination actor received,
        // in the order the forwarded messages arrived.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
769
    /// Placeholder accumulator type. It exists so that `setup_mesh::<NoneAccumulator>(None)`
    /// has a concrete type parameter; both methods are unimplemented and panic if called.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }
788
789 fn verify_split_port_paths(
791 selection: &Selection,
792 extent: &Extent,
793 reply_port_ref1: &PortRef<u64>,
794 reply_port_ref2: &PortRef<MyReply>,
795 ) {
796 let sel_paths = PathToLeaves(
798 collect_commactor_routing_tree(selection, &extent.to_slice())
799 .delivered
800 .into_iter()
801 .collect(),
802 );
803
804 let (reply1_paths, reply2_paths) = {
806 let tree = SPLIT_PORT_TREE.get().unwrap();
807 let edges = tree.lock().unwrap();
808 let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
809 .0
810 .into_iter()
811 .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
812 (
813 get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
814 get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
815 )
816 };
817
818 assert_eq!(sel_paths, reply1_paths);
820 assert_eq!(sel_paths, reply2_paths);
821 }
822
823 async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
824 where
825 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
826 {
827 let extent = extent!(replica = 4, host = 4, gpu = 4);
828 let alloc = LocalAllocator
829 .allocate(AllocSpec {
830 extent: extent.clone(),
831 constraints: Default::default(),
832 proc_name: None,
833 transport: ChannelTransport::Local,
834 proc_allocation_mode: Default::default(),
835 })
836 .await
837 .unwrap();
838
839 let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
840 let dest_actor_name = "dest_actor";
841 let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
842 let params = TestActorParams {
843 forward_port: tx.bind(),
844 };
845 let instance = crate::v1::testing::instance();
846 let actor_mesh: RootActorMesh<TestActor> = Arc::clone(&proc_mesh)
847 .spawn(&instance, dest_actor_name, ¶ms)
848 .await
849 .unwrap();
850
851 let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
852 let reply_port_ref0 = reply_port_handle0.bind();
853 let (reply_port_handle1, reply1_rx) = match accum {
854 Some(a) => proc_mesh.client().mailbox().open_accum_port(a),
855 None => open_port(proc_mesh.client()),
856 };
857 let reply_port_ref1 = reply_port_handle1.bind();
858 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
859 let reply_port_ref2 = reply_port_handle2.bind();
860 let message = TestMessage::CastAndReply {
861 arg: "abc".to_string(),
862 reply_to0: reply_port_ref0.clone(),
863 reply_to1: reply_port_ref1.clone(),
864 reply_to2: reply_port_ref2.clone(),
865 };
866
867 let selection = sel!(*);
868 clear_collected_tree();
869 actor_mesh
870 .cast(proc_mesh.client(), selection.clone(), message)
871 .unwrap();
872
873 let mut reply_tos = vec![];
874 for _ in extent.points() {
875 let msg = rx.recv().await.expect("missing");
876 match msg {
877 TestMessage::CastAndReply {
878 arg,
879 reply_to0,
880 reply_to1,
881 reply_to2,
882 } => {
883 assert_eq!(arg, "abc");
884 assert_eq!(reply_to0, reply_port_ref0);
887 assert_ne!(reply_to1, reply_port_ref1);
889 assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
890 assert_ne!(reply_to2, reply_port_ref2);
891 assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
892 reply_tos.push((reply_to1, reply_to2));
893 }
894 _ => {
895 panic!("unexpected message: {:?}", msg);
896 }
897 }
898 }
899
900 verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);
901
902 MeshSetup {
903 actor_mesh,
904 reply1_rx,
905 reply2_rx,
906 reply_tos,
907 }
908 }
909
    /// Sends replies through the split ports and checks that the client
    /// receives them.
    ///
    /// Phase 1: for each destination actor, one `u64` (its rank) is sent on
    /// `reply_to1` and one `MyReply` on `reply_to2`; each is asserted to arrive
    /// before moving to the next actor.
    ///
    /// Phase 2: each actor sends 100 `MyReply` values on `reply_to2`. The
    /// received replies are grouped by sender rank and compared with what was
    /// sent, which also checks that the order per sender is preserved.
    ///
    /// `ranks` and `reply_tos` are paired positionally via `zip`.
    async fn execute_cast_and_reply(
        ranks: Vec<ActorRef<TestActor>>,
        instance: &impl context::Actor,
        mut reply1_rx: PortReceiver<u64>,
        mut reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    ) {
        // Phase 1: a single round trip per destination actor on both ports.
        {
            for (dest_actor, (reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
                let rank = dest_actor.actor_id().rank() as u64;
                reply_to1.send(instance, rank).unwrap();
                let my_reply = MyReply {
                    sender: dest_actor.actor_id().clone(),
                    value: rank,
                };
                reply_to2.send(instance, my_reply.clone()).unwrap();

                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
            }
        }

        tracing::info!("the 1st updates from all dest actors were receivered by client");

        // Phase 2: a burst of replies per actor; only per-sender order is checked.
        {
            let n = 100;
            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
            for (dest_actor, (_reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
                let rank = dest_actor.actor_id().rank();
                let mut sent2 = vec![];
                for i in 0..n {
                    let value = (rank * 100 + i) as u64;
                    let my_reply = MyReply {
                        sender: dest_actor.actor_id().clone(),
                        value,
                    };
                    reply_to2.send(instance, my_reply.clone()).unwrap();
                    sent2.push(my_reply);
                }
                assert!(
                    expected2.insert(rank, sent2).is_none(),
                    "duplicate rank {rank} in map"
                );
            }

            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};

            // Exactly n replies per actor are expected in total.
            for _ in 0..(n * ranks.len()) {
                let my_reply = reply2_rx.recv().await.unwrap();
                received2
                    .entry(my_reply.sender.rank())
                    .or_default()
                    .push(my_reply);
            }
            assert_eq!(received2, expected2);
        }
    }
971
972 #[async_timed_test(timeout_secs = 30)]
973 async fn test_cast_and_reply() {
974 let MeshSetup {
975 actor_mesh,
976 reply1_rx,
977 reply2_rx,
978 reply_tos,
979 ..
980 } = setup_mesh::<NoneAccumulator>(None).await;
981 let proc_mesh_client = actor_mesh.proc_mesh().client();
982
983 let ranks = actor_mesh.ranks().clone();
984 execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await;
985 }
986
987 async fn wait_for_with_timeout(
988 receiver: &mut PortReceiver<u64>,
989 expected: u64,
990 dur: Duration,
991 ) -> anyhow::Result<()> {
992 RealClock
994 .timeout(dur, async {
995 loop {
996 let msg = receiver.recv().await.unwrap();
997 if msg == expected {
998 break;
999 }
1000 }
1001 })
1002 .await?;
1003 Ok(())
1004 }
1005
1006 async fn execute_cast_and_accum(
1007 ranks: Vec<ActorRef<TestActor>>,
1008 instance: &impl context::Actor,
1009 mut reply1_rx: PortReceiver<u64>,
1010 reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1011 ) {
1012 let mut sum = 0;
1016 let n = 100;
1017 for (dest_actor, (reply_to1, _reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
1018 let rank = dest_actor.actor_id().rank();
1019 for i in 0..n {
1020 let value = (rank + i) as u64;
1021 reply_to1.send(instance, value).unwrap();
1022 sum += value;
1023 }
1024 }
1025 wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
1026 .await
1027 .unwrap();
1028 RealClock.sleep(Duration::from_secs(2)).await;
1030 let msg = reply1_rx.try_recv().unwrap();
1031 assert_eq!(msg, None);
1032 }
1033
1034 #[async_timed_test(timeout_secs = 30)]
1035 async fn test_cast_and_accum() {
1036 let config = hyperactor_config::global::lock();
1037 let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1039
1040 let MeshSetup {
1041 actor_mesh,
1042 reply1_rx,
1043 reply_tos,
1044 ..
1045 } = setup_mesh(Some(accum::sum::<u64>())).await;
1046 let proc_mesh_client = actor_mesh.proc_mesh().client();
1047 let ranks = actor_mesh.ranks().clone();
1048 execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await;
1049 }
1050
    /// Everything `setup_mesh_v1` hands back to a test.
    struct MeshSetupV1 {
        // The test client instance used to open ports and cast.
        instance: &'static Instance<testing::TestRootClient>,
        // Reference to the mesh of `TestActor`s the message was cast to.
        actor_mesh_ref: v1::ActorMeshRef<TestActor>,
        // Client-side receivers for the two reply ports that get split.
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        // The (reply_to1, reply_to2) ports each destination actor received,
        // in the order the forwarded messages arrived.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
1058
1059 async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1060 where
1061 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1062 {
1063 let instance = v1::testing::instance();
1064
1065 let extent = extent!(replica = 4, host = 4, gpu = 4);
1066 let alloc = LocalAllocator
1067 .allocate(AllocSpec {
1068 extent: extent.clone(),
1069 constraints: Default::default(),
1070 proc_name: None,
1071 transport: ChannelTransport::Local,
1072 proc_allocation_mode: Default::default(),
1073 })
1074 .await
1075 .unwrap();
1076
1077 let proc_mesh = v1::ProcMesh::allocate(instance, Box::new(alloc), "test_local")
1078 .await
1079 .unwrap();
1080
1081 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1082 let params = TestActorParams {
1083 forward_port: tx.bind(),
1084 };
1085 let actor_name = v1::Name::new("test").expect("valid test name");
1086 let actor_mesh = proc_mesh
1090 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1091 .await
1092 .unwrap();
1093 let actor_mesh_ref = actor_mesh.deref().clone();
1094
1095 let (reply_port_handle0, _) = open_port::<String>(instance);
1096 let reply_port_ref0 = reply_port_handle0.bind();
1097 let (reply_port_handle1, reply1_rx) = match accum {
1098 Some(a) => instance.mailbox().open_accum_port(a),
1099 None => open_port(instance),
1100 };
1101 let reply_port_ref1 = reply_port_handle1.bind();
1102 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1103 let reply_port_ref2 = reply_port_handle2.bind();
1104 let message = TestMessage::CastAndReply {
1105 arg: "abc".to_string(),
1106 reply_to0: reply_port_ref0.clone(),
1107 reply_to1: reply_port_ref1.clone(),
1108 reply_to2: reply_port_ref2.clone(),
1109 };
1110
1111 clear_collected_tree();
1112 actor_mesh_ref.cast(instance, message).unwrap();
1113
1114 let mut reply_tos = vec![];
1115 for _ in extent.points() {
1116 let msg = rx.recv().await.expect("missing");
1117 match msg {
1118 TestMessage::CastAndReply {
1119 arg,
1120 reply_to0,
1121 reply_to1,
1122 reply_to2,
1123 } => {
1124 assert_eq!(arg, "abc");
1125 assert_eq!(reply_to0, reply_port_ref0);
1128 assert_ne!(reply_to1, reply_port_ref1);
1130 assert!(reply_to1.port_id().actor_id().name().contains("comm"));
1131 assert_ne!(reply_to2, reply_port_ref2);
1132 assert!(reply_to2.port_id().actor_id().name().contains("comm"));
1133 reply_tos.push((reply_to1, reply_to2));
1134 }
1135 _ => {
1136 panic!("unexpected message: {:?}", msg);
1137 }
1138 }
1139 }
1140
1141 let selection = sel!(*);
1143 verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);
1144
1145 MeshSetupV1 {
1146 instance,
1147 actor_mesh_ref,
1148 reply1_rx,
1149 reply2_rx,
1150 reply_tos,
1151 }
1152 }
1153
1154 #[async_timed_test(timeout_secs = 30)]
1155 async fn test_cast_and_reply_v1() {
1156 let MeshSetupV1 {
1157 instance,
1158 actor_mesh_ref,
1159 reply1_rx,
1160 reply2_rx,
1161 reply_tos,
1162 ..
1163 } = setup_mesh_v1::<NoneAccumulator>(None).await;
1164
1165 let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1166 execute_cast_and_reply(ranks, instance, reply1_rx, reply2_rx, reply_tos).await;
1167 }
1168
1169 #[async_timed_test(timeout_secs = 30)]
1170 async fn test_cast_and_accum_v1() {
1171 let config = hyperactor_config::global::lock();
1172 let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1174
1175 let MeshSetupV1 {
1176 instance,
1177 actor_mesh_ref,
1178 reply1_rx,
1179 reply_tos,
1180 ..
1181 } = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1182
1183 let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1184 execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await;
1185 }
1186}