// hyperactor_mesh/comm.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use crate::casting::CAST_ACTOR_MESH_ID;
10use crate::comm::multicast::CAST_ORIGINATING_SENDER;
11use crate::comm::multicast::CastEnvelope;
12use crate::comm::multicast::CastMessageV1;
13use crate::comm::multicast::ForwardMessageV1;
14use crate::reference::ActorMeshId;
15use crate::resource;
16pub mod multicast;
17
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use anyhow::Result;
23use async_trait::async_trait;
24use hyperactor::Actor;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::RemoteMessage;
29use hyperactor::accum::ReducerMode;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::message::ErasedUnbound;
37use hyperactor::ordering::SEQ_INFO;
38use hyperactor::ordering::SeqInfo;
39use hyperactor::reference as hyperactor_reference;
40use hyperactor_config::CONFIG;
41use hyperactor_config::ConfigAttr;
42use hyperactor_config::Flattrs;
43use hyperactor_config::attrs::declare_attrs;
44use hyperactor_mesh_macros::sel;
45use ndslice::Point;
46use ndslice::Selection;
47use ndslice::View;
48use ndslice::selection::routing::RoutingFrame;
49use serde::Deserialize;
50use serde::Serialize;
51use typeuri::Named;
52
53use crate::comm::multicast::CastMessage;
54use crate::comm::multicast::CastMessageEnvelope;
55use crate::comm::multicast::ForwardMessage;
56use crate::comm::multicast::set_cast_info_on_headers;
57
declare_attrs! {
    /// Whether to use native v1 casting in v1 ActorMesh.
    /// Settable via the `HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING`
    /// environment variable or the `enable_native_v1_casting` config key;
    /// defaults to `false`.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING".to_string()),
        Some("enable_native_v1_casting".to_string()),
    ))
    pub attr ENABLE_NATIVE_V1_CASTING: bool = false;
}
66
/// Parameters to initialize the CommActor. Currently empty: the comm
/// actor's mesh configuration is delivered after spawn via a
/// `CommMeshConfig` message rather than at construction time.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
wirevalue::register_type!(CommActorParams);
71
/// A message buffered due to out-of-order delivery.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether to deliver this message to this comm actor's local actors.
    deliver_here: bool,
    /// Peer comm actors to forward the message to, keyed by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message to deliver.
    message: CastMessageEnvelope,
}
84
/// Bookkeeping to handle sequence numbers and in-order delivery for messages
/// sent to and through this comm actor. One `ReceiveState` exists per
/// (actor mesh id, sender) stream.
#[derive(Debug, Default)]
struct ReceiveState {
    /// The sequence of the last received (and delivered) message.
    seq: usize,
    /// A buffer storing messages we received out-of-order, indexed by the seq
    /// that should precede it.
    buffer: HashMap<usize, Buffered>,
    /// A map of the last sequence number we sent to next steps, indexed by rank.
    last_seqs: HashMap<usize, usize>,
}
97
/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation.
#[derive(Debug, Default)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommMeshConfig,
        CastMessage,
        ForwardMessage,
        CastMessageV1,
        ForwardMessageV1,
    ],
)]
pub struct CommActor {
    /// Sequence numbers are maintained for each (actor mesh id, sender);
    /// incremented for every cast this actor originates on that stream.
    send_seq: HashMap<(ActorMeshId, hyperactor_reference::ActorId), usize>,
    /// Each sender is a unique stream; holds per-stream in-order delivery
    /// state (last delivered seq plus an out-of-order buffer).
    recv_state: HashMap<(ActorMeshId, hyperactor_reference::ActorId), ReceiveState>,

    /// The comm actor's mesh configuration, or buffered messages if not yet configured.
    mesh_config: MeshConfigState,
}
120
/// A message that arrived before this comm actor received its
/// `CommMeshConfig`; buffered and replayed in arrival order once the
/// configuration arrives.
#[derive(Debug)]
enum PendingMessage {
    Cast(CastMessage),
    Forward(ForwardMessage),
    ForwardV1(ForwardMessageV1),
}
127
/// Whether this comm actor has received its mesh configuration yet.
#[derive(Debug)]
enum MeshConfigState {
    /// Config not yet received; buffer incoming messages until it arrives.
    NotConfigured(Vec<PendingMessage>),
    /// Config received; ready to route messages.
    Configured(CommMeshConfig),
}
135
136impl Default for MeshConfigState {
137    fn default() -> Self {
138        MeshConfigState::NotConfigured(Vec::new())
139    }
140}
141
/// Configuration for how a `CommActor` determines its own rank and locates peers.
/// Delivered to the comm actor as a message after spawn; messages arriving
/// before it are buffered (see `MeshConfigState`).
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct CommMeshConfig {
    /// The rank of this comm actor on the root mesh.
    rank: usize,
    /// Key is the rank of the peer on the root mesh. Value is the peer's comm actor.
    peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
}
wirevalue::register_type!(CommMeshConfig);
151
152impl CommMeshConfig {
153    /// Create a new mesh configuration with the given rank and peer mapping.
154    pub fn new(
155        rank: usize,
156        peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
157    ) -> Self {
158        Self { rank, peers }
159    }
160
161    /// Return the peer comm actor for the given rank.
162    fn peer_for_rank(&self, rank: usize) -> Result<hyperactor_reference::ActorRef<CommActor>> {
163        self.peers
164            .get(&rank)
165            .cloned()
166            .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank))
167    }
168
169    /// Return the rank of the comm actor.
170    fn self_rank(&self) -> usize {
171        self.rank
172    }
173}
174
175#[async_trait]
176impl Actor for CommActor {
177    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
178        this.set_system();
179        Ok(())
180    }
181
182    // This is an override of the default actor behavior.
183    async fn handle_undeliverable_message(
184        &mut self,
185        cx: &Instance<Self>,
186        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
187    ) -> Result<(), anyhow::Error> {
188        let Undeliverable(mut message_envelope) = undelivered;
189
190        // 1. Case delivery failure at a "forwarding" step.
191        if let Ok(ForwardMessage { message, .. }) =
192            message_envelope.deserialized::<ForwardMessage>()
193        {
194            let sender = message.sender();
195            let return_port = hyperactor_reference::PortRef::attest_message_port(sender);
196            message_envelope.set_error(DeliveryError::Multicast(format!(
197                "comm actor {} failed to forward the cast message; returning to origin {}",
198                cx.self_id(),
199                return_port.port_id(),
200            )));
201
202            // Needed so that the receiver of the undeliverable message can easily find the
203            // original sender of the cast message.
204            message_envelope.set_header(CAST_ORIGINATING_SENDER, sender.clone());
205
206            return_port
207                .send(cx, Undeliverable(message_envelope.clone()))
208                .map_err(|err| {
209                    let error = DeliveryError::BrokenLink(format!(
210                        "error occured when returning ForwardMessage to the original \
211                        sender's port {}; error is: {}",
212                        return_port.port_id(),
213                        err,
214                    ));
215                    message_envelope.set_error(error);
216                    UndeliverableMessageError::ReturnFailure {
217                        envelope: message_envelope,
218                    }
219                })?;
220            return Ok(());
221        }
222
223        // 2. Case delivery failure at a "deliver here" step.
224        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
225            let return_port = hyperactor_reference::PortRef::attest_message_port(&sender);
226            message_envelope.set_error(DeliveryError::Multicast(format!(
227                "comm actor {} failed to deliver the cast message to the dest \
228                actor; returning to origin {}",
229                cx.self_id(),
230                return_port.port_id(),
231            )));
232            return_port
233                .send(cx, Undeliverable(message_envelope.clone()))
234                .map_err(|err| {
235                    let error = DeliveryError::BrokenLink(format!(
236                        "error occured when returning cast message to the origin \
237                        sender {}; error is: {}",
238                        return_port.port_id(),
239                        err,
240                    ));
241                    message_envelope.set_error(error);
242                    UndeliverableMessageError::ReturnFailure {
243                        envelope: message_envelope,
244                    }
245                })?;
246            return Ok(());
247        }
248
249        // 3. A return of an undeliverable message was itself returned.
250        UndeliverableMailboxSender
251            .post(message_envelope, /*unused */ monitored_return_handle());
252        Ok(())
253    }
254}
255
impl CommActor {
    /// Forward the message to the comm actor on the given peer rank.
    ///
    /// If the current context carries a `CAST_ACTOR_MESH_ID` header, it is
    /// propagated to the peer on a fresh header set (TEMPORARY, until v0
    /// support is dropped); otherwise the message is sent with no extra
    /// headers.
    fn forward<M: RemoteMessage>(
        cx: &Context<Self>,
        config: &CommMeshConfig,
        rank: usize,
        message: M,
    ) -> Result<()>
    where
        CommActor: hyperactor::RemoteHandles<M>,
    {
        let child = config.peer_for_rank(rank)?;
        // TEMPORARY: until dropping v0 support
        if let Some(cast_actor_mesh_id) = cx.headers().get(CAST_ACTOR_MESH_ID) {
            let mut headers = Flattrs::new();
            headers.set(CAST_ACTOR_MESH_ID, cast_actor_mesh_id);
            child.send_with_headers(cx, headers, message)?;
        } else {
            child.send(cx, message)?;
        }
        Ok(())
    }

