1use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11pub mod multicast;
12
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt::Debug;
16
17use anyhow::Result;
18use async_trait::async_trait;
19use hyperactor::Actor;
20use hyperactor::ActorId;
21use hyperactor::ActorRef;
22use hyperactor::Context;
23use hyperactor::Handler;
24use hyperactor::Instance;
25use hyperactor::Named;
26use hyperactor::PortRef;
27use hyperactor::WorldId;
28use hyperactor::data::Serialized;
29use hyperactor::mailbox::DeliveryError;
30use hyperactor::mailbox::MailboxSender;
31use hyperactor::mailbox::Undeliverable;
32use hyperactor::mailbox::UndeliverableMailboxSender;
33use hyperactor::mailbox::UndeliverableMessageError;
34use hyperactor::mailbox::monitored_return_handle;
35use hyperactor::reference::UnboundPort;
36use ndslice::selection::routing::RoutingFrame;
37use serde::Deserialize;
38use serde::Serialize;
39
40use crate::comm::multicast::CastMessage;
41use crate::comm::multicast::CastMessageEnvelope;
42use crate::comm::multicast::ForwardMessage;
43use crate::comm::multicast::set_cast_info_on_headers;
44
/// Spawn parameters for [`CommActor`].
///
/// Carries no configuration. The actor starts in the default
/// [`CommActorMode`] and can be reconfigured later by sending it a
/// `CommActorMode` message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
48
/// A message that arrived before its predecessor in the same stream.
///
/// It is held in [`ReceiveState::buffer`] until that predecessor has been
/// handled.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether the message should also be delivered to the local destination
    /// actor on this comm actor's proc.
    deliver_here: bool,
    /// Routing frames still to be forwarded, keyed by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The buffered message itself.
    message: CastMessageEnvelope,
}
61
/// Receive-side ordering state for one message stream.
#[derive(Debug, Default)]
struct ReceiveState {
    /// Sequence number of the most recently handled message (0 before any).
    seq: usize,
    /// Out-of-order messages.
    ///
    /// Each is keyed by the sequence number that must be handled before it,
    /// i.e. the message's `last_seq`.
    buffer: HashMap<usize, Buffered>,
    /// For each peer rank, the sequence number of the last message forwarded
    /// to that peer.
    ///
    /// This value is sent as `last_seq` with the next forward to that peer.
    last_seqs: HashMap<usize, usize>,
}
74
/// Actor that forwards cast messages to peer comm actors and delivers them to
/// destination actors on its own proc.
///
/// Messages are sequenced per stream key, so each comm actor handles a
/// stream in order. A message that arrives before its predecessor is buffered.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Per stream key, the last sequence number assigned to a cast that
    /// originated at this actor.
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Per stream key, the receive-side ordering state.
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// How this actor's own rank and its peers are resolved.
    mode: CommActorMode,
}
95
/// How a [`CommActor`] resolves its own rank and the peers it forwards to.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub enum CommActorMode {
    /// Explicit membership.
    ///
    /// Holds this actor's rank and the peer comm actor registered for each
    /// rank.
    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),

    /// Ranks come from this actor's (ranked) proc id.
    ///
    /// The peer for a rank is the root actor that shares this actor's name,
    /// on that rank's proc, in this proc's world.
    Implicit,

    /// Like `Implicit`, but peers are resolved in the given world.
    ///
    /// This actor's own rank is still taken from its proc id.
    ImplicitWithWorldId(WorldId),
}
117
118impl Default for CommActorMode {
119 fn default() -> Self {
120 Self::Implicit
121 }
122}
123
124impl CommActorMode {
125 fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
128 match self {
129 Self::Mesh(_self_rank, peers) => peers
130 .get(&rank)
131 .cloned()
132 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
133 Self::Implicit => {
134 let world_id = self_id
135 .proc_id()
136 .world_id()
137 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
138 let proc_id = world_id.proc_id(rank);
139 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
140 Ok(ActorRef::<CommActor>::attest(actor_id))
141 }
142 Self::ImplicitWithWorldId(world_id) => {
143 let proc_id = world_id.proc_id(rank);
144 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
145 Ok(ActorRef::<CommActor>::attest(actor_id))
146 }
147 }
148 }
149
150 fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
152 match self {
153 Self::Mesh(rank, _) => Ok(*rank),
154 Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
155 .proc_id()
156 .rank()
157 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
158 }
159 }
160}
161
#[async_trait]
impl Actor for CommActor {
    type Params = CommActorParams;

