Skip to main content

hyperactor_mesh/
comm.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use crate::casting::CAST_ACTOR_MESH_ID;
10use crate::comm::multicast::CAST_ORIGINATING_SENDER;
11use crate::comm::multicast::CastEnvelope;
12use crate::comm::multicast::CastMessageV1;
13use crate::comm::multicast::ForwardMessageV1;
14use crate::mesh_id::ActorMeshId;
15use crate::resource;
16pub mod multicast;
17
18use std::cmp::Ordering;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use anyhow::Result;
23use async_trait::async_trait;
24use hyperactor::Actor;
25use hyperactor::ActorAddr;
26use hyperactor::ActorRef;
27use hyperactor::Context;
28use hyperactor::Endpoint as _;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::PortRef;
32use hyperactor::RemoteEndpoint as _;
33use hyperactor::RemoteMessage;
34use hyperactor::UnboundPort;
35use hyperactor::UnboundPortKind;
36use hyperactor::accum::ReducerMode;
37use hyperactor::mailbox::DeliveryError;
38use hyperactor::mailbox::MailboxSender;
39use hyperactor::mailbox::Undeliverable;
40use hyperactor::mailbox::UndeliverableMailboxSender;
41use hyperactor::mailbox::UndeliverableMessageError;
42use hyperactor::mailbox::monitored_return_handle;
43use hyperactor::message::ErasedUnbound;
44use hyperactor::ordering::SEQ_INFO;
45use hyperactor::ordering::SeqInfo;
46use hyperactor_config::CONFIG;
47use hyperactor_config::ConfigAttr;
48use hyperactor_config::Flattrs;
49use hyperactor_config::attrs::declare_attrs;
50use hyperactor_mesh_macros::sel;
51use ndslice::Point;
52use ndslice::Selection;
53use ndslice::View;
54use ndslice::selection::routing::RoutingFrame;
55use serde::Deserialize;
56use serde::Serialize;
57use typeuri::Named;
58
59use crate::comm::multicast::CastMessage;
60use crate::comm::multicast::CastMessageEnvelope;
61use crate::comm::multicast::ForwardMessage;
62use crate::comm::multicast::set_cast_info_on_headers;
63
declare_attrs! {
    /// Whether to use native v1 casting in v1 ActorMesh (defaults to `true`).
    ///
    /// NOTE(review): the two `ConfigAttr::new` arguments look like an
    /// environment-variable name and a config key respectively — confirm
    /// against `hyperactor_config::ConfigAttr`.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_ENABLE_NATIVE_V1_CASTING".to_string()),
        Some("enable_native_v1_casting".to_string()),
    ))
    pub attr ENABLE_NATIVE_V1_CASTING: bool = true;
}
72
/// Parameters to initialize the [`CommActor`].
///
/// Carries no fields. The actor's rank and peers are supplied separately by a
/// [`CommMeshConfig`] message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
wirevalue::register_type!(CommActorParams);
77
/// A forwarded message held back because it arrived before its predecessor
/// in the same stream.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether to also deliver this message to the destination actor in this
    /// comm actor's own proc.
    deliver_here: bool,
    /// Peer comm actors to forward the message to, keyed by peer rank, with
    /// the routing frames each peer continues from.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message to deliver.
    message: CastMessageEnvelope,
}
90
/// Bookkeeping to handle sequence numbers and in-order delivery for messages
/// sent to and through this comm actor. There is one instance per stream.
#[derive(Debug, Default)]
struct ReceiveState {
    /// Sequence number of the last message handled in order (0 before any).
    seq: usize,
    /// Messages received out of order, keyed by the `last_seq` they declare,
    /// i.e. the sequence number that must be handled before them.
    buffer: HashMap<usize, Buffered>,
    /// The last sequence number sent to each next-step peer, keyed by peer
    /// rank. It is sent as `last_seq` on the next forward to that peer.
    last_seqs: HashMap<usize, usize>,
}
103
/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation.
///
/// Casts enter through [`CastMessage`] / [`CastMessageV1`]. They are relayed
/// between peer comm actors as [`ForwardMessage`] / [`ForwardMessageV1`], and
/// are delivered to the destination actor in each comm actor's own proc.
/// Until a [`CommMeshConfig`] arrives, incoming cast/forward traffic is buffered.
#[derive(Debug, Default)]
#[hyperactor::export(
    CommMeshConfig,
    CastMessage,
    ForwardMessage,
    CastMessageV1,
    ForwardMessageV1
)]
#[hyperactor::spawnable]
pub struct CommActor {
    /// Sequence numbers are maintained for each (actor mesh id, sender).
    /// Holds the last sequence number assigned to casts originated here.
    send_seq: HashMap<(ActorMeshId, ActorAddr), usize>,
    /// Each sender is a unique stream; this tracks in-order receive state per stream.
    recv_state: HashMap<(ActorMeshId, ActorAddr), ReceiveState>,