    /// Process one in-order cast message at this hop: split reply ports so
    /// replies route back through this comm actor, deliver locally if
    /// `deliver_here`, then forward to each peer in `next_steps`, updating
    /// the per-peer `last_seqs` bookkeeping as we go.
    fn handle_message(
        cx: &Context<Self>,
        config: &CommMeshConfig,
        deliver_here: bool,
        next_steps: HashMap<usize, Vec<RoutingFrame>>,
        sender: hyperactor_reference::ActorId,
        mut message: CastMessageEnvelope,
        seq: usize,
        last_seqs: &mut HashMap<usize, usize>,
    ) -> Result<()> {
        split_ports(cx, message.data_mut(), deliver_here, &next_steps)?;

        // Deliver message here, if necessary.
        if deliver_here {
            // We should not copy cx.headers() because it contains auto-generated
            // headers from mailbox. We want fresh headers only containing
            // user-provided headers.
            let headers = message.headers().clone();
            Self::deliver_to_dest(cx, headers, &mut message, config)?;
        }

        // Forward to peers. Each peer's ForwardMessage carries the seq of
        // the previous message we sent that peer (`last_seq`), which the
        // receiver uses for in-order delivery.
        next_steps
            .into_iter()
            .map(|(peer, dests)| {
                let last_seq = last_seqs.entry(peer).or_default();
                Self::forward(
                    cx,
                    config,
                    peer,
                    ForwardMessage {
                        dests,
                        sender: sender.clone(),
                        message: message.clone(),
                        seq,
                        last_seq: *last_seq,
                    },
                )?;
                *last_seq = seq;
                Ok(())
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(())
    }

    /// Deliver the cast message to the destination actor on this proc:
    /// rewrite rank bindings to this destination's rank within the cast,
    /// stamp cast info on the headers, and post to the local actor
    /// (addressed by the destination port's actor name with index 0).
    fn deliver_to_dest<M: CastEnvelope>(
        cx: &Context<Self>,
        mut headers: Flattrs,
        message: &mut M,
        config: &CommMeshConfig,
    ) -> anyhow::Result<()> {
        let cast_point = message.cast_point(config)?;
        // Replace ranks with self ranks.
        replace_with_self_ranks(&cast_point, message.data_mut())?;

        set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone());
        cx.post_with_external_seq_info(
            cx.self_id()
                .proc_id()
                .actor_id(message.dest_port().actor_name(), 0)
                .port_id(message.dest_port().port()),
            headers,
            wirevalue::Any::serialize(message.data())?,
        );

        Ok(())
    }
}
348
// Split ports, if any, and update message with new ports. In this
// way, children actors will reply to this comm actor's ports, instead
// of to the original ports provided by parent.
fn split_ports(
    cx: &Context<CommActor>,
    data: &mut ErasedUnbound,
    deliver_here: bool,
    next_steps: &HashMap<usize, Vec<RoutingFrame>>,
) -> anyhow::Result<()> {
    // Visit every unbound reply port in the message payload and replace
    // it with a split port owned by this comm actor.
    data.visit_mut::<hyperactor_reference::UnboundPort>(
        |hyperactor_reference::UnboundPort(
            port_id,
            reducer_spec,
            return_undeliverable,
            kind,
            unsplit,
        )| {
            // Ports explicitly marked unsplit are left untouched.
            if *unsplit {
                return Ok(());
            }
            let reducer_mode = match kind {
                hyperactor_reference::UnboundPortKind::Streaming(opts) => {
                    ReducerMode::Streaming(opts.clone().unwrap_or_default())
                }
                hyperactor_reference::UnboundPortKind::Once if reducer_spec.is_none() => {
                    // We can only split OncePorts that have reducers.
                    // Pass this through -- if it is used multiple times,
                    // it will cause a delivery error downstream.
                    // However we should reconsider this behavior
                    // as it its semantics will now differ between
                    // unicast and broadcast messages.
                    return Ok(());
                }
                hyperactor_reference::UnboundPortKind::Once => {
                    // Compute peer count for OncePort splitting. This is the number of
                    // destinations the message will be delivered to, so that the split
                    // port can correctly accumulate responses.
                    let peer_count = next_steps.len() + if deliver_here { 1 } else { 0 };
                    ReducerMode::Once(peer_count)
                }
            };

            let split = port_id.split(
                cx,
                reducer_spec.clone(),
                reducer_mode,
                *return_undeliverable,
            )?;

            // Test-only hook: record the (original, split) port pair so unit
            // tests can verify the splitting behavior.
            #[cfg(test)]
            tests::collect_split_port(port_id, &split, deliver_here);

            *port_id = split;
            Ok(())
        },
    )
}
409
410fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
411    data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
412        *rank = Some(cast_point.rank());
413        Ok(())
414    })
415}
416
#[async_trait]
impl Handler<CommMeshConfig> for CommActor {
    /// Install the mesh configuration, then replay any messages buffered
    /// while the actor was unconfigured, in arrival order. If the actor
    /// was already configured, the new config replaces the old one and
    /// there is nothing to replay.
    async fn handle(&mut self, cx: &Context<Self>, config: CommMeshConfig) -> Result<()> {
        let pending =
            match std::mem::replace(&mut self.mesh_config, MeshConfigState::Configured(config)) {
                MeshConfigState::NotConfigured(pending) => pending,
                MeshConfigState::Configured(_) => Vec::new(),
            };
        if !pending.is_empty() {
            tracing::info!(
                count = pending.len(),
                "replaying buffered pre-config messages"
            );
        }
        for msg in pending {
            match msg {
                PendingMessage::Cast(m) => self.handle(cx, m).await?,
                PendingMessage::Forward(m) => self.handle(cx, m).await?,
                PendingMessage::ForwardV1(m) => self.handle(cx, m).await?,
            }
        }
        Ok(())
    }
}
441
// TODO(T218630526): reliable casting for mutable topology
#[async_trait]
impl Handler<CastMessage> for CommActor {
    /// Originate a cast: wrap the message in a `ForwardMessage` rooted at
    /// the destination slice's root rank, allocating the next sequence
    /// number on this (mesh, sender) stream.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
        // Buffer until configured; replayed when CommMeshConfig arrives.
        let config = match &mut self.mesh_config {
            MeshConfigState::NotConfigured(pending) => {
                pending.push(PendingMessage::Cast(cast_message));
                return Ok(());
            }
            MeshConfigState::Configured(config) => config,
        };
        // Always forward the message to the root rank of the slice, casting starts from there.
        let slice = cast_message.dest.slice.clone();
        let selection = cast_message.dest.selection.clone();
        let frame = RoutingFrame::root(selection, slice);
        let rank = frame.slice.location(&frame.here)?;
        // Allocate the next seq for this stream: the message carries the
        // new seq plus the previous one (`last_seq`) for ordering.
        let seq = self
            .send_seq
            .entry(cast_message.message.stream_key())
            .or_default();
        let last_seq = *seq;
        *seq += 1;

