1use crate::casting::CAST_ACTOR_MESH_ID;
10use crate::comm::multicast::CAST_ORIGINATING_SENDER;
11use crate::comm::multicast::CastEnvelope;
12use crate::comm::multicast::CastMessageV1;
13use crate::comm::multicast::ForwardMessageV1;
14use crate::reference::ActorMeshId;
15use crate::resource;
16pub mod multicast;
17
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use anyhow::Result;
23use async_trait::async_trait;
24use hyperactor::Actor;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::RemoteMessage;
29use hyperactor::accum::ReducerMode;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::message::ErasedUnbound;
37use hyperactor::ordering::SEQ_INFO;
38use hyperactor::ordering::SeqInfo;
39use hyperactor::reference as hyperactor_reference;
40use hyperactor_config::CONFIG;
41use hyperactor_config::ConfigAttr;
42use hyperactor_config::Flattrs;
43use hyperactor_config::attrs::declare_attrs;
44use hyperactor_mesh_macros::sel;
45use ndslice::Point;
46use ndslice::Selection;
47use ndslice::View;
48use ndslice::selection::routing::RoutingFrame;
49use serde::Deserialize;
50use serde::Serialize;
51use typeuri::Named;
52
53use crate::comm::multicast::CastMessage;
54use crate::comm::multicast::CastMessageEnvelope;
55use crate::comm::multicast::ForwardMessage;
56use crate::comm::multicast::set_cast_info_on_headers;
57
// Declares the boolean attribute `ENABLE_NATIVE_V1_CASTING`, which defaults to
// `false`. Its `CONFIG` metadata is built from the two strings passed to
// `ConfigAttr::new` below.
// NOTE(review): presumably the first string is the environment-variable name
// and the second the configuration key; confirm against `ConfigAttr::new`.
58declare_attrs! {
59 @meta(CONFIG = ConfigAttr::new(
61 Some("HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING".to_string()),
62 Some("enable_native_v1_casting".to_string()),
63 ))
64 pub attr ENABLE_NATIVE_V1_CASTING: bool = false;
65}
66
/// Empty, serializable parameter struct registered with `wirevalue`.
///
/// NOTE(review): by its name this is the parameter type for `CommActor`; it
/// carries no fields and is not otherwise used in this part of the file.
67#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
69pub struct CommActorParams {}
70wirevalue::register_type!(CommActorParams);
71
/// A forwarded cast message parked by the `ForwardMessage` handler because its
/// `last_seq` is ahead of the sequence number the receiver has reached.
///
/// The routing decision is computed when the message arrives and stored here,
/// so it is not recomputed when the message is eventually replayed.
72#[derive(Debug)]
74struct Buffered {
 /// Sequence number carried by the parked message.
75 seq: usize,
 /// Whether the routing step resolved this rank as a destination.
77 deliver_here: bool,
 /// Routing frames still to be forwarded, keyed by peer rank.
79 next_steps: HashMap<usize, Vec<RoutingFrame>>,
 /// The cast message itself.
81 message: CastMessageEnvelope,
83}
84
/// Per-stream receive bookkeeping used by the `ForwardMessage` handler.
85#[derive(Debug, Default)]
88struct ReceiveState {
 /// Sequence number of the most recently handled message on this stream
 /// (0 until the first message is handled).
89 seq: usize,
 /// Messages that arrived early, keyed by the `last_seq` they carry; an entry
 /// is removed and replayed once `seq` reaches its key.
91 buffer: HashMap<usize, Buffered>,
 /// For each peer rank, the sequence number of the last message forwarded to
 /// that peer; it is sent as `last_seq` on the next forward to the same peer.
94 last_seqs: HashMap<usize, usize>,
96}
97
/// Actor that routes cast messages between peer `CommActor`s and delivers them
/// to destination actors on its own proc.
///
/// It is exported with handlers for `CommMeshConfig`, `CastMessage`,
/// `ForwardMessage`, `CastMessageV1` and `ForwardMessageV1`.
98#[derive(Debug, Default)]
101#[hyperactor::export(
102 spawn = true,
103 handlers = [
104 CommMeshConfig,
105 CastMessage,
106 ForwardMessage,
107 CastMessageV1,
108 ForwardMessageV1,
109 ],
110)]
111pub struct CommActor {
 /// Per-stream send counter, keyed by the envelope's `stream_key()`; it is
 /// advanced once for every `CastMessage` handled.
112 send_seq: HashMap<(ActorMeshId, hyperactor_reference::ActorId), usize>,
 /// Per-stream receive state (ordering and buffering), keyed the same way.
114 recv_state: HashMap<(ActorMeshId, hyperactor_reference::ActorId), ReceiveState>,
116
 /// Either the mesh configuration, or the messages queued while waiting for it.
117 mesh_config: MeshConfigState,
119}
120
/// A message received before the actor was configured, kept so it can be
/// replayed once a `CommMeshConfig` arrives.
///
/// There is no variant for `CastMessageV1`: its handler wraps the message in a
/// `ForwardMessageV1` and delegates, so it is queued as `ForwardV1`.
121#[derive(Debug)]
122enum PendingMessage {
123 Cast(CastMessage),
124 Forward(ForwardMessage),
125 ForwardV1(ForwardMessageV1),
126}
127
128#[derive(Debug)]
129enum MeshConfigState {
130 NotConfigured(Vec<PendingMessage>),
132 Configured(CommMeshConfig),
134}
135
136impl Default for MeshConfigState {
137 fn default() -> Self {
138 MeshConfigState::NotConfigured(Vec::new())
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize, Named)]
144pub struct CommMeshConfig {
145 rank: usize,
147 peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
149}
150wirevalue::register_type!(CommMeshConfig);
151
152impl CommMeshConfig {
153 pub fn new(
155 rank: usize,
156 peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
157 ) -> Self {
158 Self { rank, peers }
159 }
160
161 fn peer_for_rank(&self, rank: usize) -> Result<hyperactor_reference::ActorRef<CommActor>> {
163 self.peers
164 .get(&rank)
165 .cloned()
166 .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank))
167 }
168
169 fn self_rank(&self) -> usize {
171 self.rank
172 }
173}
174
// `Actor` implementation for `CommActor`: marks the instance via
// `set_system()` on init and routes undeliverable cast traffic back to the
// actor that originally sent it.
175#[async_trait]
176impl Actor for CommActor {
 // Calls `set_system()` on the instance and returns.
 // NOTE(review): presumably this flags the actor as a system actor; confirm
 // against `Instance::set_system`.
177 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
178 this.set_system();
179 Ok(())
180 }
181
 // Handles an envelope that could not be delivered. Three cases, tried in
 // order:
 //  1. The envelope deserializes as a `ForwardMessage` (a forward to a peer
 //     comm actor failed): it is returned to the cast's sender.
 //  2. The envelope carries a `CAST_ORIGINATING_SENDER` header (delivery to
 //     the destination actor failed): it is returned to that sender.
 //  3. Anything else is posted to `UndeliverableMailboxSender`.
 // In cases 1 and 2 a failed return yields
 // `UndeliverableMessageError::ReturnFailure` carrying the envelope.