    /// The comm actor's mesh configuration, or buffered messages if not yet configured.
    mesh_config: MeshConfigState,
}
124
/// A message received before the [`CommMeshConfig`]; it is replayed, in
/// arrival order, once the config arrives.
#[derive(Debug)]
enum PendingMessage {
    /// A buffered [`CastMessage`].
    Cast(CastMessage),
    /// A buffered [`ForwardMessage`].
    Forward(ForwardMessage),
    /// A buffered [`ForwardMessageV1`]. This also covers [`CastMessageV1`],
    /// which is wrapped into a `ForwardMessageV1` before it is buffered.
    ForwardV1(ForwardMessageV1),
}
131
/// Whether the comm actor has received its [`CommMeshConfig`] yet.
#[derive(Debug)]
enum MeshConfigState {
    /// Config not yet received; buffer incoming messages until it arrives.
    NotConfigured(Vec<PendingMessage>),
    /// Config received; ready to route messages.
    Configured(CommMeshConfig),
}
139
140impl Default for MeshConfigState {
141    fn default() -> Self {
142        MeshConfigState::NotConfigured(Vec::new())
143    }
144}
145
/// Configuration for how a `CommActor` determines its own rank and locates peers.
///
/// Delivered to a [`CommActor`] as a message. Until it arrives, the actor
/// buffers cast and forward traffic.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct CommMeshConfig {
    /// The rank of this comm actor on the root mesh.
    rank: usize,
    /// Key is the rank of the peer on the root mesh. Value is the peer's comm actor.
    peers: HashMap<usize, ActorRef<CommActor>>,
}
wirevalue::register_type!(CommMeshConfig);
155
156impl CommMeshConfig {
157    /// Create a new mesh configuration with the given rank and peer mapping.
158    pub fn new(rank: usize, peers: HashMap<usize, ActorRef<CommActor>>) -> Self {
159        Self { rank, peers }
160    }
161
162    /// Return the peer comm actor for the given rank.
163    fn peer_for_rank(&self, rank: usize) -> Result<ActorRef<CommActor>> {
164        self.peers
165            .get(&rank)
166            .cloned()
167            .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank))
168    }
169
170    /// Return the rank of the comm actor.
171    fn self_rank(&self) -> usize {
172        self.rank
173    }
174}
175
176#[async_trait]
177impl Actor for CommActor {
178    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
179        this.set_system();
180        Ok(())
181    }
182
183    // This is an override of the default actor behavior.
184    async fn handle_undeliverable_message(
185        &mut self,
186        cx: &Instance<Self>,
187        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
188    ) -> Result<(), anyhow::Error> {
189        let mut message_envelope = match undelivered {
190            Undeliverable::Message(message_envelope) => message_envelope,
191            Undeliverable::Lost(lost) => {
192                anyhow::bail!(UndeliverableMessageError::Lost { lost });
193            }
194        };
195
196        // 1. Case delivery failure at a "forwarding" step.
197        if let Ok(ForwardMessage { message, .. }) =
198            message_envelope.deserialized::<ForwardMessage>()
199        {
200            let sender = message.sender();
201            let return_port = PortRef::attest_handler_port(sender);
202            message_envelope.set_error(DeliveryError::Multicast(format!(
203                "comm actor {} failed to forward the cast message; returning to origin {}",
204                cx.self_addr(),
205                return_port.port_addr(),
206            )));
207
208            // Needed so that the receiver of the undeliverable message can easily find the
209            // original sender of the cast message.
210            message_envelope.set_header(CAST_ORIGINATING_SENDER, sender.clone());
211
212            return_port.post(cx, Undeliverable::Message(message_envelope.clone()));
213            return Ok(());
214        }
215
216        // 2. Case delivery failure at a "deliver here" step.
217        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
218            let return_port = PortRef::attest_handler_port(&sender);
219            message_envelope.set_error(DeliveryError::Multicast(format!(
220                "comm actor {} failed to deliver the cast message to the dest \
221                actor; returning to origin {}",
222                cx.self_addr(),
223                return_port.port_addr(),
224            )));
225            return_port.post(cx, Undeliverable::Message(message_envelope.clone()));
226            return Ok(());
227        }
228
229        // 3. A return of an undeliverable message was itself returned.
230        UndeliverableMailboxSender
231            .post(message_envelope, /*unused */ monitored_return_handle());
232        Ok(())
233    }
234}
235
236impl CommActor {
237    /// Forward the message to the comm actor on the given peer rank.
238    fn forward<M: RemoteMessage>(
239        cx: &Context<Self>,
240        config: &CommMeshConfig,
241        rank: usize,
242        message: M,
243    ) -> Result<()>
244    where
245        CommActor: hyperactor::RemoteHandles<M>,
246    {
247        let child = config.peer_for_rank(rank)?;
248        // TEMPORARY: until dropping v0 support
249        if let Some(cast_actor_mesh_id) = cx.headers().get(CAST_ACTOR_MESH_ID) {
250            let mut headers = Flattrs::new();
251            headers.set(CAST_ACTOR_MESH_ID, cast_actor_mesh_id);
252            child.post_with_headers(cx, headers, message);
253        } else {
254            child.post(cx, message);
255        }
256        Ok(())
257    }
258
259    fn handle_message(
260        cx: &Context<Self>,
261        config: &CommMeshConfig,
262        deliver_here: bool,
263        next_steps: HashMap<usize, Vec<RoutingFrame>>,
264        sender: ActorAddr,
265        mut message: CastMessageEnvelope,
266        seq: usize,
267        last_seqs: &mut HashMap<usize, usize>,
268    ) -> Result<()> {
269        split_ports(cx, message.data_mut(), deliver_here, &next_steps)?;
270
271        // Deliver message here, if necessary.
272        if deliver_here {
273            // We should not copy cx.headers() because it contains auto-generated
274            // headers from mailbox. We want fresh headers only containing
275            // user-provided headers.
276            let headers = message.headers().clone();
277            Self::deliver_to_dest(cx, headers, &mut message, config)?;
278        }
279
280        // Forward to peers.
281        next_steps
282            .into_iter()
283            .map(|(peer, dests)| {
284                let last_seq = last_seqs.entry(peer).or_default();
285                Self::forward(
286                    cx,
287                    config,
288                    peer,
289                    ForwardMessage {
290                        dests,
291                        sender: sender.clone(),
292                        message: message.clone(),
293                        seq,
294                        last_seq: *last_seq,
295                    },
296                )?;
297                *last_seq = seq;
298                Ok(())
299            })
300            .collect::<Result<Vec<_>>>()?;
301
302        Ok(())
303    }
304
305    fn deliver_to_dest<M: CastEnvelope>(
306        cx: &Context<Self>,
307        mut headers: Flattrs,
308        message: &mut M,
309        config: &CommMeshConfig,
310    ) -> anyhow::Result<()> {
311        let cast_point = message.cast_point(config)?;
312        // Replace ranks with self ranks.
313        replace_with_self_ranks(&cast_point, message.data_mut())?;
314
315        set_cast_info_on_headers(&mut headers, cast_point, message.sender().clone());
316        cx.post_with_external_seq_info(
317            cx.self_addr()
318                .proc_addr()
319                .actor_addr_uid(message.dest_port().actor_uid().clone())
320                .port_addr(hyperactor::port::Port::from(message.dest_port().port())),
321            headers,
322            wirevalue::Any::serialize(message.data())?,
323        );
324
325        Ok(())
326    }
327}
328
329// Split ports, if any, and update message with new ports. In this
330// way, children actors will reply to this comm actor's ports, instead
331// of to the original ports provided by parent.
332fn split_ports(
333    cx: &Context<CommActor>,
334    data: &mut ErasedUnbound,
335    deliver_here: bool,
336    next_steps: &HashMap<usize, Vec<RoutingFrame>>,
337) -> anyhow::Result<()> {
338    // Split ports, if any, and update message with new ports. In this
339    // way, children actors will reply to this comm actor's ports, instead
340    // of to the original ports provided by parent.
341    data.visit_mut::<UnboundPort>(
342        |UnboundPort(port_id, reducer_spec, return_undeliverable, kind, unsplit)| {
343            if *unsplit {
344                return Ok(());
345            }
346            let reducer_mode = match kind {
347                UnboundPortKind::Streaming(opts) => {
348                    ReducerMode::Streaming(opts.clone().unwrap_or_default())
349                }
350                UnboundPortKind::Once if reducer_spec.is_none() => {
351                    // We can only split OncePorts that have reducers.
352                    // Pass this through -- if it is used multiple times,
353                    // it will cause a delivery error downstream.
354                    // However we should reconsider this behavior
355                    // as it its semantics will now differ between
356                    // unicast and broadcast messages.
357                    return Ok(());
358                }
359                UnboundPortKind::Once => {
360                    // Compute peer count for OncePort splitting. This is the number of
361                    // destinations the message will be delivered to, so that the split
362                    // port can correctly accumulate responses.
363                    let peer_count = next_steps.len() + if deliver_here { 1 } else { 0 };
364                    ReducerMode::Once(peer_count)
365                }
366            };
367
368            let split = port_id.clone().split(
369                cx,
370                reducer_spec.clone(),
371                reducer_mode,
372                *return_undeliverable,
373            )?;
374
375            #[cfg(test)]
376            tests::collect_split_port(&port_id.clone(), &split, deliver_here);
377
378            *port_id = split;
379            Ok(())
380        },
381    )
382}
383
384fn replace_with_self_ranks(cast_point: &Point, data: &mut ErasedUnbound) -> anyhow::Result<()> {
385    data.visit_mut::<resource::Rank>(|resource::Rank(rank)| {
386        *rank = Some(cast_point.rank());
387        Ok(())
388    })
389}
390
391#[async_trait]
392impl Handler<CommMeshConfig> for CommActor {
393    async fn handle(&mut self, cx: &Context<Self>, config: CommMeshConfig) -> Result<()> {
394        let pending =
395            match std::mem::replace(&mut self.mesh_config, MeshConfigState::Configured(config)) {
396                MeshConfigState::NotConfigured(pending) => pending,
397                MeshConfigState::Configured(_) => Vec::new(),
398            };
399        if !pending.is_empty() {
400            tracing::info!(
401                count = pending.len(),
402                "replaying buffered pre-config messages"
403            );
404        }
405        for msg in pending {
406            match msg {
407                PendingMessage::Cast(m) => self.handle(cx, m).await?,
408                PendingMessage::Forward(m) => self.handle(cx, m).await?,
409                PendingMessage::ForwardV1(m) => self.handle(cx, m).await?,
410            }
411        }
412        Ok(())
413    }
414}
415
416// TODO(T218630526): reliable casting for mutable topology
417#[async_trait]
418impl Handler<CastMessage> for CommActor {
419    #[tracing::instrument(level = "debug", skip_all)]
420    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
421        let config = match &mut self.mesh_config {
422            MeshConfigState::NotConfigured(pending) => {
423                pending.push(PendingMessage::Cast(cast_message));
424                return Ok(());
425            }
426            MeshConfigState::Configured(config) => config,
427        };
428        // Always forward the message to the root rank of the slice, casting starts from there.
429        let slice = cast_message.dest.slice.clone();
430        let selection = cast_message.dest.selection.clone();
431        let frame = RoutingFrame::root(selection, slice);
432        let rank = frame.slice.location(&frame.here)?;
433        let seq = self
434            .send_seq
435            .entry(cast_message.message.stream_key())
436            .or_default();
437        let last_seq = *seq;
438        *seq += 1;
439
440        let fwd_message = ForwardMessage {
441            dests: vec![frame],
442            sender: cx.self_addr().clone(),
443            message: cast_message.message,
444            seq: *seq,
445            last_seq,
446        };
447
448        // Optimization: if forwarding to ourselves, handle inline instead of
449        // going through the message queue
450        if config.self_rank() == rank {
451            Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
452        } else {
453            Self::forward(cx, config, rank, fwd_message)?;
454        }
455        Ok(())
456    }
457}
458
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handle one hop of a (v0) cast routed between comm actors.
    ///
    /// The routing frames are resolved against this actor's rank. This decides
    /// whether to deliver locally and which peers to relay to. Messages are
    /// processed in per-stream sequence order:
    /// - a message is handled when its `last_seq` equals the last sequence
    ///   handled for its stream;
    /// - a message that arrives early is buffered under its `last_seq`;
    /// - an already-handled (duplicate) message is dropped with a warning.
    ///
    /// Buffered if the actor is not yet configured.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let config = match &mut self.mesh_config {
            MeshConfigState::NotConfigured(pending) => {
                pending.push(PendingMessage::Forward(fwd_message));
                return Ok(());
            }
            MeshConfigState::Configured(config) => config,
        };

        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve/dedup routing frames. Choices are not expected here, so the
        // choice callback panics.
        let rank = config.self_rank();
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            // We got the expected next message to deliver to this host.
            Ordering::Equal => {
                // We got an in-order operation, so handle it now.
                Self::handle_message(
                    cx,
                    config,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Also deliver any pending operations from the recv buffer that
                // were received out-of-order that are now unblocked.
                // NOTE(review): `Buffered` does not record a sender, so the
                // current message's `sender` is reused for replayed messages.
                // This assumes every message in a stream shares the same sender.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        config,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // We got an out-of-order operation, so buffer it for now, until we
            // receive the ones sequenced before it.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // We already got this message -- just drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
554
555#[async_trait]
556impl Handler<CastMessageV1> for CommActor {
557    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessageV1) -> Result<()> {
558        let slice = cast_message.dest_region.slice().clone();
559        let frame = RoutingFrame::root(sel!(*), slice);
560        let forward_message = ForwardMessageV1 {
561            dests: vec![frame],
562            message: cast_message,
563        };
564        self.handle(cx, forward_message).await
565    }
566}
567
568#[async_trait]
569impl Handler<ForwardMessageV1> for CommActor {
570    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessageV1) -> Result<()> {
571        let config = match &mut self.mesh_config {
572            MeshConfigState::NotConfigured(pending) => {
573                pending.push(PendingMessage::ForwardV1(fwd_message));
574                return Ok(());
575            }
576            MeshConfigState::Configured(config) => config,
577        };
578
579        let ForwardMessageV1 { dests, mut message } = fwd_message;
580        // Resolve/dedup routing frames.
581        let rank_on_root_mesh = config.self_rank();
582        let (deliver_here, next_steps) =
583            ndslice::selection::routing::resolve_routing(rank_on_root_mesh, dests, &mut |_| {
584                panic!("choice encountered in CommActor routing")
585            })?;
586
587        split_ports(cx, &mut message.data, deliver_here, &next_steps)?;
588
589        // Deliver message here, if necessary.
590        if deliver_here {
591            let mut headers = message.headers().clone();
592            let seq = message
593                .seqs
594                .get(message.cast_point(config)?.rank())
595                .expect("mismatched seqs and dest_region");
596            headers.set(
597                SEQ_INFO,
598                SeqInfo::Session {
599                    session_id: message.session_id,
600                    seq,
601                },
602            );
603            Self::deliver_to_dest(cx, headers, &mut message, config)?;
604        }
605
606        // Forward to peers.
607        for (peer_rank_on_root_mesh, dests) in next_steps {
608            let forward_message = ForwardMessageV1 {
609                dests,
610                message: message.clone(),
611            };
612            Self::forward(cx, config, peer_rank_on_root_mesh, forward_message)?;
613        }
614
615        Ok(())
616    }
617}
618
619pub mod test_utils {
620    use anyhow::Result;
621    use async_trait::async_trait;
622    use hyperactor::Actor;
623    use hyperactor::ActorAddr;
624    use hyperactor::Bind;
625    use hyperactor::Context;
626    use hyperactor::Handler;
627    use hyperactor::PortRef;
628    use hyperactor::Unbind;
629    use serde::Deserialize;
630    use serde::Serialize;
631    use typeuri::Named;
632
633    use super::*;
634
635    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
636    pub struct MyReply {
637        pub sender: ActorAddr,
638        pub value: u64,
639    }
640
641    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
642    #[expect(
643        clippy::large_enum_variant,
644        reason = "test fixture; CastAndReply carries #[binding(include)] PortRefs whose Bind/Unbind derive interaction with Box<T> needs verification — separate diff"
645    )]
646    pub enum TestMessage {
647        Forward(String),
648        CastAndReply {
649            arg: String,
650            // Intentionally not including 0. As a result, this port will not be
651            // split.
652            // #[binding(include)]
653            reply_to0: PortRef<String>,
654            #[binding(include)]
655            reply_to1: PortRef<u64>,
656            #[binding(include)]
657            reply_to2: PortRef<MyReply>,
658        },
659        CastAndReplyOnce {
660            arg: String,
661            #[binding(include)]
662            reply_to: hyperactor::OncePortRef<u64>,
663        },
664        CastWithUnsplitPort {
665            #[binding(include)]
666            reply_to: PortRef<u64>,
667        },
668    }
669
670    #[derive(Debug)]
671    #[hyperactor::export(TestMessage { cast = true })]
672    #[hyperactor::spawnable]
673    pub struct TestActor {
674        // Forward the received message to this port, so it can be inspected by
675        // the unit test.
676        forward_port: PortRef<TestMessage>,
677    }
678
679    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
680    pub struct TestActorParams {
681        pub forward_port: PortRef<TestMessage>,
682    }
683
684    #[async_trait]
685    impl Actor for TestActor {}
686
687    #[async_trait]
688    impl hyperactor::RemoteSpawn for TestActor {
689        type Params = TestActorParams;
690
691        async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self> {
692            let Self::Params { forward_port } = params;
693            Ok(Self { forward_port })
694        }
695    }
696
697    #[async_trait]
698    impl Handler<TestMessage> for TestActor {
699        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
700            // For CastWithUnsplitPort, send a reply so the test can
701            // verify that the unsplit port is still directly reachable.
702            if let TestMessage::CastWithUnsplitPort { ref reply_to } = msg {
703                reply_to.post(cx, 42);
704            }
705            self.forward_port.post(cx, msg);
706            Ok(())
707        }
708    }
709}
710
711#[cfg(test)]
712mod tests {
713    use std::collections::BTreeMap;
714    use std::collections::HashSet;
715    use std::fmt::Display;
716    use std::hash::Hash;
717    use std::ops::Deref;
718    use std::ops::DerefMut;
719    use std::sync::Mutex;
720    use std::sync::OnceLock;
721
722    use hyperactor::accum;
723
    /// Common setup for pre-config buffering tests: a single proc with a
    /// TestActor (for observing delivery) and an unconfigured CommActor.
    /// Returns (client, rx, comm_handle, actor_mesh_id) plus handles
    /// that must be kept alive.
    ///
    /// Everything the TestActor receives is forwarded to `rx`. The TestActor is
    /// spawned under the uid of `actor_mesh_id`, presumably so that comm-actor
    /// deliveries addressed to that mesh resolve to it — confirm against how
    /// `CommActor::deliver_to_dest` builds the destination address.
    async fn buffering_fixture(
        proc_name: &str,
    ) -> (
        Instance<()>,
        hyperactor::mailbox::PortReceiver<TestMessage>,
        hyperactor::ActorHandle<CommActor>,
        crate::mesh_id::ActorMeshId,
        // Drop guards: client handle, test actor handle, test actor ref.
        (
            hyperactor::ActorHandle<()>,
            hyperactor::ActorHandle<TestActor>,
            ActorRef<TestActor>,
        ),
    ) {
        use hyperactor::Proc;
        use hyperactor::RemoteSpawn;
        use hyperactor::channel::ChannelTransport;
        use hyperactor::id::Label;

        let proc = Proc::direct(ChannelTransport::Unix.any(), proc_name.to_string()).unwrap();
        let (client, client_handle) = proc.client("client").unwrap();

        let actor_mesh_id = crate::mesh_id::ActorMeshId::instance(Label::new("test").unwrap());

        // Port on which the test observes what the TestActor received.
        let (tx, rx) = open_port(&client);
        let forward_port = tx.bind();
        let test_actor = TestActor::new(TestActorParams { forward_port }, Default::default())
            .await
            .unwrap();
        let test_handle = proc
            .spawn_with_uid(actor_mesh_id.uid().clone(), test_actor)
            .unwrap();
        let test_ref: ActorRef<TestActor> = test_handle.bind::<TestActor>();

        // Spawned with `Default`, i.e. no CommMeshConfig yet.
        let comm_handle = proc.spawn("comm", CommActor::default()).unwrap();

        (
            client,
            rx,
            comm_handle,
            actor_mesh_id,
            (client_handle, test_handle, test_ref),
        )
    }
772
773    /// Send CommMeshConfig (single-rank mesh pointing at self).
774    fn send_config(client: &Instance<()>, comm_handle: &hyperactor::ActorHandle<CommActor>) {
775        let comm_ref = comm_handle.bind::<CommActor>();
776        let mut peers = HashMap::new();
777        peers.insert(0, comm_ref);
778        comm_handle.post(client, CommMeshConfig::new(0, peers));
779    }
780
    /// Send a message before config, send config, send another after config,
    /// and verify both are delivered in order.
    ///
    /// `make_msg` builds the message under test from the client, the mesh id,
    /// and a payload string. It is called with "buffered" and then "direct",
    /// and each message must result in `TestMessage::Forward(payload)` reaching
    /// the test actor.
    async fn assert_buffered_and_replayed<M: hyperactor::Message>(
        proc_name: &str,
        mut make_msg: impl FnMut(&Instance<()>, &crate::mesh_id::ActorMeshId, &str) -> M,
    ) where
        CommActor: hyperactor::Handler<M>,
    {
        let (client, mut rx, comm_handle, actor_mesh_id, _guards) =
            buffering_fixture(proc_name).await;

        // Arrives while unconfigured: must be buffered, then replayed by the config.
        comm_handle.post(&client, make_msg(&client, &actor_mesh_id, "buffered"));
        send_config(&client, &comm_handle);
        // Arrives after config: must be handled directly, after the replay.
        comm_handle.post(&client, make_msg(&client, &actor_mesh_id, "direct"));

        assert_eq!(
            rx.recv().await.unwrap(),
            TestMessage::Forward("buffered".to_string()),
        );
        assert_eq!(
            rx.recv().await.unwrap(),
            TestMessage::Forward("direct".to_string()),
        );
        comm_handle.drain_and_stop("test done").ok();
    }
806
    /// A `CastMessage` received before config is buffered, then replayed ahead
    /// of a cast sent after config.
    #[async_timed_test(timeout_secs = 1)]
    async fn cast_before_config_is_buffered_and_replayed() {
        use ndslice::Slice;

        assert_buffered_and_replayed("test_cast", |client, actor_mesh_id, payload| {
            let actor_mesh_id = actor_mesh_id.clone();
            // Single-rank destination, matching the single-rank config.
            let slice = Slice::new_row_major(vec![1]);
            let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
            let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
                actor_mesh_id,
                client.self_addr().clone(),
                shape,
                hyperactor_config::Flattrs::new(),
                TestMessage::Forward(payload.to_string()),
            )
            .unwrap();
            multicast::CastMessage {
                dest: multicast::Uslice {
                    slice,
                    selection: sel!(*),
                },
                message: envelope,
            }
        })
        .await;
    }