        let fwd_message = ForwardMessage {
            dests: vec![frame],
            sender: cx.self_id().clone(),
            message: cast_message.message,
            seq: *seq,
            last_seq,
        };

        // Optimization: if forwarding to ourselves, handle inline instead of
        // going through the message queue
        if config.self_rank() == rank {
            Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
        } else {
            Self::forward(cx, config, rank, fwd_message)?;
        }
        Ok(())
    }
}
484
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Route one hop of a cast: resolve the routing frames at this rank,
    /// then deliver/forward in sequence order, buffering anything that
    /// arrives out of order and dropping duplicates.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        // Buffer until configured; replayed when CommMeshConfig arrives.
        let config = match &mut self.mesh_config {
            MeshConfigState::NotConfigured(pending) => {
                pending.push(PendingMessage::Forward(fwd_message));
                return Ok(());
            }
            MeshConfigState::Configured(config) => config,
        };

        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve/dedup routing frames.
        let rank = config.self_rank();
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        // Compare the last seq we delivered on this stream against the
        // seq this message claims preceded it.
        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            // We got the expected next message to deliver to this host.
            Ordering::Equal => {
                // We got an in-order operation, so handle it now.
                Self::handle_message(
                    cx,
                    config,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Also deliver any pending operations from the recv buffer that
                // were received out-of-order that are now unblocked.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        config,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // We got an out-of-order operation, so buffer it for now, until we
            // have received the ones sequenced before it. The buffer is keyed
            // by `last_seq`: the seq that must be delivered immediately before
            // this message.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // We already got this message -- just drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
580
581#[async_trait]
582impl Handler<CastMessageV1> for CommActor {
583    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
584        let slice = cast_message.dest_region.slice().clone();
585        let frame = RoutingFrame::root(sel!(*), slice);
586        let forward_message = ForwardMessageV1 {
587            dests: vec![frame],
588            message: cast_message,
589        };
590        self.handle(cx, forward_message).await
591    }
592}
593
#[async_trait]
impl Handler<ForwardMessageV1> for CommActor {
    /// Route one hop of a v1 cast: resolve routing at this rank, split
    /// reply ports, deliver locally (stamping session/seq info on the
    /// headers) if selected, and forward to peers. Unlike the v0 path,
    /// per-destination seqs travel inside the message (`message.seqs`).
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
        // Buffer until configured; replayed when CommMeshConfig arrives.
        let config = match &mut self.mesh_config {
            MeshConfigState::NotConfigured(pending) => {
                pending.push(PendingMessage::ForwardV1(fwd_message));
                return Ok(());
            }
            MeshConfigState::Configured(config) => config,
        };

        let ForwardMessageV1 { dests, mut message } = fwd_message;
        // Resolve/dedup routing frames.
        let rank_on_root_mesh = config.self_rank();
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
                panic!("choice encountered in CommActor routing")
            })?;

        split_ports(cx, &mut message.data, deliver_here, &next_steps)?;

        // Deliver message here, if necessary.
        if deliver_here {
            let mut headers = message.headers().clone();
            // Look up this destination's sequence number by its rank within
            // the cast and attach session/seq info for ordered delivery.
            let seq = message
                .seqs
                .get(message.cast_point(config)?.rank())
                .expect("mismatched seqs and dest_region");
            headers.set(
                SEQ_INFO,
                SeqInfo::Session {
                    session_id: message.session_id,
                    seq,
                },
            );
            Self::deliver_to_dest(cx, headers, &mut message, config)?;
        }

        // Forward to peers.
        for (peer_rank_on_root_mesh, dests) in next_steps {
            let forward_message = ForwardMessageV1 {
                dests,
                message: message.clone(),
            };
            Self::forward(cx, config, peer_rank_on_root_mesh, forward_message)?;
        }

        Ok(())
    }
}
644
/// Actors and message types shared by this crate's casting tests.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Unbind;
    use hyperactor::reference as hyperactor_reference;
    use serde::Deserialize;
    use serde::Serialize;
    use typeuri::Named;

    use super::*;

    /// A reply payload carrying the responding actor's id and a value.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: hyperactor_reference::ActorId,
        pub value: u64,
    }

    /// Test message exercising the various reply-port binding cases.
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Intentionally not including 0. As a result, this port will not be
            // split.
            // #[binding(include)]
            reply_to0: hyperactor_reference::PortRef<String>,
            #[binding(include)]
            reply_to1: hyperactor_reference::PortRef<u64>,
            #[binding(include)]
            reply_to2: hyperactor_reference::PortRef<MyReply>,
        },
        CastAndReplyOnce {
            arg: String,
            #[binding(include)]
            reply_to: hyperactor::reference::OncePortRef<u64>,
        },
        CastWithUnsplitPort {
            #[binding(include)]
            reply_to: hyperactor_reference::PortRef<u64>,
        },
    }

    /// Actor that forwards every message it receives to a test-owned port.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Forward the received message to this port, so it can be inspected by
        // the unit test.
        forward_port: hyperactor_reference::PortRef<TestMessage>,
    }

    /// Spawn parameters for `TestActor`.
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: hyperactor_reference::PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for TestActor {
        type Params = TestActorParams;

        async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            // For CastWithUnsplitPort, send a reply so the test can
            // verify that the unsplit port is still directly reachable.
            if let TestMessage::CastWithUnsplitPort { ref reply_to } = msg {
                reply_to.send(cx, 42)?;
            }
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
735
736#[cfg(test)]
737mod tests {
738    use std::collections::BTreeMap;
739    use std::collections::HashSet;
740    use std::fmt::Display;
741    use std::hash::Hash;
742    use std::ops::Deref;
743    use std::ops::DerefMut;
744    use std::sync::Mutex;
745    use std::sync::OnceLock;
746
747    use hyperactor::accum;
748
    /// Common setup for pre-config buffering tests: a single proc with a
    /// TestActor (for observing delivery) and an unconfigured CommActor.
    /// Returns (client, rx, comm_handle, actor_mesh_name) plus handles
    /// that must be kept alive.
    async fn buffering_fixture(
        proc_name: &str,
    ) -> (
        Instance<()>,
        hyperactor::mailbox::PortReceiver<TestMessage>,
        hyperactor::ActorHandle<CommActor>,
        crate::Name,
        // Drop guards: client handle, test actor handle, test actor ref.
        (
            hyperactor::ActorHandle<()>,
            hyperactor::ActorHandle<TestActor>,
            hyperactor_reference::ActorRef<TestActor>,
        ),
    ) {
        use hyperactor::Proc;
        use hyperactor::RemoteSpawn;
        use hyperactor::channel::ChannelTransport;

        let proc = Proc::direct(ChannelTransport::Unix.any(), proc_name.to_string()).unwrap();
        let (client, client_handle) = proc.instance("client").unwrap();

        let actor_mesh_name = crate::Name::new("test").unwrap();
        let actor_name = actor_mesh_name.to_string();

        // The TestActor forwards every message it receives to `tx`, so the
        // test can observe delivery (and delivery order) via `rx`.
        let (tx, rx) = open_port(&client);
        let forward_port = tx.bind();
        let test_actor = TestActor::new(TestActorParams { forward_port }, Default::default())
            .await
            .unwrap();
        let test_handle = proc.spawn(&actor_name, test_actor).unwrap();
        let test_ref: hyperactor_reference::ActorRef<TestActor> = test_handle.bind::<TestActor>();

        // Deliberately left unconfigured: the tests exercise the comm
        // actor's pre-config message buffering.
        let comm_handle = proc.spawn("comm", CommActor::default()).unwrap();

        (
            client,
            rx,
            comm_handle,
            actor_mesh_name,
            (client_handle, test_handle, test_ref),
        )
    }
795
    /// Send CommMeshConfig (single-rank mesh pointing at self): rank 0,
    /// with the comm actor itself as the sole peer.
    fn send_config(client: &Instance<()>, comm_handle: &hyperactor::ActorHandle<CommActor>) {
        let comm_ref = comm_handle.bind::<CommActor>();
        let mut peers = HashMap::new();
        peers.insert(0, comm_ref);
        comm_handle
            .send(client, CommMeshConfig::new(0, peers))
            .unwrap();
    }
805
    /// Send a message before config, send config, send another after config,
    /// and verify both are delivered in order.
    async fn assert_buffered_and_replayed<M: hyperactor::Message>(
        proc_name: &str,
        mut make_msg: impl FnMut(&Instance<()>, &crate::Name, &str) -> M,
    ) where
        CommActor: hyperactor::Handler<M>,
    {
        let (client, mut rx, comm_handle, actor_mesh_name, _guards) =
            buffering_fixture(proc_name).await;

        // "buffered" arrives while the comm actor is unconfigured; "direct"
        // arrives after configuration.
        comm_handle
            .send(&client, make_msg(&client, &actor_mesh_name, "buffered"))
            .unwrap();
        send_config(&client, &comm_handle);
        comm_handle
            .send(&client, make_msg(&client, &actor_mesh_name, "direct"))
            .unwrap();

        // Delivery must preserve send order: the buffered message replays
        // before the post-config one.
        assert_eq!(
            rx.recv().await.unwrap(),
            TestMessage::Forward("buffered".to_string()),
        );
        assert_eq!(
            rx.recv().await.unwrap(),
            TestMessage::Forward("direct".to_string()),
        );
        comm_handle.drain_and_stop("test done").ok();
    }
835
    #[async_timed_test(timeout_secs = 1)]
    async fn cast_before_config_is_buffered_and_replayed() {
        use ndslice::Slice;

        // Exercise the CastMessage path: a cast over a single-rank mesh.
        assert_buffered_and_replayed("test_cast", |client, name, payload| {
            let actor_mesh_id = crate::reference::ActorMeshId(name.clone());
            let slice = Slice::new_row_major(vec![1]);
            let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
            let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
                actor_mesh_id,
                client.self_id().clone(),
                shape,
                hyperactor_config::Flattrs::new(),
                TestMessage::Forward(payload.to_string()),
            )
            .unwrap();
            multicast::CastMessage {
                dest: multicast::Uslice {
                    slice,
                    selection: sel!(*),
                },
                message: envelope,
            }
        })
        .await;
    }
862
863    #[async_timed_test(timeout_secs = 1)]
864    async fn forward_before_config_is_buffered_and_replayed() {
865        use ndslice::Slice;
866        use ndslice::selection::routing::RoutingFrame;
867
868        let mut next_seq: usize = 0;
869        assert_buffered_and_replayed("test_fwd", move |client, name, payload| {
870            let actor_mesh_id = crate::reference::ActorMeshId(name.clone());
871            let slice = Slice::new_row_major(vec![1]);
872            let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
873            let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
874                actor_mesh_id,
875                client.self_id().clone(),
876                shape,
877                hyperactor_config::Flattrs::new(),
878                TestMessage::Forward(payload.to_string()),
879            )
880            .unwrap();
881            let frame = RoutingFrame::root(sel!(*), slice);
882            let last_seq = next_seq;
883            next_seq += 1;
884            multicast::ForwardMessage {
885                sender: client.self_id().clone(),
886                dests: vec![frame],
887                seq: next_seq,
888                last_seq,
889                message: envelope,
890            }
891        })
892        .await;
893    }
894
895    #[async_timed_test(timeout_secs = 1)]
896    async fn forward_v1_before_config_is_buffered_and_replayed() {
897        use ndslice::Region;
898        use ndslice::Slice;
899        use ndslice::selection::routing::RoutingFrame;
900
901        assert_buffered_and_replayed("test_fwd_v1", |client, name, payload| {
902            let slice = Slice::new_row_major(vec![1]);
903            let region = Region::new(vec!["rank".to_string()], slice.clone());
904            let cast_msg = multicast::CastMessageV1::new::<TestActor, TestMessage>(
905                client.self_id().clone(),
906                name,
907                region.clone(),
908                hyperactor_config::Flattrs::new(),
909                TestMessage::Forward(payload.to_string()),
910                uuid::Uuid::new_v4(),
911                crate::ValueMesh::from_single(region, 0u64),
912            )
913            .unwrap();
914            let frame = RoutingFrame::root(sel!(*), slice);
915            multicast::ForwardMessageV1 {
916                dests: vec![frame],
917                message: cast_msg,
918            }
919        })
920        .await;
921    }
922
923    use hyperactor::accum::Accumulator;
924    use hyperactor::accum::ReducerSpec;
925    use hyperactor::context;
926    use hyperactor::context::Mailbox;
927    use hyperactor::mailbox::PortReceiver;
928    use hyperactor::mailbox::open_port;
929    use hyperactor::reference as hyperactor_reference;
930    use hyperactor_config;
931    use hyperactor_mesh_macros::sel;
932    use maplit::btreemap;
933    use maplit::hashmap;
934    use ndslice::Extent;
935    use ndslice::Selection;
936    use ndslice::ViewExt as _;
937    use ndslice::extent;
938    use ndslice::selection::test_utils::collect_commactor_routing_tree;
939    use test_utils::*;
940    use timed_test::async_timed_test;
941    use tokio::time::Duration;
942
943    use super::*;
944    use crate::ActorMesh;
945    use crate::Name;
946    use crate::host_mesh::HostMesh;
947    use crate::test_utils::local_host_mesh;
948    use crate::testing;
949
950    // Helper to look up the rank for a given actor ID using the rank_lookup table.
951    fn lookup_rank(
952        actor_id: &hyperactor::reference::ActorId,
953        rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
954    ) -> usize {
955        let proc_id = actor_id.proc_id();
956        *rank_lookup
957            .get(proc_id)
958            .unwrap_or_else(|| panic!("proc rank not found for {}", proc_id))
959    }
960
    /// A single parent→child edge in a routing/split-port tree. `is_leaf`
    /// marks edges whose `to` node is a final delivery target.
    struct Edge<T> {
        from: T,
        to: T,
        is_leaf: bool,
    }