182 async fn handle_undeliverable_message(
184 &mut self,
185 cx: &Instance<Self>,
186 undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
187 ) -> Result<(), anyhow::Error> {
188 let Undeliverable(mut message_envelope) = undelivered;
189
 // Case 1: a forward to another comm actor bounced.
190 if let Ok(ForwardMessage { message, .. }) =
192 message_envelope.deserialized::<ForwardMessage>()
193 {
194 let sender = message.sender();
195 let return_port = hyperactor_reference::PortRef::attest_message_port(sender);
196 message_envelope.set_error(DeliveryError::Multicast(format!(
197 "comm actor {} failed to forward the cast message; returning to origin {}",
198 cx.self_id(),
199 return_port.port_id(),
200 )));
201
 // Record the original sender on the envelope before returning it.
202 message_envelope.set_header(CAST_ORIGINATING_SENDER, sender.clone());
205
 // A clone is sent so that the envelope itself is still available to the
 // error closure if the return fails.
206 return_port
207 .send(cx, Undeliverable(message_envelope.clone()))
208 .map_err(|err| {
209 let error = DeliveryError::BrokenLink(format!(
210 "error occured when returning ForwardMessage to the original \
211 sender's port {}; error is: {}",
212 return_port.port_id(),
213 err,
214 ));
215 message_envelope.set_error(error);
216 UndeliverableMessageError::ReturnFailure {
217 envelope: message_envelope,
218 }
219 })?;
220 return Ok(());
221 }
222
 // Case 2: the envelope is tagged with the sender of the originating cast.
223 if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
225 let return_port = hyperactor_reference::PortRef::attest_message_port(&sender);
226 message_envelope.set_error(DeliveryError::Multicast(format!(
227 "comm actor {} failed to deliver the cast message to the dest \
228 actor; returning to origin {}",
229 cx.self_id(),
230 return_port.port_id(),
231 )));
232 return_port
233 .send(cx, Undeliverable(message_envelope.clone()))
234 .map_err(|err| {
235 let error = DeliveryError::BrokenLink(format!(
236 "error occured when returning cast message to the origin \
237 sender {}; error is: {}",
238 return_port.port_id(),
239 err,
240 ));
241 message_envelope.set_error(error);
242 UndeliverableMessageError::ReturnFailure {
243 envelope: message_envelope,
244 }
245 })?;
246 return Ok(());
247 }
248
 // Case 3: not cast traffic this actor recognizes.
