1use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11use crate::resource;
12pub mod multicast;
13
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use anyhow::Result;
19use async_trait::async_trait;
20use hyperactor::Actor;
21use hyperactor::ActorId;
22use hyperactor::ActorRef;
23use hyperactor::Context;
24use hyperactor::Handler;
25use hyperactor::Instance;
26use hyperactor::Named;
27use hyperactor::PortRef;
28use hyperactor::WorldId;
29use hyperactor::data::Serialized;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::reference::UnboundPort;
37use ndslice::selection::routing::RoutingFrame;
38use serde::Deserialize;
39use serde::Serialize;
40
41use crate::comm::multicast::CastMessage;
42use crate::comm::multicast::CastMessageEnvelope;
43use crate::comm::multicast::ForwardMessage;
44use crate::comm::multicast::set_cast_info_on_headers;
45
/// Spawn parameters for [`CommActor`].
///
/// Currently empty: a newly spawned comm actor always starts in the default
/// [`CommActorMode`], and can be reconfigured afterwards by sending it a
/// `CommActorMode` message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
49
/// A forwarded cast message that arrived before its predecessor on the same
/// stream. It is held until the predecessor has been handled, then processed
/// in order.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether this message should also be delivered to the destination actor
    /// on this comm actor's proc.
    deliver_here: bool,
    /// Peer ranks to forward this message to, each with the routing frames the
    /// peer should continue from.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message itself.
    message: CastMessageEnvelope,
}
62
/// Receive-side ordering state that a comm actor keeps for one stream
/// (see `CommActor::recv_state`).
#[derive(Debug, Default)]
struct ReceiveState {
    /// Sequence number of the last message handled in order (0 before any
    /// message has been handled).
    seq: usize,
    /// Out-of-order messages, keyed by the sequence number (`last_seq`) of the
    /// message that must be handled before each of them.
    buffer: HashMap<usize, Buffered>,
    /// For each peer rank, the sequence number of the last message forwarded
    /// to that peer on this stream. It is sent as `last_seq` with the next
    /// forward so the peer can order its input.
    last_seqs: HashMap<usize, usize>,
}
75
/// Actor that multicasts ("casts") messages over a mesh.
///
/// A `CastMessage` received here is sequenced and handed to the comm actor at
/// the root of the destination slice. From there, each comm actor:
/// - optionally delivers the message locally;
/// - forwards it as a `ForwardMessage` to the peers chosen by routing.
///
/// Ordering is kept per stream using sequence numbers.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Last sequence number assigned to casts originated through this actor,
    /// per stream key (the message's `stream_key()`).
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Receive-side ordering state, per stream key.
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// How this actor's own rank and its peer comm actors are resolved.
    mode: CommActorMode,
}
96
/// How a [`CommActor`] resolves its own rank and the peer comm actor for a
/// given rank.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub enum CommActorMode {
    /// Explicit configuration: this actor's own rank, and the peer comm actor
    /// for each rank.
    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),

    /// Ranks come from the proc this actor runs on.
    /// - The actor's own rank is its proc's rank.
    /// - The peer for rank `r` is the root actor with the same name on proc
    ///   `r` of the actor's own world.
    ///
    /// Requires the actor to live on a ranked proc.
    Implicit,

    /// Like `Implicit`, except that peers are addressed in the given world
    /// rather than in the actor's own world. The actor's own rank is still
    /// taken from its proc.
    ImplicitWithWorldId(WorldId),
}
118
119impl Default for CommActorMode {
120 fn default() -> Self {
121 Self::Implicit
122 }
123}
124
125impl CommActorMode {
126 fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
129 match self {
130 Self::Mesh(_self_rank, peers) => peers
131 .get(&rank)
132 .cloned()
133 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
134 Self::Implicit => {
135 let world_id = self_id
136 .proc_id()
137 .world_id()
138 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
139 let proc_id = world_id.proc_id(rank);
140 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
141 Ok(ActorRef::<CommActor>::attest(actor_id))
142 }
143 Self::ImplicitWithWorldId(world_id) => {
144 let proc_id = world_id.proc_id(rank);
145 let actor_id = ActorId::root(proc_id, self_id.name().to_string());
146 Ok(ActorRef::<CommActor>::attest(actor_id))
147 }
148 }
149 }
150
151 fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
153 match self {
154 Self::Mesh(rank, _) => Ok(*rank),
155 Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
156 .proc_id()
157 .rank()
158 .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
159 }
160 }
161}
162
163#[async_trait]
164impl Actor for CommActor {
165 type Params = CommActorParams;
166
167 async fn new(_params: Self::Params) -> Result<Self> {
168 Ok(Self {
169 send_seq: HashMap::new(),
170 recv_state: HashMap::new(),
171 mode: Default::default(),
172 })
173 }
174
175 async fn handle_undeliverable_message(
177 &mut self,
178 cx: &Instance<Self>,
179 undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
180 ) -> Result<(), anyhow::Error> {
181 let Undeliverable(mut message_envelope) = undelivered;
182
183 if let Ok(ForwardMessage { message, .. }) =
185 message_envelope.deserialized::<ForwardMessage>()
186 {
187 let sender = message.sender();
188 let return_port = PortRef::attest_message_port(sender);
189 message_envelope.set_error(DeliveryError::Multicast(format!(
190 "comm actor {} failed to forward the cast message; return to \
191 its original sender's port {}",
192 cx.self_id(),
193 return_port.port_id(),
194 )));
195 return_port
196 .send(cx, Undeliverable(message_envelope.clone()))
197 .map_err(|err| {
198 let error = DeliveryError::BrokenLink(format!(
199 "error occured when returning ForwardMessage to the original \
200 sender's port {}; error is: {}",
201 return_port.port_id(),
202 err,
203 ));
204 message_envelope.set_error(error);
205 UndeliverableMessageError::ReturnFailure {
206 envelope: message_envelope,
207 }
208 })?;
209 return Ok(());
210 }
211
212 if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
214 let return_port = PortRef::attest_message_port(sender);
215 message_envelope.set_error(DeliveryError::Multicast(format!(
216 "comm actor {} failed to deliver the cast message to the dest \
217 actor; return to its original sender's port {}",
218 cx.self_id(),
219 return_port.port_id(),
220 )));
221 return_port
222 .send(cx, Undeliverable(message_envelope.clone()))
223 .map_err(|err| {
224 let error = DeliveryError::BrokenLink(format!(
225 "error occured when returning cast message to the original \
226 sender's port {}; error is: {}",
227 return_port.port_id(),
228 err,
229 ));
230 message_envelope.set_error(error);
231 UndeliverableMessageError::ReturnFailure {
232 envelope: message_envelope,
233 }
234 })?;
235 return Ok(());
236 }
237
238 UndeliverableMailboxSender
240 .post(message_envelope, monitored_return_handle());
241 Ok(())
242 }
243}
244
245impl CommActor {
246 fn forward(
248 cx: &Instance<Self>,
249 mode: &CommActorMode,
250 rank: usize,
251 message: ForwardMessage,
252 ) -> Result<()> {
253 let child = mode.peer_for_rank(cx.self_id(), rank)?;
254 child.send(cx, message)?;
255 Ok(())
256 }
257
258 fn handle_message(
259 cx: &Context<Self>,
260 mode: &CommActorMode,
261 deliver_here: bool,
262 next_steps: HashMap<usize, Vec<RoutingFrame>>,
263 sender: ActorId,
264 mut message: CastMessageEnvelope,
265 seq: usize,
266 last_seqs: &mut HashMap<usize, usize>,
267 ) -> Result<()> {
268 message.data_mut().visit_mut::<UnboundPort>(
272 |UnboundPort(port_id, reducer_spec, reducer_opts, return_undeliverable)| {
273 let split = port_id.split(
274 cx,
275 reducer_spec.clone(),
276 reducer_opts.clone(),
277 *return_undeliverable,
278 )?;
279
280 #[cfg(test)]
281 tests::collect_split_port(port_id, &split, deliver_here);
282
283 *port_id = split;
284 Ok(())
285 },
286 )?;
287
288 if deliver_here {
290 let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
291 let cast_rank = message.relative_rank(rank_on_root_mesh)?;
292 message
294 .data_mut()
295 .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
296 *rank = Some(cast_rank);
297 Ok(())
298 })?;
299 let cast_shape = message.shape();
300 let point = cast_shape
301 .extent()
302 .point_of_rank(cast_rank)
303 .expect("rank out of bounds");
304 let mut headers = cx.headers().clone();
305 set_cast_info_on_headers(&mut headers, point, message.sender().clone());
306 cx.post(
307 cx.self_id()
308 .proc_id()
309 .actor_id(message.dest_port().actor_name(), 0)
310 .port_id(message.dest_port().port()),
311 headers,
312 Serialized::serialize(message.data())?,
313 );
314 }
315
316 next_steps
318 .into_iter()
319 .map(|(peer, dests)| {
320 let last_seq = last_seqs.entry(peer).or_default();
321 Self::forward(
322 cx,
323 mode,
324 peer,
325 ForwardMessage {
326 dests,
327 sender: sender.clone(),
328 message: message.clone(),
329 seq,
330 last_seq: *last_seq,
331 },
332 )?;
333 *last_seq = seq;
334 Ok(())
335 })
336 .collect::<Result<Vec<_>>>()?;
337
338 Ok(())
339 }
340}
341
342#[async_trait]
343impl Handler<CommActorMode> for CommActor {
344 async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
345 self.mode = mode;
346 Ok(())
347 }
348}
349
350#[async_trait]
352impl Handler<CastMessage> for CommActor {
353 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
354 let slice = cast_message.dest.slice.clone();
356 let selection = cast_message.dest.selection.clone();
357 let frame = RoutingFrame::root(selection, slice);
358 let rank = frame.slice.location(&frame.here)?;
359 let seq = self
360 .send_seq
361 .entry(cast_message.message.stream_key())
362 .or_default();
363 let last_seq = *seq;
364 *seq += 1;
365 Self::forward(
366 cx,
367 &self.mode,
368 rank,
369 ForwardMessage {
370 dests: vec![frame],
371 sender: cx.self_id().clone(),
372 message: cast_message.message,
373 seq: *seq,
374 last_seq,
375 },
376 )?;
377 Ok(())
378 }
379}
380
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handles a cast message forwarded to this comm actor, either by a parent
    /// comm actor or by the comm actor that received the original `CastMessage`.
    ///
    /// Messages on the same stream (keyed by `message.stream_key()`) are
    /// processed in sequence order. Each `ForwardMessage` carries its own `seq`
    /// and the `last_seq` of the message that must be handled before it.
    /// - `last_seq` equals the last handled seq: handle the message now, then
    ///   any buffered messages this unblocks.
    /// - `last_seq` is ahead of the last handled seq: buffer the message under
    ///   `last_seq` until its predecessor arrives.
    /// - `last_seq` is behind the last handled seq: treat it as a duplicate and
    ///   drop it with a warning.
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve the routing frames against this actor's rank. This yields
        // whether to deliver locally, and which peer ranks to forward to (with
        // the frames each should continue from).
        let rank = self.mode.self_rank(cx.self_id())?;
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                // Choice selections are not supported by comm actor routing.
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            // This is the next expected message on the stream: handle it now.
            Ordering::Equal => {
                Self::handle_message(
                    cx,
                    &self.mode,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Drain buffered messages that were waiting on the one just
                // handled, and on each other, in sequence order.
                // NOTE(review): buffered messages are handled with the `sender`
                // of the current message, because `Buffered` does not record its
                // own sender. Presumably every message on one stream has the same
                // sender; confirm.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        &self.mode,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // A predecessor has not been handled yet. Buffer this message under
            // the seq it must follow, so the drain loop above finds it.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // The stream has already moved past this message: drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
467
468pub mod test_utils {
469 use anyhow::Result;
470 use async_trait::async_trait;
471 use hyperactor::Actor;
472 use hyperactor::ActorId;
473 use hyperactor::Bind;
474 use hyperactor::Context;
475 use hyperactor::Handler;
476 use hyperactor::Named;
477 use hyperactor::PortRef;
478 use hyperactor::Unbind;
479 use serde::Deserialize;
480 use serde::Serialize;
481
482 use super::*;
483
484 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
485 pub struct MyReply {
486 pub sender: ActorId,
487 pub value: u64,
488 }
489
490 #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
491 pub enum TestMessage {
492 Forward(String),
493 CastAndReply {
494 arg: String,
495 reply_to0: PortRef<String>,
499 #[binding(include)]
500 reply_to1: PortRef<u64>,
501 #[binding(include)]
502 reply_to2: PortRef<MyReply>,
503 },
504 }
505
506 #[derive(Debug)]
507 #[hyperactor::export(
508 spawn = true,
509 handlers = [
510 TestMessage { cast = true },
511 ],
512 )]
513 pub struct TestActor {
514 forward_port: PortRef<TestMessage>,
517 }
518
519 #[derive(Debug, Clone, Named, Serialize, Deserialize)]
520 pub struct TestActorParams {
521 pub forward_port: PortRef<TestMessage>,
522 }
523
524 #[async_trait]
525 impl Actor for TestActor {
526 type Params = TestActorParams;
527
528 async fn new(params: Self::Params) -> Result<Self> {
529 let Self::Params { forward_port } = params;
530 Ok(Self { forward_port })
531 }
532 }
533
534 #[async_trait]
535 impl Handler<TestMessage> for TestActor {
536 async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
537 self.forward_port.send(cx, msg)?;
538 Ok(())
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::collections::HashSet;
    use std::fmt::Display;
    use std::hash::Hash;
    use std::ops::Deref;
    use std::ops::DerefMut;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::OnceLock;

    use hyperactor::PortId;
    use hyperactor::PortRef;
    use hyperactor::accum;
    use hyperactor::accum::Accumulator;
    use hyperactor::accum::ReducerSpec;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::config;
    use hyperactor::context::Mailbox;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::mailbox::open_port;
    use hyperactor::reference::Index;
    use hyperactor_mesh_macros::sel;
    use maplit::btreemap;
    use maplit::hashmap;
    use ndslice::Extent;
    use ndslice::Selection;
    use ndslice::ViewExt as _;
    use ndslice::extent;
    use ndslice::selection::test_utils::collect_commactor_routing_tree;
    use test_utils::*;
    use timed_test::async_timed_test;
    use tokio::time::Duration;

    use super::*;
    use crate::ProcMesh;
    use crate::actor_mesh::ActorMesh;
    use crate::actor_mesh::RootActorMesh;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::LocalAllocator;
    use crate::proc_mesh::SharedSpawnable;
    use crate::v1;

    /// A parent -> child edge of a tree. `is_leaf` marks `to` as a delivery
    /// destination.
    struct Edge<T> {
        from: T,
        to: T,
        is_leaf: bool,
    }

    impl<T> From<(T, T, bool)> for Edge<T> {
        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
            Self { from, to, is_leaf }
        }
    }

    // Every port split made by `CommActor::handle_message` in test builds,
    // recorded as an edge from the original port to the split port.
    // NOTE(review): this is process-global. Tests that clear and verify it
    // concurrently in one process could see each other's edges; confirm these
    // tests are not run in parallel within a single process.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();

    /// Records that `original` was split into `split`. `deliver_here` marks the
    /// split as made by a comm actor that also delivers locally, which makes
    /// `split` a leaf of the tree.
    pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
        let mut tree = mutex.lock().unwrap();

        tree.deref_mut().push(Edge {
            from: original.clone(),
            to: split.clone(),
            is_leaf: deliver_here,
        });
    }

    /// Clears the recorded split-port tree. Does nothing if nothing has been
    /// recorded yet.
    fn clear_collected_tree() {
        if let Some(tree) = SPLIT_PORT_TREE.get() {
            let mut tree = tree.lock().unwrap();
            tree.clear();
        }
    }

    /// Maps each leaf of a tree to its path from the root (root first, leaf
    /// last).
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);

    // Prints one "leaf -> root, ..., leaf" line per entry, so `assert_eq!`
    // failures are readable.
    impl<T: Display> Debug for PathToLeaves<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fn vec_to_string<T: Display>(v: &[T]) -> String {
                v.iter()
                    .map(ToString::to_string)
                    .collect::<Vec<String>>()
                    .join(", ")
            }

            for (src, path) in &self.0 {
                writeln!(f, "{} -> {}", src, vec_to_string(path))?;
            }
            Ok(())
        }
    }

    /// For every edge target marked as a leaf, builds the path from the tree
    /// root down to that leaf.
    ///
    /// Assumes each node has at most one parent. If several edges point at the
    /// same child, the last one wins.
    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
        // child -> parent, used to walk each leaf back up to the root.
        let mut child_parent_map = HashMap::new();
        let mut dests = HashSet::new();

        for Edge { from, to, is_leaf } in edges {
            child_parent_map.insert(to.clone(), from.clone());
            if *is_leaf {
                dests.insert(to.clone());
            }
        }

        let mut result = BTreeMap::new();
        for dest in dests {
            let mut path = vec![dest.clone()];
            let mut current = dest.clone();
            while let Some(parent) = child_parent_map.get(&current) {
                path.push(parent.clone());
                current = parent.clone();
            }
            path.reverse();
            result.insert(dest, path);
        }

        PathToLeaves(result)
    }

    #[test]
    fn test_build_paths() {
        // Tree under test (leaves marked with *):
        //
        //         0
        //        / \
        //       1   4*
        //      / \   \
        //     2*  3*  5*
        let edges: Vec<_> = [
            (0, 1, false),
            (1, 2, true),
            (1, 3, true),
            (0, 4, true),
            (4, 5, true),
        ]
        .into_iter()
        .map(Edge::from)
        .collect();

        let paths = build_paths(&edges);

        let expected = btreemap! {
            2 => vec![0, 1, 2],
            3 => vec![0, 1, 3],
            4 => vec![0, 4],
            5 => vec![0, 4, 5],
        };

        assert_eq!(paths.0, expected);
    }

    /// Converts split-port paths into rank paths, for comparison with the
    /// expected comm-actor routing tree.
    ///
    /// For each path:
    /// - the first element must be `client_reply`, the port originally cast by
    ///   the client; it is dropped from the result;
    /// - every remaining port, including the leaf, must belong to a comm actor;
    /// - each port is replaced by the rank of its owning actor.
    fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
        let ranks = paths
            .0
            .into_iter()
            .map(|(dst, mut path)| {
                let first = path.remove(0);
                // Every path starts at the client's original reply port.
                assert_eq!(&first, client_reply);
                // Leaves are split ports, so they belong to comm actors.
                assert!(dst.actor_id().name().contains("comm"));
                // Map each split port along the path to its comm actor's rank.
                let actor_path = path
                    .into_iter()
                    .map(|p| {
                        assert!(p.actor_id().name().contains("comm"));
                        p.actor_id().rank()
                    })
                    .collect();
                (dst.into_actor_id().rank(), actor_path)
            })
            .collect();
        PathToLeaves(ranks)
    }

    /// State produced by `setup_mesh` for the `RootActorMesh`-based tests.
    struct MeshSetup {
        actor_mesh: RootActorMesh<'static, TestActor>,
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }

    /// Type placeholder for `setup_mesh::<NoneAccumulator>(None)`. Its methods
    /// are never meant to be called.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }

    /// Asserts that the recorded port splits for both reply ports follow
    /// exactly the comm-actor routing tree computed for `selection` over
    /// `extent`.
    fn verify_split_port_paths(
        selection: &Selection,
        extent: &Extent,
        reply_port_ref1: &PortRef<u64>,
        reply_port_ref2: &PortRef<MyReply>,
    ) {
        // Expected: rank paths of the routing tree for this selection.
        let sel_paths = PathToLeaves(
            collect_commactor_routing_tree(selection, &extent.to_slice())
                .delivered
                .into_iter()
                .collect(),
        );

        // Observed: the recorded split-port tree, grouped by the reply port
        // each path starts from.
        let (reply1_paths, reply2_paths) = {
            let tree = SPLIT_PORT_TREE.get().unwrap();
            let edges = tree.lock().unwrap();
            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
                .0
                .into_iter()
                .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
            (
                get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
                get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
            )
        };

        assert_eq!(sel_paths, reply1_paths);
        assert_eq!(sel_paths, reply2_paths);
    }

    /// Spawns a 4x4x4 `RootActorMesh` of `TestActor`s, casts a `CastAndReply`
    /// to all of them, and checks what each destination received.
    ///
    /// The bound reply ports must have been replaced by comm-actor split
    /// ports, `reply_to0` must be untouched, and the split tree must match the
    /// routing tree.
    ///
    /// `accum`, if given, is installed on the port behind `reply_to1`. Returns
    /// the client-side receivers and the reply ports seen by each destination.
    async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
        let dest_actor_name = "dest_actor";
        let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let instance = crate::v1::testing::instance().await;
        let actor_mesh = Arc::clone(&proc_mesh)
            .spawn::<TestActor>(&instance, dest_actor_name, &params)
            .await
            .unwrap();

        let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => proc_mesh.client().mailbox().open_accum_port(a),
            None => open_port(proc_mesh.client()),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        let selection = sel!(*);
        clear_collected_tree();
        actor_mesh
            .cast(proc_mesh.client(), selection.clone(), message)
            .unwrap();

        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // reply_to0 is not a bound-included port, so it is not split.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // reply_to1 and reply_to2 were split: they now point at
                    // ports on a comm actor.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);

        MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }

    /// Sends replies on each destination's split ports and checks that they
    /// reach the client.
    ///
    /// - First, one reply per destination, each awaited before moving on.
    /// - Then a burst of replies per destination on `reply_to2`; for each
    ///   sender, these must arrive in the order sent.
    async fn execute_cast_and_reply(
        ranks: Vec<ActorRef<TestActor>>,
        instance: &Instance<()>,
        mut reply1_rx: PortReceiver<u64>,
        mut reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    ) {
        {
            // One reply per destination on each port, received before the next
            // destination replies.
            for (dest_actor, (reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
                let rank = dest_actor.actor_id().rank() as u64;
                reply_to1.send(instance, rank).unwrap();
                let my_reply = MyReply {
                    sender: dest_actor.actor_id().clone(),
                    value: rank,
                };
                reply_to2.send(instance, my_reply.clone()).unwrap();

                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
            }
        }

        tracing::info!("the 1st updates from all dest actors were received by client");

        {
            // Burst of `n` replies per destination; per-sender order must be
            // preserved.
            let n = 100;
            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
            for (dest_actor, (_reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
                let rank = dest_actor.actor_id().rank();
                let mut sent2 = vec![];
                for i in 0..n {
                    let value = (rank * 100 + i) as u64;
                    let my_reply = MyReply {
                        sender: dest_actor.actor_id().clone(),
                        value,
                    };
                    reply_to2.send(instance, my_reply.clone()).unwrap();
                    sent2.push(my_reply);
                }
                assert!(
                    expected2.insert(rank, sent2).is_none(),
                    "duplicate rank {rank} in map"
                );
            }

            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};

            for _ in 0..(n * ranks.len()) {
                let my_reply = reply2_rx.recv().await.unwrap();
                received2
                    .entry(my_reply.sender.rank())
                    .or_default()
                    .push(my_reply);
            }
            assert_eq!(received2, expected2);
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_reply() {
        let MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
            ..
        } = setup_mesh::<NoneAccumulator>(None).await;
        let proc_mesh_client = actor_mesh.proc_mesh().client();

        let ranks = actor_mesh.ranks().clone();
        execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await;
    }

    /// Receives from `receiver` until a message equal to `expected` arrives.
    ///
    /// # Errors
    /// Returns an error if that does not happen within `dur`.
    async fn wait_for_with_timeout(
        receiver: &mut PortReceiver<u64>,
        expected: u64,
        dur: Duration,
    ) -> anyhow::Result<()> {
        // Intermediate (partially accumulated) values are skipped.
        RealClock
            .timeout(dur, async {
                loop {
                    let msg = receiver.recv().await.unwrap();
                    if msg == expected {
                        break;
                    }
                }
            })
            .await?;
        Ok(())
    }

    /// Sends `n` values from every destination on `reply_to1`, whose client
    /// port sums updates.
    ///
    /// Checks that the client eventually observes the total, and receives
    /// nothing further after that.
    async fn execute_cast_and_accum(
        ranks: Vec<ActorRef<TestActor>>,
        instance: &Instance<()>,
        mut reply1_rx: PortReceiver<u64>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    ) {
        let mut sum = 0;
        let n = 100;
        for (dest_actor, (reply_to1, _reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
            let rank = dest_actor.actor_id().rank();
            for i in 0..n {
                let value = (rank + i) as u64;
                reply_to1.send(instance, value).unwrap();
                sum += value;
            }
        }
        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
            .await
            .unwrap();
        // Give any stray extra messages time to arrive before asserting there
        // are none.
        RealClock.sleep(Duration::from_secs(2)).await;
        let msg = reply1_rx.try_recv().unwrap();
        assert_eq!(msg, None);
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_accum() {
        let config = config::global::lock();
        // Presumably bounds how many updates a split port buffers before
        // flushing; confirm against `SPLIT_MAX_BUFFER_SIZE`'s definition.
        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);

        let MeshSetup {
            actor_mesh,
            reply1_rx,
            reply_tos,
            ..
        } = setup_mesh(Some(accum::sum::<u64>())).await;
        let proc_mesh_client = actor_mesh.proc_mesh().client();
        let ranks = actor_mesh.ranks().clone();
        execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await;
    }

    /// State produced by `setup_mesh_v1` for the `v1` mesh tests.
    struct MeshSetupV1 {
        instance: &'static Instance<()>,
        actor_mesh_ref: v1::ActorMeshRef<TestActor>,
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }

    /// `v1` counterpart of `setup_mesh`: spawns a 4x4x4 `v1` mesh of
    /// `TestActor`s, casts a `CastAndReply` to all of them, and performs the
    /// same checks on what each destination received.
    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let instance = v1::testing::instance().await;

        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let proc_mesh = v1::ProcMesh::allocate(instance, Box::new(alloc), "test.local")
            .await
            .unwrap();

        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_mesh = proc_mesh.spawn(&instance, "test", &params).await.unwrap();
        let actor_mesh_ref = actor_mesh.deref().clone();

        let (reply_port_handle0, _) = open_port::<String>(instance);
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => instance.mailbox().open_accum_port(a),
            None => open_port(instance),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // reply_to0 is not a bound-included port, so it is not split.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // reply_to1 and reply_to2 were split: they now point at
                    // ports on a comm actor, whose name here only contains
                    // "comm".
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert!(reply_to1.port_id().actor_id().name().contains("comm"));
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert!(reply_to2.port_id().actor_id().name().contains("comm"));
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        // `ActorMeshRef::cast` above casts to every actor, equivalent to
        // selection `*`.
        let selection = sel!(*);
        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);

        MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_reply_v1() {
        let MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
            ..
        } = setup_mesh_v1::<NoneAccumulator>(None).await;

        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
        execute_cast_and_reply(ranks, instance, reply1_rx, reply2_rx, reply_tos).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_accum_v1() {
        let config = config::global::lock();
        // Presumably bounds how many updates a split port buffers before
        // flushing; confirm against `SPLIT_MAX_BUFFER_SIZE`'s definition.
        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);

        let MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply_tos,
            ..
        } = setup_mesh_v1(Some(accum::sum::<u64>())).await;

        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
        execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await;
    }
}