966
967    impl<T> From<(T, T, bool)> for Edge<T> {
968        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
969            Self { from, to, is_leaf }
970        }
971    }
972
    // The relationship between original ports and split ports. The elements in
    // the tuple are (original port, split port, deliver_here).
    // Lazily initialized by [collect_split_port]; reset between casts via
    // [clear_collected_tree]. Only edges created in this process are visible.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<hyperactor_reference::PortId>>>> =
        OnceLock::new();
977
978    // Collect the relationships between original ports and split ports into
979    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
980    pub(crate) fn collect_split_port(
981        original: &hyperactor_reference::PortId,
982        split: &hyperactor_reference::PortId,
983        deliver_here: bool,
984    ) {
985        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
986        let mut tree = mutex.lock().unwrap();
987
988        tree.deref_mut().push(Edge {
989            from: original.clone(),
990            to: split.clone(),
991            is_leaf: deliver_here,
992        });
993    }
994
995    // There could be other cast calls before the one we want to check, e.g. from
996    // allocating the proc mesh, or spawning the actor mesh. Clear the collected
997    // tree so it will only contain the cast we want to check.
998    fn clear_collected_tree() {
999        if let Some(tree) = SPLIT_PORT_TREE.get() {
1000            let mut tree = tree.lock().unwrap();
1001            tree.clear();
1002        }
1003    }
1004
    // A representation of a tree.
    //   * Map's keys are the tree's leafs;
    //   * Map's values are the path from the root to that leaf.
    //
    // `T` is the node type, e.g. a `PortId` or a rank index (see [get_ranks]).
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
1010
1011    // Add a custom Debug trait impl so the result from assert_eq! is readable.
1012    impl<T: Display> Debug for PathToLeaves<T> {
1013        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1014            fn vec_to_string<T: Display>(v: &[T]) -> String {
1015                v.iter()
1016                    .map(ToString::to_string)
1017                    .collect::<Vec<String>>()
1018                    .join(", ")
1019            }
1020
1021            for (src, path) in &self.0 {
1022                write!(f, "{} -> {}\n", src, vec_to_string(path))?;
1023            }
1024            Ok(())
1025        }
1026    }
1027
1028    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
1029        let mut child_parent_map = HashMap::new();
1030        let mut all_nodes = HashSet::new();
1031        let mut parents = HashSet::new();
1032        let mut children = HashSet::new();
1033        let mut dests = HashSet::new();
1034
1035        // Build parent map and track all nodes and children
1036        for Edge { from, to, is_leaf } in edges {
1037            child_parent_map.insert(to.clone(), from.clone());
1038            all_nodes.insert(from.clone());
1039            all_nodes.insert(to.clone());
1040            parents.insert(from.clone());
1041            children.insert(to.clone());
1042            if *is_leaf {
1043                dests.insert(to.clone());
1044            }
1045        }
1046
1047        // For each leaf, reconstruct path back to root
1048        let mut result = BTreeMap::new();
1049        for dest in dests {
1050            let mut path = vec![dest.clone()];
1051            let mut current = dest.clone();
1052            while let Some(parent) = child_parent_map.get(&current) {
1053                path.push(parent.clone());
1054                current = parent.clone();
1055            }
1056            path.reverse();
1057            result.insert(dest, path);
1058        }
1059
1060        PathToLeaves(result)
1061    }
1062
1063    #[test]
1064    fn test_build_paths() {
1065        // Given the tree:
1066        //     0
1067        //    / \
1068        //   1   4
1069        //  / \   \
1070        // 2   3   5
1071        let edges: Vec<_> = [
1072            (0, 1, false),
1073            (1, 2, true),
1074            (1, 3, true),
1075            (0, 4, true),
1076            (4, 5, true),
1077        ]
1078        .into_iter()
1079        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
1080        .collect();
1081
1082        let paths = build_paths(&edges);
1083
1084        let expected = btreemap! {
1085            2 => vec![0, 1, 2],
1086            3 => vec![0, 1, 3],
1087            4 => vec![0, 4],
1088            5 => vec![0, 4, 5],
1089        };
1090
1091        assert_eq!(paths.0, expected);
1092    }
1093
1094    //  Given a port tree,
1095    //     * remove the client port, i.e. the 1st element of the path;
1096    //     * verify all remaining ports are comm actor ports;
1097    //     * remove the actor information and return a rank-based tree representation.
1098    //
1099    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
1100    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
1101    //
1102    //      For example, for a 2x2 slice, the port tree could look like:
1103    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
1104    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
1105    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
1106    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
1107    //
1108    //     The result should be:
1109    //     0 -> 0
1110    //     1 -> 0, 1
1111    //     2 -> 0, 2
1112    //     3 -> 0, 2, 3
1113    fn get_ranks(
1114        paths: PathToLeaves<hyperactor_reference::PortId>,
1115        client_reply: &hyperactor_reference::PortId,
1116        rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
1117    ) -> PathToLeaves<hyperactor_reference::Index> {
1118        let ranks = paths
1119            .0
1120            .into_iter()
1121            .map(|(dst, mut path)| {
1122                let first = path.remove(0);
1123                // The first PortId is the client's reply port.
1124                assert_eq!(&first, client_reply);
1125                // Other ports's actor ID must be dest[?].comm[0], where ? is
1126                // the rank we want to extract here.
1127                assert!(dst.actor_id().name().contains("comm"));
1128                let actor_path = path
1129                    .into_iter()
1130                    .map(|p| {
1131                        assert!(p.actor_id().name().contains("comm"));
1132                        lookup_rank(p.actor_id(), rank_lookup)
1133                    })
1134                    .collect();
1135                (lookup_rank(dst.actor_id(), rank_lookup), actor_path)
1136            })
1137            .collect();
1138        PathToLeaves(ranks)
1139    }
1140
    /// Placeholder accumulator used only to satisfy the generic type parameter
    /// when calling `setup_mesh_v1::<NoneAccumulator>(None)` / the once-port
    /// variant with `None`; its methods are never invoked in that path.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        // Never called: with `None`, a plain port is opened instead of an
        // accumulating port, so no updates flow through this impl.
        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }
