1use crate::casting::CAST_ACTOR_MESH_ID;
10use crate::comm::multicast::CAST_ORIGINATING_SENDER;
11use crate::comm::multicast::CastEnvelope;
12use crate::comm::multicast::CastMessageV1;
13use crate::comm::multicast::ForwardMessageV1;
14use crate::mesh_id::ActorMeshId;
15use crate::resource;
16pub mod multicast;
17
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use anyhow::Result;
23use async_trait::async_trait;
24use hyperactor::Actor;
25use hyperactor::ActorAddr;
26use hyperactor::ActorRef;
27use hyperactor::Context;
28use hyperactor::Endpoint as _;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::PortRef;
32use hyperactor::RemoteEndpoint as _;
33use hyperactor::RemoteMessage;
34use hyperactor::UnboundPort;
35use hyperactor::UnboundPortKind;
36use hyperactor::accum::ReducerMode;
37use hyperactor::mailbox::DeliveryError;
38use hyperactor::mailbox::MailboxSender;
39use hyperactor::mailbox::Undeliverable;
40use hyperactor::mailbox::UndeliverableMailboxSender;
41use hyperactor::mailbox::UndeliverableMessageError;
42use hyperactor::mailbox::monitored_return_handle;
43use hyperactor::message::ErasedUnbound;
44use hyperactor::ordering::SEQ_INFO;
45use hyperactor::ordering::SeqInfo;
46use hyperactor_config::CONFIG;
47use hyperactor_config::ConfigAttr;
48use hyperactor_config::Flattrs;
49use hyperactor_config::attrs::declare_attrs;
50use hyperactor_mesh_macros::sel;
51use ndslice::Point;
52use ndslice::Selection;
53use ndslice::View;
54use ndslice::selection::routing::RoutingFrame;
55use serde::Deserialize;
56use serde::Serialize;
57use typeuri::Named;
58
59use crate::comm::multicast::CastMessage;
60use crate::comm::multicast::CastMessageEnvelope;
61use crate::comm::multicast::ForwardMessage;
62use crate::comm::multicast::set_cast_info_on_headers;
63
// Configuration attributes declared by the comm module.
declare_attrs! {
    // Boolean attribute `ENABLE_NATIVE_V1_CASTING`, defaulting to `true`.
    // NOTE(review): the two `ConfigAttr::new` arguments look like the
    // environment-variable name and the config-file key for this attribute —
    // confirm against `hyperactor_config::ConfigAttr`.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING".to_string()),
        Some("enable_native_v1_casting".to_string()),
    ))
    pub attr ENABLE_NATIVE_V1_CASTING: bool = true;
}
72
/// Parameter type associated with [`CommActor`]. It currently carries no
/// fields; it is registered with `wirevalue` so it can be sent over the wire.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
wirevalue::register_type!(CommActorParams);
77
/// A forwarded cast message that arrived ahead of its predecessor and is
/// parked in [`ReceiveState::buffer`] until the stream catches up.
#[derive(Debug)]
struct Buffered {
    /// Sequence number carried by the buffered message.
    seq: usize,
    /// Whether routing resolved to delivering the message on this rank.
    deliver_here: bool,
    /// Routing frames still to be forwarded, grouped by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The cast payload itself.
    message: CastMessageEnvelope,
}
90
/// Receive-side ordering state for one stream (one entry of
/// `CommActor::recv_state`).
#[derive(Debug, Default)]
struct ReceiveState {
    /// Sequence number of the most recently processed message; an incoming
    /// message is in order when its `last_seq` equals this value.
    seq: usize,
    /// Out-of-order messages, keyed by the `last_seq` they are waiting for.
    buffer: HashMap<usize, Buffered>,
    /// For each peer rank, the sequence number of the last message forwarded
    /// to that peer; sent along as `last_seq` with the next forward.
    last_seqs: HashMap<usize, usize>,
}
103
/// Actor that routes cast messages across a mesh: it accepts `CastMessage` /
/// `CastMessageV1`, forwards them to peer comm actors as `ForwardMessage` /
/// `ForwardMessageV1`, and delivers them to the destination port on its own
/// proc when routing says so.
#[derive(Debug, Default)]
#[hyperactor::export(
    CommMeshConfig,
    CastMessage,
    ForwardMessage,
    CastMessageV1,
    ForwardMessageV1
)]
#[hyperactor::spawnable]
pub struct CommActor {
    /// Send-side sequence counter per stream, keyed by the envelope's
    /// `stream_key()`; bumped once for every `CastMessage` handled.
    send_seq: HashMap<(ActorMeshId, ActorAddr), usize>,
    /// Receive-side ordering state per stream, keyed by the envelope's
    /// `stream_key()`.
    recv_state: HashMap<(ActorMeshId, ActorAddr), ReceiveState>,