249 UndeliverableMailboxSender
251 .post(message_envelope, monitored_return_handle());
252 Ok(())
253 }
254}
255
256impl CommActor {
257 fn forward<M: RemoteMessage>(
259 cx: &Context<Self>,
260 config: &CommMeshConfig,
261 rank: usize,
262 message: M,
263 ) -> Result<()>
264 where
265 CommActor: hyperactor::RemoteHandles<M>,
266 {
267 let child = config.peer_for_rank(rank)?;
268 if let Some(cast_actor_mesh_id) = cx.headers().get(CAST_ACTOR_MESH_ID) {
270 let mut headers = Flattrs::new();
271 headers.set(CAST_ACTOR_MESH_ID, cast_actor_mesh_id);
272 child.send_with_headers(cx, headers, message)?;
273 } else {
274 child.send(cx, message)?;
275 }
276 Ok(())
277 }
278
279 fn handle_message(
280 cx: &Context<Self>,
281 config: &CommMeshConfig,
282 deliver_here: bool,
283 next_steps: HashMap<usize, Vec<RoutingFrame>>,
284 sender: hyperactor_reference::ActorId,
285 mut message: CastMessageEnvelope,
286 seq: usize,
287 last_seqs: &mut HashMap<usize, usize>,
288 ) -> Result<()> {
289 split_ports(cx, message.data_mut(), deliver_here, &next_steps)?;
290
291 if deliver_here {
293 let headers = message.headers().clone();
297 Self::deliver_to_dest(cx, headers, &mut message, config)?;
298 }
299
300 next_steps
302 .into_iter()
303 .map(|(peer, dests)| {
304 let last_seq = last_seqs.entry(peer).or_default();
305 Self::forward(
306 cx,
307 config,
308 peer,
309 ForwardMessage {
310 dests,
311 sender: sender.clone(),
312 message: message.clone(),
313 seq,
314 last_seq: *last_seq,
315 },
316 )?;
317 *last_seq = seq;
318 Ok(())
319 })
320 .collect::<Result<Vec<_>>>()?;
321
322 Ok(())
323 }
324
325 fn deliver_to_dest<M: CastEnvelope>(
326 cx: &Context<Self>,
327 mut headers: Flattrs,
328 message: &mut M,
329 config: &CommMeshConfig,
330 ) -> anyhow::Result<()> {
331 let cast_point = message.cast_point(config)?;
332 replace_with_self_ranks(&cast_point, message.data_mut())?;
334
335 set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone());
336 cx.post_with_external_seq_info(
337 cx.self_id()
338 .proc_id()
339 .actor_id(message.dest_port().actor_name(), 0)
340 .port_id(message.dest_port().port()),
341 headers,
342 wirevalue::Any::serialize(message.data())?,
343 );
344
345 Ok(())
346 }
347}
348
349fn split_ports(
353 cx: &Context<CommActor>,
354 data: &mut ErasedUnbound,
355 deliver_here: bool,
356 next_steps: &HashMap<usize, Vec<RoutingFrame>>,
357) -> anyhow::Result<()> {
358 data.visit_mut::<hyperactor_reference::UnboundPort>(
362 |hyperactor_reference::UnboundPort(
363 port_id,
364 reducer_spec,
365 return_undeliverable,
366 kind,
367 unsplit,
368 )| {
369 if *unsplit {
370 return Ok(());
371 }
372 let reducer_mode = match kind {
373 hyperactor_reference::UnboundPortKind::Streaming(opts) => {
374 ReducerMode::Streaming(opts.clone().unwrap_or_default())
375 }
376 hyperactor_reference::UnboundPortKind::Once if reducer_spec.is_none() => {
377 return Ok(());
384 }
385 hyperactor_reference::UnboundPortKind::Once => {
386 let peer_count = next_steps.len() + if deliver_here { 1 } else { 0 };
390 ReducerMode::Once(peer_count)
391 }
392 };
393
394 let split = port_id.split(
395 cx,
396 reducer_spec.clone(),
397 reducer_mode,
398 *return_undeliverable,
399 )?;
400
401 #[cfg(test)]
402 tests::collect_split_port(port_id, &split, deliver_here);
403
404 *port_id = split;
405 Ok(())
406 },
407 )
408}
409
410fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
411 data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
412 *rank = Some(cast_point.rank());
413 Ok(())
414 })
415}
416
417#[async_trait]
418impl Handler<CommMeshConfig> for CommActor {
419 async fn handle(&mut self, cx: &Context<Self>, config: CommMeshConfig) -> Result<()> {
420 let pending =
421 match std::mem::replace(&mut self.mesh_config, MeshConfigState::Configured(config)) {
422 MeshConfigState::NotConfigured(pending) => pending,
423 MeshConfigState::Configured(_) => Vec::new(),
424 };
425 if !pending.is_empty() {
426 tracing::info!(
427 count = pending.len(),
428 "replaying buffered pre-config messages"
429 );
430 }
431 for msg in pending {
432 match msg {
433 PendingMessage::Cast(m) => self.handle(cx, m).await?,
434 PendingMessage::Forward(m) => self.handle(cx, m).await?,
435 PendingMessage::ForwardV1(m) => self.handle(cx, m).await?,
436 }
437 }
438 Ok(())
439 }
440}
441
442#[async_trait]
444impl Handler<CastMessage> for CommActor {
445 #[tracing::instrument(level = "debug", skip_all)]
446 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
447 let config = match &mut self.mesh_config {
448 MeshConfigState::NotConfigured(pending) => {
449 pending.push(PendingMessage::Cast(cast_message));
450 return Ok(());
451 }
452 MeshConfigState::Configured(config) => config,
453 };
454 let slice = cast_message.dest.slice.clone();
456 let selection = cast_message.dest.selection.clone();
457 let frame = RoutingFrame::root(selection, slice);
458 let rank = frame.slice.location(&frame.here)?;
459 let seq = self
460 .send_seq
461 .entry(cast_message.message.stream_key())
462 .or_default();
463 let last_seq = *seq;
464 *seq += 1;
465
466 let fwd_message = ForwardMessage {
467 dests: vec![frame],
468 sender: cx.self_id().clone(),
469 message: cast_message.message,
470 seq: *seq,
471 last_seq,
472 };
473
474 if config.self_rank() == rank {
477 Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
478 } else {
479 Self::forward(cx, config, rank, fwd_message)?;
480 }
481 Ok(())
482 }
483}
484
// Handles a forwarded cast message: resolves routing for this rank, then
// processes the message in sequence order per stream, buffering messages that
// arrive early and dropping the ones treated as duplicates.
485#[async_trait]
486impl Handler<ForwardMessage> for CommActor {
487 #[tracing::instrument(level = "debug", skip_all)]
488 async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
 // Until a `CommMeshConfig` arrives, messages are only queued.
489 let config = match &mut self.mesh_config {
490 MeshConfigState::NotConfigured(pending) => {
491 pending.push(PendingMessage::Forward(fwd_message));
492 return Ok(());
493 }
494 MeshConfigState::Configured(config) => config,
495 };
496
497 let ForwardMessage {
498 sender,
499 dests,
500 message,
501 seq,
502 last_seq,
503 } = fwd_message;
504
 // Decide whether this rank is a destination and which peers come next.
 // The callback passed to `resolve_routing` panics if it is invoked.
505 let rank = config.self_rank();
507 let (deliver_here, next_steps) =
508 ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
509 panic!("Choice encountered in CommActor routing")
510 })?;
511
 // `recv_state.seq` is the last sequence number handled on this stream;
 // a message is in order when its `last_seq` equals it.