1159
    // Verify the split port paths are the same as the casting paths.
    //
    // The edges recorded in SPLIT_PORT_TREE during the cast contain two
    // interleaved trees — one rooted at each client reply port. They are
    // separated by each path's first element, converted to rank-based paths,
    // and each compared against the routing tree the selection implies.
    fn verify_split_port_paths(
        selection: &Selection,
        extent: &Extent,
        reply_port_ref1: &hyperactor_reference::PortRef<u64>,
        reply_port_ref2: &hyperactor_reference::PortRef<MyReply>,
        rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
    ) {
        // Get the paths used in casting
        let sel_paths = PathToLeaves(
            collect_commactor_routing_tree(selection, &extent.to_slice())
                .delivered
                .into_iter()
                .collect(),
        );

        // Get the split port paths collected in SPLIT_PORT_TREE during casting
        let (reply1_paths, reply2_paths) = {
            let tree = SPLIT_PORT_TREE
                .get()
                .expect("not initialized; are Hosts in the same process as SPLIT_PORT_TREE?");
            let edges = tree.lock().unwrap();
            // Paths whose head is reply port 1 go left; the rest go right.
            // get_ranks re-asserts that each tree's head matches its port.
            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
                .0
                .into_iter()
                .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
            (
                get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id(), rank_lookup),
                get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id(), rank_lookup),
            )
        };

        // split port paths should be the same as casting paths
        assert_eq!(sel_paths, reply1_paths);
        assert_eq!(sel_paths, reply2_paths);
    }