    /// Either the mesh configuration, or the queue of messages that arrived
    /// before a configuration was received.
    mesh_config: MeshConfigState,
}
124
/// A message received before the actor was configured; it is queued in
/// `MeshConfigState::NotConfigured` and replayed by the `CommMeshConfig`
/// handler.
#[derive(Debug)]
enum PendingMessage {
    /// A queued `CastMessage`.
    Cast(CastMessage),
    /// A queued `ForwardMessage`.
    Forward(ForwardMessage),
    /// A queued `ForwardMessageV1`.
    ForwardV1(ForwardMessageV1),
}
131
132#[derive(Debug)]
133enum MeshConfigState {
134 NotConfigured(Vec<PendingMessage>),
136 Configured(CommMeshConfig),
138}
139
140impl Default for MeshConfigState {
141 fn default() -> Self {
142 MeshConfigState::NotConfigured(Vec::new())
143 }
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize, Named)]
148pub struct CommMeshConfig {
149 rank: usize,
151 peers: HashMap<usize, ActorRef<CommActor>>,
153}
154wirevalue::register_type!(CommMeshConfig);
155
156impl CommMeshConfig {
157 pub fn new(rank: usize, peers: HashMap<usize, ActorRef<CommActor>>) -> Self {
159 Self { rank, peers }
160 }
161
162 fn peer_for_rank(&self, rank: usize) -> Result<ActorRef<CommActor>> {
164 self.peers
165 .get(&rank)
166 .cloned()
167 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank))
168 }
169
170 fn self_rank(&self) -> usize {
172 self.rank
173 }
174}
175
176#[async_trait]
177impl Actor for CommActor {
178 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
179 this.set_system();
180 Ok(())
181 }
182
183 async fn handle_undeliverable_message(
185 &mut self,
186 cx: &Instance<Self>,
187 undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
188 ) -> Result<(), anyhow::Error> {
189 let mut message_envelope = match undelivered {
190 Undeliverable::Message(message_envelope) => message_envelope,
191 Undeliverable::Lost(lost) => {
192 anyhow::bail!(UndeliverableMessageError::Lost { lost });
193 }
194 };
195
196 if let Ok(ForwardMessage { message, .. }) =
198 message_envelope.deserialized::<ForwardMessage>()
199 {
200 let sender = message.sender();
201 let return_port = PortRef::attest_handler_port(sender);
202 message_envelope.set_error(DeliveryError::Multicast(format!(
203 "comm actor {} failed to forward the cast message; returning to origin {}",
204 cx.self_addr(),
205 return_port.port_addr(),
206 )));
207
208 message_envelope.set_header(CAST_ORIGINATING_SENDER, sender.clone());
211
212 return_port.post(cx, Undeliverable::Message(message_envelope.clone()));
213 return Ok(());
214 }
215
216 if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
218 let return_port = PortRef::attest_handler_port(&sender);
219 message_envelope.set_error(DeliveryError::Multicast(format!(
220 "comm actor {} failed to deliver the cast message to the dest \
221 actor; returning to origin {}",
222 cx.self_addr(),
223 return_port.port_addr(),
224 )));
225 return_port.post(cx, Undeliverable::Message(message_envelope.clone()));
226 return Ok(());
227 }
228
229 UndeliverableMailboxSender
231 .post(message_envelope, monitored_return_handle());
232 Ok(())
233 }
234}
235
236impl CommActor {
237 fn forward<M: RemoteMessage>(
239 cx: &Context<Self>,
240 config: &CommMeshConfig,
241 rank: usize,
242 message: M,
243 ) -> Result<()>
244 where
245 CommActor: hyperactor::RemoteHandles<M>,
246 {
247 let child = config.peer_for_rank(rank)?;
248 if let Some(cast_actor_mesh_id) = cx.headers().get(CAST_ACTOR_MESH_ID) {
250 let mut headers = Flattrs::new();
251 headers.set(CAST_ACTOR_MESH_ID, cast_actor_mesh_id);
252 child.post_with_headers(cx, headers, message);
253 } else {
254 child.post(cx, message);
255 }
256 Ok(())
257 }
258
259 fn handle_message(
260 cx: &Context<Self>,
261 config: &CommMeshConfig,
262 deliver_here: bool,
263 next_steps: HashMap<usize, Vec<RoutingFrame>>,
264 sender: ActorAddr,
265 mut message: CastMessageEnvelope,
266 seq: usize,
267 last_seqs: &mut HashMap<usize, usize>,
268 ) -> Result<()> {
269 split_ports(cx, message.data_mut(), deliver_here, &next_steps)?;
270
271 if deliver_here {
273 let headers = message.headers().clone();
277 Self::deliver_to_dest(cx, headers, &mut message, config)?;
278 }
279
280 next_steps
282 .into_iter()
283 .map(|(peer, dests)| {
284 let last_seq = last_seqs.entry(peer).or_default();
285 Self::forward(
286 cx,
287 config,
288 peer,
289 ForwardMessage {
290 dests,
291 sender: sender.clone(),
292 message: message.clone(),
293 seq,
294 last_seq: *last_seq,
295 },
296 )?;
297 *last_seq = seq;
298 Ok(())
299 })
300 .collect::<Result<Vec<_>>>()?;
301
302 Ok(())
303 }
304
305 fn deliver_to_dest<M: CastEnvelope>(
306 cx: &Context<Self>,
307 mut headers: Flattrs,
308 message: &mut M,
309 config: &CommMeshConfig,
310 ) -> anyhow::Result<()> {
311 let cast_point = message.cast_point(config)?;
312 replace_with_self_ranks(&cast_point, message.data_mut())?;
314
315 set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone());
316 cx.post_with_external_seq_info(
317 cx.self_addr()
318 .proc_addr()
319 .actor_addr_uid(message.dest_port().actor_uid().clone())
320 .port_addr(hyperactor::port::Port::from(message.dest_port().port())),
321 headers,
322 wirevalue::Any::serialize(message.data())?,
323 );
324
325 Ok(())
326 }
327}
328
329fn split_ports(
333 cx: &Context<CommActor>,
334 data: &mut ErasedUnbound,
335 deliver_here: bool,
336 next_steps: &HashMap<usize, Vec<RoutingFrame>>,
337) -> anyhow::Result<()> {
338 data.visit_mut::<UnboundPort>(
342 |UnboundPort(port_id, reducer_spec, return_undeliverable, kind, unsplit)| {
343 if *unsplit {
344 return Ok(());
345 }
346 let reducer_mode = match kind {
347 UnboundPortKind::Streaming(opts) => {
348 ReducerMode::Streaming(opts.clone().unwrap_or_default())
349 }
350 UnboundPortKind::Once if reducer_spec.is_none() => {
351 return Ok(());
358 }
359 UnboundPortKind::Once => {
360 let peer_count = next_steps.len() + if deliver_here { 1 } else { 0 };
364 ReducerMode::Once(peer_count)
365 }
366 };
367
368 let split = port_id.clone().split(
369 cx,
370 reducer_spec.clone(),
371 reducer_mode,
372 *return_undeliverable,
373 )?;
374
375 #[cfg(test)]
376 tests::collect_split_port(&port_id.clone(), &split, deliver_here);
377
378 *port_id = split;
379 Ok(())
380 },
381 )
382}
383
384fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
385 data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
386 *rank = Some(cast_point.rank());
387 Ok(())
388 })
389}
390
391#[async_trait]
392impl Handler<CommMeshConfig> for CommActor {
393 async fn handle(&mut self, cx: &Context<Self>, config: CommMeshConfig) -> Result<()> {
394 let pending =
395 match std::mem::replace(&mut self.mesh_config, MeshConfigState::Configured(config)) {
396 MeshConfigState::NotConfigured(pending) => pending,
397 MeshConfigState::Configured(_) => Vec::new(),
398 };
399 if !pending.is_empty() {
400 tracing::info!(
401 count = pending.len(),
402 "replaying buffered pre-config messages"
403 );
404 }
405 for msg in pending {
406 match msg {
407 PendingMessage::Cast(m) => self.handle(cx, m).await?,
408 PendingMessage::Forward(m) => self.handle(cx, m).await?,
409 PendingMessage::ForwardV1(m) => self.handle(cx, m).await?,
410 }
411 }
412 Ok(())
413 }
414}
415
416#[async_trait]
418impl Handler<CastMessage> for CommActor {
419 #[tracing::instrument(level = "debug", skip_all)]
420 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
421 let config = match &mut self.mesh_config {
422 MeshConfigState::NotConfigured(pending) => {
423 pending.push(PendingMessage::Cast(cast_message));
424 return Ok(());
425 }
426 MeshConfigState::Configured(config) => config,
427 };
428 let slice = cast_message.dest.slice.clone();
430 let selection = cast_message.dest.selection.clone();
431 let frame = RoutingFrame::root(selection, slice);
432 let rank = frame.slice.location(&frame.here)?;
433 let seq = self
434 .send_seq
435 .entry(cast_message.message.stream_key())
436 .or_default();
437 let last_seq = *seq;
438 *seq += 1;
439
440 let fwd_message = ForwardMessage {
441 dests: vec![frame],
442 sender: cx.self_addr().clone(),
443 message: cast_message.message,
444 seq: *seq,
445 last_seq,
446 };
447
448 if config.self_rank() == rank {
451 Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
452 } else {
453 Self::forward(cx, config, rank, fwd_message)?;
454 }
455 Ok(())
456 }
457}
458
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handles one hop of a cast: resolves routing for this rank and then
    /// processes the message in per-stream sequence order, buffering it if
    /// its predecessor has not been seen yet.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        // Before configuration there is no rank or peer table; queue the
        // message for replay by the `CommMeshConfig` handler.
        let config = match &mut self.mesh_config {
            MeshConfigState::NotConfigured(pending) => {
                pending.push(PendingMessage::Forward(fwd_message));
                return Ok(());
            }
            MeshConfigState::Configured(config) => config,
        };

        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        let rank = config.self_rank();
        // Work out whether to deliver locally and which peers come next. The
        // callback panics if routing asks it to resolve a choice.
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        // `last_seq` names the message that must have been processed before
        // this one; compare it with what this stream has processed so far.
        match recv_state.seq.cmp(&last_seq) {
            Ordering::Equal => {
                // In order: process it now.
                Self::handle_message(
                    cx,
                    config,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Drain any buffered successors that are now unblocked.
                // NOTE(review): `Buffered` does not record its own `sender`,
                // so replay reuses the sender of the message that unblocked
                // it — presumably identical within a stream; confirm.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        config,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            Ordering::Less => {
                // The predecessor has not been processed yet: park the
                // message under the sequence number it is waiting for.
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            Ordering::Greater => {
                // The stream is already past this message's predecessor;
                // log it as a duplicate and drop it.
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
554
555#[async_trait]
556impl Handler<CastMessageV1> for CommActor {
557 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
558 let slice = cast_message.dest_region.slice().clone();
559 let frame = RoutingFrame::root(sel!(*), slice);
560 let forward_message = ForwardMessageV1 {
561 dests: vec![frame],
562 message: cast_message,
563 };
564 self.handle(cx, forward_message).await
565 }
566}
567
568#[async_trait]
569impl Handler<ForwardMessageV1> for CommActor {
570 async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
571 let config = match &mut self.mesh_config {
572 MeshConfigState::NotConfigured(pending) => {
573 pending.push(PendingMessage::ForwardV1(fwd_message));
574 return Ok(());
575 }
576 MeshConfigState::Configured(config) => config,
577 };
578
579 let ForwardMessageV1 { dests, mut message } = fwd_message;
580 let rank_on_root_mesh = config.self_rank();
582 let (deliver_here, next_steps) =
583 ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
584 panic!("choice encountered in CommActor routing")
585 })?;
586
587 split_ports(cx, &mut message.data, deliver_here, &next_steps)?;
588
589 if deliver_here {
591 let mut headers = message.headers().clone();
592 let seq = message
593 .seqs
594 .get(message.cast_point(config)?.rank())
595 .expect("mismatched seqs and dest_region");
596 headers.set(
597 SEQ_INFO,
598 SeqInfo::Session {
599 session_id: message.session_id,
600 seq,
601 },
602 );
603 Self::deliver_to_dest(cx, headers, &mut message, config)?;
604 }
605
606 for (peer_rank_on_root_mesh, dests) in next_steps {
608 let forward_message = ForwardMessageV1 {
609 dests,
610 message: message.clone(),
611 };
612 Self::forward(cx, config, peer_rank_on_root_mesh, forward_message)?;
613 }
614
615 Ok(())
616 }
617}
618
/// Message and actor fixtures used by the comm actor tests.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorAddr;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;
    use typeuri::Named;

    use super::*;

    /// Reply payload that identifies the actor it is attributed to.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: ActorAddr,
        pub value: u64,
    }

    /// Messages understood by [`TestActor`].
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    #[expect(
        clippy::large_enum_variant,
        reason = "test fixture; CastAndReply carries #[binding(include)] PortRefs whose Bind/Unbind derive interaction with Box<T> needs verification — separate diff"
    )]
    pub enum TestMessage {
        // Plain payload with no ports.
        Forward(String),
        // Cast carrying three reply ports.
        CastAndReply {
            arg: String,
            // Not marked `#[binding(include)]`; the mesh test asserts this
            // port arrives equal to the one that was sent.
            reply_to0: PortRef<String>,
            // Marked `#[binding(include)]`; the mesh test asserts these two
            // ports arrive different from the ones that were sent.
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
        // Cast carrying a once-port reply.
        CastAndReplyOnce {
            arg: String,
            #[binding(include)]
            reply_to: hyperactor::OncePortRef<u64>,
        },
        // Cast whose receiver answers `42` directly on `reply_to`.
        CastWithUnsplitPort {
            #[binding(include)]
            reply_to: PortRef<u64>,
        },
    }

    /// Test actor that relays every message it receives to `forward_port`.
    #[derive(Debug)]
    #[hyperactor::export(TestMessage { cast = true })]
    #[hyperactor::spawnable]
    pub struct TestActor {
        // Port every received message is re-posted to.
        forward_port: PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`].
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for TestActor {
        type Params = TestActorParams;

        /// Builds the actor from its parameters; `_environment` is unused.
        async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Replies `42` on the port of a `CastWithUnsplitPort` message, then
        /// relays the message (whatever its variant) to `forward_port`.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            if let TestMessage::CastWithUnsplitPort { ref reply_to } = msg {
                reply_to.post(cx, 42);
            }
            self.forward_port.post(cx, msg);
            Ok(())
        }
    }
}
710
711#[cfg(test)]
712mod tests {
713 use std::collections::BTreeMap;
714 use std::collections::HashSet;
715 use std::fmt::Display;
716 use std::hash::Hash;
717 use std::ops::Deref;
718 use std::ops::DerefMut;
719 use std::sync::Mutex;
720 use std::sync::OnceLock;
721
722 use hyperactor::accum;
723
724 async fn buffering_fixture(
729 proc_name: &str,
730 ) -> (
731 Instance<()>,
732 hyperactor::mailbox::PortReceiver<TestMessage>,
733 hyperactor::ActorHandle<CommActor>,
734 crate::mesh_id::ActorMeshId,
735 (
737 hyperactor::ActorHandle<()>,
738 hyperactor::ActorHandle<TestActor>,
739 ActorRef<TestActor>,
740 ),
741 ) {
742 use hyperactor::Proc;
743 use hyperactor::RemoteSpawn;
744 use hyperactor::channel::ChannelTransport;
745 use hyperactor::id::Label;
746
747 let proc = Proc::direct(ChannelTransport::Unix.any(), proc_name.to_string()).unwrap();
748 let (client, client_handle) = proc.client("client").unwrap();
749
750 let actor_mesh_id = crate::mesh_id::ActorMeshId::instance(Label::new("test").unwrap());
751
752 let (tx, rx) = open_port(&client);
753 let forward_port = tx.bind();
754 let test_actor = TestActor::new(TestActorParams { forward_port }, Default::default())
755 .await
756 .unwrap();
757 let test_handle = proc
758 .spawn_with_uid(actor_mesh_id.uid().clone(), test_actor)
759 .unwrap();
760 let test_ref: ActorRef<TestActor> = test_handle.bind::<TestActor>();
761
762 let comm_handle = proc.spawn("comm", CommActor::default()).unwrap();
763
764 (
765 client,
766 rx,
767 comm_handle,
768 actor_mesh_id,
769 (client_handle, test_handle, test_ref),
770 )
771 }
772
773 fn send_config(client: &Instance<()>, comm_handle: &hyperactor::ActorHandle<CommActor>) {
775 let comm_ref = comm_handle.bind::<CommActor>();
776 let mut peers = HashMap::new();
777 peers.insert(0, comm_ref);
778 comm_handle.post(client, CommMeshConfig::new(0, peers));
779 }
780
781 async fn assert_buffered_and_replayed<M: hyperactor::Message>(
784 proc_name: &str,
785 mut make_msg: impl FnMut(&Instance<()>, &crate::mesh_id::ActorMeshId, &str) -> M,
786 ) where
787 CommActor: hyperactor::Handler<M>,
788 {
789 let (client, mut rx, comm_handle, actor_mesh_id, _guards) =
790 buffering_fixture(proc_name).await;
791
792 comm_handle.post(&client, make_msg(&client, &actor_mesh_id, "buffered"));
793 send_config(&client, &comm_handle);
794 comm_handle.post(&client, make_msg(&client, &actor_mesh_id, "direct"));
795
796 assert_eq!(
797 rx.recv().await.unwrap(),
798 TestMessage::Forward("buffered".to_string()),
799 );
800 assert_eq!(
801 rx.recv().await.unwrap(),
802 TestMessage::Forward("direct".to_string()),
803 );
804 comm_handle.drain_and_stop("test done").ok();
805 }
806
807 #[async_timed_test(timeout_secs = 1)]
808 async fn cast_before_config_is_buffered_and_replayed() {
809 use ndslice::Slice;
810
811 assert_buffered_and_replayed("test_cast", |client, actor_mesh_id, payload| {
812 let actor_mesh_id = actor_mesh_id.clone();
813 let slice = Slice::new_row_major(vec![1]);
814 let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
815 let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
816 actor_mesh_id,
817 client.self_addr().clone(),
818 shape,
819 hyperactor_config::Flattrs::new(),
820 TestMessage::Forward(payload.to_string()),
821 )
822 .unwrap();
823 multicast::CastMessage {
824 dest: multicast::Uslice {
825 slice,
826 selection: sel!(*),
827 },
828 message: envelope,
829 }
830 })
831 .await;
832 }
833
834 #[async_timed_test(timeout_secs = 1)]
835 async fn forward_before_config_is_buffered_and_replayed() {
836 use ndslice::Slice;
837 use ndslice::selection::routing::RoutingFrame;
838
839 let mut next_seq: usize = 0;
840 assert_buffered_and_replayed("test_fwd", move |client, actor_mesh_id, payload| {
841 let actor_mesh_id = actor_mesh_id.clone();
842 let slice = Slice::new_row_major(vec![1]);
843 let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
844 let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
845 actor_mesh_id,
846 client.self_addr().clone(),
847 shape,
848 hyperactor_config::Flattrs::new(),
849 TestMessage::Forward(payload.to_string()),
850 )
851 .unwrap();
852 let frame = RoutingFrame::root(sel!(*), slice);
853 let last_seq = next_seq;
854 next_seq += 1;
855 multicast::ForwardMessage {
856 sender: client.self_addr().clone(),
857 dests: vec![frame],
858 seq: next_seq,
859 last_seq,
860 message: envelope,
861 }
862 })
863 .await;
864 }
865
866 #[async_timed_test(timeout_secs = 1)]
867 async fn forward_v1_before_config_is_buffered_and_replayed() {
868 use ndslice::Region;
869 use ndslice::Slice;
870 use ndslice::selection::routing::RoutingFrame;
871
872 assert_buffered_and_replayed("test_fwd_v1", |client, actor_mesh_id, payload| {
873 let slice = Slice::new_row_major(vec![1]);
874 let region = Region::new(vec!["rank".to_string()], slice.clone());
875 let cast_msg = multicast::CastMessageV1::new::<TestActor, TestMessage>(
876 client.self_addr().clone(),
877 actor_mesh_id,
878 region.clone(),
879 hyperactor_config::Flattrs::new(),
880 TestMessage::Forward(payload.to_string()),
881 uuid::Uuid::new_v4(),
882 crate::ValueMesh::from_single(region, 1u64),
883 )
884 .unwrap();
885 let frame = RoutingFrame::root(sel!(*), slice);
886 multicast::ForwardMessageV1 {
887 dests: vec![frame],
888 message: cast_msg,
889 }
890 })
891 .await;
892 }
893
894 use hyperactor::ActorAddr;
895 use hyperactor::ActorRef;
896 use hyperactor::Index;
897 use hyperactor::OncePortRef;
898 use hyperactor::PortAddr;
899 use hyperactor::PortRef;
900 use hyperactor::ProcAddr;
901 use hyperactor::accum::Accumulator;
902 use hyperactor::accum::ReducerSpec;
903 use hyperactor::context;
904 use hyperactor::context::Mailbox;
905 use hyperactor::mailbox::PortReceiver;
906 use hyperactor::mailbox::open_port;
907 use hyperactor_config;
908 use hyperactor_mesh_macros::sel;
909 use maplit::btreemap;
910 use maplit::hashmap;
911 use ndslice::Extent;
912 use ndslice::Selection;
913 use ndslice::ViewExt as _;
914 use ndslice::extent;
915 use ndslice::selection::test_utils::collect_commactor_routing_tree;
916 use test_utils::*;
917 use timed_test::async_timed_test;
918 use tokio::time::Duration;
919
920 use super::*;
921 use crate::ActorMesh;
922 use crate::host_mesh::HostMesh;
923 use crate::test_utils::local_host_mesh;
924 use crate::testing;
925
926 fn lookup_rank(actor_id: &ActorAddr, rank_lookup: &HashMap<ProcAddr, usize>) -> usize {
928 let proc_id = actor_id.proc_addr();
929 *rank_lookup
930 .get(&proc_id)
931 .unwrap_or_else(|| panic!("proc rank not found for {}", proc_id))
932 }
933
934 struct Edge<T> {
935 from: T,
936 to: T,
937 is_leaf: bool,
938 }
939
940 impl<T> From<(T, T, bool)> for Edge<T> {
941 fn from((from, to, is_leaf): (T, T, bool)) -> Self {
942 Self { from, to, is_leaf }
943 }
944 }
945
946 static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortAddr>>>> = OnceLock::new();
949
950 pub(crate) fn collect_split_port(original: &PortAddr, split: &PortAddr, deliver_here: bool) {
953 let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
954 let mut tree = mutex.lock().unwrap();
955
956 tree.deref_mut().push(Edge {
957 from: original.clone(),
958 to: split.clone(),
959 is_leaf: deliver_here,
960 });
961 }
962
963 fn clear_collected_tree() {
967 if let Some(tree) = SPLIT_PORT_TREE.get() {
968 let mut tree: std::sync::MutexGuard<'_, Vec<Edge<PortAddr>>> = tree.lock().unwrap();
969 tree.clear();
970 }
971 }
972
973 #[derive(PartialEq)]
977 struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
978
979 impl<T: Display> Debug for PathToLeaves<T> {
981 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
982 fn vec_to_string<T: Display>(v: &[T]) -> String {
983 v.iter()
984 .map(ToString::to_string)
985 .collect::<Vec<String>>()
986 .join(", ")
987 }
988
989 for (src, path) in &self.0 {
990 writeln!(f, "{} -> {}", src, vec_to_string(path))?;
991 }
992 Ok(())
993 }
994 }
995
996 fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
997 let mut child_parent_map = HashMap::new();
998 let mut all_nodes = HashSet::new();
999 let mut parents = HashSet::new();
1000 let mut children = HashSet::new();
1001 let mut dests = HashSet::new();
1002
1003 for Edge { from, to, is_leaf } in edges {
1005 child_parent_map.insert(to.clone(), from.clone());
1006 all_nodes.insert(from.clone());
1007 all_nodes.insert(to.clone());
1008 parents.insert(from.clone());
1009 children.insert(to.clone());
1010 if *is_leaf {
1011 dests.insert(to.clone());
1012 }
1013 }
1014
1015 let mut result = BTreeMap::new();
1017 for dest in dests {
1018 let mut path = vec![dest.clone()];
1019 let mut current = dest.clone();
1020 while let Some(parent) = child_parent_map.get(¤t) {
1021 path.push(parent.clone());
1022 current = parent.clone();
1023 }
1024 path.reverse();
1025 result.insert(dest, path);
1026 }
1027
1028 PathToLeaves(result)
1029 }
1030
1031 #[test]
1032 fn test_build_paths() {
1033 let edges: Vec<_> = [
1040 (0, 1, false),
1041 (1, 2, true),
1042 (1, 3, true),
1043 (0, 4, true),
1044 (4, 5, true),
1045 ]
1046 .into_iter()
1047 .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
1048 .collect();
1049
1050 let paths = build_paths(&edges);
1051
1052 let expected = btreemap! {
1053 2 => vec![0, 1, 2],
1054 3 => vec![0, 1, 3],
1055 4 => vec![0, 4],
1056 5 => vec![0, 4, 5],
1057 };
1058
1059 assert_eq!(paths.0, expected);
1060 }
1061
1062 fn get_ranks(
1082 paths: PathToLeaves<PortAddr>,
1083 client_reply: &PortAddr,
1084 rank_lookup: &HashMap<ProcAddr, usize>,
1085 ) -> PathToLeaves<Index> {
1086 let ranks = paths
1087 .0
1088 .into_iter()
1089 .map(|(dst, mut path): (PortAddr, Vec<PortAddr>)| {
1090 let first = path.remove(0);
1091 assert_eq!(&first, client_reply);
1093 assert!(dst.actor_addr().label().unwrap().as_str().contains("comm"));
1096 let actor_path = path
1097 .into_iter()
1098 .map(|p: PortAddr| {
1099 assert!(p.actor_addr().label().unwrap().as_str().contains("comm"));
1100 let actor_ref = p.actor_addr();
1101 lookup_rank(&actor_ref, rank_lookup)
1102 })
1103 .collect();
1104 let dst_actor_ref = dst.actor_addr();
1105 (lookup_rank(&dst_actor_ref, rank_lookup), actor_path)
1106 })
1107 .collect();
1108 PathToLeaves(ranks)
1109 }
1110
    /// Placeholder accumulator over `u64` whose methods are all
    /// `unimplemented!()`.
    // NOTE(review): presumably exists only to name a concrete type when
    // calling `setup_mesh_v1` with `None` — confirm at the call sites.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        /// Not implemented; panics if called.
        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        /// Not implemented; panics if called.
        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }
1129
1130 fn verify_split_port_paths(
1132 selection: &Selection,
1133 extent: &Extent,
1134 reply_port_ref1: &PortRef<u64>,
1135 reply_port_ref2: &PortRef<MyReply>,
1136 rank_lookup: &HashMap<ProcAddr, usize>,
1137 ) {
1138 let sel_paths = PathToLeaves(
1140 collect_commactor_routing_tree(selection, &extent.to_slice())
1141 .delivered
1142 .into_iter()
1143 .collect(),
1144 );
1145
1146 let (reply1_paths, reply2_paths) = {
1148 let tree = SPLIT_PORT_TREE
1149 .get()
1150 .expect("not initialized; are Hosts in the same process as SPLIT_PORT_TREE?");
1151 let edges = tree.lock().unwrap();
1152 let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
1153 .0
1154 .into_iter()
1155 .partition(|(_dst, path)| path[0] == *reply_port_ref1.port_addr());
1156 (
1157 get_ranks(
1158 PathToLeaves(reply1),
1159 reply_port_ref1.port_addr(),
1160 rank_lookup,
1161 ),
1162 get_ranks(
1163 PathToLeaves(reply2),
1164 reply_port_ref2.port_addr(),
1165 rank_lookup,
1166 ),
1167 )
1168 };
1169
1170 assert_eq!(sel_paths, reply1_paths);
1172 assert_eq!(sel_paths, reply2_paths);
1173 }
1174
1175 async fn execute_cast_and_reply(
1176 ranks: Vec<ActorRef<TestActor>>,
1177 instance: &impl context::Actor,
1178 mut reply1_rx: PortReceiver<u64>,
1179 mut reply2_rx: PortReceiver<MyReply>,
1180 reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1181 ) {
1182 {
1184 for (rank, (dest_actor, (reply_to1, reply_to2))) in
1185 ranks.iter().zip(reply_tos.iter()).enumerate()
1186 {
1187 let rank_u64 = rank as u64;
1188 reply_to1.post(instance, rank_u64);
1189 let my_reply = MyReply {
1190 sender: dest_actor.actor_addr().clone(),
1191 value: rank_u64,
1192 };
1193 reply_to2.post(instance, my_reply.clone());
1194
1195 assert_eq!(reply1_rx.recv().await.unwrap(), rank_u64);
1196 assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
1197 }
1198 }
1199
1200 tracing::info!("the 1st updates from all dest actors were receivered by client");
1201
1202 {
1206 let n = 100;
1207 let mut expected2: HashMap<ActorAddr, Vec<MyReply>> = hashmap! {};
1208 for (i, (dest_actor, (_reply_to1, reply_to2))) in
1209 ranks.iter().zip(reply_tos.iter()).enumerate()
1210 {
1211 let mut sent2 = vec![];
1212 for j in 0..n {
1213 let value = (i * 100 + j) as u64;
1214 let my_reply = MyReply {
1215 sender: dest_actor.actor_addr().clone(),
1216 value,
1217 };
1218 reply_to2.post(instance, my_reply.clone());
1219 sent2.push(my_reply);
1220 }
1221 assert!(
1222 expected2
1223 .insert(dest_actor.actor_addr().clone(), sent2)
1224 .is_none(),
1225 "duplicate actor_id {} in map",
1226 dest_actor.actor_addr()
1227 );
1228 }
1229
1230 let mut received2: HashMap<ActorAddr, Vec<MyReply>> = hashmap! {};
1231
1232 for _ in 0..(n * ranks.len()) {
1233 let my_reply = reply2_rx.recv().await.unwrap();
1234 received2
1235 .entry(my_reply.sender.clone())
1236 .or_default()
1237 .push(my_reply);
1238 }
1239 assert_eq!(received2, expected2);
1240 }
1241 }
1242
1243 async fn wait_for_with_timeout(
1244 receiver: &mut PortReceiver<u64>,
1245 expected: u64,
1246 dur: Duration,
1247 ) -> anyhow::Result<()> {
1248 tokio::time::timeout(dur, async {
1250 loop {
1251 let msg = receiver.recv().await.unwrap();
1252 if msg == expected {
1253 break;
1254 }
1255 }
1256 })
1257 .await?;
1258 Ok(())
1259 }
1260
1261 async fn execute_cast_and_accum(
1262 ranks: Vec<ActorRef<TestActor>>,
1263 instance: &impl context::Actor,
1264 mut reply1_rx: PortReceiver<u64>,
1265 reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1266 ) {
1267 let mut sum = 0;
1271 let n = 100;
1272 for (i, (_dest_actor, (reply_to1, _reply_to2))) in
1273 ranks.iter().zip(reply_tos.iter()).enumerate()
1274 {
1275 for j in 0..n {
1276 let value = (i + j) as u64;
1277 reply_to1.post(instance, value);
1278 sum += value;
1279 }
1280 }
1281 wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(8))
1282 .await
1283 .unwrap();
1284 tokio::time::sleep(Duration::from_secs(2)).await;
1286 let msg = reply1_rx.try_recv().unwrap();
1287 assert_eq!(msg, None);
1288 }
1289
    /// Everything `setup_mesh_v1` hands back to a test after building a mesh
    /// and casting the initial `CastAndReply` message.
    struct MeshSetupV1 {
        /// The root client instance used for the test.
        instance: &'static Instance<testing::TestRootClient>,
        /// Reference to the mesh of `TestActor`s.
        actor_mesh_ref: crate::ActorMeshRef<TestActor>,
        /// Receiver for the first (`u64`) reply port.
        reply1_rx: PortReceiver<u64>,
        /// Receiver for the second (`MyReply`) reply port.
        reply2_rx: PortReceiver<MyReply>,
        /// Pairs of reply ports.
        // NOTE(review): presumably one pair per rank, taken from the
        // `CastAndReply` messages the test actors relayed back — the code
        // that fills this lies outside the visible part of the file; confirm.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
        // NOTE(review): presumably held so the hosts stay alive for the
        // duration of the test — confirm.
        host_mesh: HostMesh,
    }
1299
1300 async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1301 where
1302 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1303 {
1304 let instance = crate::testing::instance();
1305 let host_mesh = local_host_mesh(8).await;
1308 let proc_mesh = host_mesh
1309 .spawn(instance, "test", extent!(gpu = 8), None, None)
1310 .await
1311 .unwrap();
1312
1313 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1314 let params = TestActorParams {
1315 forward_port: tx.bind(),
1316 };
1317 let actor_name =
1318 crate::mesh_id::ActorMeshId::instance(hyperactor::id::Label::new("test").unwrap());
1319 let actor_mesh = proc_mesh
1323 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1324 .await
1325 .unwrap();
1326 let actor_mesh_ref: crate::ActorMeshRef<TestActor> = actor_mesh.deref().clone();
1327
1328 let (reply_port_handle0, _) = open_port::<String>(instance);
1329 let reply_port_ref0 = reply_port_handle0.bind();
1330 let (reply_port_handle1, reply1_rx) = match accum {
1331 Some(a) => instance.mailbox().open_accum_port(a),
1332 None => open_port(instance),
1333 };
1334 let reply_port_ref1 = reply_port_handle1.bind();
1335 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1336 let reply_port_ref2 = reply_port_handle2.bind();
1337 let message = TestMessage::CastAndReply {
1338 arg: "abc".to_string(),
1339 reply_to0: reply_port_ref0.clone(),
1340 reply_to1: reply_port_ref1.clone(),
1341 reply_to2: reply_port_ref2.clone(),
1342 };
1343
1344 clear_collected_tree();
1345 actor_mesh_ref.cast(instance, message).unwrap();
1346
1347 let mut reply_tos = vec![];
1348 for _ in proc_mesh.extent().points() {
1349 let msg = rx.recv().await.expect("missing");
1350 match msg {
1351 TestMessage::CastAndReply {
1352 arg,
1353 reply_to0,
1354 reply_to1,
1355 reply_to2,
1356 } => {
1357 assert_eq!(arg, "abc");
1358 assert_eq!(reply_to0, reply_port_ref0);
1361 assert_ne!(reply_to1, reply_port_ref1);
1363 assert_ne!(reply_to2, reply_port_ref2);
1364 let p2p_threshold = hyperactor_config::global::get(
1365 crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD,
1366 );
1367 if p2p_threshold == 0 {
1368 assert!(
1370 reply_to1
1371 .port_addr()
1372 .actor_id()
1373 .label()
1374 .unwrap()
1375 .as_str()
1376 .contains("comm")
1377 );
1378 assert!(
1379 reply_to2
1380 .port_addr()
1381 .actor_id()
1382 .label()
1383 .unwrap()
1384 .as_str()
1385 .contains("comm")
1386 );
1387 }
1388 reply_tos.push((reply_to1, reply_to2));
1389 }
1390 _ => {
1391 panic!("unexpected message: {:?}", msg);
1392 }
1393 }
1394 }
1395
1396 let p2p_threshold =
1399 hyperactor_config::global::get(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD);
1400 if p2p_threshold == 0 {
1401 let rank_lookup = proc_mesh
1404 .ranks()
1405 .iter()
1406 .enumerate()
1407 .map(|(i, r)| (r.proc_addr().clone(), i))
1408 .collect::<HashMap<ProcAddr, usize>>();
1409
1410 let selection = sel!(*);
1412 verify_split_port_paths(
1413 &selection,
1414 &proc_mesh.extent(),
1415 &reply_port_ref1,
1416 &reply_port_ref2,
1417 &rank_lookup,
1418 );
1419 }
1420
1421 MeshSetupV1 {
1422 instance,
1423 actor_mesh_ref,
1424 reply1_rx,
1425 reply2_rx,
1426 reply_tos,
1427 host_mesh,
1428 }
1429 }
1430
1431 async fn execute_cast_and_reply_v1() {
1432 let mut setup = setup_mesh_v1::<NoneAccumulator>(None).await;
1433
1434 let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1435 execute_cast_and_reply(
1436 ranks,
1437 setup.instance,
1438 setup.reply1_rx,
1439 setup.reply2_rx,
1440 setup.reply_tos,
1441 )
1442 .await;
1443
1444 let _ = setup.host_mesh.shutdown(setup.instance).await;
1445 }
1446
1447 #[async_timed_test(timeout_secs = 60)]
1448 async fn test_cast_and_reply_v1_retrofit() {
1449 let config = hyperactor_config::global::lock();
1450 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1451 let _guard2 = config.override_key(
1452 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1453 false,
1454 );
1455 execute_cast_and_reply_v1().await
1456 }
1457
1458 #[async_timed_test(timeout_secs = 60)]
1459 async fn test_cast_and_reply_v1_native() {
1460 let config = hyperactor_config::global::lock();
1461 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1462 let _guard2 = config.override_key(
1463 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1464 true,
1465 );
1466 execute_cast_and_reply_v1().await
1467 }
1468
1469 #[async_timed_test(timeout_secs = 60)]
1470 async fn test_cast_and_reply_v1_native_p2p() {
1471 let config = hyperactor_config::global::lock();
1472 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1473 let _guard2 = config.override_key(
1474 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1475 true,
1476 );
1477 let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1478 execute_cast_and_reply_v1().await
1479 }
1480
1481 async fn execute_cast_and_accum_v1(config: &hyperactor_config::global::ConfigLock) {
1482 let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1484
1485 let mut setup = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1486
1487 let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1488 execute_cast_and_accum(ranks, setup.instance, setup.reply1_rx, setup.reply_tos).await;
1489
1490 let _ = setup.host_mesh.shutdown(setup.instance).await;
1491 }
1492
1493 #[async_timed_test(timeout_secs = 60)]
1494 async fn test_cast_and_accum_v1_retrofit() {
1495 let config = hyperactor_config::global::lock();
1496 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1497 let _guard2 = config.override_key(
1498 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1499 false,
1500 );
1501 execute_cast_and_accum_v1(&config).await
1502 }
1503
1504 #[async_timed_test(timeout_secs = 60)]
1505 async fn test_cast_and_accum_v1_native() {
1506 let config = hyperactor_config::global::lock();
1507 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1508 let _guard2 = config.override_key(
1509 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1510 true,
1511 );
1512 execute_cast_and_accum_v1(&config).await
1513 }
1514
1515 #[async_timed_test(timeout_secs = 60)]
1516 async fn test_cast_and_accum_v1_native_p2p() {
1517 let config = hyperactor_config::global::lock();
1518 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1519 let _guard2 = config.override_key(
1520 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1521 true,
1522 );
1523 let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1524 execute_cast_and_accum_v1(&config).await
1525 }
1526
/// Fixture state returned by `setup_once_port_mesh` and consumed by the
/// once-port cast tests.
struct OncePortMeshSetupV1 {
    /// Test client instance used to post replies and shut the mesh down.
    instance: &'static Instance<testing::TestRootClient>,
    /// Receiver of the original once/reduce reply port.
    reply_rx: hyperactor::mailbox::OncePortReceiver<u64>,
    /// The `reply_to` port taken from each message received on the forward
    /// port, in arrival order.
    reply_tos: Vec<OncePortRef<u64>>,
    /// The bound reference of the original reply port; not read by the tests
    /// in this file, only kept in the struct.
    _reply_port_ref: OncePortRef<u64>,
    /// Host mesh backing the test; callers invoke `shutdown` on it when done.
    host_mesh: HostMesh,
}
1534
1535 async fn setup_once_port_mesh<A>(reducer: Option<A>) -> OncePortMeshSetupV1
1536 where
1537 A: Accumulator<State = u64, Update = u64> + Send + Sync + 'static,
1538 {
1539 let instance = crate::testing::instance();
1540 let host_mesh = local_host_mesh(8).await;
1543 let proc_mesh = host_mesh
1544 .spawn(instance, "test", extent!(gpu = 8), None, None)
1545 .await
1546 .unwrap();
1547
1548 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1549 let params = TestActorParams {
1550 forward_port: tx.bind(),
1551 };
1552 let actor_name =
1553 crate::mesh_id::ActorMeshId::instance(hyperactor::id::Label::new("test").unwrap());
1554 let actor_mesh: crate::ActorMesh<TestActor> = proc_mesh
1556 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1557 .await
1558 .unwrap();
1559 let actor_mesh_ref = actor_mesh.deref().clone();
1560
1561 let has_reducer = reducer.is_some();
1562 let (reply_port_handle, reply_rx) = match reducer {
1563 Some(reducer) => instance.mailbox().open_reduce_port(reducer),
1564 None => instance.mailbox().open_once_port::<u64>(),
1565 };
1566 let reply_port_ref = reply_port_handle.bind();
1567
1568 let message = TestMessage::CastAndReplyOnce {
1569 arg: "abc".to_string(),
1570 reply_to: reply_port_ref.clone(),
1571 };
1572
1573 clear_collected_tree();
1574 actor_mesh_ref.cast(instance, message).unwrap();
1575
1576 let mut reply_tos = vec![];
1577 for _ in proc_mesh.extent().points() {
1578 let msg = rx.recv().await.expect("missing");
1579 match msg {
1580 TestMessage::CastAndReplyOnce { arg, reply_to } => {
1581 assert_eq!(arg, "abc");
1582 if has_reducer {
1583 assert_ne!(reply_to, reply_port_ref);
1585 let p2p_threshold = hyperactor_config::global::get(
1586 crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD,
1587 );
1588 if p2p_threshold == 0 {
1589 assert!(
1591 reply_to
1592 .port_addr()
1593 .actor_id()
1594 .label()
1595 .unwrap()
1596 .as_str()
1597 .contains("comm")
1598 );
1599 }
1600 } else {
1601 assert_eq!(reply_to, reply_port_ref);
1603 }
1604 reply_tos.push(reply_to);
1605 }
1606 _ => {
1607 panic!("unexpected message: {:?}", msg);
1608 }
1609 }
1610 }
1611
1612 OncePortMeshSetupV1 {
1613 instance,
1614 reply_rx,
1615 reply_tos,
1616 _reply_port_ref: reply_port_ref,
1617 host_mesh,
1618 }
1619 }
1620
1621 async fn execute_cast_and_reply_once_v1() {
1622 let mut setup = setup_once_port_mesh::<NoneAccumulator>(None).await;
1626
1627 let num_replies = setup.reply_tos.len();
1630 for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1631 reply_to.post(setup.instance, i as u64);
1632 }
1633
1634 let result = setup.reply_rx.recv().await.unwrap();
1636 assert!(result < num_replies as u64);
1638
1639 let _ = setup.host_mesh.shutdown(setup.instance).await;
1640 }
1641
/// Runs the once-port cast-and-reply scenario with no config overrides.
#[async_timed_test(timeout_secs = 60)]
async fn test_cast_and_reply_once_v1() {
    // NOTE(review): unlike the `_p2p` variant, this test takes no config
    // lock, so it runs with whatever global config is in effect — confirm
    // that is intended.
    execute_cast_and_reply_once_v1().await
}
1646
1647 #[async_timed_test(timeout_secs = 60)]
1648 async fn test_cast_and_reply_once_v1_p2p() {
1649 let config = hyperactor_config::global::lock();
1650 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1651 let _guard2 = config.override_key(
1652 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1653 true,
1654 );
1655 let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1656 execute_cast_and_reply_once_v1().await
1657 }
1658
1659 async fn execute_cast_and_accum_once_v1() {
1660 let mut setup = setup_once_port_mesh(Some(accum::sum::<u64>())).await;
1664
1665 let mut expected_sum = 0u64;
1667 for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1668 reply_to.post(setup.instance, i as u64);
1669 expected_sum += i as u64;
1670 }
1671
1672 let result = setup.reply_rx.recv().await.unwrap();
1674 assert_eq!(result, expected_sum);
1675
1676 let _ = setup.host_mesh.shutdown(setup.instance).await;
1677 }
1678
/// Runs the once-port cast-and-accumulate scenario with no config overrides.
#[async_timed_test(timeout_secs = 60)]
async fn test_cast_and_accum_once_v1() {
    // NOTE(review): unlike the `_p2p` variant, this test takes no config
    // lock, so it runs with whatever global config is in effect — confirm
    // that is intended.
    execute_cast_and_accum_once_v1().await
}
1683
1684 #[async_timed_test(timeout_secs = 60)]
1685 async fn test_cast_and_accum_once_v1_p2p() {
1686 let config = hyperactor_config::global::lock();
1687 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1688 let _guard2 = config.override_key(
1689 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1690 true,
1691 );
1692 let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1693 execute_cast_and_accum_once_v1().await
1694 }
1695
1696 #[async_timed_test(timeout_secs = 60)]
1697 async fn test_unsplit_port_not_split() {
1698 let instance = crate::testing::instance();
1699 let mut host_mesh = local_host_mesh(8).await;
1700 let proc_mesh = host_mesh
1701 .spawn(instance, "test", extent!(gpu = 8), None, None)
1702 .await
1703 .unwrap();
1704
1705 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1706 let params = TestActorParams {
1707 forward_port: tx.bind(),
1708 };
1709 let actor_name =
1710 crate::mesh_id::ActorMeshId::instance(hyperactor::id::Label::new("test").unwrap());
1711 let actor_mesh: ActorMesh<TestActor> = proc_mesh
1712 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1713 .await
1714 .unwrap();
1715 let (reply_port_handle, mut reply_rx) = open_port::<u64>(instance);
1716 let reply_port_ref = reply_port_handle.bind().unsplit();
1717
1718 let message = TestMessage::CastWithUnsplitPort {
1719 reply_to: reply_port_ref.clone(),
1720 };
1721
1722 clear_collected_tree();
1723 actor_mesh.cast(instance, message).unwrap();
1724
1725 let num_points = proc_mesh.extent().points().count();
1727 for _ in 0..num_points {
1728 let msg = rx.recv().await.expect("missing");
1729 match msg {
1730 TestMessage::CastWithUnsplitPort { reply_to } => {
1731 assert_eq!(
1732 reply_to.port_addr(),
1733 reply_port_ref.port_addr(),
1734 "unsplit port should not be replaced by a comm actor split port"
1735 );
1736 }
1737 _ => panic!("unexpected message: {:?}", msg),
1738 }
1739 }
1740
1741 for _ in 0..8 {
1744 let val = reply_rx.recv().await.unwrap();
1745 assert_eq!(val, 42);
1746 }
1747 let _ = host_mesh.shutdown(instance).await;
1748 }
1749}