512 let recv_state = self.recv_state.entry(message.stream_key()).or_default();
513 match recv_state.seq.cmp(&last_seq) {
 // In order: process it now.
514 Ordering::Equal => {
516 Self::handle_message(
518 cx,
519 config,
520 deliver_here,
521 next_steps,
522 sender.clone(),
523 message,
524 seq,
525 &mut recv_state.last_seqs,
526 )?;
527 recv_state.seq = seq;
528
 // Replay buffered successors: each buffered message is keyed by the
 // `last_seq` it was waiting for, which is now `recv_state.seq`.
 // NOTE(review): buffered messages are replayed with the `sender` of
 // the message that triggered the drain, because `Buffered` does not
 // store its own sender; confirm all messages of a stream share one.
529 while let Some(Buffered {
532 seq,
533 deliver_here,
534 next_steps,
535 message,
536 }) = recv_state.buffer.remove(&recv_state.seq)
537 {
538 Self::handle_message(
539 cx,
540 config,
541 deliver_here,
542 next_steps,
543 sender.clone(),
544 message,
545 seq,
546 &mut recv_state.last_seqs,
547 )?;
548 recv_state.seq = seq;
549 }
550 }
 // The message's predecessor has not been handled yet: park it under
 // the sequence number it is waiting for.
551 Ordering::Less => {
554 tracing::warn!(
555 "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
556 seq,
557 last_seq,
558 recv_state.seq,
559 message
560 );
561 recv_state.buffer.insert(
562 last_seq,
563 Buffered {
564 seq,
565 deliver_here,
566 next_steps,
567 message,
568 },
569 );
570 }
 // The receiver is already past this message's predecessor: it is
 // treated as a duplicate, logged, and dropped.
571 Ordering::Greater => {
573 tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
574 }
575 }
576
577 Ok(())
578 }
579}
580
581#[async_trait]
582impl Handler<CastMessageV1> for CommActor {
583 async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
584 let slice = cast_message.dest_region.slice().clone();
585 let frame = RoutingFrame::root(sel!(*), slice);
586 let forward_message = ForwardMessageV1 {
587 dests: vec![frame],
588 message: cast_message,
589 };
590 self.handle(cx, forward_message).await
591 }
592}
593
594#[async_trait]
595impl Handler<ForwardMessageV1> for CommActor {
596 async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
597 let config = match &mut self.mesh_config {
598 MeshConfigState::NotConfigured(pending) => {
599 pending.push(PendingMessage::ForwardV1(fwd_message));
600 return Ok(());
601 }
602 MeshConfigState::Configured(config) => config,
603 };
604
605 let ForwardMessageV1 { dests, mut message } = fwd_message;
606 let rank_on_root_mesh = config.self_rank();
608 let (deliver_here, next_steps) =
609 ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
610 panic!("choice encountered in CommActor routing")
611 })?;
612
613 split_ports(cx, &mut message.data, deliver_here, &next_steps)?;
614
615 if deliver_here {
617 let mut headers = message.headers().clone();
618 let seq = message
619 .seqs
620 .get(message.cast_point(config)?.rank())
621 .expect("mismatched seqs and dest_region");
622 headers.set(
623 SEQ_INFO,
624 SeqInfo::Session {
625 session_id: message.session_id,
626 seq,
627 },
628 );
629 Self::deliver_to_dest(cx, headers, &mut message, config)?;
630 }
631
632 for (peer_rank_on_root_mesh, dests) in next_steps {
634 let forward_message = ForwardMessageV1 {
635 dests,
636 message: message.clone(),
637 };
638 Self::forward(cx, config, peer_rank_on_root_mesh, forward_message)?;
639 }
640
641 Ok(())
642 }
643}
644
645pub mod test_utils {
646 use anyhow::Result;
647 use async_trait::async_trait;
648 use hyperactor::Actor;
649 use hyperactor::Bind;
650 use hyperactor::Context;
651 use hyperactor::Handler;
652 use hyperactor::Unbind;
653 use hyperactor::reference as hyperactor_reference;
654 use serde::Deserialize;
655 use serde::Serialize;
656 use typeuri::Named;
657
658 use super::*;
659
660 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
661 pub struct MyReply {
662 pub sender: hyperactor_reference::ActorId,
663 pub value: u64,
664 }
665
666 #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
667 pub enum TestMessage {
668 Forward(String),
669 CastAndReply {
670 arg: String,
671 reply_to0: hyperactor_reference::PortRef<String>,
675 #[binding(include)]
676 reply_to1: hyperactor_reference::PortRef<u64>,
677 #[binding(include)]
678 reply_to2: hyperactor_reference::PortRef<MyReply>,
679 },
680 CastAndReplyOnce {
681 arg: String,
682 #[binding(include)]
683 reply_to: hyperactor::reference::OncePortRef<u64>,
684 },
685 CastWithUnsplitPort {
686 #[binding(include)]
687 reply_to: hyperactor_reference::PortRef<u64>,
688 },
689 }
690
691 #[derive(Debug)]
692 #[hyperactor::export(
693 spawn = true,
694 handlers = [
695 TestMessage { cast = true },
696 ],
697 )]
698 pub struct TestActor {
699 forward_port: hyperactor_reference::PortRef<TestMessage>,
702 }
703
704 #[derive(Debug, Clone, Named, Serialize, Deserialize)]
705 pub struct TestActorParams {
706 pub forward_port: hyperactor_reference::PortRef<TestMessage>,
707 }
708
709 #[async_trait]
710 impl Actor for TestActor {}
711
712 #[async_trait]
713 impl hyperactor::RemoteSpawn for TestActor {
714 type Params = TestActorParams;
715
716 async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self> {
717 let Self::Params { forward_port } = params;
718 Ok(Self { forward_port })
719 }
720 }
721
722 #[async_trait]
723 impl Handler<TestMessage> for TestActor {
724 async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
725 if let TestMessage::CastWithUnsplitPort { ref reply_to } = msg {
728 reply_to.send(cx, 42)?;
729 }
730 self.forward_port.send(cx, msg)?;
731 Ok(())
732 }
733 }
734}
735
736#[cfg(test)]
737mod tests {
738 use std::collections::BTreeMap;
739 use std::collections::HashSet;
740 use std::fmt::Display;
741 use std::hash::Hash;
742 use std::ops::Deref;
743 use std::ops::DerefMut;
744 use std::sync::Mutex;
745 use std::sync::OnceLock;
746
747 use hyperactor::accum;
748
749 async fn buffering_fixture(
754 proc_name: &str,
755 ) -> (
756 Instance<()>,
757 hyperactor::mailbox::PortReceiver<TestMessage>,
758 hyperactor::ActorHandle<CommActor>,
759 crate::Name,
760 (
762 hyperactor::ActorHandle<()>,
763 hyperactor::ActorHandle<TestActor>,
764 hyperactor_reference::ActorRef<TestActor>,
765 ),
766 ) {
767 use hyperactor::Proc;
768 use hyperactor::RemoteSpawn;
769 use hyperactor::channel::ChannelTransport;
770
771 let proc = Proc::direct(ChannelTransport::Unix.any(), proc_name.to_string()).unwrap();
772 let (client, client_handle) = proc.instance("client").unwrap();
773
774 let actor_mesh_name = crate::Name::new("test").unwrap();
775 let actor_name = actor_mesh_name.to_string();
776
777 let (tx, rx) = open_port(&client);
778 let forward_port = tx.bind();
779 let test_actor = TestActor::new(TestActorParams { forward_port }, Default::default())
780 .await
781 .unwrap();
782 let test_handle = proc.spawn(&actor_name, test_actor).unwrap();
783 let test_ref: hyperactor_reference::ActorRef<TestActor> = test_handle.bind::<TestActor>();
784
785 let comm_handle = proc.spawn("comm", CommActor::default()).unwrap();
786
787 (
788 client,
789 rx,
790 comm_handle,
791 actor_mesh_name,
792 (client_handle, test_handle, test_ref),
793 )
794 }
795
796 fn send_config(client: &Instance<()>, comm_handle: &hyperactor::ActorHandle<CommActor>) {
798 let comm_ref = comm_handle.bind::<CommActor>();
799 let mut peers = HashMap::new();
800 peers.insert(0, comm_ref);
801 comm_handle
802 .send(client, CommMeshConfig::new(0, peers))
803 .unwrap();
804 }
805
806 async fn assert_buffered_and_replayed<M: hyperactor::Message>(
809 proc_name: &str,
810 mut make_msg: impl FnMut(&Instance<()>, &crate::Name, &str) -> M,
811 ) where
812 CommActor: hyperactor::Handler<M>,
813 {
814 let (client, mut rx, comm_handle, actor_mesh_name, _guards) =
815 buffering_fixture(proc_name).await;
816
817 comm_handle
818 .send(&client, make_msg(&client, &actor_mesh_name, "buffered"))
819 .unwrap();
820 send_config(&client, &comm_handle);
821 comm_handle
822 .send(&client, make_msg(&client, &actor_mesh_name, "direct"))
823 .unwrap();
824
825 assert_eq!(
826 rx.recv().await.unwrap(),
827 TestMessage::Forward("buffered".to_string()),
828 );
829 assert_eq!(
830 rx.recv().await.unwrap(),
831 TestMessage::Forward("direct".to_string()),
832 );
833 comm_handle.drain_and_stop("test done").ok();
834 }
835
836 #[async_timed_test(timeout_secs = 1)]
837 async fn cast_before_config_is_buffered_and_replayed() {
838 use ndslice::Slice;
839
840 assert_buffered_and_replayed("test_cast", |client, name, payload| {
841 let actor_mesh_id = crate::reference::ActorMeshId(name.clone());
842 let slice = Slice::new_row_major(vec![1]);
843 let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
844 let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
845 actor_mesh_id,
846 client.self_id().clone(),
847 shape,
848 hyperactor_config::Flattrs::new(),
849 TestMessage::Forward(payload.to_string()),
850 )
851 .unwrap();
852 multicast::CastMessage {
853 dest: multicast::Uslice {
854 slice,
855 selection: sel!(*),
856 },
857 message: envelope,
858 }
859 })
860 .await;
861 }
862
863 #[async_timed_test(timeout_secs = 1)]
864 async fn forward_before_config_is_buffered_and_replayed() {
865 use ndslice::Slice;
866 use ndslice::selection::routing::RoutingFrame;
867
868 let mut next_seq: usize = 0;
869 assert_buffered_and_replayed("test_fwd", move |client, name, payload| {
870 let actor_mesh_id = crate::reference::ActorMeshId(name.clone());
871 let slice = Slice::new_row_major(vec![1]);
872 let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
873 let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
874 actor_mesh_id,
875 client.self_id().clone(),
876 shape,
877 hyperactor_config::Flattrs::new(),
878 TestMessage::Forward(payload.to_string()),
879 )
880 .unwrap();
881 let frame = RoutingFrame::root(sel!(*), slice);
882 let last_seq = next_seq;
883 next_seq += 1;
884 multicast::ForwardMessage {
885 sender: client.self_id().clone(),
886 dests: vec![frame],
887 seq: next_seq,
888 last_seq,
889 message: envelope,
890 }
891 })
892 .await;
893 }
894
895 #[async_timed_test(timeout_secs = 1)]
896 async fn forward_v1_before_config_is_buffered_and_replayed() {
897 use ndslice::Region;
898 use ndslice::Slice;
899 use ndslice::selection::routing::RoutingFrame;
900
901 assert_buffered_and_replayed("test_fwd_v1", |client, name, payload| {
902 let slice = Slice::new_row_major(vec![1]);
903 let region = Region::new(vec!["rank".to_string()], slice.clone());
904 let cast_msg = multicast::CastMessageV1::new::<TestActor, TestMessage>(
905 client.self_id().clone(),
906 name,
907 region.clone(),
908 hyperactor_config::Flattrs::new(),
909 TestMessage::Forward(payload.to_string()),
910 uuid::Uuid::new_v4(),
911 crate::ValueMesh::from_single(region, 0u64),
912 )
913 .unwrap();
914 let frame = RoutingFrame::root(sel!(*), slice);
915 multicast::ForwardMessageV1 {
916 dests: vec![frame],
917 message: cast_msg,
918 }
919 })
920 .await;
921 }
922
923 use hyperactor::accum::Accumulator;
924 use hyperactor::accum::ReducerSpec;
925 use hyperactor::context;
926 use hyperactor::context::Mailbox;
927 use hyperactor::mailbox::PortReceiver;
928 use hyperactor::mailbox::open_port;
929 use hyperactor::reference as hyperactor_reference;
930 use hyperactor_config;
931 use hyperactor_mesh_macros::sel;
932 use maplit::btreemap;
933 use maplit::hashmap;
934 use ndslice::Extent;
935 use ndslice::Selection;
936 use ndslice::ViewExt as _;
937 use ndslice::extent;
938 use ndslice::selection::test_utils::collect_commactor_routing_tree;
939 use test_utils::*;
940 use timed_test::async_timed_test;
941 use tokio::time::Duration;
942
943 use super::*;
944 use crate::ActorMesh;
945 use crate::Name;
946 use crate::host_mesh::HostMesh;
947 use crate::test_utils::local_host_mesh;
948 use crate::testing;
949
950 fn lookup_rank(
952 actor_id: &hyperactor::reference::ActorId,
953 rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
954 ) -> usize {
955 let proc_id = actor_id.proc_id();
956 *rank_lookup
957 .get(proc_id)
958 .unwrap_or_else(|| panic!("proc rank not found for {}", proc_id))
959 }
960
/// A directed edge `from -> to` in a recorded tree; `is_leaf` marks `to` as a
/// leaf of that tree.
struct Edge<T> {
    from: T,
    to: T,
    is_leaf: bool,
}