1196
1197    async fn execute_cast_and_reply(
1198        ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
1199        instance: &impl context::Actor,
1200        mut reply1_rx: PortReceiver<u64>,
1201        mut reply2_rx: PortReceiver<MyReply>,
1202        reply_tos: Vec<(
1203            hyperactor_reference::PortRef<u64>,
1204            hyperactor_reference::PortRef<MyReply>,
1205        )>,
1206    ) {
1207        // Reply from each dest actor. The replies should be received by client.
1208        {
1209            for (rank, (dest_actor, (reply_to1, reply_to2))) in
1210                ranks.iter().zip(reply_tos.iter()).enumerate()
1211            {
1212                let rank_u64 = rank as u64;
1213                reply_to1.send(instance, rank_u64).unwrap();
1214                let my_reply = MyReply {
1215                    sender: dest_actor.actor_id().clone(),
1216                    value: rank_u64,
1217                };
1218                reply_to2.send(instance, my_reply.clone()).unwrap();
1219
1220                assert_eq!(reply1_rx.recv().await.unwrap(), rank_u64);
1221                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
1222            }
1223        }
1224
1225        tracing::info!("the 1st updates from all dest actors were receivered by client");
1226
1227        // Now send multiple replies from the dest actors. They should all be
1228        // received by client. Replies sent from the same dest actor should
1229        // be received in the same order as they were sent out.
1230        {
1231            let n = 100;
1232            let mut expected2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
1233            for (i, (dest_actor, (_reply_to1, reply_to2))) in
1234                ranks.iter().zip(reply_tos.iter()).enumerate()
1235            {
1236                let mut sent2 = vec![];
1237                for j in 0..n {
1238                    let value = (i * 100 + j) as u64;
1239                    let my_reply = MyReply {
1240                        sender: dest_actor.actor_id().clone(),
1241                        value,
1242                    };
1243                    reply_to2.send(instance, my_reply.clone()).unwrap();
1244                    sent2.push(my_reply);
1245                }
1246                assert!(
1247                    expected2
1248                        .insert(dest_actor.actor_id().clone(), sent2)
1249                        .is_none(),
1250                    "duplicate actor_id {} in map",
1251                    dest_actor.actor_id()
1252                );
1253            }
1254
1255            let mut received2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
1256
1257            for _ in 0..(n * ranks.len()) {
1258                let my_reply = reply2_rx.recv().await.unwrap();
1259                received2
1260                    .entry(my_reply.sender.clone())
1261                    .or_default()
1262                    .push(my_reply);
1263            }
1264            assert_eq!(received2, expected2);
1265        }
1266    }
1267
1268    async fn wait_for_with_timeout(
1269        receiver: &mut PortReceiver<u64>,
1270        expected: u64,
1271        dur: Duration,
1272    ) -> anyhow::Result<()> {
1273        // timeout wraps the entire async block
1274        tokio::time::timeout(dur, async {
1275            loop {
1276                let msg = receiver.recv().await.unwrap();
1277                if msg == expected {
1278                    break;
1279                }
1280            }
1281        })
1282        .await?;
1283        Ok(())
1284    }
1285
1286    async fn execute_cast_and_accum(
1287        ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
1288        instance: &impl context::Actor,
1289        mut reply1_rx: PortReceiver<u64>,
1290        reply_tos: Vec<(
1291            hyperactor_reference::PortRef<u64>,
1292            hyperactor_reference::PortRef<MyReply>,
1293        )>,
1294    ) {
1295        // Now send multiple replies from the dest actors. They should all be
1296        // received by client. Replies sent from the same dest actor should
1297        // be received in the same order as they were sent out.
1298        let mut sum = 0;
1299        let n = 100;
1300        for (i, (_dest_actor, (reply_to1, _reply_to2))) in
1301            ranks.iter().zip(reply_tos.iter()).enumerate()
1302        {
1303            for j in 0..n {
1304                let value = (i + j) as u64;
1305                reply_to1.send(instance, value).unwrap();
1306                sum += value;
1307            }
1308        }
1309        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(8))
1310            .await
1311            .unwrap();
1312        // no more messages
1313        tokio::time::sleep(Duration::from_secs(2)).await;
1314        let msg = reply1_rx.try_recv().unwrap();
1315        assert_eq!(msg, None);
1316    }
1317
    /// Everything a cast-and-reply test needs after [setup_mesh_v1] completes.
    struct MeshSetupV1 {
        // Root client instance used to send and receive messages.
        instance: &'static Instance<testing::TestRootClient>,
        // Reference to the spawned TestActor mesh.
        actor_mesh_ref: crate::ActorMeshRef<TestActor>,
        // Receivers for the two reply channels opened before casting.
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        // Per-destination reply ports (split by comm actors) extracted from
        // the messages each rank received.
        reply_tos: Vec<(
            hyperactor_reference::PortRef<u64>,
            hyperactor_reference::PortRef<MyReply>,
        )>,
        // Keep the host mesh alive so comm actors aren't shut down.
        host_mesh: HostMesh,
    }
