1use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11use crate::resource;
12pub mod multicast;
13
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use anyhow::Result;
19use async_trait::async_trait;
20use hyperactor::Actor;
21use hyperactor::ActorId;
22use hyperactor::ActorRef;
23use hyperactor::Context;
24use hyperactor::Handler;
25use hyperactor::Instance;
26use hyperactor::Named;
27use hyperactor::PortRef;
28use hyperactor::WorldId;
29use hyperactor::data::Serialized;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::reference::UnboundPort;
37use ndslice::selection::routing::RoutingFrame;
38use serde::Deserialize;
39use serde::Serialize;
40
41use crate::comm::multicast::CastMessage;
42use crate::comm::multicast::CastMessageEnvelope;
43use crate::comm::multicast::ForwardMessage;
44use crate::comm::multicast::set_cast_info_on_headers;
45
/// Spawn parameters for [`CommActor`].
///
/// They carry no data: `CommActor::new` ignores them and starts in the default
/// [`CommActorMode`]. The mode can be changed later by sending the actor a
/// `CommActorMode` message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
49
/// A forwarded message that arrived before its predecessor on the stream.
///
/// It is held in `ReceiveState::buffer` until it can be processed in order.
#[derive(Debug)]
struct Buffered {
    /// Sequence number carried by the `ForwardMessage`.
    seq: usize,
    /// Whether this comm actor must deliver the payload on its own proc.
    deliver_here: bool,
    /// Routing frames to forward to next, keyed by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The cast payload.
    message: CastMessageEnvelope,
}
62
/// Receive-side ordering state for one stream.
///
/// Streams are keyed by `CastMessageEnvelope::stream_key()`.
#[derive(Debug, Default)]
struct ReceiveState {
    /// Sequence number of the last message processed in order (0 before any).
    seq: usize,
    /// Messages that arrived early.
    ///
    /// Each is keyed by its `last_seq`, i.e. the position the stream must
    /// reach before it can be processed.
    buffer: HashMap<usize, Buffered>,
    /// For each downstream peer rank, the sequence number of the last message
    /// this actor forwarded to it on this stream.
    last_seqs: HashMap<usize, usize>,
}
75
/// Actor that propagates cast messages along a routing tree of comm actors.
///
/// It delivers each message to the destination actor on its own proc when the
/// routing says so.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Last sequence number assigned to casts originated at this actor,
    /// per stream (`stream_key()`).
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Receive ordering state, per stream (`stream_key()`).
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// How peers are located and how this actor's own rank is determined.
    mode: CommActorMode,
}
96
/// Determines how a [`CommActor`] finds its peers and its own rank.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub enum CommActorMode {
    /// Explicit configuration: this actor's rank, and a map from rank to the
    /// peer comm actor for that rank.
    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),

    /// Peers are the root actors with this actor's name on the other procs
    /// of the world this actor's proc belongs to.
    ///
    /// The own rank is the proc's rank. Requires a ranked proc.
    Implicit,

    /// Like `Implicit`, but peers are resolved in the given world instead of
    /// the one this actor's proc belongs to.
    ///
    /// The own rank is still taken from this actor's proc.
    ImplicitWithWorldId(WorldId),
}
118
119impl Default for CommActorMode {
120 fn default() -> Self {
121 Self::Implicit
122 }
123}
124
125impl CommActorMode {
126 fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
129 match self {
130 Self::Mesh(_self_rank, peers) => peers
131 .get(&rank)
132 .cloned()
133 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
134 Self::Implicit => {
135 let world_id = self_id
136 .proc_id()
137 .world_id()
138 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
139 let proc_id = world_id.proc_id(rank);
140 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
141 Ok(ActorRef::<CommActor>::attest(actor_id))
142 }
143 Self::ImplicitWithWorldId(world_id) => {
144 let proc_id = world_id.proc_id(rank);
145 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
146 Ok(ActorRef::<CommActor>::attest(actor_id))
147 }
148 }
149 }
150
151 fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
153 match self {
154 Self::Mesh(rank, _) => Ok(*rank),
155 Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
156 .proc_id()
157 .rank()
158 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
159 }
160 }
161}
162
#[async_trait]
impl Actor for CommActor {
    type Params = CommActorParams;

    /// Creates a comm actor with empty sequencing state and the default
    /// [`CommActorMode`]. The spawn parameters are ignored.
    async fn new(_params: Self::Params) -> Result<Self> {
        Ok(Self {
            send_seq: HashMap::new(),
            recv_state: HashMap::new(),
            mode: Default::default(),
        })
    }

