// hyperactor_mesh/comm.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use crate::casting::CAST_ACTOR_MESH_ID;
10use crate::comm::multicast::CAST_ORIGINATING_SENDER;
11use crate::comm::multicast::CastEnvelope;
12use crate::comm::multicast::CastMessageV1;
13use crate::comm::multicast::ForwardMessageV1;
14use crate::reference::ActorMeshId;
15use crate::resource;
16pub mod multicast;
17
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use anyhow::Result;
23use async_trait::async_trait;
24use hyperactor::Actor;
25use hyperactor::Context;
26use hyperactor::Handler;
27use hyperactor::Instance;
28use hyperactor::RemoteMessage;
29use hyperactor::accum::ReducerMode;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::message::ErasedUnbound;
37use hyperactor::ordering::SEQ_INFO;
38use hyperactor::ordering::SeqInfo;
39use hyperactor::reference as hyperactor_reference;
40use hyperactor_config::CONFIG;
41use hyperactor_config::ConfigAttr;
42use hyperactor_config::Flattrs;
43use hyperactor_config::attrs::declare_attrs;
44use hyperactor_mesh_macros::sel;
45use ndslice::Point;
46use ndslice::Selection;
47use ndslice::View;
48use ndslice::selection::routing::RoutingFrame;
49use serde::Deserialize;
50use serde::Serialize;
51use typeuri::Named;
52
53use crate::comm::multicast::CastMessage;
54use crate::comm::multicast::CastMessageEnvelope;
55use crate::comm::multicast::ForwardMessage;
56use crate::comm::multicast::set_cast_info_on_headers;
57
declare_attrs! {
    /// Whether to use native v1 casting in v1 ActorMesh.
    ///
    /// Configurable via the `HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING`
    /// environment variable or the `enable_native_v1_casting` config key;
    /// defaults to `false`.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING".to_string()),
        Some("enable_native_v1_casting".to_string()),
    ))
    pub attr ENABLE_NATIVE_V1_CASTING: bool = false;
}
66
/// Parameters to initialize the CommActor.
///
/// Currently empty: the comm actor needs no spawn-time parameters. Its mesh
/// configuration is delivered after spawn via a `CommMeshConfig` message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
wirevalue::register_type!(CommActorParams);
71
/// A message buffered due to out-of-order delivery.
///
/// Held in `ReceiveState::buffer` until all messages sequenced before it
/// have been handled.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether to deliver this message to this comm actor's local actors.
    deliver_here: bool,
    /// Peer comm actors to forward message to, keyed by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message to deliver.
    message: CastMessageEnvelope,
}
84
/// Bookkeeping to handle sequence numbers and in-order delivery for messages
/// sent to and through this comm actor.
#[derive(Debug, Default)]
struct ReceiveState {
    /// The sequence of the last received (i.e. handled in-order) message.
    seq: usize,
    /// A buffer storing messages we received out-of-order, indexed by the seq
    /// that should precede it (i.e. by the message's `last_seq`).
    buffer: HashMap<usize, Buffered>,
    /// A map of the last sequence number we sent to next steps, indexed by rank.
    last_seqs: HashMap<usize, usize>,
}
97
/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation.
#[derive(Debug, Default)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommMeshConfig,
        CastMessage,
        ForwardMessage,
        CastMessageV1,
        ForwardMessageV1,
    ],
)]
pub struct CommActor {
    /// Sequence numbers are maintained for each (actor mesh id, sender).
    /// Incremented for every cast initiated on that stream.
    send_seq: HashMap<(ActorMeshId, hyperactor_reference::ActorId), usize>,
    /// Each sender is a unique stream. Per-stream receive-side bookkeeping:
    /// last in-order seq, out-of-order buffer, and per-peer forwarded seqs.
    recv_state: HashMap<(ActorMeshId, hyperactor_reference::ActorId), ReceiveState>,

