1use crate::casting::CAST_ACTOR_MESH_ID;
10use crate::comm::multicast::CAST_ORIGINATING_SENDER;
11use crate::comm::multicast::CastEnvelope;
12use crate::comm::multicast::CastMessageV1;
13use crate::comm::multicast::ForwardMessageV1;
14use crate::reference::ActorMeshId;
15use crate::resource;
16pub mod multicast;
17
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use anyhow::Result;
23use async_trait::async_trait;
24use hyperactor::Actor;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::RemoteMessage;
29use hyperactor::accum::ReducerMode;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::message::ErasedUnbound;
37use hyperactor::ordering::SEQ_INFO;
38use hyperactor::ordering::SeqInfo;
39use hyperactor::reference as hyperactor_reference;
40use hyperactor_config::CONFIG;
41use hyperactor_config::ConfigAttr;
42use hyperactor_config::Flattrs;
43use hyperactor_config::attrs::declare_attrs;
44use hyperactor_mesh_macros::sel;
45use ndslice::Point;
46use ndslice::Selection;
47use ndslice::View;
48use ndslice::selection::routing::RoutingFrame;
49use serde::Deserialize;
50use serde::Serialize;
51use typeuri::Named;
52
53use crate::comm::multicast::CastMessage;
54use crate::comm::multicast::CastMessageEnvelope;
55use crate::comm::multicast::ForwardMessage;
56use crate::comm::multicast::set_cast_info_on_headers;
57
declare_attrs! {
    // Config switch read by the casting tests in this file, which run once with it
    // set to `false` ("retrofit") and once with it set to `true` ("native").
    // Defaults to `false`.
    @meta(CONFIG = ConfigAttr::new(
        // NOTE(review): presumably the environment-variable name for this key — confirm
        // against `ConfigAttr::new`.
        Some("HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING".to_string()),
        // NOTE(review): presumably the Python-facing config key — confirm.
        Some("enable_native_v1_casting".to_string()),
    ))
    pub attr ENABLE_NATIVE_V1_CASTING: bool = false;
}
66
/// Spawn parameters for [`CommActor`]. It currently takes none.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
wirevalue::register_type!(CommActorParams);
71
/// A `ForwardMessage` that arrived ahead of its predecessor and is held until
/// the predecessor has been processed. It stores the already-resolved routing.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of the buffered message.
    seq: usize,
    /// Whether the message is also delivered to a destination on this process.
    deliver_here: bool,
    /// Routing frames still to be forwarded, keyed by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The cast payload.
    message: CastMessageEnvelope,
}
84
/// Per-stream receive bookkeeping. The `ForwardMessage` handler uses it to
/// process the messages of one stream strictly in sequence order.
#[derive(Debug, Default)]
struct ReceiveState {
    /// Seq of the last message processed on this stream (0 before any).
    seq: usize,
    /// Out-of-order messages, keyed by the `last_seq` they are waiting on.
    buffer: HashMap<usize, Buffered>,
    /// For each peer rank, the seq of the last message forwarded to that peer.
    /// It is sent as `last_seq` on the next `ForwardMessage` to the peer.
    last_seqs: HashMap<usize, usize>,
}
97
/// Multicast routing actor.
///
/// It routes cast messages through a tree of peer comm actors. It also delivers
/// them to destination actors on its own proc.
#[derive(Debug, Default)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommMeshConfig,
        CastMessage,
        ForwardMessage,
        CastMessageV1,
        ForwardMessageV1,
    ],
)]
pub struct CommActor {
    /// For each stream key, the seq of the last `CastMessage` this actor started
    /// routing. The next cast on that stream is assigned this value + 1.
    /// NOTE(review): the key is `stream_key()`, presumably (mesh id, originating sender).
    send_seq: HashMap<(ActorMeshId, hyperactor_reference::ActorId), usize>,
    /// For each stream key, the in-order receive state used by the
    /// `ForwardMessage` handler.
    recv_state: HashMap<(ActorMeshId, hyperactor_reference::ActorId), ReceiveState>,