833
    /// A `ForwardMessage` received before config is buffered, then replayed
    /// ahead of a forward sent after config.
    #[async_timed_test(timeout_secs = 1)]
    async fn forward_before_config_is_buffered_and_replayed() {
        use ndslice::Slice;
        use ndslice::selection::routing::RoutingFrame;

        // Chain the messages: the first has last_seq 0 / seq 1, the second
        // last_seq 1 / seq 2, so both are in order for the comm actor.
        let mut next_seq: usize = 0;
        assert_buffered_and_replayed("test_fwd", move |client, actor_mesh_id, payload| {
            let actor_mesh_id = actor_mesh_id.clone();
            let slice = Slice::new_row_major(vec![1]);
            let shape = ndslice::Shape::new(vec!["rank".to_string()], slice.clone()).unwrap();
            let envelope = multicast::CastMessageEnvelope::new::<TestActor, TestMessage>(
                actor_mesh_id,
                client.self_addr().clone(),
                shape,
                hyperactor_config::Flattrs::new(),
                TestMessage::Forward(payload.to_string()),
            )
            .unwrap();
            let frame = RoutingFrame::root(sel!(*), slice);
            let last_seq = next_seq;
            next_seq += 1;
            multicast::ForwardMessage {
                sender: client.self_addr().clone(),
                dests: vec![frame],
                seq: next_seq,
                last_seq,
                message: envelope,
            }
        })
        .await;
    }