impl<T> From<(T, T, bool)> for Edge<T> {
    /// Builds an edge from a `(from, to, is_leaf)` triple.
    fn from(triple: (T, T, bool)) -> Self {
        let (from, to, is_leaf) = triple;
        Edge { from, to, is_leaf }
    }
}
972
// Lazily initialized, mutex-guarded list of (original port -> split port)
// edges. `collect_split_port` appends to it and creates it on first use;
// `clear_collected_tree` empties it; `verify_split_port_paths` reads it.
973 static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<hyperactor_reference::PortId>>>> =
976 OnceLock::new();
977
978 pub(crate) fn collect_split_port(
981 original: &hyperactor_reference::PortId,
982 split: &hyperactor_reference::PortId,
983 deliver_here: bool,
984 ) {
985 let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
986 let mut tree = mutex.lock().unwrap();
987
988 tree.deref_mut().push(Edge {
989 from: original.clone(),
990 to: split.clone(),
991 is_leaf: deliver_here,
992 });
993 }
994
995 fn clear_collected_tree() {
999 if let Some(tree) = SPLIT_PORT_TREE.get() {
1000 let mut tree = tree.lock().unwrap();
1001 tree.clear();
1002 }
1003 }
1004
/// Maps each leaf to the path that leads to it, ordered by leaf.
#[derive(PartialEq)]
struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);