    /// The comm actor's mesh configuration. `None` until a `CommMeshConfig`
    /// message is received; the cast handlers error out before then.
    mesh_config: Option<CommMeshConfig>,
}
120
/// Configuration for how a `CommActor` determines its own rank and locates peers.
///
/// Delivered to the comm actor as a message (see `Handler<CommMeshConfig>`),
/// which installs it as the actor's active configuration.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct CommMeshConfig {
    /// The rank of this comm actor on the root mesh.
    rank: usize,
    /// Key is the rank of the peer on the root mesh. Value is the peer's comm actor.
    peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
}
wirevalue::register_type!(CommMeshConfig);
130
131impl CommMeshConfig {
132    /// Create a new mesh configuration with the given rank and peer mapping.
133    pub fn new(
134        rank: usize,
135        peers: HashMap<usize, hyperactor_reference::ActorRef<CommActor>>,
136    ) -> Self {
137        Self { rank, peers }
138    }
139
140    /// Return the peer comm actor for the given rank.
141    fn peer_for_rank(&self, rank: usize) -> Result<hyperactor_reference::ActorRef<CommActor>> {
142        self.peers
143            .get(&rank)
144            .cloned()
145            .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank))
146    }
147
148    /// Return the rank of the comm actor.
149    fn self_rank(&self) -> usize {
150        self.rank
151    }
152}
153
#[async_trait]
impl Actor for CommActor {
    /// Mark this actor as a system actor at startup.
    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
        this.set_system();
        Ok(())
    }

    // This is an override of the default actor behavior: instead of treating
    // an undeliverable message as a local failure, route it back to the
    // original sender of the cast so the failure surfaces at the origin.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // 1. Case delivery failure at a "forwarding" step: the failed message
        // was a ForwardMessage between comm actors. The cast's original sender
        // is recorded inside the message itself; return the envelope there.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = hyperactor_reference::PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to forward the cast message; returning to origin {}",
                cx.self_id(),
                return_port.port_id(),
            )));

            // Needed so that the receiver of the undeliverable message can easily find the
            // original sender of the cast message.
            message_envelope.set_header(CAST_ORIGINATING_SENDER, sender.clone());

            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    // Returning to the origin itself failed; record the broken
                    // link and surface it as a ReturnFailure.
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning ForwardMessage to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::ReturnFailure {
                        envelope: message_envelope,
                    }
                })?;
            return Ok(());
        }

        // 2. Case delivery failure at a "deliver here" step: the message failed
        // on the final hop to a destination actor. The originating sender was
        // stashed in the headers during forwarding (see case 1 above).
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = hyperactor_reference::PortRef::attest_message_port(&sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to deliver the cast message to the dest \
                actor; returning to origin {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning cast message to the origin \
                        sender {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::ReturnFailure {
                        envelope: message_envelope,
                    }
                })?;
            return Ok(());
        }

        // 3. A return of an undeliverable message was itself returned: nothing
        // more we can do; hand it off to the terminal undeliverable sink.
        UndeliverableMailboxSender
            .post(message_envelope, /*unused */ monitored_return_handle());
        Ok(())
    }
}
234
impl CommActor {
    /// Forward the message to the comm actor on the given peer rank.
    ///
    /// If the current context carries a `CAST_ACTOR_MESH_ID` header, it is
    /// re-attached to the outgoing send so downstream hops retain the cast's
    /// mesh identity.
    fn forward<M: RemoteMessage>(
        cx: &Context<Self>,
        config: &CommMeshConfig,
        rank: usize,
        message: M,
    ) -> Result<()>
    where
        CommActor: hyperactor::RemoteHandles<M>,
    {
        let child = config.peer_for_rank(rank)?;
        // TEMPORARY: until dropping v0 support
        if let Some(cast_actor_mesh_id) = cx.headers().get(CAST_ACTOR_MESH_ID) {
            let mut headers = Flattrs::new();
            headers.set(CAST_ACTOR_MESH_ID, cast_actor_mesh_id);
            child.send_with_headers(cx, headers, message)?;
        } else {
            child.send(cx, message)?;
        }
        Ok(())
    }

    /// Perform one routing step of a cast: split reply ports, deliver the
    /// message to the local destination actor when `deliver_here` is set, and
    /// forward it to each peer in `next_steps` with updated per-peer sequence
    /// numbers (`last_seqs` is mutated to record the seq sent to each peer).
    fn handle_message(
        cx: &Context<Self>,
        config: &CommMeshConfig,
        deliver_here: bool,
        next_steps: HashMap<usize, Vec<RoutingFrame>>,
        sender: hyperactor_reference::ActorId,
        mut message: CastMessageEnvelope,
        seq: usize,
        last_seqs: &mut HashMap<usize, usize>,
    ) -> Result<()> {
        split_ports(cx, message.data_mut(), deliver_here, &next_steps)?;

        // Deliver message here, if necessary.
        if deliver_here {
            // We should not copy cx.headers() because it contains auto-generated
            // headers from mailbox. We want fresh headers only containing
            // user-provided headers.
            let headers = message.headers().clone();
            Self::deliver_to_dest(cx, headers, &mut message, config)?;
        }

        // Forward to peers, threading the per-peer last-seq so each peer can
        // enforce in-order delivery on its side.
        next_steps
            .into_iter()
            .map(|(peer, dests)| {
                let last_seq = last_seqs.entry(peer).or_default();
                Self::forward(
                    cx,
                    config,
                    peer,
                    ForwardMessage {
                        dests,
                        sender: sender.clone(),
                        message: message.clone(),
                        seq,
                        last_seq: *last_seq,
                    },
                )?;
                *last_seq = seq;
                Ok(())
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(())
    }

    /// Deliver `message` to the destination actor co-located with this comm
    /// actor: rewrite rank bindings to the local cast point, stamp cast info
    /// onto the headers, and post to the destination actor's port on this proc.
    fn deliver_to_dest<M: CastEnvelope>(
        cx: &Context<Self>,
        mut headers: Flattrs,
        message: &mut M,
        config: &CommMeshConfig,
    ) -> anyhow::Result<()> {
        let cast_point = message.cast_point(config)?;
        // Replace ranks with self ranks.
        replace_with_self_ranks(&cast_point, message.data_mut())?;

        set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone());
        cx.post_with_external_seq_info(
            // Destination is the named actor (instance 0) on this comm actor's
            // own proc, at the message's declared destination port.
            cx.self_id()
                .proc_id()
                .actor_id(message.dest_port().actor_name(), 0)
                .port_id(message.dest_port().port()),
            headers,
            wirevalue::Any::serialize(message.data())?,
        );

        Ok(())
    }
}
327
// Split ports, if any, and update message with new ports. In this
// way, children actors will reply to this comm actor's ports, instead
// of to the original ports provided by parent.
fn split_ports(
    cx: &Context<CommActor>,
    data: &mut ErasedUnbound,
    deliver_here: bool,
    next_steps: &HashMap<usize, Vec<RoutingFrame>>,
) -> anyhow::Result<()> {
    // Visit every unbound port embedded in the message payload and replace
    // its port id with a locally split port.
    data.visit_mut::<hyperactor_reference::UnboundPort>(
        |hyperactor_reference::UnboundPort(port_id, reducer_spec, return_undeliverable, kind)| {
            let reducer_mode = match kind {
                hyperactor_reference::UnboundPortKind::Streaming(opts) => {
                    ReducerMode::Streaming(opts.clone().unwrap_or_default())
                }
                hyperactor_reference::UnboundPortKind::Once if reducer_spec.is_none() => {
                    // We can only split OncePorts that have reducers.
                    // Pass this through -- if it is used multiple times,
                    // it will cause a delivery error downstream.
                    // However we should reconsider this behavior
                    // as it its semantics will now differ between
                    // unicast and broadcast messages.
                    return Ok(());
                }
                hyperactor_reference::UnboundPortKind::Once => {
                    // Compute peer count for OncePort splitting. This is the number of
                    // destinations the message will be delivered to, so that the split
                    // port can correctly accumulate responses.
                    let peer_count = next_steps.len() + if deliver_here { 1 } else { 0 };
                    ReducerMode::Once(peer_count)
                }
            };

            let split = port_id.split(
                cx,
                reducer_spec.clone(),
                reducer_mode,
                *return_undeliverable,
            )?;

            // In tests, record the original->split edge so unit tests can
            // verify the shape of the split-port tree.
            #[cfg(test)]
            tests::collect_split_port(port_id, &split, deliver_here);

            *port_id = split;
            Ok(())
        },
    )
}
379
380fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
381    data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
382        *rank = Some(cast_point.rank());
383        Ok(())
384    })
385}
386
387#[async_trait]
388impl Handler<CommMeshConfig> for CommActor {
389    async fn handle(&mut self, _cx: &Context<Self>, config: CommMeshConfig) -> Result<()> {
390        self.mesh_config = Some(config);
391        Ok(())
392    }
393}
394
// TODO(T218630526): reliable casting for mutable topology
#[async_trait]
impl Handler<CastMessage> for CommActor {
    /// Entry point for a (v0) cast: allocate the next sequence number for
    /// this stream and forward the message to the root rank of the
    /// destination slice, where the multicast fan-out begins.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
        // Always forward the message to the root rank of the slice, casting starts from there.
        let slice = cast_message.dest.slice.clone();
        let selection = cast_message.dest.selection.clone();
        let frame = RoutingFrame::root(selection, slice);
        let rank = frame.slice.location(&frame.here)?;
        // Bump the per-(mesh, sender) send counter; the pre-increment value
        // becomes `last_seq` and the post-increment value this cast's `seq`.
        let seq = self
            .send_seq
            .entry(cast_message.message.stream_key())
            .or_default();
        let last_seq = *seq;
        *seq += 1;