865
866    #[async_timed_test(timeout_secs = 1)]
867    async fn forward_v1_before_config_is_buffered_and_replayed() {
868        use ndslice::Region;
869        use ndslice::Slice;
870        use ndslice::selection::routing::RoutingFrame;
871
872        assert_buffered_and_replayed("test_fwd_v1", |client, actor_mesh_id, payload| {
873            let slice = Slice::new_row_major(vec![1]);
874            let region = Region::new(vec!["rank".to_string()], slice.clone());
875            let cast_msg = multicast::CastMessageV1::new::<TestActor, TestMessage>(
876                client.self_addr().clone(),
877                actor_mesh_id,
878                region.clone(),
879                hyperactor_config::Flattrs::new(),
880                TestMessage::Forward(payload.to_string()),
881                uuid::Uuid::new_v4(),
882                crate::ValueMesh::from_single(region, 1u64),
883            )
884            .unwrap();
885            let frame = RoutingFrame::root(sel!(*), slice);
886            multicast::ForwardMessageV1 {
887                dests: vec![frame],
888                message: cast_msg,
889            }
890        })
891        .await;
892    }
893
894    use hyperactor::ActorAddr;
895    use hyperactor::ActorRef;
896    use hyperactor::Index;
897    use hyperactor::OncePortRef;
898    use hyperactor::PortAddr;
899    use hyperactor::PortRef;
900    use hyperactor::ProcAddr;
901    use hyperactor::accum::Accumulator;
902    use hyperactor::accum::ReducerSpec;
903    use hyperactor::context;
904    use hyperactor::context::Mailbox;
905    use hyperactor::mailbox::PortReceiver;
906    use hyperactor::mailbox::open_port;
907    use hyperactor_config;
908    use hyperactor_mesh_macros::sel;
909    use maplit::btreemap;
910    use maplit::hashmap;
911    use ndslice::Extent;
912    use ndslice::Selection;
913    use ndslice::ViewExt as _;
914    use ndslice::extent;
915    use ndslice::selection::test_utils::collect_commactor_routing_tree;
916    use test_utils::*;
917    use timed_test::async_timed_test;
918    use tokio::time::Duration;
919
920    use super::*;
921    use crate::ActorMesh;
922    use crate::host_mesh::HostMesh;
923    use crate::test_utils::local_host_mesh;
924    use crate::testing;
925
926    // Helper to look up the rank for a given actor ID using the rank_lookup table.
927    fn lookup_rank(actor_id: &ActorAddr, rank_lookup: &HashMap<ProcAddr, usize>) -> usize {
928        let proc_id = actor_id.proc_addr();
929        *rank_lookup
930            .get(&proc_id)
931            .unwrap_or_else(|| panic!("proc rank not found for {}", proc_id))
932    }
933
934    struct Edge<T> {
935        from: T,
936        to: T,
937        is_leaf: bool,
938    }
939
940    impl<T> From<(T, T, bool)> for Edge<T> {
941        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
942            Self { from, to, is_leaf }
943        }
944    }
945
946    // The relationship between original ports and split ports. The elements in
947    // the tuple are (original port, split port, deliver_here).
948    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortAddr>>>> = OnceLock::new();
949
950    // Collect the relationships between original ports and split ports into
951    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
952    pub(crate) fn collect_split_port(original: &PortAddr, split: &PortAddr, deliver_here: bool) {
953        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
954        let mut tree = mutex.lock().unwrap();
955
956        tree.deref_mut().push(Edge {
957            from: original.clone(),
958            to: split.clone(),
959            is_leaf: deliver_here,
960        });
961    }
962
963    // There could be other cast calls before the one we want to check, e.g. from
964    // allocating the proc mesh, or spawning the actor mesh. Clear the collected
965    // tree so it will only contain the cast we want to check.
966    fn clear_collected_tree() {
967        if let Some(tree) = SPLIT_PORT_TREE.get() {
968            let mut tree: std::sync::MutexGuard<'_, Vec<Edge<PortAddr>>> = tree.lock().unwrap();
969            tree.clear();
970        }
971    }
972
973    // A representation of a tree.
974    //   * Map's keys are the tree's leafs;
975    //   * Map's values are the path from the root to that leaf.
976    #[derive(PartialEq)]
977    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
978
979    // Add a custom Debug trait impl so the result from assert_eq! is readable.
980    impl<T: Display> Debug for PathToLeaves<T> {
981        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
982            fn vec_to_string<T: Display>(v: &[T]) -> String {
983                v.iter()
984                    .map(ToString::to_string)
985                    .collect::<Vec<String>>()
986                    .join(", ")
987            }
988
989            for (src, path) in &self.0 {
990                writeln!(f, "{} -> {}", src, vec_to_string(path))?;
991            }
992            Ok(())
993        }
994    }
995
996    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
997        let mut child_parent_map = HashMap::new();
998        let mut all_nodes = HashSet::new();
999        let mut parents = HashSet::new();
1000        let mut children = HashSet::new();
1001        let mut dests = HashSet::new();
1002
1003        // Build parent map and track all nodes and children
1004        for Edge { from, to, is_leaf } in edges {
1005            child_parent_map.insert(to.clone(), from.clone());
1006            all_nodes.insert(from.clone());
1007            all_nodes.insert(to.clone());
1008            parents.insert(from.clone());
1009            children.insert(to.clone());
1010            if *is_leaf {
1011                dests.insert(to.clone());
1012            }
1013        }
1014
1015        // For each leaf, reconstruct path back to root
1016        let mut result = BTreeMap::new();
1017        for dest in dests {
1018            let mut path = vec![dest.clone()];
1019            let mut current = dest.clone();
1020            while let Some(parent) = child_parent_map.get(&current) {
1021                path.push(parent.clone());
1022                current = parent.clone();
1023            }
1024            path.reverse();
1025            result.insert(dest, path);
1026        }
1027
1028        PathToLeaves(result)
1029    }
1030
1031    #[test]
1032    fn test_build_paths() {
1033        // Given the tree:
1034        //     0
1035        //    / \
1036        //   1   4
1037        //  / \   \
1038        // 2   3   5
1039        let edges: Vec<_> = [
1040            (0, 1, false),
1041            (1, 2, true),
1042            (1, 3, true),
1043            (0, 4, true),
1044            (4, 5, true),
1045        ]
1046        .into_iter()
1047        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
1048        .collect();
1049
1050        let paths = build_paths(&edges);
1051
1052        let expected = btreemap! {
1053            2 => vec![0, 1, 2],
1054            3 => vec![0, 1, 3],
1055            4 => vec![0, 4],
1056            5 => vec![0, 4, 5],
1057        };
1058
1059        assert_eq!(paths.0, expected);
1060    }
1061
1062    //  Given a port tree,
1063    //     * remove the client port, i.e. the 1st element of the path;
1064    //     * verify all remaining ports are comm handler ports;
1065    //     * remove the actor information and return a rank-based tree representation.
1066    //
1067    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
1068    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
1069    //
1070    //      For example, for a 2x2 slice, the port tree could look like:
1071    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
1072    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
1073    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
1074    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
1075    //
1076    //     The result should be:
1077    //     0 -> 0
1078    //     1 -> 0, 1
1079    //     2 -> 0, 2
1080    //     3 -> 0, 2, 3
1081    fn get_ranks(
1082        paths: PathToLeaves<PortAddr>,
1083        client_reply: &PortAddr,
1084        rank_lookup: &HashMap<ProcAddr, usize>,
1085    ) -> PathToLeaves<Index> {
1086        let ranks = paths
1087            .0
1088            .into_iter()
1089            .map(|(dst, mut path): (PortAddr, Vec<PortAddr>)| {
1090                let first = path.remove(0);
1091                // The first PortId is the client's reply port.
1092                assert_eq!(&first, client_reply);
1093                // Other ports's actor ID must be dest[?].comm[0], where ? is
1094                // the rank we want to extract here.
1095                assert!(dst.actor_addr().label().unwrap().as_str().contains("comm"));
1096                let actor_path = path
1097                    .into_iter()
1098                    .map(|p: PortAddr| {
1099                        assert!(p.actor_addr().label().unwrap().as_str().contains("comm"));
1100                        let actor_ref = p.actor_addr();
1101                        lookup_rank(&actor_ref, rank_lookup)
1102                    })
1103                    .collect();
1104                let dst_actor_ref = dst.actor_addr();
1105                (lookup_rank(&dst_actor_ref, rank_lookup), actor_path)
1106            })
1107            .collect();
1108        PathToLeaves(ranks)
1109    }
1110
1111    struct NoneAccumulator;
1112
1113    impl Accumulator for NoneAccumulator {
1114        type State = u64;
1115        type Update = u64;
1116
1117        fn accumulate(
1118            &self,
1119            _state: &mut Self::State,
1120            _update: Self::Update,
1121        ) -> anyhow::Result<()> {
1122            unimplemented!()
1123        }
1124
1125        fn reducer_spec(&self) -> Option<ReducerSpec> {
1126            unimplemented!()
1127        }
1128    }
1129
1130    // Verify the split port paths are the same as the casting paths.
1131    fn verify_split_port_paths(
1132        selection: &Selection,
1133        extent: &Extent,
1134        reply_port_ref1: &PortRef<u64>,
1135        reply_port_ref2: &PortRef<MyReply>,
1136        rank_lookup: &HashMap<ProcAddr, usize>,
1137    ) {
1138        // Get the paths used in casting
1139        let sel_paths = PathToLeaves(
1140            collect_commactor_routing_tree(selection, &extent.to_slice())
1141                .delivered
1142                .into_iter()
1143                .collect(),
1144        );
1145
1146        // Get the split port paths collected in SPLIT_PORT_TREE during casting
1147        let (reply1_paths, reply2_paths) = {
1148            let tree = SPLIT_PORT_TREE
1149                .get()
1150                .expect("not initialized; are Hosts in the same process as SPLIT_PORT_TREE?");
1151            let edges = tree.lock().unwrap();
1152            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
1153                .0
1154                .into_iter()
1155                .partition(|(_dst, path)| path[0] == *reply_port_ref1.port_addr());
1156            (
1157                get_ranks(
1158                    PathToLeaves(reply1),
1159                    reply_port_ref1.port_addr(),
1160                    rank_lookup,
1161                ),
1162                get_ranks(
1163                    PathToLeaves(reply2),
1164                    reply_port_ref2.port_addr(),
1165                    rank_lookup,
1166                ),
1167            )
1168        };
1169
1170        // split port paths should be the same as casting paths
1171        assert_eq!(sel_paths, reply1_paths);
1172        assert_eq!(sel_paths, reply2_paths);
1173    }
1174
1175    async fn execute_cast_and_reply(
1176        ranks: Vec<ActorRef<TestActor>>,
1177        instance: &impl context::Actor,
1178        mut reply1_rx: PortReceiver<u64>,
1179        mut reply2_rx: PortReceiver<MyReply>,
1180        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1181    ) {
1182        // Reply from each dest actor. The replies should be received by client.
1183        {
1184            for (rank, (dest_actor, (reply_to1, reply_to2))) in
1185                ranks.iter().zip(reply_tos.iter()).enumerate()
1186            {
1187                let rank_u64 = rank as u64;
1188                reply_to1.post(instance, rank_u64);
1189                let my_reply = MyReply {
1190                    sender: dest_actor.actor_addr().clone(),
1191                    value: rank_u64,
1192                };
1193                reply_to2.post(instance, my_reply.clone());
1194
1195                assert_eq!(reply1_rx.recv().await.unwrap(), rank_u64);
1196                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
1197            }
1198        }
1199
1200        tracing::info!("the 1st updates from all dest actors were receivered by client");
1201
1202        // Now send multiple replies from the dest actors. They should all be
1203        // received by client. Replies sent from the same dest actor should
1204        // be received in the same order as they were sent out.
1205        {
1206            let n = 100;
1207            let mut expected2: HashMap<ActorAddr, Vec<MyReply>> = hashmap! {};
1208            for (i, (dest_actor, (_reply_to1, reply_to2))) in
1209                ranks.iter().zip(reply_tos.iter()).enumerate()
1210            {
1211                let mut sent2 = vec![];
1212                for j in 0..n {
1213                    let value = (i * 100 + j) as u64;
1214                    let my_reply = MyReply {
1215                        sender: dest_actor.actor_addr().clone(),
1216                        value,
1217                    };
1218                    reply_to2.post(instance, my_reply.clone());
1219                    sent2.push(my_reply);
1220                }
1221                assert!(
1222                    expected2
1223                        .insert(dest_actor.actor_addr().clone(), sent2)
1224                        .is_none(),
1225                    "duplicate actor_id {} in map",
1226                    dest_actor.actor_addr()
1227                );
1228            }
1229
1230            let mut received2: HashMap<ActorAddr, Vec<MyReply>> = hashmap! {};
1231
1232            for _ in 0..(n * ranks.len()) {
1233                let my_reply = reply2_rx.recv().await.unwrap();
1234                received2
1235                    .entry(my_reply.sender.clone())
1236                    .or_default()
1237                    .push(my_reply);
1238            }
1239            assert_eq!(received2, expected2);
1240        }
1241    }
1242
1243    async fn wait_for_with_timeout(
1244        receiver: &mut PortReceiver<u64>,
1245        expected: u64,
1246        dur: Duration,
1247    ) -> anyhow::Result<()> {
1248        // timeout wraps the entire async block
1249        tokio::time::timeout(dur, async {
1250            loop {
1251                let msg = receiver.recv().await.unwrap();
1252                if msg == expected {
1253                    break;
1254                }
1255            }
1256        })
1257        .await?;
1258        Ok(())
1259    }
1260
1261    async fn execute_cast_and_accum(
1262        ranks: Vec<ActorRef<TestActor>>,
1263        instance: &impl context::Actor,
1264        mut reply1_rx: PortReceiver<u64>,
1265        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1266    ) {
1267        // Now send multiple replies from the dest actors. They should all be
1268        // received by client. Replies sent from the same dest actor should
1269        // be received in the same order as they were sent out.
1270        let mut sum = 0;
1271        let n = 100;
1272        for (i, (_dest_actor, (reply_to1, _reply_to2))) in
1273            ranks.iter().zip(reply_tos.iter()).enumerate()
1274        {
1275            for j in 0..n {
1276                let value = (i + j) as u64;
1277                reply_to1.post(instance, value);
1278                sum += value;
1279            }
1280        }
1281        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(8))
1282            .await
1283            .unwrap();
1284        // no more messages
1285        tokio::time::sleep(Duration::from_secs(2)).await;
1286        let msg = reply1_rx.try_recv().unwrap();
1287        assert_eq!(msg, None);
1288    }
1289
1290    struct MeshSetupV1 {
1291        instance: &'static Instance<testing::TestRootClient>,
1292        actor_mesh_ref: crate::ActorMeshRef<TestActor>,
1293        reply1_rx: PortReceiver<u64>,
1294        reply2_rx: PortReceiver<MyReply>,
1295        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1296        // Keep the host mesh alive so comm actors aren't shut down.
1297        host_mesh: HostMesh,
1298    }
1299
1300    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1301    where
1302        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1303    {
1304        let instance = crate::testing::instance();
1305        // We have to use a in process host mesh, because SPLIT_PORT_TREE only
1306        // can collect paths from the same process.
1307        let host_mesh = local_host_mesh(8).await;
1308        let proc_mesh = host_mesh
1309            .spawn(instance, "test", extent!(gpu = 8), None, None)
1310            .await
1311            .unwrap();
1312
1313        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1314        let params = TestActorParams {
1315            forward_port: tx.bind(),
1316        };
1317        let actor_name =
1318            crate::mesh_id::ActorMeshId::instance(hyperactor::id::Label::new("test").unwrap());
1319        // Make this actor a "system" actor to avoid spawning a controller actor.
1320        // This test is verifying the whole comm tree, so we want fewer actors
1321        // involved.
1322        let actor_mesh = proc_mesh
1323            .spawn_with_name(&instance, actor_name, &params, None, true)
1324            .await
1325            .unwrap();
1326        let actor_mesh_ref: crate::ActorMeshRef<TestActor> = actor_mesh.deref().clone();
1327
1328        let (reply_port_handle0, _) = open_port::<String>(instance);
1329        let reply_port_ref0 = reply_port_handle0.bind();
1330        let (reply_port_handle1, reply1_rx) = match accum {
1331            Some(a) => instance.mailbox().open_accum_port(a),
1332            None => open_port(instance),
1333        };
1334        let reply_port_ref1 = reply_port_handle1.bind();
1335        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1336        let reply_port_ref2 = reply_port_handle2.bind();
1337        let message = TestMessage::CastAndReply {
1338            arg: "abc".to_string(),
1339            reply_to0: reply_port_ref0.clone(),
1340            reply_to1: reply_port_ref1.clone(),
1341            reply_to2: reply_port_ref2.clone(),
1342        };
1343
1344        clear_collected_tree();
1345        actor_mesh_ref.cast(instance, message).unwrap();
1346
1347        let mut reply_tos = vec![];
1348        for _ in proc_mesh.extent().points() {
1349            let msg = rx.recv().await.expect("missing");
1350            match msg {
1351                TestMessage::CastAndReply {
1352                    arg,
1353                    reply_to0,
1354                    reply_to1,
1355                    reply_to2,
1356                } => {
1357                    assert_eq!(arg, "abc");
1358                    // port 0 is still the same as the original one because it
1359                    // is not included in MutVisitor.
1360                    assert_eq!(reply_to0, reply_port_ref0);
1361                    // ports have been replaced by split ports.
1362                    assert_ne!(reply_to1, reply_port_ref1);
1363                    assert_ne!(reply_to2, reply_port_ref2);
1364                    let p2p_threshold = hyperactor_config::global::get(
1365                        crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD,
1366                    );
1367                    if p2p_threshold == 0 {
1368                        // Tree path: split ports belong to comm actors.
1369                        assert!(
1370                            reply_to1
1371                                .port_addr()
1372                                .actor_id()
1373                                .label()
1374                                .unwrap()
1375                                .as_str()
1376                                .contains("comm")
1377                        );
1378                        assert!(
1379                            reply_to2
1380                                .port_addr()
1381                                .actor_id()
1382                                .label()
1383                                .unwrap()
1384                                .as_str()
1385                                .contains("comm")
1386                        );
1387                    }
1388                    reply_tos.push((reply_to1, reply_to2));
1389                }
1390                _ => {
1391                    panic!("unexpected message: {:?}", msg);
1392                }
1393            }
1394        }
1395
1396        // verify_split_port_paths only applies to the tree path;
1397        // in point-to-point mode no comm actors are involved.
1398        let p2p_threshold =
1399            hyperactor_config::global::get(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD);
1400        if p2p_threshold == 0 {
1401            // [collect_commactor_routing_tree] only returns ranks, so we need
1402            // to map proc IDs collected in SPLIT_PORT_TREE to ranks.
1403            let rank_lookup = proc_mesh
1404                .ranks()
1405                .iter()
1406                .enumerate()
1407                .map(|(i, r)| (r.proc_addr().clone(), i))
1408                .collect::<HashMap<ProcAddr, usize>>();
1409
1410            // v1 always uses sel!(*) when casting to a mesh.
1411            let selection = sel!(*);
1412            verify_split_port_paths(
1413                &selection,
1414                &proc_mesh.extent(),
1415                &reply_port_ref1,
1416                &reply_port_ref2,
1417                &rank_lookup,
1418            );
1419        }
1420
1421        MeshSetupV1 {
1422            instance,
1423            actor_mesh_ref,
1424            reply1_rx,
1425            reply2_rx,
1426            reply_tos,
1427            host_mesh,
1428        }
1429    }
1430
1431    async fn execute_cast_and_reply_v1() {
1432        let mut setup = setup_mesh_v1::<NoneAccumulator>(None).await;
1433
1434        let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1435        execute_cast_and_reply(
1436            ranks,
1437            setup.instance,
1438            setup.reply1_rx,
1439            setup.reply2_rx,
1440            setup.reply_tos,
1441        )
1442        .await;
1443
1444        let _ = setup.host_mesh.shutdown(setup.instance).await;
1445    }
1446
1447    #[async_timed_test(timeout_secs = 60)]
1448    async fn test_cast_and_reply_v1_retrofit() {
1449        let config = hyperactor_config::global::lock();
1450        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1451        let _guard2 = config.override_key(
1452            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1453            false,
1454        );
1455        execute_cast_and_reply_v1().await
1456    }
1457
1458    #[async_timed_test(timeout_secs = 60)]
1459    async fn test_cast_and_reply_v1_native() {
1460        let config = hyperactor_config::global::lock();
1461        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1462        let _guard2 = config.override_key(
1463            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1464            true,
1465        );
1466        execute_cast_and_reply_v1().await
1467    }
1468
1469    #[async_timed_test(timeout_secs = 60)]
1470    async fn test_cast_and_reply_v1_native_p2p() {
1471        let config = hyperactor_config::global::lock();
1472        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1473        let _guard2 = config.override_key(
1474            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1475            true,
1476        );
1477        let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1478        execute_cast_and_reply_v1().await
1479    }
1480
1481    async fn execute_cast_and_accum_v1(config: &hyperactor_config::global::ConfigLock) {
1482        // Use temporary config for this test
1483        let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1484
1485        let mut setup = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1486
1487        let ranks = setup.actor_mesh_ref.values().collect::<Vec<_>>();
1488        execute_cast_and_accum(ranks, setup.instance, setup.reply1_rx, setup.reply_tos).await;
1489
1490        let _ = setup.host_mesh.shutdown(setup.instance).await;
1491    }
1492
1493    #[async_timed_test(timeout_secs = 60)]
1494    async fn test_cast_and_accum_v1_retrofit() {
1495        let config = hyperactor_config::global::lock();
1496        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, false);
1497        let _guard2 = config.override_key(
1498            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1499            false,
1500        );
1501        execute_cast_and_accum_v1(&config).await
1502    }
1503
1504    #[async_timed_test(timeout_secs = 60)]
1505    async fn test_cast_and_accum_v1_native() {
1506        let config = hyperactor_config::global::lock();
1507        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1508        let _guard2 = config.override_key(
1509            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1510            true,
1511        );
1512        execute_cast_and_accum_v1(&config).await
1513    }
1514
1515    #[async_timed_test(timeout_secs = 60)]
1516    async fn test_cast_and_accum_v1_native_p2p() {
1517        let config = hyperactor_config::global::lock();
1518        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1519        let _guard2 = config.override_key(
1520            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1521            true,
1522        );
1523        let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1524        execute_cast_and_accum_v1(&config).await
1525    }
1526
1527    struct OncePortMeshSetupV1 {
1528        instance: &'static Instance<testing::TestRootClient>,
1529        reply_rx: hyperactor::mailbox::OncePortReceiver<u64>,
1530        reply_tos: Vec<OncePortRef<u64>>,
1531        _reply_port_ref: OncePortRef<u64>,
1532        host_mesh: HostMesh,
1533    }
1534
1535    async fn setup_once_port_mesh<A>(reducer: Option<A>) -> OncePortMeshSetupV1
1536    where
1537        A: Accumulator<State = u64, Update = u64> + Send + Sync + 'static,
1538    {
1539        let instance = crate::testing::instance();
1540        // We have to use a in process host mesh, because SPLIT_PORT_TREE only
1541        // can collect paths from the same process.
1542        let host_mesh = local_host_mesh(8).await;
1543        let proc_mesh = host_mesh
1544            .spawn(instance, "test", extent!(gpu = 8), None, None)
1545            .await
1546            .unwrap();
1547
1548        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1549        let params = TestActorParams {
1550            forward_port: tx.bind(),
1551        };
1552        let actor_name =
1553            crate::mesh_id::ActorMeshId::instance(hyperactor::id::Label::new("test").unwrap());
1554        // Make this actor a "system" actor to avoid spawning a controller actor.
1555        let actor_mesh: crate::ActorMesh<TestActor> = proc_mesh
1556            .spawn_with_name(&instance, actor_name, &params, None, true)
1557            .await
1558            .unwrap();
1559        let actor_mesh_ref = actor_mesh.deref().clone();
1560
1561        let has_reducer = reducer.is_some();
1562        let (reply_port_handle, reply_rx) = match reducer {
1563            Some(reducer) => instance.mailbox().open_reduce_port(reducer),
1564            None => instance.mailbox().open_once_port::<u64>(),
1565        };
1566        let reply_port_ref = reply_port_handle.bind();
1567
1568        let message = TestMessage::CastAndReplyOnce {
1569            arg: "abc".to_string(),
1570            reply_to: reply_port_ref.clone(),
1571        };
1572
1573        clear_collected_tree();
1574        actor_mesh_ref.cast(instance, message).unwrap();
1575
1576        let mut reply_tos = vec![];
1577        for _ in proc_mesh.extent().points() {
1578            let msg = rx.recv().await.expect("missing");
1579            match msg {
1580                TestMessage::CastAndReplyOnce { arg, reply_to } => {
1581                    assert_eq!(arg, "abc");
1582                    if has_reducer {
1583                        // With reducer: port is split.
1584                        assert_ne!(reply_to, reply_port_ref);
1585                        let p2p_threshold = hyperactor_config::global::get(
1586                            crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD,
1587                        );
1588                        if p2p_threshold == 0 {
1589                            // Tree path: split port belongs to a comm actor.
1590                            assert!(
1591                                reply_to
1592                                    .port_addr()
1593                                    .actor_id()
1594                                    .label()
1595                                    .unwrap()
1596                                    .as_str()
1597                                    .contains("comm")
1598                            );
1599                        }
1600                    } else {
1601                        // Without reducer: port is passed through unchanged.
1602                        assert_eq!(reply_to, reply_port_ref);
1603                    }
1604                    reply_tos.push(reply_to);
1605                }
1606                _ => {
1607                    panic!("unexpected message: {:?}", msg);
1608                }
1609            }
1610        }
1611
1612        OncePortMeshSetupV1 {
1613            instance,
1614            reply_rx,
1615            reply_tos,
1616            _reply_port_ref: reply_port_ref,
1617            host_mesh,
1618        }
1619    }
1620
1621    async fn execute_cast_and_reply_once_v1() {
1622        // Test OncePort without accumulator - port is NOT split.
1623        // All destinations receive the same original port.
1624        // First reply is delivered, others fail at receiver (port closed).
1625        let mut setup = setup_once_port_mesh::<NoneAccumulator>(None).await;
1626
1627        // All reply_tos point to the same port (not split).
1628        // Only the first message will be delivered successfully.
1629        let num_replies = setup.reply_tos.len();
1630        for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1631            reply_to.post(setup.instance, i as u64);
1632        }
1633
1634        // OncePort receives exactly one value (the first to arrive)
1635        let result = setup.reply_rx.recv().await.unwrap();
1636        // The result should be one of the values sent
1637        assert!(result < num_replies as u64);
1638
1639        let _ = setup.host_mesh.shutdown(setup.instance).await;
1640    }
1641
1642    #[async_timed_test(timeout_secs = 60)]
1643    async fn test_cast_and_reply_once_v1() {
1644        execute_cast_and_reply_once_v1().await
1645    }
1646
1647    #[async_timed_test(timeout_secs = 60)]
1648    async fn test_cast_and_reply_once_v1_p2p() {
1649        let config = hyperactor_config::global::lock();
1650        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1651        let _guard2 = config.override_key(
1652            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1653            true,
1654        );
1655        let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1656        execute_cast_and_reply_once_v1().await
1657    }
1658
1659    async fn execute_cast_and_accum_once_v1() {
1660        // Test OncePort splitting with sum accumulator.
1661        // Each destination actor replies with its rank.
1662        // The sum of all ranks should be received at the original port.
1663        let mut setup = setup_once_port_mesh(Some(accum::sum::<u64>())).await;
1664
1665        // Each actor replies with its index
1666        let mut expected_sum = 0u64;
1667        for (i, reply_to) in setup.reply_tos.into_iter().enumerate() {
1668            reply_to.post(setup.instance, i as u64);
1669            expected_sum += i as u64;
1670        }
1671
1672        // OncePort should receive the sum of all responses
1673        let result = setup.reply_rx.recv().await.unwrap();
1674        assert_eq!(result, expected_sum);
1675
1676        let _ = setup.host_mesh.shutdown(setup.instance).await;
1677    }
1678
1679    #[async_timed_test(timeout_secs = 60)]
1680    async fn test_cast_and_accum_once_v1() {
1681        execute_cast_and_accum_once_v1().await
1682    }
1683
1684    #[async_timed_test(timeout_secs = 60)]
1685    async fn test_cast_and_accum_once_v1_p2p() {
1686        let config = hyperactor_config::global::lock();
1687        let _guard = config.override_key(ENABLE_NATIVE_V1_CASTING, true);
1688        let _guard2 = config.override_key(
1689            hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER,
1690            true,
1691        );
1692        let _guard3 = config.override_key(crate::config::V1_CAST_POINT_TO_POINT_THRESHOLD, 1024);
1693        execute_cast_and_accum_once_v1().await
1694    }
1695
1696    #[async_timed_test(timeout_secs = 60)]
1697    async fn test_unsplit_port_not_split() {
1698        let instance = crate::testing::instance();
1699        let mut host_mesh = local_host_mesh(8).await;
1700        let proc_mesh = host_mesh
1701            .spawn(instance, "test", extent!(gpu = 8), None, None)
1702            .await
1703            .unwrap();
1704
1705        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1706        let params = TestActorParams {
1707            forward_port: tx.bind(),
1708        };
1709        let actor_name =
1710            crate::mesh_id::ActorMeshId::instance(hyperactor::id::Label::new("test").unwrap());
1711        let actor_mesh: ActorMesh<TestActor> = proc_mesh
1712            .spawn_with_name(&instance, actor_name, &params, None, true)
1713            .await
1714            .unwrap();
1715        let (reply_port_handle, mut reply_rx) = open_port::<u64>(instance);
1716        let reply_port_ref = reply_port_handle.bind().unsplit();
1717
1718        let message = TestMessage::CastWithUnsplitPort {
1719            reply_to: reply_port_ref.clone(),
1720        };
1721
1722        clear_collected_tree();
1723        actor_mesh.cast(instance, message).unwrap();
1724
1725        // Verify that all destinations received the original port (not split).
1726        let num_points = proc_mesh.extent().points().count();
1727        for _ in 0..num_points {
1728            let msg = rx.recv().await.expect("missing");
1729            match msg {
1730                TestMessage::CastWithUnsplitPort { reply_to } => {
1731                    assert_eq!(
1732                        reply_to.port_addr(),
1733                        reply_port_ref.port_addr(),
1734                        "unsplit port should not be replaced by a comm actor split port"
1735                    );
1736                }
1737                _ => panic!("unexpected message: {:?}", msg),
1738            }
1739        }
1740
1741        // All 8 actors sent replies directly to the same port.
1742        // Verify we receive all 8 replies.
1743        for _ in 0..8 {
1744            let val = reply_rx.recv().await.unwrap();
1745            assert_eq!(val, 42);
1746        }
1747        let _ = host_mesh.shutdown(instance).await;
1748    }
1749}