    /// Sends an undeliverable message back toward the originator of the cast.
    ///
    /// The cases are checked in order:
    /// 1. The envelope deserializes as a [`ForwardMessage`], meaning a hop to a
    ///    peer comm actor failed. It is returned to the message port of the
    ///    cast envelope's `sender()`.
    /// 2. The headers carry [`CAST_ORIGINATING_SENDER`]. The error text calls
    ///    this a failed delivery to the destination actor. The envelope is
    ///    returned to the message port of that sender.
    /// 3. Otherwise the envelope is posted to [`UndeliverableMailboxSender`].
    ///
    /// In cases 1 and 2:
    /// - the envelope's error is set to `DeliveryError::Multicast` before the
    ///   envelope is returned;
    /// - if sending it back fails, the error is replaced with
    ///   `DeliveryError::BrokenLink`, and an
    ///   `UndeliverableMessageError::return_failure` error is returned.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // Case 1: forwarding to a peer comm actor failed.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to forward the cast message; return to \
                 its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning ForwardMessage to the original \
                         sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // Case 2: a message carrying the originating sender in its headers.
        // Presumably this is set by `set_cast_info_on_headers` during local
        // delivery (verify).
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to deliver the cast message to the dest \
                 actor; return to its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning cast message to the original \
                         sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // Case 3: nowhere specific to return it; hand it to the default
        // undeliverable sink.
        UndeliverableMailboxSender
            .post(message_envelope, monitored_return_handle());
        Ok(())
    }
}
240
241impl CommActor {
242 fn forward(
244 cx: &Instance<Self>,
245 mode: &CommActorMode,
246 rank: usize,
247 message: ForwardMessage,
248 ) -> Result<()> {
249 let child = mode.peer_for_rank(cx.self_id(), rank)?;
250 child.send(cx, message)?;
251 Ok(())
252 }
253
254 fn handle_message(
255 cx: &Context<Self>,
256 mode: &CommActorMode,
257 deliver_here: bool,
258 next_steps: HashMap<usize, Vec<RoutingFrame>>,
259 sender: ActorId,
260 mut message: CastMessageEnvelope,
261 seq: usize,
262 last_seqs: &mut HashMap<usize, usize>,
263 ) -> Result<()> {
264 message.data_mut().visit_mut::<UnboundPort>(
268 |UnboundPort(port_id, reducer_spec, reducer_opts)| {
269 let split = port_id.split(cx, reducer_spec.clone(), reducer_opts.clone())?;
270
271 #[cfg(test)]
272 tests::collect_split_port(port_id, &split, deliver_here);
273
274 *port_id = split;
275 Ok(())
276 },
277 )?;
278
279 if deliver_here {
281 let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
282 let cast_rank = message.relative_rank(rank_on_root_mesh)?;
283 message
285 .data_mut()
286 .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
287 *rank = Some(cast_rank);
288 Ok(())
289 })?;
290 let cast_shape = message.shape();
291 let point = cast_shape
292 .extent()
293 .point_of_rank(cast_rank)
294 .expect("rank out of bounds");
295 let mut headers = cx.headers().clone();
296 set_cast_info_on_headers(&mut headers, point, message.sender().clone());
297 cx.post(
298 cx.self_id()
299 .proc_id()
300 .actor_id(message.dest_port().actor_name(), 0)
301 .port_id(message.dest_port().port()),
302 headers,
303 Serialized::serialize(message.data())?,
304 );
305 }
306
307 next_steps
309 .into_iter()
310 .map(|(peer, dests)| {
311 let last_seq = last_seqs.entry(peer).or_default();
312 Self::forward(
313 cx,
314 mode,
315 peer,
316 ForwardMessage {
317 dests,
318 sender: sender.clone(),
319 message: message.clone(),
320 seq,
321 last_seq: *last_seq,
322 },
323 )?;
324 *last_seq = seq;
325 Ok(())
326 })
327 .collect::<Result<Vec<_>>>()?;
328
329 Ok(())
330 }
331}
332
333#[async_trait]
334impl Handler<CommActorMode> for CommActor {
335 async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
336 self.mode = mode;
337 Ok(())
338 }
339}
340
341#[async_trait]
343impl Handler<CastMessage> for CommActor {
344 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
345 let slice = cast_message.dest.slice.clone();
347 let selection = cast_message.dest.selection.clone();
348 let frame = RoutingFrame::root(selection, slice);
349 let rank = frame.slice.location(&frame.here)?;
350 let seq = self
351 .send_seq
352 .entry(cast_message.message.stream_key())
353 .or_default();
354 let last_seq = *seq;
355 *seq += 1;
356 Self::forward(
357 cx,
358 &self.mode,
359 rank,
360 ForwardMessage {
361 dests: vec![frame],
362 sender: cx.self_id().clone(),
363 message: cast_message.message,
364 seq: *seq,
365 last_seq,
366 },
367 )?;
368 Ok(())
369 }
370}
371
/// Handles a [`ForwardMessage`] from the root or a parent comm actor.
///
/// Messages are processed in order per stream (`message.stream_key()`). Each
/// message carries its own `seq`, plus `last_seq`: the sequence number its
/// sender last forwarded to this actor on the stream (0 if none). The message
/// is compared against this actor's position in the stream
/// (`recv_state.seq`):
/// - equal: process it, then drain any buffered messages it unblocks;
/// - position behind `last_seq`: buffer it until its predecessor arrives;
/// - position ahead of `last_seq`: log it as a duplicate and drop it.
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Decide whether to deliver on this rank and which peers get the
        // message next. A routing choice is not expected here and panics.
        let rank = self.mode.self_rank(cx.self_id())?;
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            Ordering::Equal => {
                // In order: handle it now. If handling fails, `?` returns
                // before `recv_state.seq` is advanced.
                Self::handle_message(
                    cx,
                    &self.mode,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Drain buffered messages waiting on the position just reached.
                // They are keyed by their own `last_seq`.
                // NOTE(review): buffered messages are handled with the `sender`
                // of the message that unblocked them, because `Buffered` does
                // not record one. Confirm that all messages on a stream share a
                // sender.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        &self.mode,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            Ordering::Less => {
                // Arrived before its predecessor: park it under the position
                // (`last_seq`) it is waiting for.
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            Ordering::Greater => {
                // The stream has already moved past this point; drop it.
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
458
/// Actor and message types used by the comm actor tests.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorId;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Named;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;

    use super::*;

    /// Structured reply sent through a `PortRef<MyReply>` in tests.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        /// The actor on whose behalf the reply was sent.
        pub sender: ActorId,
        /// Reply value.
        pub value: u64,
    }

    /// Messages accepted by [`TestActor`].
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        // A plain string payload with no ports.
        Forward(String),
        // A payload carrying three reply ports.
        CastAndReply {
            arg: String,
            // Not marked `#[binding(include)]`. The tests expect this port to
            // reach the destination actors unchanged.
            reply_to0: PortRef<String>,
            // Marked `#[binding(include)]`. The tests expect comm actors to
            // replace these ports with split ports that they own.
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
    }

    /// Test actor that forwards every message it receives to `forward_port`.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Test-owned port; every handled message is re-sent here.
        forward_port: PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`].
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        /// Where the actor forwards every message it receives.
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {
        type Params = TestActorParams;

        /// Builds the actor from its spawn parameters.
        async fn new(params: Self::Params) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Forwards `msg` unchanged to `forward_port`.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
533
534#[cfg(test)]
535mod tests {
536 use std::collections::BTreeMap;
537 use std::collections::HashSet;
538 use std::fmt::Display;
539 use std::hash::Hash;
540 use std::ops::Deref;
541 use std::ops::DerefMut;
542 use std::sync::Arc;
543 use std::sync::Mutex;
544 use std::sync::OnceLock;
545
546 use hyperactor::PortId;
547 use hyperactor::PortRef;
548 use hyperactor::accum;
549 use hyperactor::accum::Accumulator;
550 use hyperactor::accum::ReducerSpec;
551 use hyperactor::channel::ChannelTransport;
552 use hyperactor::clock::Clock;
553 use hyperactor::clock::RealClock;
554 use hyperactor::config;
555 use hyperactor::context::Mailbox;
556 use hyperactor::mailbox::PortReceiver;
557 use hyperactor::mailbox::open_port;
558 use hyperactor::reference::Index;
559 use hyperactor_mesh_macros::sel;
560 use maplit::btreemap;
561 use maplit::hashmap;
562 use ndslice::Extent;
563 use ndslice::Selection;
564 use ndslice::ViewExt as _;
565 use ndslice::extent;
566 use ndslice::selection::test_utils::collect_commactor_routing_tree;
567 use test_utils::*;
568 use timed_test::async_timed_test;
569 use tokio::time::Duration;
570
571 use super::*;
572 use crate::ProcMesh;
573 use crate::actor_mesh::ActorMesh;
574 use crate::actor_mesh::RootActorMesh;
575 use crate::alloc::AllocSpec;
576 use crate::alloc::Allocator;
577 use crate::alloc::LocalAllocator;
578 use crate::proc_mesh::SharedSpawnable;
579 use crate::v1;
580
    /// One parent-to-child link in a recorded tree.
    ///
    /// In the split-port tree, `from` is the port that was split and `to` is
    /// the port it was split into.
    struct Edge<T> {
        from: T,
        to: T,
        /// `true` when `to` is a delivery point (recorded from `deliver_here`).
        is_leaf: bool,
    }
586
587 impl<T> From<(T, T, bool)> for Edge<T> {
588 fn from((from, to, is_leaf): (T, T, bool)) -> Self {
589 Self { from, to, is_leaf }
590 }
591 }
592
    /// Process-wide record of every port split made by comm actors in test
    /// builds. It is appended to by `collect_split_port`.
    ///
    /// NOTE: this static is shared by every test running in the same process.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();
596
597 pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
600 let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
601 let mut tree = mutex.lock().unwrap();
602
603 tree.deref_mut().push(Edge {
604 from: original.clone(),
605 to: split.clone(),
606 is_leaf: deliver_here,
607 });
608 }
609
610 fn clear_collected_tree() {
614 if let Some(tree) = SPLIT_PORT_TREE.get() {
615 let mut tree = tree.lock().unwrap();
616 tree.clear();
617 }
618 }
619
    /// Maps each leaf to its full path through a tree.
    ///
    /// For trees built by `build_paths`, the path runs from the root down to
    /// the leaf: root first, leaf last.
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
625
626 impl<T: Display> Debug for PathToLeaves<T> {
628 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
629 fn vec_to_string<T: Display>(v: &[T]) -> String {
630 v.iter()
631 .map(ToString::to_string)
632 .collect::<Vec<String>>()
633 .join(", ")
634 }
635
636 for (src, path) in &self.0 {
637 write!(f, "{} -> {}\n", src, vec_to_string(path))?;
638 }
639 Ok(())
640 }
641 }
642
643 fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
644 let mut child_parent_map = HashMap::new();
645 let mut all_nodes = HashSet::new();
646 let mut parents = HashSet::new();
647 let mut children = HashSet::new();
648 let mut dests = HashSet::new();
649
650 for Edge { from, to, is_leaf } in edges {
652 child_parent_map.insert(to.clone(), from.clone());
653 all_nodes.insert(from.clone());
654 all_nodes.insert(to.clone());
655 parents.insert(from.clone());
656 children.insert(to.clone());
657 if *is_leaf {
658 dests.insert(to.clone());
659 }
660 }
661
662 let mut result = BTreeMap::new();
664 for dest in dests {
665 let mut path = vec![dest.clone()];
666 let mut current = dest.clone();
667 while let Some(parent) = child_parent_map.get(¤t) {
668 path.push(parent.clone());
669 current = parent.clone();
670 }
671 path.reverse();
672 result.insert(dest, path);
673 }
674
675 PathToLeaves(result)
676 }
677
678 #[test]
679 fn test_build_paths() {
680 let edges: Vec<_> = [
687 (0, 1, false),
688 (1, 2, true),
689 (1, 3, true),
690 (0, 4, true),
691 (4, 5, true),
692 ]
693 .into_iter()
694 .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
695 .collect();
696
697 let paths = build_paths(&edges);
698
699 let expected = btreemap! {
700 2 => vec![0, 1, 2],
701 3 => vec![0, 1, 3],
702 4 => vec![0, 4],
703 5 => vec![0, 4, 5],
704 };
705
706 assert_eq!(paths.0, expected);
707 }
708
709 fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
729 let ranks = paths
730 .0
731 .into_iter()
732 .map(|(dst, mut path)| {
733 let first = path.remove(0);
734 assert_eq!(&first, client_reply);
736 assert!(dst.actor_id().name().contains("comm"));
739 let actor_path = path
740 .into_iter()
741 .map(|p| {
742 assert!(p.actor_id().name().contains("comm"));
743 p.actor_id().rank()
744 })
745 .collect();
746 (dst.into_actor_id().rank(), actor_path)
747 })
748 .collect();
749 PathToLeaves(ranks)
750 }
751
    /// State returned by `setup_mesh` for the v0-API tests.
    struct MeshSetup {
        /// The spawned `TestActor` mesh.
        actor_mesh: RootActorMesh<'static, TestActor>,
        /// Client-side receiver for `reply_to1`. It is an accumulating port
        /// when `setup_mesh` was given an accumulator.
        reply1_rx: PortReceiver<u64>,
        /// Client-side receiver for `reply_to2`.
        reply2_rx: PortReceiver<MyReply>,
        /// The split `(reply_to1, reply_to2)` ports seen by destination actors,
        /// in the order their forwarded messages arrived.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
758
759 struct NoneAccumulator;
760
761 impl Accumulator for NoneAccumulator {
762 type State = u64;
763 type Update = u64;
764
765 fn accumulate(
766 &self,
767 _state: &mut Self::State,
768 _update: Self::Update,
769 ) -> anyhow::Result<()> {
770 unimplemented!()
771 }
772
773 fn reducer_spec(&self) -> Option<ReducerSpec> {
774 unimplemented!()
775 }
776 }
777
778 fn verify_split_port_paths(
780 selection: &Selection,
781 extent: &Extent,
782 reply_port_ref1: &PortRef<u64>,
783 reply_port_ref2: &PortRef<MyReply>,
784 ) {
785 let sel_paths = PathToLeaves(
787 collect_commactor_routing_tree(selection, &extent.to_slice())
788 .delivered
789 .into_iter()
790 .collect(),
791 );
792
793 let (reply1_paths, reply2_paths) = {
795 let tree = SPLIT_PORT_TREE.get().unwrap();
796 let edges = tree.lock().unwrap();
797 let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
798 .0
799 .into_iter()
800 .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
801 (
802 get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
803 get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
804 )
805 };
806
807 assert_eq!(sel_paths, reply1_paths);
809 assert_eq!(sel_paths, reply2_paths);
810 }
811
812 async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
813 where
814 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
815 {
816 let extent = extent!(replica = 4, host = 4, gpu = 4);
817 let alloc = LocalAllocator
818 .allocate(AllocSpec {
819 extent: extent.clone(),
820 constraints: Default::default(),
821 proc_name: None,
822 transport: ChannelTransport::Local,
823 })
824 .await
825 .unwrap();
826
827 let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
828 let dest_actor_name = "dest_actor";
829 let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
830 let params = TestActorParams {
831 forward_port: tx.bind(),
832 };
833 let actor_mesh = proc_mesh
834 .clone()
835 .spawn::<TestActor>(dest_actor_name, ¶ms)
836 .await
837 .unwrap();
838
839 let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
840 let reply_port_ref0 = reply_port_handle0.bind();
841 let (reply_port_handle1, reply1_rx) = match accum {
842 Some(a) => proc_mesh.client().mailbox().open_accum_port(a),
843 None => open_port(proc_mesh.client()),
844 };
845 let reply_port_ref1 = reply_port_handle1.bind();
846 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
847 let reply_port_ref2 = reply_port_handle2.bind();
848 let message = TestMessage::CastAndReply {
849 arg: "abc".to_string(),
850 reply_to0: reply_port_ref0.clone(),
851 reply_to1: reply_port_ref1.clone(),
852 reply_to2: reply_port_ref2.clone(),
853 };
854
855 let selection = sel!(*);
856 clear_collected_tree();
857 actor_mesh
858 .cast(proc_mesh.client(), selection.clone(), message)
859 .unwrap();
860
861 let mut reply_tos = vec![];
862 for _ in extent.points() {
863 let msg = rx.recv().await.expect("missing");
864 match msg {
865 TestMessage::CastAndReply {
866 arg,
867 reply_to0,
868 reply_to1,
869 reply_to2,
870 } => {
871 assert_eq!(arg, "abc");
872 assert_eq!(reply_to0, reply_port_ref0);
875 assert_ne!(reply_to1, reply_port_ref1);
877 assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
878 assert_ne!(reply_to2, reply_port_ref2);
879 assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
880 reply_tos.push((reply_to1, reply_to2));
881 }
882 _ => {
883 panic!("unexpected message: {:?}", msg);
884 }
885 }
886 }
887
888 verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);
889
890 MeshSetup {
891 actor_mesh,
892 reply1_rx,
893 reply2_rx,
894 reply_tos,
895 }
896 }
897
898 async fn execute_cast_and_reply(
899 ranks: Vec<ActorRef<TestActor>>,
900 instance: &Instance<()>,
901 mut reply1_rx: PortReceiver<u64>,
902 mut reply2_rx: PortReceiver<MyReply>,
903 reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
904 ) {
905 {
907 for (dest_actor, (reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
908 let rank = dest_actor.actor_id().rank() as u64;
909 reply_to1.send(instance, rank).unwrap();
910 let my_reply = MyReply {
911 sender: dest_actor.actor_id().clone(),
912 value: rank,
913 };
914 reply_to2.send(instance, my_reply.clone()).unwrap();
915
916 assert_eq!(reply1_rx.recv().await.unwrap(), rank);
917 assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
918 }
919 }
920
921 tracing::info!("the 1st updates from all dest actors were receivered by client");
922
923 {
927 let n = 100;
928 let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
929 for (dest_actor, (_reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
930 let rank = dest_actor.actor_id().rank();
931 let mut sent2 = vec![];
932 for i in 0..n {
933 let value = (rank * 100 + i) as u64;
934 let my_reply = MyReply {
935 sender: dest_actor.actor_id().clone(),
936 value,
937 };
938 reply_to2.send(instance, my_reply.clone()).unwrap();
939 sent2.push(my_reply);
940 }
941 assert!(
942 expected2.insert(rank, sent2).is_none(),
943 "duplicate rank {rank} in map"
944 );
945 }
946
947 let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};
948
949 for _ in 0..(n * ranks.len()) {
950 let my_reply = reply2_rx.recv().await.unwrap();
951 received2
952 .entry(my_reply.sender.rank())
953 .or_default()
954 .push(my_reply);
955 }
956 assert_eq!(received2, expected2);
957 }
958 }
959
960 #[async_timed_test(timeout_secs = 30)]
961 async fn test_cast_and_reply() {
962 let MeshSetup {
963 actor_mesh,
964 reply1_rx,
965 reply2_rx,
966 reply_tos,
967 ..
968 } = setup_mesh::<NoneAccumulator>(None).await;
969 let proc_mesh_client = actor_mesh.proc_mesh().client();
970
971 let ranks = actor_mesh.ranks.clone();
972 execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await;
973 }
974
975 async fn wait_for_with_timeout(
976 receiver: &mut PortReceiver<u64>,
977 expected: u64,
978 dur: Duration,
979 ) -> anyhow::Result<()> {
980 RealClock
982 .timeout(dur, async {
983 loop {
984 let msg = receiver.recv().await.unwrap();
985 if msg == expected {
986 break;
987 }
988 }
989 })
990 .await?;
991 Ok(())
992 }
993
994 async fn execute_cast_and_accum(
995 ranks: Vec<ActorRef<TestActor>>,
996 instance: &Instance<()>,
997 mut reply1_rx: PortReceiver<u64>,
998 reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
999 ) {
1000 let mut sum = 0;
1004 let n = 100;
1005 for (dest_actor, (reply_to1, _reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
1006 let rank = dest_actor.actor_id().rank();
1007 for i in 0..n {
1008 let value = (rank + i) as u64;
1009 reply_to1.send(instance, value).unwrap();
1010 sum += value;
1011 }
1012 }
1013 wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
1014 .await
1015 .unwrap();
1016 RealClock.sleep(Duration::from_secs(2)).await;
1018 let msg = reply1_rx.try_recv().unwrap();
1019 assert_eq!(msg, None);
1020 }
1021
1022 #[async_timed_test(timeout_secs = 30)]
1023 async fn test_cast_and_accum() {
1024 let config = config::global::lock();
1025 let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);
1027
1028 let MeshSetup {
1029 actor_mesh,
1030 reply1_rx,
1031 reply_tos,
1032 ..
1033 } = setup_mesh(Some(accum::sum::<u64>())).await;
1034 let proc_mesh_client = actor_mesh.proc_mesh().client();
1035 let ranks = actor_mesh.ranks.clone();
1036 execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await;
1037 }
1038
    /// State returned by `setup_mesh_v1` for the v1-API tests.
    struct MeshSetupV1 {
        /// Client instance used to allocate, cast, and send replies.
        instance: &'static Instance<()>,
        /// Reference to the spawned `TestActor` mesh.
        actor_mesh_ref: v1::ActorMeshRef<TestActor>,
        /// Client-side receiver for `reply_to1`. It is an accumulating port
        /// when `setup_mesh_v1` was given an accumulator.
        reply1_rx: PortReceiver<u64>,
        /// Client-side receiver for `reply_to2`.
        reply2_rx: PortReceiver<MyReply>,
        /// The split `(reply_to1, reply_to2)` ports seen by destination actors,
        /// in the order their forwarded messages arrived.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
1046
1047 async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1048 where
1049 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1050 {
1051 let instance = v1::testing::instance().await;
1052
1053 let extent = extent!(replica = 4, host = 4, gpu = 4);
1054 let alloc = LocalAllocator
1055 .allocate(AllocSpec {
1056 extent: extent.clone(),
1057 constraints: Default::default(),
1058 proc_name: None,
1059 transport: ChannelTransport::Local,
1060 })
1061 .await
1062 .unwrap();
1063
1064 let proc_mesh = v1::ProcMesh::allocate(instance, Box::new(alloc), "test.local")
1065 .await
1066 .unwrap();
1067
1068 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1069 let params = TestActorParams {
1070 forward_port: tx.bind(),
1071 };
1072 let actor_mesh = proc_mesh.spawn(&instance, "test", ¶ms).await.unwrap();
1073 let actor_mesh_ref = actor_mesh.deref().clone();
1074
1075 let (reply_port_handle0, _) = open_port::<String>(instance);
1076 let reply_port_ref0 = reply_port_handle0.bind();
1077 let (reply_port_handle1, reply1_rx) = match accum {
1078 Some(a) => instance.mailbox().open_accum_port(a),
1079 None => open_port(instance),
1080 };
1081 let reply_port_ref1 = reply_port_handle1.bind();
1082 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1083 let reply_port_ref2 = reply_port_handle2.bind();
1084 let message = TestMessage::CastAndReply {
1085 arg: "abc".to_string(),
1086 reply_to0: reply_port_ref0.clone(),
1087 reply_to1: reply_port_ref1.clone(),
1088 reply_to2: reply_port_ref2.clone(),
1089 };
1090
1091 clear_collected_tree();
1092 actor_mesh_ref.cast(instance, message).unwrap();
1093
1094 let mut reply_tos = vec![];
1095 for _ in extent.points() {
1096 let msg = rx.recv().await.expect("missing");
1097 match msg {
1098 TestMessage::CastAndReply {
1099 arg,
1100 reply_to0,
1101 reply_to1,
1102 reply_to2,
1103 } => {
1104 assert_eq!(arg, "abc");
1105 assert_eq!(reply_to0, reply_port_ref0);
1108 assert_ne!(reply_to1, reply_port_ref1);
1110 assert!(reply_to1.port_id().actor_id().name().contains("comm"));
1111 assert_ne!(reply_to2, reply_port_ref2);
1112 assert!(reply_to2.port_id().actor_id().name().contains("comm"));
1113 reply_tos.push((reply_to1, reply_to2));
1114 }
1115 _ => {
1116 panic!("unexpected message: {:?}", msg);
1117 }
1118 }
1119 }
1120
1121 let selection = sel!(*);
1123 verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);
1124
1125 MeshSetupV1 {
1126 instance,
1127 actor_mesh_ref,
1128 reply1_rx,
1129 reply2_rx,
1130 reply_tos,
1131 }
1132 }
1133
1134 #[async_timed_test(timeout_secs = 30)]
1135 async fn test_cast_and_reply_v1() {
1136 let MeshSetupV1 {
1137 instance,
1138 actor_mesh_ref,
1139 reply1_rx,
1140 reply2_rx,
1141 reply_tos,
1142 ..
1143 } = setup_mesh_v1::<NoneAccumulator>(None).await;
1144
1145 let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1146 execute_cast_and_reply(ranks, instance, reply1_rx, reply2_rx, reply_tos).await;
1147 }
1148
1149 #[async_timed_test(timeout_secs = 30)]
1150 async fn test_cast_and_accum_v1() {
1151 let config = config::global::lock();
1152 let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);
1154
1155 let MeshSetupV1 {
1156 instance,
1157 actor_mesh_ref,
1158 reply1_rx,
1159 reply_tos,
1160 ..
1161 } = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1162
1163 let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1164 execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await;
1165 }
1166}