    /// This actor's rank and peers. It is `None` until a `CommMeshConfig`
    /// message has been handled; the routing handlers fail while it is unset.
    mesh_config: Option<CommMeshConfig>,
}
120
/// Routing configuration delivered to a [`CommActor`]: its own rank and the
/// comm actors of its peers.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct CommMeshConfig {
    /// This comm actor's own rank.
    rank: usize,
    /// Peer comm actors, keyed by rank.
    peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
}
wirevalue::register_type!(CommMeshConfig);
130
131impl CommMeshConfig {
132 pub fn new(
134 rank: usize,
135 peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
136 ) -> Self {
137 Self { rank, peers }
138 }
139
140 fn peer_for_rank(&self, rank: usize) -> Result<hyperactor_reference::ActorRef<CommActor>> {
142 self.peers
143 .get(&rank)
144 .cloned()
145 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank))
146 }
147
148 fn self_rank(&self) -> usize {
150 self.rank
151 }
152}
153
154#[async_trait]
155impl Actor for CommActor {
156 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
157 this.set_system();
158 Ok(())
159 }
160
161 async fn handle_undeliverable_message(
163 &mut self,
164 cx: &Instance<Self>,
165 undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
166 ) -> Result<(), anyhow::Error> {
167 let Undeliverable(mut message_envelope) = undelivered;
168
169 if let Ok(ForwardMessage { message, .. }) =
171 message_envelope.deserialized::<ForwardMessage>()
172 {
173 let sender = message.sender();
174 let return_port = hyperactor_reference::PortRef::attest_message_port(sender);
175 message_envelope.set_error(DeliveryError::Multicast(format!(
176 "comm actor {} failed to forward the cast message; returning to origin {}",
177 cx.self_id(),
178 return_port.port_id(),
179 )));
180
181 message_envelope.set_header(CAST_ORIGINATING_SENDER, sender.clone());
184
185 return_port
186 .send(cx, Undeliverable(message_envelope.clone()))
187 .map_err(|err| {
188 let error = DeliveryError::BrokenLink(format!(
189 "error occured when returning ForwardMessage to the original \
190 sender's port {}; error is: {}",
191 return_port.port_id(),
192 err,
193 ));
194 message_envelope.set_error(error);
195 UndeliverableMessageError::ReturnFailure {
196 envelope: message_envelope,
197 }
198 })?;
199 return Ok(());
200 }
201
202 if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
204 let return_port = hyperactor_reference::PortRef::attest_message_port(&sender);
205 message_envelope.set_error(DeliveryError::Multicast(format!(
206 "comm actor {} failed to deliver the cast message to the dest \
207 actor; returning to origin {}",
208 cx.self_id(),
209 return_port.port_id(),
210 )));
211 return_port
212 .send(cx, Undeliverable(message_envelope.clone()))
213 .map_err(|err| {
214 let error = DeliveryError::BrokenLink(format!(
215 "error occured when returning cast message to the origin \
216 sender {}; error is: {}",
217 return_port.port_id(),
218 err,
219 ));
220 message_envelope.set_error(error);
221 UndeliverableMessageError::ReturnFailure {
222 envelope: message_envelope,
223 }
224 })?;
225 return Ok(());
226 }
227
228 UndeliverableMailboxSender
230 .post(message_envelope, monitored_return_handle());
231 Ok(())
232 }
233}
234
235impl CommActor {
236 fn forward<M: RemoteMessage>(
238 cx: &Context<Self>,
239 config: &CommMeshConfig,
240 rank: usize,
241 message: M,
242 ) -> Result<()>
243 where
244 CommActor: hyperactor::RemoteHandles<M>,
245 {
246 let child = config.peer_for_rank(rank)?;
247 if let Some(cast_actor_mesh_id) = cx.headers().get(CAST_ACTOR_MESH_ID) {
249 let mut headers = Flattrs::new();
250 headers.set(CAST_ACTOR_MESH_ID, cast_actor_mesh_id);
251 child.send_with_headers(cx, headers, message)?;
252 } else {
253 child.send(cx, message)?;
254 }
255 Ok(())
256 }
257
258 fn handle_message(
259 cx: &Context<Self>,
260 config: &CommMeshConfig,
261 deliver_here: bool,
262 next_steps: HashMap<usize, Vec<RoutingFrame>>,
263 sender: hyperactor_reference::ActorId,
264 mut message: CastMessageEnvelope,
265 seq: usize,
266 last_seqs: &mut HashMap<usize, usize>,
267 ) -> Result<()> {
268 split_ports(cx, message.data_mut(), deliver_here, &next_steps)?;
269
270 if deliver_here {
272 let headers = message.headers().clone();
276 Self::deliver_to_dest(cx, headers, &mut message, config)?;
277 }
278
279 next_steps
281 .into_iter()
282 .map(|(peer, dests)| {
283 let last_seq = last_seqs.entry(peer).or_default();
284 Self::forward(
285 cx,
286 config,
287 peer,
288 ForwardMessage {
289 dests,
290 sender: sender.clone(),
291 message: message.clone(),
292 seq,
293 last_seq: *last_seq,
294 },
295 )?;
296 *last_seq = seq;
297 Ok(())
298 })
299 .collect::<Result<Vec<_>>>()?;
300
301 Ok(())
302 }
303
304 fn deliver_to_dest<M: CastEnvelope>(
305 cx: &Context<Self>,
306 mut headers: Flattrs,
307 message: &mut M,
308 config: &CommMeshConfig,
309 ) -> anyhow::Result<()> {
310 let cast_point = message.cast_point(config)?;
311 replace_with_self_ranks(&cast_point, message.data_mut())?;
313
314 set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone());
315 cx.post_with_external_seq_info(
316 cx.self_id()
317 .proc_id()
318 .actor_id(message.dest_port().actor_name(), 0)
319 .port_id(message.dest_port().port()),
320 headers,
321 wirevalue::Any::serialize(message.data())?,
322 );
323
324 Ok(())
325 }
326}
327
328fn split_ports(
332 cx: &Context<CommActor>,
333 data: &mut ErasedUnbound,
334 deliver_here: bool,
335 next_steps: &HashMap<usize, Vec<RoutingFrame>>,
336) -> anyhow::Result<()> {
337 data.visit_mut::<hyperactor_reference::UnboundPort>(
341 |hyperactor_reference::UnboundPort(port_id, reducer_spec, return_undeliverable, kind)| {
342 let reducer_mode = match kind {
343 hyperactor_reference::UnboundPortKind::Streaming(opts) => {
344 ReducerMode::Streaming(opts.clone().unwrap_or_default())
345 }
346 hyperactor_reference::UnboundPortKind::Once if reducer_spec.is_none() => {
347 return Ok(());
354 }
355 hyperactor_reference::UnboundPortKind::Once => {
356 let peer_count = next_steps.len() + if deliver_here { 1 } else { 0 };
360 ReducerMode::Once(peer_count)
361 }
362 };
363
364 let split = port_id.split(
365 cx,
366 reducer_spec.clone(),
367 reducer_mode,
368 *return_undeliverable,
369 )?;
370
371 #[cfg(test)]
372 tests::collect_split_port(port_id, &split, deliver_here);
373
374 *port_id = split;
375 Ok(())
376 },
377 )
378}
379
380fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
381 data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
382 *rank = Some(cast_point.rank());
383 Ok(())
384 })
385}
386
387#[async_trait]
388impl Handler<CommMeshConfig> for CommActor {
389 async fn handle(&mut self, _cx: &Context<Self>, config: CommMeshConfig) -> Result<()> {
390 self.mesh_config = Some(config);
391 Ok(())
392 }
393}
394
395#[async_trait]
397impl Handler<CastMessage> for CommActor {
398 #[tracing::instrument(level = "debug", skip_all)]
399 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
400 let slice = cast_message.dest.slice.clone();
402 let selection = cast_message.dest.selection.clone();
403 let frame = RoutingFrame::root(selection, slice);
404 let rank = frame.slice.location(&frame.here)?;
405 let seq = self
406 .send_seq
407 .entry(cast_message.message.stream_key())
408 .or_default();
409 let last_seq = *seq;
410 *seq += 1;
411
412 let fwd_message = ForwardMessage {
413 dests: vec![frame],
414 sender: cx.self_id().clone(),
415 message: cast_message.message,
416 seq: *seq,
417 last_seq,
418 };
419
420 let config = self
421 .mesh_config
422 .as_ref()
423 .ok_or_else(|| anyhow::anyhow!("CommMeshConfig has not been set yet"))?;
424
425 if config.self_rank() == rank {
428 Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
429 } else {
430 Self::forward(cx, config, rank, fwd_message)?;
431 }
432 Ok(())
433 }
434}
435
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handles one hop of a cast, processing each stream in sequence order.
    ///
    /// The message's routing frames are first resolved against this comm
    /// actor's rank. Each message carries `last_seq`, the seq of the previous
    /// message sent to this comm actor on the same stream:
    /// - If `last_seq` equals the last processed seq, the message is handled.
    ///   Any buffered messages it unblocks are then handled in turn.
    /// - If `last_seq` is ahead, the message is buffered under `last_seq`.
    /// - If `last_seq` is behind, the message is dropped with a duplicate warning.
    ///
    /// # Errors
    /// Fails if `CommMeshConfig` has not been received yet, if routing
    /// resolution fails, or if handling a message fails.
    ///
    /// # Panics
    /// Panics if the routing frames contain a choice.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        let config = self
            .mesh_config
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("CommMeshConfig has not been set yet"))?;

        // Resolve routing relative to this comm actor's rank: whether to deliver
        // here, and which frames go to which peer.
        let rank = config.self_rank();
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            Ordering::Equal => {
                // In order: process this message now.
                Self::handle_message(
                    cx,
                    config,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Drain buffered messages that were waiting on the seq just
                // processed (the buffer is keyed by each message's `last_seq`).
                // NOTE(review): buffered messages are forwarded with the current
                // message's `sender`, because `Buffered` does not record its own.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        config,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            Ordering::Less => {
                // A predecessor has not been processed yet: hold this message
                // until the stream reaches `last_seq`.
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            Ordering::Greater => {
                // The stream has already moved past this point: drop it.
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
528
529#[async_trait]
530impl Handler<CastMessageV1> for CommActor {
531 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
532 let slice = cast_message.dest_region.slice().clone();
533 let frame = RoutingFrame::root(sel!(*), slice);
534 let forward_message = ForwardMessageV1 {
535 dests: vec![frame],
536 message: cast_message,
537 };
538 self.handle(cx, forward_message).await
539 }
540}
541
542#[async_trait]
543impl Handler<ForwardMessageV1> for CommActor {
544 async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
545 let ForwardMessageV1 { dests, mut message } = fwd_message;
546 let config = self
547 .mesh_config
548 .as_ref()
549 .ok_or_else(|| anyhow::anyhow!("CommMeshConfig has not been set yet"))?;
550 let rank_on_root_mesh = config.self_rank();
552 let (deliver_here, next_steps) =
553 ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
554 panic!("choice encountered in CommActor routing")
555 })?;
556
557 split_ports(cx, &mut message.data, deliver_here, &next_steps)?;
558
559 if deliver_here {
561 let mut headers = message.headers().clone();
562 let seq = message
563 .seqs
564 .get(message.cast_point(config)?.rank())
565 .expect("mismatched seqs and dest_region");
566 headers.set(
567 SEQ_INFO,
568 SeqInfo::Session {
569 session_id: message.session_id,
570 seq,
571 },
572 );
573 Self::deliver_to_dest(cx, headers, &mut message, config)?;
574 }
575
576 for (peer_rank_on_root_mesh, dests) in next_steps {
578 let forward_message = ForwardMessageV1 {
579 dests,
580 message: message.clone(),
581 };
582 Self::forward(cx, config, peer_rank_on_root_mesh, forward_message)?;
583 }
584
585 Ok(())
586 }
587}
588
/// Actor and message types used to exercise casting through [`CommActor`].
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Unbind;
    use hyperactor::reference as hyperactor_reference;
    use serde::Deserialize;
    use serde::Serialize;
    use typeuri::Named;

    use super::*;

    /// Reply payload that records which actor it is attributed to.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: hyperactor_reference::ActorId,
        pub value: u64,
    }

    /// Messages handled by [`TestActor`].
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Not marked `#[binding(include)]`. The cast tests expect this port
            // to reach the destination actors unchanged.
            reply_to0: hyperactor_reference::PortRef<String>,
            // Included in binding. The cast tests expect these ports to be
            // replaced with split ports owned by comm actors.
            #[binding(include)]
            reply_to1: hyperactor_reference::PortRef<u64>,
            #[binding(include)]
            reply_to2: hyperactor_reference::PortRef<MyReply>,
        },
        CastAndReplyOnce {
            arg: String,
            #[binding(include)]
            reply_to: hyperactor::reference::OncePortRef<u64>,
        },
    }

    /// Actor that forwards every `TestMessage` it receives to `forward_port`,
    /// so the test client can observe what each destination received.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Port on the test client that receives forwarded messages.
        forward_port: hyperactor_reference::PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`].
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: hyperactor_reference::PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for TestActor {
        type Params = TestActorParams;

        // Builds the actor from its params; the environment is not used.
        async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        // Forwards the message as-is to the client's port.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
670
671#[cfg(test)]
672mod tests {
673 use std::collections::BTreeMap;
674 use std::collections::HashSet;
675 use std::fmt::Display;
676 use std::hash::Hash;
677 use std::ops::Deref;
678 use std::ops::DerefMut;
679 use std::sync::Mutex;
680 use std::sync::OnceLock;
681
682 use hyperactor::accum;
683 use hyperactor::accum::Accumulator;
684 use hyperactor::accum::ReducerSpec;
685 use hyperactor::context;
686 use hyperactor::context::Mailbox;
687 use hyperactor::mailbox::PortReceiver;
688 use hyperactor::mailbox::open_port;
689 use hyperactor::reference as hyperactor_reference;
690 use hyperactor_config;
691 use hyperactor_mesh_macros::sel;
692 use maplit::btreemap;
693 use maplit::hashmap;
694 use ndslice::Extent;
695 use ndslice::Selection;
696 use ndslice::ViewExt as _;
697 use ndslice::extent;
698 use ndslice::selection::test_utils::collect_commactor_routing_tree;
699 use test_utils::*;
700 use timed_test::async_timed_test;
701 use tokio::time::Duration;
702
703 use super::*;
704 use crate::host_mesh::HostMesh;
705 use crate::test_utils::local_host_mesh;
706 use crate::testing;
707
708 fn lookup_rank(
710 actor_id: &hyperactor::reference::ActorId,
711 rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
712 ) -> usize {
713 let proc_id = actor_id.proc_id();
714 *rank_lookup
715 .get(proc_id)
716 .unwrap_or_else(|| panic!("proc rank not found for {}", proc_id))
717 }
718
/// A parent -> child link in a tree of split ports (or, in unit tests, plain
/// node ids).
struct Edge<T> {
    /// The node that was split.
    from: T,
    /// The node produced by the split.
    to: T,
    /// Marks `to` as a delivery point. For recorded split ports this is the
    /// `deliver_here` of the hop that performed the split.
    is_leaf: bool,
}
724
725 impl<T> From<(T, T, bool)> for Edge<T> {
726 fn from((from, to, is_leaf): (T, T, bool)) -> Self {
727 Self { from, to, is_leaf }
728 }
729 }
730
/// Process-wide record of every port split performed by `split_ports` under
/// `cfg(test)`. Entries are appended by `collect_split_port` and cleared by
/// `clear_collected_tree`.
static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<hyperactor_reference::PortId>>>> =
    OnceLock::new();
735
736 pub(crate) fn collect_split_port(
739 original: &hyperactor_reference::PortId,
740 split: &hyperactor_reference::PortId,
741 deliver_here: bool,
742 ) {
743 let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
744 let mut tree = mutex.lock().unwrap();
745
746 tree.deref_mut().push(Edge {
747 from: original.clone(),
748 to: split.clone(),
749 is_leaf: deliver_here,
750 });
751 }
752
753 fn clear_collected_tree() {
757 if let Some(tree) = SPLIT_PORT_TREE.get() {
758 let mut tree = tree.lock().unwrap();
759 tree.clear();
760 }
761 }
762
/// Maps each leaf to its full path from the root, ordered root first and leaf last.
#[derive(PartialEq)]
struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);

impl<T: Display> Debug for PathToLeaves<T> {
    /// Writes one `leaf -> root, ..., leaf` line per entry, so that
    /// `assert_eq!` failures print readable paths.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn vec_to_string<T: Display>(v: &[T]) -> String {
            v.iter()
                .map(ToString::to_string)
                .collect::<Vec<String>>()
                .join(", ")
        }

        for (src, path) in &self.0 {
            writeln!(f, "{} -> {}", src, vec_to_string(path))?;
        }
        Ok(())
    }
}
785
786 fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
787 let mut child_parent_map = HashMap::new();
788 let mut all_nodes = HashSet::new();
789 let mut parents = HashSet::new();
790 let mut children = HashSet::new();
791 let mut dests = HashSet::new();
792
793 for Edge { from, to, is_leaf } in edges {
795 child_parent_map.insert(to.clone(), from.clone());
796 all_nodes.insert(from.clone());
797 all_nodes.insert(to.clone());
798 parents.insert(from.clone());
799 children.insert(to.clone());
800 if *is_leaf {
801 dests.insert(to.clone());
802 }
803 }
804
805 let mut result = BTreeMap::new();
807 for dest in dests {
808 let mut path = vec![dest.clone()];
809 let mut current = dest.clone();
810 while let Some(parent) = child_parent_map.get(¤t) {
811 path.push(parent.clone());
812 current = parent.clone();
813 }
814 path.reverse();
815 result.insert(dest, path);
816 }
817
818 PathToLeaves(result)
819 }
820
#[test]
fn test_build_paths() {
    // Tree under test (parent -> child, `true` marks the child as a leaf):
    //
    //   0 -> 1 (not a leaf) -> 2, 3
    //   0 -> 4              -> 5
    let edges: Vec<Edge<i32>> = [
        (0, 1, false),
        (1, 2, true),
        (1, 3, true),
        (0, 4, true),
        (4, 5, true),
    ]
    .into_iter()
    .map(Edge::from)
    .collect();

    let expected = btreemap! {
        2 => vec![0, 1, 2],
        3 => vec![0, 1, 3],
        4 => vec![0, 4],
        5 => vec![0, 4, 5],
    };

    assert_eq!(build_paths(&edges).0, expected);
}
851
852 fn get_ranks(
872 paths: PathToLeaves<hyperactor_reference::PortId>,
873 client_reply: &hyperactor_reference::PortId,
874 rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
875 ) -> PathToLeaves<hyperactor_reference::Index> {
876 let ranks = paths
877 .0
878 .into_iter()
879 .map(|(dst, mut path)| {
880 let first = path.remove(0);
881 assert_eq!(&first, client_reply);
883 assert!(dst.actor_id().name().contains("comm"));
886 let actor_path = path
887 .into_iter()
888 .map(|p| {
889 assert!(p.actor_id().name().contains("comm"));
890 lookup_rank(p.actor_id(), rank_lookup)
891 })
892 .collect();
893 (lookup_rank(dst.actor_id(), rank_lookup), actor_path)
894 })
895 .collect();
896 PathToLeaves(ranks)
897 }
898
899 struct NoneAccumulator;
900
901 impl Accumulator for NoneAccumulator {
902 type State = u64;
903 type Update = u64;
904
905 fn accumulate(
906 &self,
907 _state: &mut Self::State,
908 _update: Self::Update,
909 ) -> anyhow::Result<()> {
910 unimplemented!()
911 }
912
913 fn reducer_spec(&self) -> Option<ReducerSpec> {
914 unimplemented!()
915 }
916 }
917
918 fn verify_split_port_paths(
920 selection: &Selection,
921 extent: &Extent,
922 reply_port_ref1: &hyperactor_reference::PortRef<u64>,
923 reply_port_ref2: &hyperactor_reference::PortRef<MyReply>,
924 rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
925 ) {
926 let sel_paths = PathToLeaves(
928 collect_commactor_routing_tree(selection, &extent.to_slice())
929 .delivered
930 .into_iter()
931 .collect(),
932 );
933
934 let (reply1_paths, reply2_paths) = {
936 let tree = SPLIT_PORT_TREE
937 .get()
938 .expect("not initialized; are Hosts in the same process as SPLIT_PORT_TREE?");
939 let edges = tree.lock().unwrap();
940 let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
941 .0
942 .into_iter()
943 .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
944 (
945 get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id(), rank_lookup),
946 get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id(), rank_lookup),
947 )
948 };
949
950 assert_eq!(sel_paths, reply1_paths);
952 assert_eq!(sel_paths, reply2_paths);
953 }
954
955 async fn execute_cast_and_reply(
956 ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
957 instance: &impl context::Actor,
958 mut reply1_rx: PortReceiver<u64>,
959 mut reply2_rx: PortReceiver<MyReply>,
960 reply_tos: Vec<(
961 hyperactor_reference::PortRef<u64>,
962 hyperactor_reference::PortRef<MyReply>,
963 )>,
964 ) {
965 {
967 for (rank, (dest_actor, (reply_to1, reply_to2))) in
968 ranks.iter().zip(reply_tos.iter()).enumerate()
969 {
970 let rank_u64 = rank as u64;
971 reply_to1.send(instance, rank_u64).unwrap();
972 let my_reply = MyReply {
973 sender: dest_actor.actor_id().clone(),
974 value: rank_u64,
975 };
976 reply_to2.send(instance, my_reply.clone()).unwrap();
977
978 assert_eq!(reply1_rx.recv().await.unwrap(), rank_u64);
979 assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
980 }
981 }
982
983 tracing::info!("the 1st updates from all dest actors were receivered by client");
984
985 {
989 let n = 100;
990 let mut expected2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
991 for (i, (dest_actor, (_reply_to1, reply_to2))) in
992 ranks.iter().zip(reply_tos.iter()).enumerate()
993 {
994 let mut sent2 = vec![];
995 for j in 0..n {
996 let value = (i * 100 + j) as u64;
997 let my_reply = MyReply {
998 sender: dest_actor.actor_id().clone(),
999 value,
1000 };
1001 reply_to2.send(instance, my_reply.clone()).unwrap();
1002 sent2.push(my_reply);
1003 }
1004 assert!(
1005 expected2
1006 .insert(dest_actor.actor_id().clone(), sent2)
1007 .is_none(),
1008 "duplicate actor_id {} in map",
1009 dest_actor.actor_id()
1010 );
1011 }
1012
1013 let mut received2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
1014
1015 for _ in 0..(n * ranks.len()) {
1016 let my_reply = reply2_rx.recv().await.unwrap();
1017 received2
1018 .entry(my_reply.sender.clone())
1019 .or_default()
1020 .push(my_reply);
1021 }
1022 assert_eq!(received2, expected2);
1023 }
1024 }
1025
1026 async fn wait_for_with_timeout(
1027 receiver: &mut PortReceiver<u64>,
1028 expected: u64,
1029 dur: Duration,
1030 ) -> anyhow::Result<()> {
1031 tokio::time::timeout(dur, async {
1033 loop {
1034 let msg = receiver.recv().await.unwrap();
1035 if msg == expected {
1036 break;
1037 }
1038 }
1039 })
1040 .await?;
1041 Ok(())
1042 }
1043
1044 async fn execute_cast_and_accum(
1045 ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
1046 instance: &impl context::Actor,
1047 mut reply1_rx: PortReceiver<u64>,
1048 reply_tos: Vec<(
1049 hyperactor_reference::PortRef<u64>,
1050 hyperactor_reference::PortRef<MyReply>,
1051 )>,
1052 ) {
1053 let mut sum = 0;
1057 let n = 100;
1058 for (i, (_dest_actor, (reply_to1, _reply_to2))) in
1059 ranks.iter().zip(reply_tos.iter()).enumerate()
1060 {
1061 for j in 0..n {
1062 let value = (i + j) as u64;
1063 reply_to1.send(instance, value).unwrap();
1064 sum += value;
1065 }
1066 }
1067 wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(8))
1068 .await
1069 .unwrap();
1070 tokio::time::sleep(Duration::from_secs(2)).await;
1072 let msg = reply1_rx.try_recv().unwrap();
1073 assert_eq!(msg, None);
1074 }
1075
/// State returned by `setup_mesh_v1` after the initial `CastAndReply` cast has
/// been received by every destination actor.
struct MeshSetupV1 {
    instance: &'static Instance<testing::TestRootClient>,
    actor_mesh_ref: crate::ActorMeshRef<TestActor>,
    // Client-side receiver for `reply_to1`; an accumulating port when
    // `setup_mesh_v1` was given an accumulator.
    reply1_rx: PortReceiver<u64>,
    // Client-side receiver for `reply_to2`.
    reply2_rx: PortReceiver<MyReply>,
    // The split `(reply_to1, reply_to2)` ports each destination received, in
    // the order the forwarded messages reached the client.
    reply_tos: Vec<(
        hyperactor_reference::PortRef<u64>,
        hyperactor_reference::PortRef<MyReply>,
    )>,
    // Kept so the test can shut the mesh down at the end.
    host_mesh: HostMesh,
}
1088
1089 async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1090 where
1091 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1092 {
1093 let instance = crate::testing::instance();
1094 let host_mesh = local_host_mesh(8).await;
1097 let proc_mesh = host_mesh
1098 .spawn(instance, "test", extent!(gpu = 8))
1099 .await
1100 .unwrap();
1101
1102 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1103 let params = TestActorParams {
1104 forward_port: tx.bind(),
1105 };
1106 let actor_name = crate::Name::new("test").expect("valid test name");
1107 let actor_mesh = proc_mesh
1111 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1112 .await
1113 .unwrap();
1114 let actor_mesh_ref: crate::ActorMeshRef<TestActor> = actor_mesh.deref().clone();
1115
1116 let (reply_port_handle0, _) = open_port::<String>(instance);
1117 let reply_port_ref0 = reply_port_handle0.bind();
1118 let (reply_port_handle1, reply1_rx) = match accum {
1119 Some(a) => instance.mailbox().open_accum_port(a),
1120 None => open_port(instance),
1121 };
1122 let reply_port_ref1 = reply_port_handle1.bind();
1123 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1124 let reply_port_ref2 = reply_port_handle2.bind();
1125 let message = TestMessage::CastAndReply {
1126 arg: "abc".to_string(),
1127 reply_to0: reply_port_ref0.clone(),
1128 reply_to1: reply_port_ref1.clone(),
1129 reply_to2: reply_port_ref2.clone(),
1130 };
1131
1132 clear_collected_tree();
1133 actor_mesh_ref.cast(instance, message).unwrap();
1134
1135 let mut reply_tos = vec![];
1136 for _ in proc_mesh.extent().points() {
1137 let msg = rx.recv().await.expect("missing");
1138 match msg {
1139 TestMessage::CastAndReply {
1140 arg,
1141 reply_to0,
1142 reply_to1,
1143 reply_to2,
1144 } => {
1145 assert_eq!(arg, "abc");
1146 assert_eq!(reply_to0, reply_port_ref0);
1149 assert_ne!(reply_to1, reply_port_ref1);
1151 assert!(reply_to1.port_id().actor_id().name().contains("comm"));
1152 assert_ne!(reply_to2, reply_port_ref2);
1153 assert!(reply_to2.port_id().actor_id().name().contains("comm"));
1154 reply_tos.push((reply_to1, reply_to2));
1155 }
1156 _ => {
1157 panic!("unexpected message: {:?}", msg);
1158 }
1159 }
1160 }
1161
1162 let rank_lookup = proc_mesh
1165 .ranks()
1166 .iter()
1167 .enumerate()
1168 .map(|(i, r)| (r.proc_id().clone(), i))
1169 .collect::<HashMap<hyperactor_reference::ProcId, usize>>();
1170
1171 let selection = sel!(*);
1173 verify_split_port_paths(
1174 &selection,
1175 &proc_mesh.extent(),
1176 &reply_port_ref1,
1177 &reply_port_ref2,
1178 &rank_lookup,
1179 );
1180
1181 MeshSetupV1 {
1182 instance,
1183 actor_mesh_ref,
1184 reply1_rx,
1185 reply2_rx,
1186 reply_tos,
1187 host_mesh,
1188 }
1189 }
1190
1191 async fn execute_cast_and_reply_v1() {
1192 let mut setup = setup_mesh_v1::<NoneAccumulator>(None).await;
1193
1194 let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1195 execute_cast_and_reply(
1196 ranks,
1197 setup.instance,
1198 setup.reply1_rx,
1199 setup.reply2_rx,
1200 setup.reply_tos,
1201 )
1202 .await;
1203
1204 let _ = setup.host_mesh.shutdown(setup.instance).await;
1205 }
1206
1207 #[async_timed_test(timeout_secs = 60)]
1208 async fn test_cast_and_reply_v1_retrofit() {
1209 let config = hyperactor_config::global::lock();
1210 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1211 let _guard2 = config.override_key(
1212 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1213 false,
1214 );
1215 execute_cast_and_reply_v1().await
1216 }
1217
1218 #[async_timed_test(timeout_secs = 60)]
1219 async fn test_cast_and_reply_v1_native() {
1220 let config = hyperactor_config::global::lock();
1221 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1222 let _guard2 = config.override_key(
1223 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1224 true,
1225 );
1226 execute_cast_and_reply_v1().await
1227 }
1228
1229 async fn execute_cast_and_accum_v1(config: &hyperactor_config::global::ConfigLock) {
1230 let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1232
1233 let mut setup = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1234
1235 let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1236 execute_cast_and_accum(ranks, setup.instance, setup.reply1_rx, setup.reply_tos).await;
1237
1238 let _ = setup.host_mesh.shutdown(setup.instance).await;
1239 }
1240
1241 #[async_timed_test(timeout_secs = 60)]
1242 async fn test_cast_and_accum_v1_retrofit() {
1243 let config = hyperactor_config::global::lock();
1244 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1245 let _guard2 = config.override_key(
1246 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1247 false,
1248 );
1249 execute_cast_and_accum_v1(&config).await
1250 }
1251
1252 #[async_timed_test(timeout_secs = 60)]
1253 async fn test_cast_and_accum_v1_native() {
1254 let config = hyperactor_config::global::lock();
1255 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1256 let _guard2 = config.override_key(
1257 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1258 true,
1259 );
1260 execute_cast_and_accum_v1(&config).await
1261 }
1262
/// State returned by `setup_once_port_mesh` after the initial
/// `CastAndReplyOnce` cast has been received by every destination actor.
struct OncePortMeshSetupV1 {
    instance: &'static Instance<testing::TestRootClient>,
    // Client-side receiver for the once port (a reduce port when a reducer was given).
    reply_rx: hyperactor::mailbox::OncePortReceiver<u64>,
    // The `reply_to` port each destination received, in the order the
    // forwarded messages reached the client.
    reply_tos: Vec<hyperactor::reference::OncePortRef<u64>>,
    // The client-side port ref that was cast; stored but not read.
    _reply_port_ref: hyperactor::reference::OncePortRef<u64>,
    // Kept so the test can shut the mesh down at the end.
    host_mesh: HostMesh,
}
1270
1271 async fn setup_once_port_mesh<A>(reducer: Option<A>) -> OncePortMeshSetupV1
1272 where
1273 A: Accumulator<State = u64, Update = u64> + Send + Sync + 'static,
1274 {
1275 let instance = crate::testing::instance();
1276 let host_mesh = local_host_mesh(8).await;
1279 let proc_mesh = host_mesh
1280 .spawn(instance, "test", extent!(gpu = 8))
1281 .await
1282 .unwrap();
1283
1284 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1285 let params = TestActorParams {
1286 forward_port: tx.bind(),
1287 };
1288 let actor_name = crate::Name::new("test").expect("valid test name");
1289 let actor_mesh: crate::ActorMesh<TestActor> = proc_mesh
1291 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1292 .await
1293 .unwrap();
1294 let actor_mesh_ref = actor_mesh.deref().clone();
1295
1296 let has_reducer = reducer.is_some();
1297 let (reply_port_handle, reply_rx) = match reducer {
1298 Some(reducer) => instance.mailbox().open_reduce_port(reducer),
1299 None => instance.mailbox().open_once_port::<u64>(),
1300 };
1301 let reply_port_ref = reply_port_handle.bind();
1302
1303 let message = TestMessage::CastAndReplyOnce {
1304 arg: "abc".to_string(),
1305 reply_to: reply_port_ref.clone(),
1306 };
1307
1308 clear_collected_tree();
1309 actor_mesh_ref.cast(instance, message).unwrap();
1310
1311 let mut reply_tos = vec![];
1312 for _ in proc_mesh.extent().points() {
1313 let msg = rx.recv().await.expect("missing");
1314 match msg {
1315 TestMessage::CastAndReplyOnce { arg, reply_to } => {
1316 assert_eq!(arg, "abc");
1317 if has_reducer {
1318 assert_ne!(reply_to, reply_port_ref);
1320 assert!(reply_to.port_id().actor_id().name().contains("comm"));
1321 } else {
1322 assert_eq!(reply_to, reply_port_ref);
1324 }
1325 reply_tos.push(reply_to);
1326 }
1327 _ => {
1328 panic!("unexpected message: {:?}", msg);
1329 }
1330 }
1331 }
1332
1333 OncePortMeshSetupV1 {
1334 instance,
1335 reply_rx,
1336 reply_tos,
1337 _reply_port_ref: reply_port_ref,
1338 host_mesh,
1339 }
1340 }
1341
1342 #[async_timed_test(timeout_secs = 60)]
1343 async fn test_cast_and_reply_once_v1() {
1344 let mut setup = setup_once_port_mesh::<NoneAccumulator>(None).await;
1348
1349 let num_replies = setup.reply_tos.len();
1352 for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1353 reply_to.send(setup.instance, i as u64).unwrap();
1354 }
1355
1356 let result = setup.reply_rx.recv().await.unwrap();
1358 assert!(result < num_replies as u64);
1360
1361 let _ = setup.host_mesh.shutdown(setup.instance).await;
1362 }
1363
1364 #[async_timed_test(timeout_secs = 60)]
1365 async fn test_cast_and_accum_once_v1() {
1366 let mut setup = setup_once_port_mesh(Some(accum::sum::<u64>())).await;
1370
1371 let mut expected_sum = 0u64;
1373 for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1374 reply_to.send(setup.instance, i as u64).unwrap();
1375 expected_sum += i as u64;
1376 }
1377
1378 let result = setup.reply_rx.recv().await.unwrap();
1380 assert_eq!(result, expected_sum);
1381
1382 let _ = setup.host_mesh.shutdown(setup.instance).await;
1383 }
1384}