1330
    /// Spawn an 8-proc mesh of `TestActor`s, cast a `CastAndReply` message to
    /// every rank, collect the (split) reply ports each destination received,
    /// and verify the split-port tree matches the casting routing tree.
    ///
    /// `accum`: optional accumulator for reply port 1; when `None`, a plain
    /// port is opened instead.
    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let instance = crate::testing::instance();
        // We have to use an in-process host mesh, because SPLIT_PORT_TREE only
        // can collect paths from the same process.
        let host_mesh = local_host_mesh(8).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", extent!(gpu = 8), None)
            .await
            .unwrap();

        // Destination actors forward whatever they receive to this port, so
        // the test can inspect the message each rank actually got.
        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_name = crate::Name::new("test").expect("valid test name");
        // Make this actor a "system" actor to avoid spawning a controller actor.
        // This test is verifying the whole comm tree, so we want fewer actors
        // involved.
        let actor_mesh = proc_mesh
            .spawn_with_name(&instance, actor_name, &params, None, true)
            .await
            .unwrap();
        let actor_mesh_ref: crate::ActorMeshRef<TestActor> = actor_mesh.deref().clone();

        // Port 0 is not subject to splitting (asserted below).
        let (reply_port_handle0, _) = open_port::<String>(instance);
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => instance.mailbox().open_accum_port(a),
            None => open_port(instance),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        // Wait for each rank's forwarded copy and record its reply ports.
        let mut reply_tos = vec![];
        for _ in proc_mesh.extent().points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it
                    // is not included in MutVisitor.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert!(reply_to1.port_id().actor_id().name().contains("comm"));
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert!(reply_to2.port_id().actor_id().name().contains("comm"));
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        // [collect_commactor_routing_tree] only returns ranks, so we need to
        // map proc Ids collected in SPLIT_PORT_TREE to ranks.
        let rank_lookup = proc_mesh
            .ranks()
            .iter()
            .enumerate()
            .map(|(i, r)| (r.proc_id().clone(), i))
            .collect::<HashMap<hyperactor_reference::ProcId, usize>>();

        // v1 always uses sel!(*) when casting to a mesh.
        let selection = sel!(*);
        verify_split_port_paths(
            &selection,
            &proc_mesh.extent(),
            &reply_port_ref1,
            &reply_port_ref2,
            &rank_lookup,
        );

        MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
            host_mesh,
        }
    }
1432
1433    async fn execute_cast_and_reply_v1() {
1434        let mut setup = setup_mesh_v1::<NoneAccumulator>(None).await;
1435
1436        let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1437        execute_cast_and_reply(
1438            ranks,
1439            setup.instance,
1440            setup.reply1_rx,
1441            setup.reply2_rx,
1442            setup.reply_tos,
1443        )
1444        .await;
1445
1446        let _ = setup.host_mesh.shutdown(setup.instance).await;
1447    }
1448
1449    #[async_timed_test(timeout_secs = 60)]
1450    async fn test_cast_and_reply_v1_retrofit() {
1451        let config = hyperactor_config::global::lock();
1452        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1453        let _guard2 = config.override_key(
1454            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1455            false,
1456        );
1457        execute_cast_and_reply_v1().await
1458    }
1459
1460    #[async_timed_test(timeout_secs = 60)]
1461    async fn test_cast_and_reply_v1_native() {
1462        let config = hyperactor_config::global::lock();
1463        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1464        let _guard2 = config.override_key(
1465            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1466            true,
1467        );
1468        execute_cast_and_reply_v1().await
1469    }
1470
1471    async fn execute_cast_and_accum_v1(config: &hyperactor_config::global::ConfigLock) {
1472        // Use temporary config for this test
1473        let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1474
1475        let mut setup = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1476
1477        let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1478        execute_cast_and_accum(ranks, setup.instance, setup.reply1_rx, setup.reply_tos).await;
1479
1480        let _ = setup.host_mesh.shutdown(setup.instance).await;
1481    }
1482
1483    #[async_timed_test(timeout_secs = 60)]
1484    async fn test_cast_and_accum_v1_retrofit() {
1485        let config = hyperactor_config::global::lock();
1486        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1487        let _guard2 = config.override_key(
1488            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1489            false,
1490        );
1491        execute_cast_and_accum_v1(&config).await
1492    }
1493
1494    #[async_timed_test(timeout_secs = 60)]
1495    async fn test_cast_and_accum_v1_native() {
1496        let config = hyperactor_config::global::lock();
1497        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1498        let _guard2 = config.override_key(
1499            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1500            true,
1501        );
1502        execute_cast_and_accum_v1(&config).await
1503    }
1504
    /// Everything a once-port cast test needs after [setup_once_port_mesh].
    struct OncePortMeshSetupV1 {
        // Root client instance used to send and receive messages.
        instance: &'static Instance<testing::TestRootClient>,
        // Receiver for the single-shot reply port.
        reply_rx: hyperactor::mailbox::OncePortReceiver<u64>,
        // Per-destination reply ports extracted from the forwarded messages
        // (split by comm actors only when a reducer is used).
        reply_tos: Vec<hyperactor::reference::OncePortRef<u64>>,
        // Held so the original port reference outlives the test body.
        _reply_port_ref: hyperactor::reference::OncePortRef<u64>,
        // Keep the host mesh alive so comm actors aren't shut down.
        host_mesh: HostMesh,
    }