        let fwd_message = ForwardMessage {
            dests: vec![frame],
            sender: cx.self_id().clone(),
            message: cast_message.message,
            seq: *seq,
            last_seq,
        };

        let config = self
            .mesh_config
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("CommMeshConfig has not been set yet"))?;

        // Optimization: if forwarding to ourselves, handle inline instead of
        // going through the message queue.
        if config.self_rank() == rank {
            Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
        } else {
            Self::forward(cx, config, rank, fwd_message)?;
        }
        Ok(())
    }
}
435
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handle one hop of a multicast: resolve the routing frames against this
    /// comm actor's rank, then deliver locally and/or forward to peers while
    /// enforcing per-stream in-order delivery via sequence numbers.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        let config = self
            .mesh_config
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("CommMeshConfig has not been set yet"))?;

        // Resolve/dedup routing frames.
        let rank = config.self_rank();
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        // Compare the last seq we handled against the sender's declared
        // predecessor seq to decide: in-order, early, or duplicate.
        match recv_state.seq.cmp(&last_seq) {
            // We got the expected next message to deliver to this host.
            Ordering::Equal => {
                // We got an in-order operation, so handle it now.
                Self::handle_message(
                    cx,
                    config,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Also deliver any pending operations from the recv buffer that
                // were received out-of-order that are now unblocked.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        config,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // We got an out-of-order operation, so buffer it for now, until we
            // received the ones sequenced before it.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                // Keyed by the seq that must precede this message, so the
                // drain loop above finds it once that seq has been handled.
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // We already got this message -- just drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
528
529#[async_trait]
530impl Handler<CastMessageV1> for CommActor {
531    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
532        let slice = cast_message.dest_region.slice().clone();
533        let frame = RoutingFrame::root(sel!(*), slice);
534        let forward_message = ForwardMessageV1 {
535            dests: vec![frame],
536            message: cast_message,
537        };
538        self.handle(cx, forward_message).await
539    }
540}
541
#[async_trait]
impl Handler<ForwardMessageV1> for CommActor {
    /// Handle one hop of a v1 multicast: split reply ports, deliver locally
    /// (stamping per-destination sequence info into the headers), and forward
    /// remaining routing frames to peer comm actors.
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
        let ForwardMessageV1 { dests, mut message } = fwd_message;
        let config = self
            .mesh_config
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("CommMeshConfig has not been set yet"))?;
        // Resolve/dedup routing frames.
        let rank_on_root_mesh = config.self_rank();
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
                panic!("choice encountered in CommActor routing")
            })?;

        split_ports(cx, &mut message.data, deliver_here, &next_steps)?;

        // Deliver message here, if necessary.
        if deliver_here {
            let mut headers = message.headers().clone();
            // Look up this destination's sequence number; `seqs` is indexed
            // by the destination's rank within the cast region.
            let seq = message
                .seqs
                .get(message.cast_point(config)?.rank())
                .expect("mismatched seqs and dest_region");
            headers.set(
                SEQ_INFO,
                SeqInfo::Session {
                    session_id: message.session_id,
                    seq,
                },
            );
            Self::deliver_to_dest(cx, headers, &mut message, config)?;
        }

        // Forward to peers.
        for (peer_rank_on_root_mesh, dests) in next_steps {
            let forward_message = ForwardMessageV1 {
                dests,
                message: message.clone(),
            };
            Self::forward(cx, config, peer_rank_on_root_mesh, forward_message)?;
        }

        Ok(())
    }
}
588
/// Actors and message types shared by the unit tests exercising casting.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Unbind;
    use hyperactor::reference as hyperactor_reference;
    use serde::Deserialize;
    use serde::Serialize;
    use typeuri::Named;

    use super::*;

    /// A reply payload carrying the responding actor's id and a value.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: hyperactor_reference::ActorId,
        pub value: u64,
    }

    /// Messages accepted by [`TestActor`], covering the different reply-port
    /// binding combinations (unbound, bound streaming, bound once).
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Intentionally not including 0. As a result, this port will not be
            // split.
            // #[binding(include)]
            reply_to0: hyperactor_reference::PortRef<String>,
            #[binding(include)]
            reply_to1: hyperactor_reference::PortRef<u64>,
            #[binding(include)]
            reply_to2: hyperactor_reference::PortRef<MyReply>,
        },
        CastAndReplyOnce {
            arg: String,
            #[binding(include)]
            reply_to: hyperactor::reference::OncePortRef<u64>,
        },
    }

    /// An actor that echoes every received message to a port supplied at
    /// spawn time, so tests can observe exactly what was delivered.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Forward the received message to this port, so it can be inspected by
        // the unit test.
        forward_port: hyperactor_reference::PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`].
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: hyperactor_reference::PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for TestActor {
        type Params = TestActorParams;

        async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Relay the received message to the test's observation port.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
670
671#[cfg(test)]
672mod tests {
673    use std::collections::BTreeMap;
674    use std::collections::HashSet;
675    use std::fmt::Display;
676    use std::hash::Hash;
677    use std::ops::Deref;
678    use std::ops::DerefMut;
679    use std::sync::Mutex;
680    use std::sync::OnceLock;
681
682    use hyperactor::accum;
683    use hyperactor::accum::Accumulator;
684    use hyperactor::accum::ReducerSpec;
685    use hyperactor::context;
686    use hyperactor::context::Mailbox;
687    use hyperactor::mailbox::PortReceiver;
688    use hyperactor::mailbox::open_port;
689    use hyperactor::reference as hyperactor_reference;
690    use hyperactor_config;
691    use hyperactor_mesh_macros::sel;
692    use maplit::btreemap;
693    use maplit::hashmap;
694    use ndslice::Extent;
695    use ndslice::Selection;
696    use ndslice::ViewExt as _;
697    use ndslice::extent;
698    use ndslice::selection::test_utils::collect_commactor_routing_tree;
699    use test_utils::*;
700    use timed_test::async_timed_test;
701    use tokio::time::Duration;
702
703    use super::*;
704    use crate::host_mesh::HostMesh;
705    use crate::test_utils::local_host_mesh;
706    use crate::testing;
707
708    // Helper to look up the rank for a given actor ID using the rank_lookup table.
709    fn lookup_rank(
710        actor_id: &hyperactor::reference::ActorId,
711        rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
712    ) -> usize {
713        let proc_id = actor_id.proc_id();
714        *rank_lookup
715            .get(proc_id)
716            .unwrap_or_else(|| panic!("proc rank not found for {}", proc_id))
717    }
718
    /// A directed edge in a split-port/routing tree.
    struct Edge<T> {
        from: T,
        to: T,
        /// True when `to` is a final delivery destination (a leaf).
        is_leaf: bool,
    }
724
725    impl<T> From<(T, T, bool)> for Edge<T> {
726        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
727            Self { from, to, is_leaf }
728        }
729    }
730
    // The relationship between original ports and split ports. The elements in
    // the tuple are (original port, split port, deliver_here).
    // Lazily initialized on first use; the edge list is guarded by a Mutex.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<hyperactor_reference::PortId>>>> =
        OnceLock::new();
735
736    // Collect the relationships between original ports and split ports into
737    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
738    pub(crate) fn collect_split_port(
739        original: &hyperactor_reference::PortId,
740        split: &hyperactor_reference::PortId,
741        deliver_here: bool,
742    ) {
743        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
744        let mut tree = mutex.lock().unwrap();
745
746        tree.deref_mut().push(Edge {
747            from: original.clone(),
748            to: split.clone(),
749            is_leaf: deliver_here,
750        });
751    }
752
753    // There could be other cast calls before the one we want to check, e.g. from
754    // allocating the proc mesh, or spawning the actor mesh. Clear the collected
755    // tree so it will only contain the cast we want to check.
756    fn clear_collected_tree() {
757        if let Some(tree) = SPLIT_PORT_TREE.get() {
758            let mut tree = tree.lock().unwrap();
759            tree.clear();
760        }
761    }
762
    // A representation of a tree.
    //   * Map's keys are the tree's leafs;
    //   * Map's values are the path from the root to that leaf (root first,
    //     leaf last).
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
768
769    // Add a custom Debug trait impl so the result from assert_eq! is readable.
770    impl<T: Display> Debug for PathToLeaves<T> {
771        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
772            fn vec_to_string<T: Display>(v: &[T]) -> String {
773                v.iter()
774                    .map(ToString::to_string)
775                    .collect::<Vec<String>>()
776                    .join(", ")
777            }
778
779            for (src, path) in &self.0 {
780                write!(f, "{} -> {}\n", src, vec_to_string(path))?;
781            }
782            Ok(())
783        }
784    }
785
786    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
787        let mut child_parent_map = HashMap::new();
788        let mut all_nodes = HashSet::new();
789        let mut parents = HashSet::new();
790        let mut children = HashSet::new();
791        let mut dests = HashSet::new();
792
793        // Build parent map and track all nodes and children
794        for Edge { from, to, is_leaf } in edges {
795            child_parent_map.insert(to.clone(), from.clone());
796            all_nodes.insert(from.clone());
797            all_nodes.insert(to.clone());
798            parents.insert(from.clone());
799            children.insert(to.clone());
800            if *is_leaf {
801                dests.insert(to.clone());
802            }
803        }
804
805        // For each leaf, reconstruct path back to root
806        let mut result = BTreeMap::new();
807        for dest in dests {
808            let mut path = vec![dest.clone()];
809            let mut current = dest.clone();
810            while let Some(parent) = child_parent_map.get(&current) {
811                path.push(parent.clone());
812                current = parent.clone();
813            }
814            path.reverse();
815            result.insert(dest, path);
816        }
817
818        PathToLeaves(result)
819    }
820
821    #[test]
822    fn test_build_paths() {
823        // Given the tree:
824        //     0
825        //    / \
826        //   1   4
827        //  / \   \
828        // 2   3   5
829        let edges: Vec<_> = [
830            (0, 1, false),
831            (1, 2, true),
832            (1, 3, true),
833            (0, 4, true),
834            (4, 5, true),
835        ]
836        .into_iter()
837        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
838        .collect();
839
840        let paths = build_paths(&edges);
841
842        let expected = btreemap! {
843            2 => vec![0, 1, 2],
844            3 => vec![0, 1, 3],
845            4 => vec![0, 4],
846            5 => vec![0, 4, 5],
847        };
848
849        assert_eq!(paths.0, expected);
850    }
851
852    //  Given a port tree,
853    //     * remove the client port, i.e. the 1st element of the path;
854    //     * verify all remaining ports are comm actor ports;
855    //     * remove the actor information and return a rank-based tree representation.
856    //
857    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
858    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
859    //
860    //      For example, for a 2x2 slice, the port tree could look like:
861    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
862    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
863    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
864    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
865    //
866    //     The result should be:
867    //     0 -> 0
868    //     1 -> 0, 1
869    //     2 -> 0, 2
870    //     3 -> 0, 2, 3
871    fn get_ranks(
872        paths: PathToLeaves<hyperactor_reference::PortId>,
873        client_reply: &hyperactor_reference::PortId,
874        rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
875    ) -> PathToLeaves<hyperactor_reference::Index> {
876        let ranks = paths
877            .0
878            .into_iter()
879            .map(|(dst, mut path)| {
880                let first = path.remove(0);
881                // The first PortId is the client's reply port.
882                assert_eq!(&first, client_reply);
883                // Other ports's actor ID must be dest[?].comm[0], where ? is
884                // the rank we want to extract here.
885                assert!(dst.actor_id().name().contains("comm"));
886                let actor_path = path
887                    .into_iter()
888                    .map(|p| {
889                        assert!(p.actor_id().name().contains("comm"));
890                        lookup_rank(p.actor_id(), rank_lookup)
891                    })
892                    .collect();
893                (lookup_rank(dst.actor_id(), rank_lookup), actor_path)
894            })
895            .collect();
896        PathToLeaves(ranks)
897    }
898
899    struct NoneAccumulator;
900
901    impl Accumulator for NoneAccumulator {
902        type State = u64;
903        type Update = u64;
904
905        fn accumulate(
906            &self,
907            _state: &mut Self::State,
908            _update: Self::Update,
909        ) -> anyhow::Result<()> {
910            unimplemented!()
911        }
912
913        fn reducer_spec(&self) -> Option<ReducerSpec> {
914            unimplemented!()
915        }
916    }
917
    // Verify the split port paths are the same as the casting paths.
    //
    // Computes the expected routing tree via [collect_commactor_routing_tree]
    // for `selection` over `extent`, then compares it against the actual
    // split-port edges recorded in the process-global SPLIT_PORT_TREE during
    // the cast. The recorded edges cover both reply ports, so they are first
    // partitioned by which client reply port roots each path.
    fn verify_split_port_paths(
        selection: &Selection,
        extent: &Extent,
        reply_port_ref1: &hyperactor_reference::PortRef<u64>,
        reply_port_ref2: &hyperactor_reference::PortRef<MyReply>,
        rank_lookup: &HashMap<hyperactor_reference::ProcId, usize>,
    ) {
        // Get the paths used in casting
        let sel_paths = PathToLeaves(
            collect_commactor_routing_tree(selection, &extent.to_slice())
                .delivered
                .into_iter()
                .collect(),
        );

        // Get the split port paths collected in SPLIT_PORT_TREE during casting
        let (reply1_paths, reply2_paths) = {
            let tree = SPLIT_PORT_TREE
                .get()
                .expect("not initialized; are Hosts in the same process as SPLIT_PORT_TREE?");
            let edges = tree.lock().unwrap();
            // A path rooted at reply port 1 belongs to the reply-1 tree;
            // everything else belongs to the reply-2 tree.
            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
                .0
                .into_iter()
                .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
            (
                get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id(), rank_lookup),
                get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id(), rank_lookup),
            )
        };

        // split port paths should be the same as casting paths
        assert_eq!(sel_paths, reply1_paths);
        assert_eq!(sel_paths, reply2_paths);
    }
954
955    async fn execute_cast_and_reply(
956        ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
957        instance: &impl context::Actor,
958        mut reply1_rx: PortReceiver<u64>,
959        mut reply2_rx: PortReceiver<MyReply>,
960        reply_tos: Vec<(
961            hyperactor_reference::PortRef<u64>,
962            hyperactor_reference::PortRef<MyReply>,
963        )>,
964    ) {
965        // Reply from each dest actor. The replies should be received by client.
966        {
967            for (rank, (dest_actor, (reply_to1, reply_to2))) in
968                ranks.iter().zip(reply_tos.iter()).enumerate()
969            {
970                let rank_u64 = rank as u64;
971                reply_to1.send(instance, rank_u64).unwrap();
972                let my_reply = MyReply {
973                    sender: dest_actor.actor_id().clone(),
974                    value: rank_u64,
975                };
976                reply_to2.send(instance, my_reply.clone()).unwrap();
977
978                assert_eq!(reply1_rx.recv().await.unwrap(), rank_u64);
979                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
980            }
981        }
982
983        tracing::info!("the 1st updates from all dest actors were receivered by client");
984
985        // Now send multiple replies from the dest actors. They should all be
986        // received by client. Replies sent from the same dest actor should
987        // be received in the same order as they were sent out.
988        {
989            let n = 100;
990            let mut expected2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
991            for (i, (dest_actor, (_reply_to1, reply_to2))) in
992                ranks.iter().zip(reply_tos.iter()).enumerate()
993            {
994                let mut sent2 = vec![];
995                for j in 0..n {
996                    let value = (i * 100 + j) as u64;
997                    let my_reply = MyReply {
998                        sender: dest_actor.actor_id().clone(),
999                        value,
1000                    };
1001                    reply_to2.send(instance, my_reply.clone()).unwrap();
1002                    sent2.push(my_reply);
1003                }
1004                assert!(
1005                    expected2
1006                        .insert(dest_actor.actor_id().clone(), sent2)
1007                        .is_none(),
1008                    "duplicate actor_id {} in map",
1009                    dest_actor.actor_id()
1010                );
1011            }
1012
1013            let mut received2: HashMap<hyperactor_reference::ActorId, Vec<MyReply>> = hashmap! {};
1014
1015            for _ in 0..(n * ranks.len()) {
1016                let my_reply = reply2_rx.recv().await.unwrap();
1017                received2
1018                    .entry(my_reply.sender.clone())
1019                    .or_default()
1020                    .push(my_reply);
1021            }
1022            assert_eq!(received2, expected2);
1023        }
1024    }
1025
1026    async fn wait_for_with_timeout(
1027        receiver: &mut PortReceiver<u64>,
1028        expected: u64,
1029        dur: Duration,
1030    ) -> anyhow::Result<()> {
1031        // timeout wraps the entire async block
1032        tokio::time::timeout(dur, async {
1033            loop {
1034                let msg = receiver.recv().await.unwrap();
1035                if msg == expected {
1036                    break;
1037                }
1038            }
1039        })
1040        .await?;
1041        Ok(())
1042    }
1043
1044    async fn execute_cast_and_accum(
1045        ranks: Vec<hyperactor_reference::ActorRef<TestActor>>,
1046        instance: &impl context::Actor,
1047        mut reply1_rx: PortReceiver<u64>,
1048        reply_tos: Vec<(
1049            hyperactor_reference::PortRef<u64>,
1050            hyperactor_reference::PortRef<MyReply>,
1051        )>,
1052    ) {
1053        // Now send multiple replies from the dest actors. They should all be
1054        // received by client. Replies sent from the same dest actor should
1055        // be received in the same order as they were sent out.
1056        let mut sum = 0;
1057        let n = 100;
1058        for (i, (_dest_actor, (reply_to1, _reply_to2))) in
1059            ranks.iter().zip(reply_tos.iter()).enumerate()
1060        {
1061            for j in 0..n {
1062                let value = (i + j) as u64;
1063                reply_to1.send(instance, value).unwrap();
1064                sum += value;
1065            }
1066        }
1067        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(8))
1068            .await
1069            .unwrap();
1070        // no more messages
1071        tokio::time::sleep(Duration::from_secs(2)).await;
1072        let msg = reply1_rx.try_recv().unwrap();
1073        assert_eq!(msg, None);
1074    }
1075
    /// Everything a v1 cast test needs after mesh setup: the client
    /// instance, a reference to the spawned actor mesh, the client-side
    /// reply receivers, and the per-destination reply ports.
    struct MeshSetupV1 {
        // Root client instance used to send and receive messages.
        instance: &'static Instance<testing::TestRootClient>,
        // Reference to the spawned TestActor mesh being cast to.
        actor_mesh_ref: crate::ActorMeshRef<TestActor>,
        // Client-side receiver for u64 replies (plain or accumulated,
        // depending on how the mesh was set up).
        reply1_rx: PortReceiver<u64>,
        // Client-side receiver for structured MyReply replies.
        reply2_rx: PortReceiver<MyReply>,
        // Reply ports as each destination actor saw them; casting replaces
        // these with comm-actor split ports.
        reply_tos: Vec<(
            hyperactor_reference::PortRef<u64>,
            hyperactor_reference::PortRef<MyReply>,
        )>,
        // Keep the host mesh alive so comm actors aren't shut down.
        host_mesh: HostMesh,
    }
1088
    /// Spawn an in-process host mesh and a TestActor mesh, cast a
    /// CastAndReply message to every actor, collect the (split) reply ports
    /// each destination received, and verify the split-port tree matches the
    /// comm-actor routing tree.
    ///
    /// `accum`, when present, is installed on the u64 reply port so replies
    /// are accumulated instead of delivered individually.
    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let instance = crate::testing::instance();
        // We have to use an in-process host mesh, because SPLIT_PORT_TREE
        // can only collect paths from the same process.
        let host_mesh = local_host_mesh(8).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", extent!(gpu = 8))
            .await
            .unwrap();

        // Each TestActor forwards the messages it receives to this port so
        // the test can observe exactly what was delivered.
        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_name = crate::Name::new("test").expect("valid test name");
        // Make this actor a "system" actor to avoid spawning a controller actor.
        // This test is verifying the whole comm tree, so we want fewer actors
        // involved.
        let actor_mesh = proc_mesh
            .spawn_with_name(&instance, actor_name, &params, None, true)
            .await
            .unwrap();
        let actor_mesh_ref: crate::ActorMeshRef<TestActor> = actor_mesh.deref().clone();

        // Port 0 is a control: it is never split, so destinations should see
        // it unchanged.
        let (reply_port_handle0, _) = open_port::<String>(instance);
        let reply_port_ref0 = reply_port_handle0.bind();
        // Port 1 is optionally backed by an accumulator.
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => instance.mailbox().open_accum_port(a),
            None => open_port(instance),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        // Reset the global tree so only this cast's edges are recorded.
        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        // Collect the reply ports as each destination actor saw them.
        let mut reply_tos = vec![];
        for _ in proc_mesh.extent().points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it
                    // is not included in MutVisitor.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert!(reply_to1.port_id().actor_id().name().contains("comm"));
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert!(reply_to2.port_id().actor_id().name().contains("comm"));
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        // [collect_commactor_routing_tree] only returns ranks, so we need to
        // map proc IDs collected in SPLIT_PORT_TREE to ranks.
        let rank_lookup = proc_mesh
            .ranks()
            .iter()
            .enumerate()
            .map(|(i, r)| (r.proc_id().clone(), i))
            .collect::<HashMap<hyperactor_reference::ProcId, usize>>();

        // v1 always uses sel!(*) when casting to a mesh.
        let selection = sel!(*);
        verify_split_port_paths(
            &selection,
            &proc_mesh.extent(),
            &reply_port_ref1,
            &reply_port_ref2,
            &rank_lookup,
        );

        MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
            host_mesh,
        }
    }
1190
1191    async fn execute_cast_and_reply_v1() {
1192        let mut setup = setup_mesh_v1::<NoneAccumulator>(None).await;
1193
1194        let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1195        execute_cast_and_reply(
1196            ranks,
1197            setup.instance,
1198            setup.reply1_rx,
1199            setup.reply2_rx,
1200            setup.reply_tos,
1201        )
1202        .await;
1203
1204        let _ = setup.host_mesh.shutdown(setup.instance).await;
1205    }
1206
    #[async_timed_test(timeout_secs = 60)]
    async fn test_cast_and_reply_v1_retrofit() {
        // "Retrofit" configuration: native v1 casting and the dest-actor
        // reordering buffer are both disabled for the duration of this test.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
        let _guard2 = config.override_key(
            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
            false,
        );
        execute_cast_and_reply_v1().await
    }
1217
    #[async_timed_test(timeout_secs = 60)]
    async fn test_cast_and_reply_v1_native() {
        // "Native" configuration: native v1 casting and the dest-actor
        // reordering buffer are both enabled for the duration of this test.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
        let _guard2 = config.override_key(
            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
            true,
        );
        execute_cast_and_reply_v1().await
    }
1228
1229    async fn execute_cast_and_accum_v1(config: &hyperactor_config::global::ConfigLock) {
1230        // Use temporary config for this test
1231        let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1232
1233        let mut setup = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1234
1235        let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1236        execute_cast_and_accum(ranks, setup.instance, setup.reply1_rx, setup.reply_tos).await;
1237
1238        let _ = setup.host_mesh.shutdown(setup.instance).await;
1239    }
1240
    #[async_timed_test(timeout_secs = 60)]
    async fn test_cast_and_accum_v1_retrofit() {
        // "Retrofit" configuration: native v1 casting and the dest-actor
        // reordering buffer are both disabled for the duration of this test.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
        let _guard2 = config.override_key(
            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
            false,
        );
        execute_cast_and_accum_v1(&config).await
    }
1251
    #[async_timed_test(timeout_secs = 60)]
    async fn test_cast_and_accum_v1_native() {
        // "Native" configuration: native v1 casting and the dest-actor
        // reordering buffer are both enabled for the duration of this test.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
        let _guard2 = config.override_key(
            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
            true,
        );
        execute_cast_and_accum_v1(&config).await
    }
1262
    /// State for once-port cast tests: the client instance, the once-port
    /// receiver, and the reply ports each destination actor received.
    struct OncePortMeshSetupV1 {
        // Root client instance used to send and receive messages.
        instance: &'static Instance<testing::TestRootClient>,
        // Client-side once-port receiver (reducing, when a reducer is used).
        reply_rx: hyperactor::mailbox::OncePortReceiver<u64>,
        // Reply port as each destination actor saw it.
        reply_tos: Vec<hyperactor::reference::OncePortRef<u64>>,
        // Kept for the test's duration; presumably retained so the original
        // once-port ref outlives the cast — NOTE(review): confirm.
        _reply_port_ref: hyperactor::reference::OncePortRef<u64>,
        // Keep the host mesh alive so comm actors aren't shut down.
        host_mesh: HostMesh,
    }
1270
    /// Spawn an in-process mesh of TestActors and cast a CastAndReplyOnce
    /// message carrying a once-port. With a reducer, the comm actors split
    /// the once-port; without one, it is passed through unchanged. Returns
    /// the reply ports each destination received plus the client receiver.
    async fn setup_once_port_mesh<A>(reducer: Option<A>) -> OncePortMeshSetupV1
    where
        A: Accumulator<State = u64, Update = u64> + Send + Sync + 'static,
    {
        let instance = crate::testing::instance();
        // We have to use an in-process host mesh, because SPLIT_PORT_TREE
        // can only collect paths from the same process.
        let host_mesh = local_host_mesh(8).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", extent!(gpu = 8))
            .await
            .unwrap();

        // Each TestActor forwards the messages it receives to this port so
        // the test can observe exactly what was delivered.
        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_name = crate::Name::new("test").expect("valid test name");
        // Make this actor a "system" actor to avoid spawning a controller actor.
        let actor_mesh: crate::ActorMesh<TestActor> = proc_mesh
            .spawn_with_name(&instance, actor_name, &params, None, true)
            .await
            .unwrap();
        let actor_mesh_ref = actor_mesh.deref().clone();

        let has_reducer = reducer.is_some();
        // With a reducer, open a reducing once-port; otherwise a plain one.
        let (reply_port_handle, reply_rx) = match reducer {
            Some(reducer) => instance.mailbox().open_reduce_port(reducer),
            None => instance.mailbox().open_once_port::<u64>(),
        };
        let reply_port_ref = reply_port_handle.bind();

        let message = TestMessage::CastAndReplyOnce {
            arg: "abc".to_string(),
            reply_to: reply_port_ref.clone(),
        };

        // Reset the global tree so only this cast's edges are recorded.
        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        // Collect the reply port as each destination actor saw it.
        let mut reply_tos = vec![];
        for _ in proc_mesh.extent().points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReplyOnce { arg, reply_to } => {
                    assert_eq!(arg, "abc");
                    if has_reducer {
                        // With reducer: port is split by comm actor.
                        assert_ne!(reply_to, reply_port_ref);
                        assert!(reply_to.port_id().actor_id().name().contains("comm"));
                    } else {
                        // Without reducer: port is passed through unchanged.
                        assert_eq!(reply_to, reply_port_ref);
                    }
                    reply_tos.push(reply_to);
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        OncePortMeshSetupV1 {
            instance,
            reply_rx,
            reply_tos,
            _reply_port_ref: reply_port_ref,
            host_mesh,
        }
    }
1341
1342    #[async_timed_test(timeout_secs = 60)]
1343    async fn test_cast_and_reply_once_v1() {
1344        // Test OncePort without accumulator - port is NOT split.
1345        // All destinations receive the same original port.
1346        // First reply is delivered, others fail at receiver (port closed).
1347        let mut setup = setup_once_port_mesh::<NoneAccumulator>(None).await;
1348
1349        // All reply_tos point to the same port (not split).
1350        // Only the first message will be delivered successfully.
1351        let num_replies = setup.reply_tos.len();
1352        for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1353            reply_to.send(setup.instance, i as u64).unwrap();
1354        }
1355
1356        // OncePort receives exactly one value (the first to arrive)
1357        let result = setup.reply_rx.recv().await.unwrap();
1358        // The result should be one of the values sent
1359        assert!(result < num_replies as u64);
1360
1361        let _ = setup.host_mesh.shutdown(setup.instance).await;
1362    }
1363
1364    #[async_timed_test(timeout_secs = 60)]
1365    async fn test_cast_and_accum_once_v1() {
1366        // Test OncePort splitting with sum accumulator.
1367        // Each destination actor replies with its rank.
1368        // The sum of all ranks should be received at the original port.
1369        let mut setup = setup_once_port_mesh(Some(accum::sum::<u64>())).await;
1370
1371        // Each actor replies with its index
1372        let mut expected_sum = 0u64;
1373        for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1374            reply_to.send(setup.instance, i as u64).unwrap();
1375            expected_sum += i as u64;
1376        }
1377
1378        // OncePort should receive the sum of all responses
1379        let result = setup.reply_rx.recv().await.unwrap();
1380        assert_eq!(result, expected_sum);
1381
1382        let _ = setup.host_mesh.shutdown(setup.instance).await;
1383    }
1384}