impl<T: Display> Debug for PathToLeaves<T> {
    /// Writes one `leaf -> a, b, c` line per entry, each ending in a newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for (leaf, path) in &self.0 {
            let hops: Vec<String> = path.iter().map(|hop| hop.to_string()).collect();
            writeln!(f, "{} -> {}", leaf, hops.join(", "))?;
        }
        Ok(())
    }
}
1027
1028 fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
1029 let mut child_parent_map = HashMap::new();
1030 let mut all_nodes = HashSet::new();
1031 let mut parents = HashSet::new();
1032 let mut children = HashSet::new();
1033 let mut dests = HashSet::new();
1034
1035 for Edge { from, to, is_leaf } in edges {
1037 child_parent_map.insert(to.clone(), from.clone());
1038 all_nodes.insert(from.clone());
1039 all_nodes.insert(to.clone());
1040 parents.insert(from.clone());
1041 children.insert(to.clone());
1042 if *is_leaf {
1043 dests.insert(to.clone());
1044 }
1045 }
1046
1047 let mut result = BTreeMap::new();
1049 for dest in dests {
1050 let mut path = vec![dest.clone()];
1051 let mut current = dest.clone();
1052 while let Some(parent) = child_parent_map.get(¤t) {
1053 path.push(parent.clone());
1054 current = parent.clone();
1055 }
1056 path.reverse();
1057 result.insert(dest, path);
1058 }
1059
1060 PathToLeaves(result)
1061 }
1062
/// `build_paths` returns the root-to-leaf path for every leaf edge, including
/// a leaf (4) that also has a child (5).
#[test]
fn test_build_paths() {
    let raw_edges = [
        (0, 1, false),
        (1, 2, true),
        (1, 3, true),
        (0, 4, true),
        (4, 5, true),
    ];
    let edges: Vec<_> = raw_edges.into_iter().map(Edge::from).collect();

    let expected = btreemap! {
        2 => vec![0, 1, 2],
        3 => vec![0, 1, 3],
        4 => vec![0, 4],
        5 => vec![0, 4, 5],
    };

    let paths = build_paths(&edges);
    assert_eq!(paths.0, expected);
}
1093
1094 fn get_ranks(
1114 paths: PathToLeaves<hyperactor_reference::PortId>,
1115 client_reply: &hyperactor_reference::PortId,
1116 rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
1117 ) -> PathToLeaves<hyperactor_reference::Index> {
1118 let ranks = paths
1119 .0
1120 .into_iter()
1121 .map(|(dst, mut path)| {
1122 let first = path.remove(0);
1123 assert_eq!(&first, client_reply);
1125 assert!(dst.actor_id().name().contains("comm"));
1128 let actor_path = path
1129 .into_iter()
1130 .map(|p| {
1131 assert!(p.actor_id().name().contains("comm"));
1132 lookup_rank(p.actor_id(), rank_lookup)
1133 })
1134 .collect();
1135 (lookup_rank(dst.actor_id(), rank_lookup), actor_path)
1136 })
1137 .collect();
1138 PathToLeaves(ranks)
1139 }
1140
// Placeholder `Accumulator` over `u64` whose methods are both
// `unimplemented!()`: calling either one panics.
// NOTE(review): presumably it exists only to supply a concrete type where an
// `Option<A: Accumulator>` is `None`; confirm at the call sites.
1141 struct NoneAccumulator;
1142
1143 impl Accumulator for NoneAccumulator {
1144 type State = u64;
1145 type Update = u64;
1146
 // Never expected to run; panics if called.
1147 fn accumulate(
1148 &self,
1149 _state: &mut Self::State,
1150 _update: Self::Update,
1151 ) -> anyhow::Result<()> {
1152 unimplemented!()
1153 }
1154
 // Never expected to run; panics if called.
1155 fn reducer_spec(&self) -> Option<ReducerSpec> {
1156 unimplemented!()
1157 }
1158 }
1159
1160 fn verify_split_port_paths(
1162 selection: &Selection,
1163 extent: &Extent,
1164 reply_port_ref1: &hyperactor_reference::PortRef<u64>,
1165 reply_port_ref2: &hyperactor_reference::PortRef<MyReply>,
1166 rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
1167 ) {
1168 let sel_paths = PathToLeaves(
1170 collect_commactor_routing_tree(selection, &extent.to_slice())
1171 .delivered
1172 .into_iter()
1173 .collect(),
1174 );
1175
1176 let (reply1_paths, reply2_paths) = {
1178 let tree = SPLIT_PORT_TREE
1179 .get()
1180 .expect("not initialized; are Hosts in the same process as SPLIT_PORT_TREE?");
1181 let edges = tree.lock().unwrap();
1182 let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
1183 .0
1184 .into_iter()
1185 .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
1186 (
1187 get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id(), rank_lookup),
1188 get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id(), rank_lookup),
1189 )
1190 };
1191
1192 assert_eq!(sel_paths, reply1_paths);
1194 assert_eq!(sel_paths, reply2_paths);
1195 }
1196
1197 async fn execute_cast_and_reply(
1198 ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
1199 instance: &impl context::Actor,
1200 mut reply1_rx: PortReceiver<u64>,
1201 mut reply2_rx: PortReceiver<MyReply>,
1202 reply_tos: Vec<(
1203 hyperactor_reference::PortRef<u64>,
1204 hyperactor_reference::PortRef<MyReply>,
1205 )>,
1206 ) {
1207 {
1209 for (rank, (dest_actor, (reply_to1, reply_to2))) in
1210 ranks.iter().zip(reply_tos.iter()).enumerate()
1211 {
1212 let rank_u64 = rank as u64;
1213 reply_to1.send(instance, rank_u64).unwrap();
1214 let my_reply = MyReply {
1215 sender: dest_actor.actor_id().clone(),
1216 value: rank_u64,
1217 };
1218 reply_to2.send(instance, my_reply.clone()).unwrap();
1219
1220 assert_eq!(reply1_rx.recv().await.unwrap(), rank_u64);
1221 assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
1222 }
1223 }
1224
1225 tracing::info!("the 1st updates from all dest actors were receivered by client");
1226
1227 {
1231 let n = 100;
1232 let mut expected2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
1233 for (i, (dest_actor, (_reply_to1, reply_to2))) in
1234 ranks.iter().zip(reply_tos.iter()).enumerate()
1235 {
1236 let mut sent2 = vec![];
1237 for j in 0..n {
1238 let value = (i * 100 + j) as u64;
1239 let my_reply = MyReply {
1240 sender: dest_actor.actor_id().clone(),
1241 value,
1242 };
1243 reply_to2.send(instance, my_reply.clone()).unwrap();
1244 sent2.push(my_reply);
1245 }
1246 assert!(
1247 expected2
1248 .insert(dest_actor.actor_id().clone(), sent2)
1249 .is_none(),
1250 "duplicate actor_id {} in map",
1251 dest_actor.actor_id()
1252 );
1253 }
1254
1255 let mut received2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
1256
1257 for _ in 0..(n * ranks.len()) {
1258 let my_reply = reply2_rx.recv().await.unwrap();
1259 received2
1260 .entry(my_reply.sender.clone())
1261 .or_default()
1262 .push(my_reply);
1263 }
1264 assert_eq!(received2, expected2);
1265 }
1266 }
1267
1268 async fn wait_for_with_timeout(
1269 receiver: &mut PortReceiver<u64>,
1270 expected: u64,
1271 dur: Duration,
1272 ) -> anyhow::Result<()> {
1273 tokio::time::timeout(dur, async {
1275 loop {
1276 let msg = receiver.recv().await.unwrap();
1277 if msg == expected {
1278 break;
1279 }
1280 }
1281 })
1282 .await?;
1283 Ok(())
1284 }
1285
1286 async fn execute_cast_and_accum(
1287 ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
1288 instance: &impl context::Actor,
1289 mut reply1_rx: PortReceiver<u64>,
1290 reply_tos: Vec<(
1291 hyperactor_reference::PortRef<u64>,
1292 hyperactor_reference::PortRef<MyReply>,
1293 )>,
1294 ) {
1295 let mut sum = 0;
1299 let n = 100;
1300 for (i, (_dest_actor, (reply_to1, _reply_to2))) in
1301 ranks.iter().zip(reply_tos.iter()).enumerate()
1302 {
1303 for j in 0..n {
1304 let value = (i + j) as u64;
1305 reply_to1.send(instance, value).unwrap();
1306 sum += value;
1307 }
1308 }
1309 wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(8))
1310 .await
1311 .unwrap();
1312 tokio::time::sleep(Duration::from_secs(2)).await;
1314 let msg = reply1_rx.try_recv().unwrap();
1315 assert_eq!(msg, None);
1316 }
1317
1318 struct MeshSetupV1 {
1319 instance: &'static Instance<testing::TestRootClient>,
1320 actor_mesh_ref: crate::ActorMeshRef<TestActor>,
1321 reply1_rx: PortReceiver<u64>,
1322 reply2_rx: PortReceiver<MyReply>,
1323 reply_tos: Vec<(
1324 hyperactor_reference::PortRef<u64>,
1325 hyperactor_reference::PortRef<MyReply>,
1326 )>,
1327 host_mesh: HostMesh,
1329 }
1330
1331 async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1332 where
1333 A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1334 {
1335 let instance = crate::testing::instance();
1336 let host_mesh = local_host_mesh(8).await;
1339 let proc_mesh = host_mesh
1340 .spawn(instance, "test", extent!(gpu = 8), None)
1341 .await
1342 .unwrap();
1343
1344 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1345 let params = TestActorParams {
1346 forward_port: tx.bind(),
1347 };
1348 let actor_name = crate::Name::new("test").expect("valid test name");
1349 let actor_mesh = proc_mesh
1353 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1354 .await
1355 .unwrap();
1356 let actor_mesh_ref: crate::ActorMeshRef<TestActor> = actor_mesh.deref().clone();
1357
1358 let (reply_port_handle0, _) = open_port::<String>(instance);
1359 let reply_port_ref0 = reply_port_handle0.bind();
1360 let (reply_port_handle1, reply1_rx) = match accum {
1361 Some(a) => instance.mailbox().open_accum_port(a),
1362 None => open_port(instance),
1363 };
1364 let reply_port_ref1 = reply_port_handle1.bind();
1365 let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1366 let reply_port_ref2 = reply_port_handle2.bind();
1367 let message = TestMessage::CastAndReply {
1368 arg: "abc".to_string(),
1369 reply_to0: reply_port_ref0.clone(),
1370 reply_to1: reply_port_ref1.clone(),
1371 reply_to2: reply_port_ref2.clone(),
1372 };
1373
1374 clear_collected_tree();
1375 actor_mesh_ref.cast(instance, message).unwrap();
1376
1377 let mut reply_tos = vec![];
1378 for _ in proc_mesh.extent().points() {
1379 let msg = rx.recv().await.expect("missing");
1380 match msg {
1381 TestMessage::CastAndReply {
1382 arg,
1383 reply_to0,
1384 reply_to1,
1385 reply_to2,
1386 } => {
1387 assert_eq!(arg, "abc");
1388 assert_eq!(reply_to0, reply_port_ref0);
1391 assert_ne!(reply_to1, reply_port_ref1);
1393 assert!(reply_to1.port_id().actor_id().name().contains("comm"));
1394 assert_ne!(reply_to2, reply_port_ref2);
1395 assert!(reply_to2.port_id().actor_id().name().contains("comm"));
1396 reply_tos.push((reply_to1, reply_to2));
1397 }
1398 _ => {
1399 panic!("unexpected message: {:?}", msg);
1400 }
1401 }
1402 }
1403
1404 let rank_lookup = proc_mesh
1407 .ranks()
1408 .iter()
1409 .enumerate()
1410 .map(|(i, r)| (r.proc_id().clone(), i))
1411 .collect::<HashMap<hyperactor_reference::ProcId, usize>>();
1412
1413 let selection = sel!(*);
1415 verify_split_port_paths(
1416 &selection,
1417 &proc_mesh.extent(),
1418 &reply_port_ref1,
1419 &reply_port_ref2,
1420 &rank_lookup,
1421 );
1422
1423 MeshSetupV1 {
1424 instance,
1425 actor_mesh_ref,
1426 reply1_rx,
1427 reply2_rx,
1428 reply_tos,
1429 host_mesh,
1430 }
1431 }
1432
1433 async fn execute_cast_and_reply_v1() {
1434 let mut setup = setup_mesh_v1::<NoneAccumulator>(None).await;
1435
1436 let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1437 execute_cast_and_reply(
1438 ranks,
1439 setup.instance,
1440 setup.reply1_rx,
1441 setup.reply2_rx,
1442 setup.reply_tos,
1443 )
1444 .await;
1445
1446 let _ = setup.host_mesh.shutdown(setup.instance).await;
1447 }
1448
1449 #[async_timed_test(timeout_secs = 60)]
1450 async fn test_cast_and_reply_v1_retrofit() {
1451 let config = hyperactor_config::global::lock();
1452 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1453 let _guard2 = config.override_key(
1454 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1455 false,
1456 );
1457 execute_cast_and_reply_v1().await
1458 }
1459
1460 #[async_timed_test(timeout_secs = 60)]
1461 async fn test_cast_and_reply_v1_native() {
1462 let config = hyperactor_config::global::lock();
1463 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1464 let _guard2 = config.override_key(
1465 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1466 true,
1467 );
1468 execute_cast_and_reply_v1().await
1469 }
1470
1471 async fn execute_cast_and_accum_v1(config: &hyperactor_config::global::ConfigLock) {
1472 let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1474
1475 let mut setup = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1476
1477 let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1478 execute_cast_and_accum(ranks, setup.instance, setup.reply1_rx, setup.reply_tos).await;
1479
1480 let _ = setup.host_mesh.shutdown(setup.instance).await;
1481 }
1482
1483 #[async_timed_test(timeout_secs = 60)]
1484 async fn test_cast_and_accum_v1_retrofit() {
1485 let config = hyperactor_config::global::lock();
1486 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1487 let _guard2 = config.override_key(
1488 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1489 false,
1490 );
1491 execute_cast_and_accum_v1(&config).await
1492 }
1493
1494 #[async_timed_test(timeout_secs = 60)]
1495 async fn test_cast_and_accum_v1_native() {
1496 let config = hyperactor_config::global::lock();
1497 let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1498 let _guard2 = config.override_key(
1499 hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1500 true,
1501 );
1502 execute_cast_and_accum_v1(&config).await
1503 }
1504
1505 struct OncePortMeshSetupV1 {
1506 instance: &'static Instance<testing::TestRootClient>,
1507 reply_rx: hyperactor::mailbox::OncePortReceiver<u64>,
1508 reply_tos: Vec<hyperactor::reference::OncePortRef<u64>>,
1509 _reply_port_ref: hyperactor::reference::OncePortRef<u64>,
1510 host_mesh: HostMesh,
1511 }
1512
1513 async fn setup_once_port_mesh<A>(reducer: Option<A>) -> OncePortMeshSetupV1
1514 where
1515 A: Accumulator<State = u64, Update = u64> + Send + Sync + 'static,
1516 {
1517 let instance = crate::testing::instance();
1518 let host_mesh = local_host_mesh(8).await;
1521 let proc_mesh = host_mesh
1522 .spawn(instance, "test", extent!(gpu = 8), None)
1523 .await
1524 .unwrap();
1525
1526 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1527 let params = TestActorParams {
1528 forward_port: tx.bind(),
1529 };
1530 let actor_name = crate::Name::new("test").expect("valid test name");
1531 let actor_mesh: crate::ActorMesh<TestActor> = proc_mesh
1533 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1534 .await
1535 .unwrap();
1536 let actor_mesh_ref = actor_mesh.deref().clone();
1537
1538 let has_reducer = reducer.is_some();
1539 let (reply_port_handle, reply_rx) = match reducer {
1540 Some(reducer) => instance.mailbox().open_reduce_port(reducer),
1541 None => instance.mailbox().open_once_port::<u64>(),
1542 };
1543 let reply_port_ref = reply_port_handle.bind();
1544
1545 let message = TestMessage::CastAndReplyOnce {
1546 arg: "abc".to_string(),
1547 reply_to: reply_port_ref.clone(),
1548 };
1549
1550 clear_collected_tree();
1551 actor_mesh_ref.cast(instance, message).unwrap();
1552
1553 let mut reply_tos = vec![];
1554 for _ in proc_mesh.extent().points() {
1555 let msg = rx.recv().await.expect("missing");
1556 match msg {
1557 TestMessage::CastAndReplyOnce { arg, reply_to } => {
1558 assert_eq!(arg, "abc");
1559 if has_reducer {
1560 assert_ne!(reply_to, reply_port_ref);
1562 assert!(reply_to.port_id().actor_id().name().contains("comm"));
1563 } else {
1564 assert_eq!(reply_to, reply_port_ref);
1566 }
1567 reply_tos.push(reply_to);
1568 }
1569 _ => {
1570 panic!("unexpected message: {:?}", msg);
1571 }
1572 }
1573 }
1574
1575 OncePortMeshSetupV1 {
1576 instance,
1577 reply_rx,
1578 reply_tos,
1579 _reply_port_ref: reply_port_ref,
1580 host_mesh,
1581 }
1582 }
1583
1584 #[async_timed_test(timeout_secs = 60)]
1585 async fn test_cast_and_reply_once_v1() {
1586 let mut setup = setup_once_port_mesh::<NoneAccumulator>(None).await;
1590
1591 let num_replies = setup.reply_tos.len();
1594 for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1595 reply_to.send(setup.instance, i as u64).unwrap();
1596 }
1597
1598 let result = setup.reply_rx.recv().await.unwrap();
1600 assert!(result < num_replies as u64);
1602
1603 let _ = setup.host_mesh.shutdown(setup.instance).await;
1604 }
1605
1606 #[async_timed_test(timeout_secs = 60)]
1607 async fn test_cast_and_accum_once_v1() {
1608 let mut setup = setup_once_port_mesh(Some(accum::sum::<u64>())).await;
1612
1613 let mut expected_sum = 0u64;
1615 for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1616 reply_to.send(setup.instance, i as u64).unwrap();
1617 expected_sum += i as u64;
1618 }
1619
1620 let result = setup.reply_rx.recv().await.unwrap();
1622 assert_eq!(result, expected_sum);
1623
1624 let _ = setup.host_mesh.shutdown(setup.instance).await;
1625 }
1626
1627 #[async_timed_test(timeout_secs = 60)]
1628 async fn test_unsplit_port_not_split() {
1629 let instance = crate::testing::instance();
1630 let mut host_mesh = local_host_mesh(8).await;
1631 let proc_mesh = host_mesh
1632 .spawn(instance, "test", extent!(gpu = 8), None)
1633 .await
1634 .unwrap();
1635
1636 let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1637 let params = TestActorParams {
1638 forward_port: tx.bind(),
1639 };
1640 let actor_name = Name::new("test").expect("valid test name");
1641 let actor_mesh: ActorMesh<TestActor> = proc_mesh
1642 .spawn_with_name(&instance, actor_name, ¶ms, None, true)
1643 .await
1644 .unwrap();
1645 let (reply_port_handle, mut reply_rx) = open_port::<u64>(instance);
1646 let reply_port_ref = reply_port_handle.bind().unsplit();
1647
1648 let message = TestMessage::CastWithUnsplitPort {
1649 reply_to: reply_port_ref.clone(),
1650 };
1651
1652 clear_collected_tree();
1653 actor_mesh.cast(instance, message).unwrap();
1654
1655 let num_points = proc_mesh.extent().points().count();
1657 for _ in 0..num_points {
1658 let msg = rx.recv().await.expect("missing");
1659 match msg {
1660 TestMessage::CastWithUnsplitPort { reply_to } => {
1661 assert_eq!(
1662 reply_to.port_id(),
1663 reply_port_ref.port_id(),
1664 "unsplit port should not be replaced by a comm actor split port"
1665 );
1666 }
1667 _ => panic!("unexpected message: {:?}", msg),
1668 }
1669 }
1670
1671 for _ in 0..8 {
1674 let val = reply_rx.recv().await.unwrap();
1675 assert_eq!(val, 42);
1676 }
1677 let _ = host_mesh.shutdown(instance).await;
1678 }
1679}