1512
    /// Spawn an 8-proc mesh of `TestActor`s, cast a `CastAndReplyOnce` message
    /// to every rank, and collect the reply port each destination received.
    ///
    /// With `Some(reducer)` the once-port is split by comm actors; with `None`
    /// every rank receives the original, unsplit port.
    async fn setup_once_port_mesh<A>(reducer: Option<A>) -> OncePortMeshSetupV1
    where
        A: Accumulator<State = u64, Update = u64> + Send + Sync + 'static,
    {
        let instance = crate::testing::instance();
        // We have to use an in-process host mesh, because SPLIT_PORT_TREE only
        // can collect paths from the same process.
        let host_mesh = local_host_mesh(8).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", extent!(gpu = 8), None)
            .await
            .unwrap();

        // Destination actors forward whatever they receive to this port, so
        // the test can inspect the message each rank actually got.
        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_name = crate::Name::new("test").expect("valid test name");
        // Make this actor a "system" actor to avoid spawning a controller actor.
        let actor_mesh: crate::ActorMesh<TestActor> = proc_mesh
            .spawn_with_name(&instance, actor_name, &params, None, true)
            .await
            .unwrap();
        let actor_mesh_ref = actor_mesh.deref().clone();

        // Remember whether a reducer was supplied; the assertions below differ.
        let has_reducer = reducer.is_some();
        let (reply_port_handle, reply_rx) = match reducer {
            Some(reducer) => instance.mailbox().open_reduce_port(reducer),
            None => instance.mailbox().open_once_port::<u64>(),
        };
        let reply_port_ref = reply_port_handle.bind();

        let message = TestMessage::CastAndReplyOnce {
            arg: "abc".to_string(),
            reply_to: reply_port_ref.clone(),
        };

        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        // Wait for each rank's forwarded copy and record its reply port.
        let mut reply_tos = vec![];
        for _ in proc_mesh.extent().points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReplyOnce { arg, reply_to } => {
                    assert_eq!(arg, "abc");
                    if has_reducer {
                        // With reducer: port is split by comm actor.
                        assert_ne!(reply_to, reply_port_ref);
                        assert!(reply_to.port_id().actor_id().name().contains("comm"));
                    } else {
                        // Without reducer: port is passed through unchanged.
                        assert_eq!(reply_to, reply_port_ref);
                    }
                    reply_tos.push(reply_to);
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        OncePortMeshSetupV1 {
            instance,
            reply_rx,
            reply_tos,
            _reply_port_ref: reply_port_ref,
            host_mesh,
        }
    }
1583
1584    #[async_timed_test(timeout_secs = 60)]
1585    async fn test_cast_and_reply_once_v1() {
1586        // Test OncePort without accumulator - port is NOT split.
1587        // All destinations receive the same original port.
1588        // First reply is delivered, others fail at receiver (port closed).
1589        let mut setup = setup_once_port_mesh::<NoneAccumulator>(None).await;
1590
1591        // All reply_tos point to the same port (not split).
1592        // Only the first message will be delivered successfully.
1593        let num_replies = setup.reply_tos.len();
1594        for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1595            reply_to.send(setup.instance, i as u64).unwrap();
1596        }
1597
1598        // OncePort receives exactly one value (the first to arrive)
1599        let result = setup.reply_rx.recv().await.unwrap();
1600        // The result should be one of the values sent
1601        assert!(result < num_replies as u64);
1602
1603        let _ = setup.host_mesh.shutdown(setup.instance).await;
1604    }
1605
1606    #[async_timed_test(timeout_secs = 60)]
1607    async fn test_cast_and_accum_once_v1() {
1608        // Test OncePort splitting with sum accumulator.
1609        // Each destination actor replies with its rank.
1610        // The sum of all ranks should be received at the original port.
1611        let mut setup = setup_once_port_mesh(Some(accum::sum::<u64>())).await;
1612
1613        // Each actor replies with its index
1614        let mut expected_sum = 0u64;
1615        for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1616            reply_to.send(setup.instance, i as u64).unwrap();
1617            expected_sum += i as u64;
1618        }
1619
1620        // OncePort should receive the sum of all responses
1621        let result = setup.reply_rx.recv().await.unwrap();
1622        assert_eq!(result, expected_sum);
1623
1624        let _ = setup.host_mesh.shutdown(setup.instance).await;
1625    }
1626
1627    #[async_timed_test(timeout_secs = 60)]
1628    async fn test_unsplit_port_not_split() {
1629        let instance = crate::testing::instance();
1630        let mut host_mesh = local_host_mesh(8).await;
1631        let proc_mesh = host_mesh
1632            .spawn(instance, "test", extent!(gpu = 8), None)
1633            .await
1634            .unwrap();
1635
1636        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1637        let params = TestActorParams {
1638            forward_port: tx.bind(),
1639        };
1640        let actor_name = Name::new("test").expect("valid test name");
1641        let actor_mesh: ActorMesh<TestActor> = proc_mesh
1642            .spawn_with_name(&instance, actor_name, &params, None, true)
1643            .await
1644            .unwrap();
1645        let (reply_port_handle, mut reply_rx) = open_port::<u64>(instance);
1646        let reply_port_ref = reply_port_handle.bind().unsplit();
1647
1648        let message = TestMessage::CastWithUnsplitPort {
1649            reply_to: reply_port_ref.clone(),
1650        };
1651
1652        clear_collected_tree();
1653        actor_mesh.cast(instance, message).unwrap();
1654
1655        // Verify that all destinations received the original port (not split).
1656        let num_points = proc_mesh.extent().points().count();
1657        for _ in 0..num_points {
1658            let msg = rx.recv().await.expect("missing");
1659            match msg {
1660                TestMessage::CastWithUnsplitPort { reply_to } => {
1661                    assert_eq!(
1662                        reply_to.port_id(),
1663                        reply_port_ref.port_id(),
1664                        "unsplit port should not be replaced by a comm actor split port"
1665                    );
1666                }
1667                _ => panic!("unexpected message: {:?}", msg),
1668            }
1669        }
1670
1671        // All 8 actors sent replies directly to the same port.
1672        // Verify we receive all 8 replies.
1673        for _ in 0..8 {
1674            let val = reply_rx.recv().await.unwrap();
1675            assert_eq!(val, 42);
1676        }
1677        let _ = host_mesh.shutdown(instance).await;
1678    }
1679}