    /// Creates a comm actor with empty sequencing state and the default
    /// routing mode. `_params` carries no configuration.
    async fn new(_params: Self::Params) -> Result<Self> {
        Ok(Self {
            send_seq: HashMap::new(),
            recv_state: HashMap::new(),
            mode: Default::default(),
        })
    }

    /// Routes an undeliverable envelope back to whoever should learn about
    /// the failure.
    ///
    /// Resolution order:
    /// 1. The envelope holds a `ForwardMessage`, meaning a comm-actor-to-comm-actor
    ///    hop failed. It is returned to the message port of the wrapped cast
    ///    message's sender.
    /// 2. The envelope's headers carry `CAST_ORIGINATING_SENDER`. It is
    ///    returned to that sender's message port. (Presumably these headers
    ///    were set on local delivery by `set_cast_info_on_headers`; confirm.)
    /// 3. Otherwise, it is handed to `UndeliverableMailboxSender`.
    ///
    /// # Errors
    /// If sending the return fails in case 1 or 2, the envelope's error is
    /// set to `DeliveryError::BrokenLink` and
    /// `UndeliverableMessageError::return_failure` is returned.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // Case 1: delivery failed while forwarding between comm actors.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = PortRef::attest_message_port(sender);
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    message_envelope
                        .try_set_error(DeliveryError::BrokenLink(format!("send failure: {err}")));
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // Case 2: the envelope records the cast's originating sender in its headers.
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = PortRef::attest_message_port(sender);
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    message_envelope
                        .try_set_error(DeliveryError::BrokenLink(format!("send failure: {err}")));
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // Case 3: no sender to return to; fall back to the undeliverable sink.
        UndeliverableMailboxSender
            .post(message_envelope, monitored_return_handle());
        Ok(())
    }
}
217
218impl CommActor {
219 fn forward(
221 cx: &Instance<Self>,
222 mode: &CommActorMode,
223 rank: usize,
224 message: ForwardMessage,
225 ) -> Result<()> {
226 let child = mode.peer_for_rank(cx.self_id(), rank)?;
227 child.send(cx, message)?;
228 Ok(())
229 }
230
231 fn handle_message(
232 cx: &Context<Self>,
233 mode: &CommActorMode,
234 deliver_here: bool,
235 next_steps: HashMap<usize, Vec<RoutingFrame>>,
236 sender: ActorId,
237 mut message: CastMessageEnvelope,
238 seq: usize,
239 last_seqs: &mut HashMap<usize, usize>,
240 ) -> Result<()> {
241 message
245 .data_mut()
246 .visit_mut::<UnboundPort>(|UnboundPort(port_id, reducer_spec)| {
247 let split = port_id.split(cx, reducer_spec.clone())?;
248
249 #[cfg(test)]
250 tests::collect_split_port(port_id, &split, deliver_here);
251
252 *port_id = split;
253 Ok(())
254 })?;
255
256 if deliver_here {
258 let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
259 let cast_rank = message.relative_rank(rank_on_root_mesh)?;
260 let cast_shape = message.shape();
261 let mut headers = cx.headers().clone();
262 set_cast_info_on_headers(
263 &mut headers,
264 cast_rank,
265 cast_shape.clone(),
266 message.sender().clone(),
267 );
268 cx.post(
269 cx.self_id()
270 .proc_id()
271 .actor_id(message.dest_port().actor_name(), 0)
272 .port_id(message.dest_port().port()),
273 headers,
274 Serialized::serialize(message.data())?,
275 );
276 }
277
278 next_steps
280 .into_iter()
281 .map(|(peer, dests)| {
282 let last_seq = last_seqs.entry(peer).or_default();
283 Self::forward(
284 cx,
285 mode,
286 peer,
287 ForwardMessage {
288 dests,
289 sender: sender.clone(),
290 message: message.clone(),
291 seq,
292 last_seq: *last_seq,
293 },
294 )?;
295 *last_seq = seq;
296 Ok(())
297 })
298 .collect::<Result<Vec<_>>>()?;
299
300 Ok(())
301 }
302}
303
304#[async_trait]
305impl Handler<CommActorMode> for CommActor {
306 async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
307 self.mode = mode;
308 Ok(())
309 }
310}
311
312#[async_trait]
314impl Handler<CastMessage> for CommActor {
315 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
316 let slice = cast_message.dest.slice.clone();
318 let selection = cast_message.dest.selection.clone();
319 let frame = RoutingFrame::root(selection, slice);
320 let rank = frame.slice.location(&frame.here)?;
321 let seq = self
322 .send_seq
323 .entry(cast_message.message.stream_key())
324 .or_default();
325 let last_seq = *seq;
326 *seq += 1;
327 Self::forward(
328 cx,
329 &self.mode,
330 rank,
331 ForwardMessage {
332 dests: vec![frame],
333 sender: cx.self_id().clone(),
334 message: cast_message.message,
335 seq: *seq,
336 last_seq,
337 },
338 )?;
339 Ok(())
340 }
341}
342
343#[async_trait]
344impl Handler<ForwardMessage> for CommActor {
345 async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
346 let ForwardMessage {
347 sender,
348 dests,
349 message,
350 seq,
351 last_seq,
352 } = fwd_message;
353
354 let rank = self.mode.self_rank(cx.self_id())?;
356 let (deliver_here, next_steps) =
357 ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
358 panic!("Choice encountered in CommActor routing")
359 })?;
360
361 let recv_state = self.recv_state.entry(message.stream_key()).or_default();
362 match recv_state.seq.cmp(&last_seq) {
363 Ordering::Equal => {
365 Self::handle_message(
367 cx,
368 &self.mode,
369 deliver_here,
370 next_steps,
371 sender.clone(),
372 message,
373 seq,
374 &mut recv_state.last_seqs,
375 )?;
376 recv_state.seq = seq;
377
378 while let Some(Buffered {
381 seq,
382 deliver_here,
383 next_steps,
384 message,
385 }) = recv_state.buffer.remove(&recv_state.seq)
386 {
387 Self::handle_message(
388 cx,
389 &self.mode,
390 deliver_here,
391 next_steps,
392 sender.clone(),
393 message,
394 seq,
395 &mut recv_state.last_seqs,
396 )?;
397 recv_state.seq = seq;
398 }
399 }
400 Ordering::Less => {
403 tracing::warn!(
404 "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
405 seq,
406 last_seq,
407 recv_state.seq,
408 message
409 );
410 recv_state.buffer.insert(
411 last_seq,
412 Buffered {
413 seq,
414 deliver_here,
415 next_steps,
416 message,
417 },
418 );
419 }
420 Ordering::Greater => {
422 tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
423 }
424 }
425
426 Ok(())
427 }
428}
429
/// Test actor and message types used by this file's tests.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorId;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Named;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;

    use super::*;

    /// Reply payload that records which actor sent it.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: ActorId,
        pub value: u64,
    }

    /// Message cast to [`TestActor`]s.
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Unlike `reply_to1`/`reply_to2`, this port is intentionally not
            // marked `#[binding(include)]`. The tests expect it to reach the
            // destination unchanged, while the other two are replaced by
            // ports on comm actors.
            reply_to0: PortRef<String>,
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
    }

    /// Actor that forwards every [`TestMessage`] it receives to
    /// `forward_port`, letting a test observe what was delivered.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        forward_port: PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`].
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        /// Port that receives every message delivered to the actor.
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {
        type Params = TestActorParams;

        async fn new(params: Self::Params) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Forwards `msg` unchanged to `forward_port`.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
504
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::collections::HashSet;
    use std::fmt::Display;
    use std::hash::Hash;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::OnceLock;

    use hyperactor::PortId;
    use hyperactor::PortRef;
    use hyperactor::accum;
    use hyperactor::accum::Accumulator;
    use hyperactor::accum::ReducerSpec;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::config;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::mailbox::open_port;
    use hyperactor::reference::Index;
    use hyperactor_mesh_macros::sel;
    use maplit::btreemap;
    use maplit::hashmap;
    use ndslice::Selection;
    use ndslice::extent;
    use ndslice::selection::test_utils::collect_commactor_routing_tree;
    use test_utils::*;
    use timed_test::async_timed_test;
    use tokio::time::Duration;

    use super::*;
    use crate::ProcMesh;
    use crate::actor_mesh::ActorMesh;
    use crate::actor_mesh::RootActorMesh;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::LocalAllocator;
    use crate::proc_mesh::SharedSpawnable;

    /// A recorded port split: `from` was split into `to`.
    ///
    /// `is_leaf` is the `deliver_here` flag of the comm actor that performed
    /// the split.
    struct Edge<T> {
        from: T,
        to: T,
        is_leaf: bool,
    }

    impl<T> From<(T, T, bool)> for Edge<T> {
        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
            Self { from, to, is_leaf }
        }
    }

    /// Every port split recorded by `CommActor::handle_message` in this test
    /// process.
    ///
    /// The store is process-global, so it may also hold splits from other
    /// tests running in the same process. Readers must filter by root port.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();

    /// Records that `original` was split into `split`.
    ///
    /// Called from `CommActor::handle_message` in test builds.
    pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
        let mut tree = mutex.lock().unwrap();

        tree.push(Edge {
            from: original.clone(),
            to: split.clone(),
            is_leaf: deliver_here,
        });
    }

    /// Maps each leaf node to its full path, root first.
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);

    impl<T: Display> Debug for PathToLeaves<T> {
        /// Prints one `leaf -> root, ..., leaf` line per entry.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fn vec_to_string<T: Display>(v: &[T]) -> String {
                v.iter()
                    .map(ToString::to_string)
                    .collect::<Vec<String>>()
                    .join(", ")
            }

            for (src, path) in &self.0 {
                writeln!(f, "{} -> {}", src, vec_to_string(path))?;
            }
            Ok(())
        }
    }

    /// Reconstructs root-to-leaf paths from split edges.
    ///
    /// Each `to` node has a single parent; if several edges share a `to`, the
    /// last one wins. Leaves are the `to` nodes of edges marked `is_leaf`.
    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
        let mut child_parent_map = HashMap::new();
        let mut dests = HashSet::new();

        for Edge { from, to, is_leaf } in edges {
            child_parent_map.insert(to.clone(), from.clone());
            if *is_leaf {
                dests.insert(to.clone());
            }
        }

        let mut result = BTreeMap::new();
        for dest in dests {
            // Walk leaf -> root, then flip so the path reads root -> leaf.
            let mut path = vec![dest.clone()];
            let mut current = dest.clone();
            while let Some(parent) = child_parent_map.get(&current) {
                path.push(parent.clone());
                current = parent.clone();
            }
            path.reverse();
            result.insert(dest, path);
        }

        PathToLeaves(result)
    }

    #[test]
    fn test_build_paths() {
        // Tree under test:
        //   0 -> 1 -> {2, 3}
        //   0 -> 4 -> 5
        // with 2, 3, 4 and 5 marked as leaves.
        let edges: Vec<_> = [
            (0, 1, false),
            (1, 2, true),
            (1, 3, true),
            (0, 4, true),
            (4, 5, true),
        ]
        .into_iter()
        .map(Edge::from)
        .collect();

        let paths = build_paths(&edges);

        let expected = btreemap! {
            2 => vec![0, 1, 2],
            3 => vec![0, 1, 3],
            4 => vec![0, 4],
            5 => vec![0, 4, 5],
        };

        assert_eq!(paths.0, expected);
    }

    /// Converts port paths into rank paths.
    ///
    /// Each path must start at `client_reply`, which is dropped. Every
    /// remaining port, and each destination, must belong to an actor named
    /// `comm`; they are replaced by that actor's rank.
    fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
        let ranks = paths
            .0
            .into_iter()
            .map(|(dst, mut path)| {
                let first = path.remove(0);
                assert_eq!(&first, client_reply);
                assert_eq!(dst.actor_id().name(), "comm");
                let actor_path = path
                    .into_iter()
                    .map(|p| {
                        assert_eq!(p.actor_id().name(), "comm");
                        p.actor_id().rank()
                    })
                    .collect();
                (dst.into_actor_id().rank(), actor_path)
            })
            .collect();
        PathToLeaves(ranks)
    }

    /// Result of [`setup_mesh`]: the spawned mesh, the client-side receivers
    /// for the two split reply ports, and the split ports each destination
    /// actor received.
    struct MeshSetup {
        actor_mesh: RootActorMesh<'static, TestActor>,
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }

    /// Placeholder accumulator type for `setup_mesh(None)`. It is never
    /// called.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }

    /// Spawns a 4x4x4 mesh of [`TestActor`]s and casts one `CastAndReply` to
    /// all of them.
    ///
    /// It checks that `reply_to0` arrives unchanged, that `reply_to1` and
    /// `reply_to2` arrive as split ports on `comm` actors, and that the split
    /// paths match the comm actor routing tree for the selection.
    /// `reply_to1` is an accumulating port when `accum` is given.
    async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
        let dest_actor_name = "dest_actor";
        let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_mesh = proc_mesh
            .clone()
            .spawn::<TestActor>(dest_actor_name, &params)
            .await
            .unwrap();

        let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => proc_mesh.client().open_accum_port(a),
            None => open_port(proc_mesh.client()),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        let selection = sel!(*);
        actor_mesh
            .cast(proc_mesh.client(), selection.clone(), message)
            .unwrap();

        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // reply_to0 is not bound, so it is delivered as-is.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // The bound ports are replaced by split ports on comm actors.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        {
            // Paths the comm actor routing tree prescribes for `selection`.
            let sel_paths = PathToLeaves(
                collect_commactor_routing_tree(&selection, &extent.to_slice())
                    .delivered
                    .into_iter()
                    .collect(),
            );

            // Paths the split reply ports actually took. SPLIT_PORT_TREE is
            // shared by every test in the process, so keep only the paths
            // rooted at this test's own reply ports.
            let (reply1_paths, reply2_paths) = {
                let tree = SPLIT_PORT_TREE.get().unwrap();
                let edges = tree.lock().unwrap();
                let all_paths = build_paths(&edges).0;
                let rooted_at = |root: &PortId| -> PathToLeaves<PortId> {
                    PathToLeaves(
                        all_paths
                            .iter()
                            .filter(|(_dst, path)| &path[0] == root)
                            .map(|(dst, path)| (dst.clone(), path.clone()))
                            .collect(),
                    )
                };
                (
                    get_ranks(
                        rooted_at(reply_port_ref1.port_id()),
                        reply_port_ref1.port_id(),
                    ),
                    get_ranks(
                        rooted_at(reply_port_ref2.port_id()),
                        reply_port_ref2.port_id(),
                    ),
                )
            };

            // Both split ports must follow exactly the routing tree.
            assert_eq!(sel_paths, reply1_paths);
            assert_eq!(sel_paths, reply2_paths);
        }

        MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_reply() {
        let MeshSetup {
            actor_mesh,
            mut reply1_rx,
            mut reply2_rx,
            reply_tos,
            ..
        } = setup_mesh::<NoneAccumulator>(None).await;
        let proc_mesh_client = actor_mesh.proc_mesh().client();

        {
            // Each destination replies once on both split ports; the client
            // must receive each reply unchanged.
            for (dest_actor, (reply_to1, reply_to2)) in
                actor_mesh.ranks.iter().zip(reply_tos.iter())
            {
                let rank = dest_actor.actor_id().rank() as u64;
                reply_to1.send(proc_mesh_client, rank).unwrap();
                let my_reply = MyReply {
                    sender: dest_actor.actor_id().clone(),
                    value: rank,
                };
                reply_to2.send(proc_mesh_client, my_reply.clone()).unwrap();

                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
            }
        }

        tracing::info!("the 1st updates from all dest actors were receivered by client");

        {
            // Many replies per destination on the non-accumulating port: all
            // must arrive, in per-sender send order.
            let n = 100;
            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
            for (dest_actor, (_reply_to1, reply_to2)) in
                actor_mesh.ranks.iter().zip(reply_tos.iter())
            {
                let rank = dest_actor.actor_id().rank();
                let mut sent2 = vec![];
                for i in 0..n {
                    let value = (rank * 100 + i) as u64;
                    let my_reply = MyReply {
                        sender: dest_actor.actor_id().clone(),
                        value,
                    };
                    reply_to2.send(proc_mesh_client, my_reply.clone()).unwrap();
                    sent2.push(my_reply);
                }
                assert!(
                    expected2.insert(rank, sent2).is_none(),
                    "duplicate rank {rank} in map"
                );
            }

            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};

            for _ in 0..(n * actor_mesh.ranks.len()) {
                let my_reply = reply2_rx.recv().await.unwrap();
                received2
                    .entry(my_reply.sender.rank())
                    .or_default()
                    .push(my_reply);
            }
            assert_eq!(received2, expected2);
        }
    }

    /// Receives from `receiver` until a message equal to `expected` arrives.
    ///
    /// Fails if that takes longer than `dur`.
    async fn wait_for_with_timeout(
        receiver: &mut PortReceiver<u64>,
        expected: u64,
        dur: Duration,
    ) -> anyhow::Result<()> {
        RealClock
            .timeout(dur, async {
                loop {
                    let msg = receiver.recv().await.unwrap();
                    if msg == expected {
                        break;
                    }
                }
            })
            .await?;
        Ok(())
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_accum() -> Result<()> {
        let config = config::global::lock();
        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);

        let MeshSetup {
            actor_mesh,
            mut reply1_rx,
            reply_tos,
            ..
        } = setup_mesh(Some(accum::sum::<u64>())).await;
        let proc_mesh_client = actor_mesh.proc_mesh().client();

        {
            // Every destination sends `n` updates on the summing port. The
            // client must eventually see the full sum, and then nothing else.
            let mut sum = 0;
            let n = 100;
            for (dest_actor, (reply_to1, _reply_to2)) in
                actor_mesh.ranks.iter().zip(reply_tos.iter())
            {
                let rank = dest_actor.actor_id().rank();
                for i in 0..n {
                    let value = (rank + i) as u64;
                    reply_to1.send(proc_mesh_client, value).unwrap();
                    sum += value;
                }
            }
            wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
                .await
                .unwrap();
            // No further messages should follow the final sum.
            RealClock.sleep(Duration::from_secs(2)).await;
            let msg = reply1_rx.try_recv().unwrap();
            assert_eq!(msg, None);
        }
        Ok(())
    }
}