hyperactor_mesh/
comm.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11use crate::resource;
12pub mod multicast;
13
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use anyhow::Result;
19use async_trait::async_trait;
20use hyperactor::Actor;
21use hyperactor::ActorId;
22use hyperactor::ActorRef;
23use hyperactor::Context;
24use hyperactor::Handler;
25use hyperactor::Instance;
26use hyperactor::PortRef;
27use hyperactor::WorldId;
28use hyperactor::mailbox::DeliveryError;
29use hyperactor::mailbox::MailboxSender;
30use hyperactor::mailbox::Undeliverable;
31use hyperactor::mailbox::UndeliverableMailboxSender;
32use hyperactor::mailbox::UndeliverableMessageError;
33use hyperactor::mailbox::monitored_return_handle;
34use hyperactor::reference::UnboundPort;
35use ndslice::selection::routing::RoutingFrame;
36use serde::Deserialize;
37use serde::Serialize;
38use typeuri::Named;
39
40use crate::comm::multicast::CastMessage;
41use crate::comm::multicast::CastMessageEnvelope;
42use crate::comm::multicast::ForwardMessage;
43use crate::comm::multicast::set_cast_info_on_headers;
44
/// Parameters to initialize the CommActor.
///
/// Currently empty: the comm actor needs no construction-time state; its
/// mode is configured after spawn via the [`CommActorMode`] message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
wirevalue::register_type!(CommActorParams);
49
/// A message buffered due to out-of-order delivery, held until all
/// preceding sequence numbers on its stream have been processed.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether to deliver this message to actors on this comm actor's proc.
    deliver_here: bool,
    /// Peer comm actors to forward the message to, keyed by peer rank.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message to deliver.
    message: CastMessageEnvelope,
}
62
/// Bookkeeping to handle sequence numbers and in-order delivery for messages
/// sent to and through this comm actor. One `ReceiveState` exists per
/// (actor mesh id, sender) stream.
#[derive(Debug, Default)]
struct ReceiveState {
    /// The sequence number of the last message handled in order.
    seq: usize,
    /// A buffer storing messages we received out-of-order, indexed by the seq
    /// that should precede each one.
    buffer: HashMap<usize, Buffered>,
    /// A map of the last sequence number we sent to next steps, indexed by
    /// peer rank.
    last_seqs: HashMap<usize, usize>,
}
75
/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation. Messages fan out along a routing tree of comm
/// actors; sequence numbers per stream preserve delivery order.
#[derive(Debug, Default)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Outgoing sequence numbers, maintained per (actor mesh id, sender) stream.
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Incoming per-stream ordering state; each (actor mesh id, sender) pair
    /// is a unique stream.
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// The comm actor's mode (how it resolves its own rank and its peers).
    mode: CommActorMode,
}
96
/// Configuration for how a `CommActor` determines its own rank and locates peers.
///
/// - In `Mesh` mode, the comm actor is assigned an explicit rank and a mapping to each peer by rank.
/// - In `Implicit` mode, the comm actor infers its rank and peers from its own actor ID.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
#[derive(Default)]
pub enum CommActorMode {
    /// When configured as a mesh, the comm actor is assigned a rank
    /// and a set of references for each peer rank.
    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),

    /// In an implicit mode, the comm actor derives its rank and
    /// peers from its own ID.
    #[default]
    Implicit,

    /// Like `Implicit`, but override the destination world id.
    /// This is useful for setups where comm actors may not reside
    /// in the destination world. It is meant as a temporary bridge
    /// until we are fully onto ActorMeshes.
    // TODO: T224926642 Remove this once we are fully onto ActorMeshes.
    ImplicitWithWorldId(WorldId),
}
wirevalue::register_type!(CommActorMode);
121
122impl CommActorMode {
123    /// Return the peer comm actor for the given rank, given a self id,
124    /// destination port, and rank.
125    fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
126        match self {
127            Self::Mesh(_self_rank, peers) => peers
128                .get(&rank)
129                .cloned()
130                .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
131            Self::Implicit => {
132                let world_id = self_id
133                    .proc_id()
134                    .world_id()
135                    .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
136                let proc_id = world_id.proc_id(rank);
137                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
138                Ok(ActorRef::<CommActor>::attest(actor_id))
139            }
140            Self::ImplicitWithWorldId(world_id) => {
141                let proc_id = world_id.proc_id(rank);
142                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
143                Ok(ActorRef::<CommActor>::attest(actor_id))
144            }
145        }
146    }
147
148    /// Return the rank of the comm actor, given a self id.
149    fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
150        match self {
151            Self::Mesh(rank, _) => Ok(*rank),
152            Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
153                .proc_id()
154                .rank()
155                .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
156        }
157    }
158}
159
#[async_trait]
impl Actor for CommActor {
    // This is an override of the default actor behavior: rather than treating
    // an undeliverable message as terminal here, attempt to return it to the
    // cast's original sender so the failure is observable at the source.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // 1. Case delivery failure at a "forwarding" step: the undelivered
        // payload is a ForwardMessage bound for a peer comm actor. Return it
        // to the cast's original sender, whose id is carried in the message.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to forward the cast message; return to \
                its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    // Returning to the sender itself failed; record the
                    // broken link and surface it as a ReturnFailure.
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning ForwardMessage to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::ReturnFailure {
                        envelope: message_envelope,
                    }
                })?;
            return Ok(());
        }

        // 2. Case delivery failure at a "deliver here" step: the message was
        // posted to a local destination actor. The original sender is carried
        // in the CAST_ORIGINATING_SENDER header set at delivery time.
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to deliver the cast message to the dest \
                actor; return to its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occured when returning cast message to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::ReturnFailure {
                        envelope: message_envelope,
                    }
                })?;
            return Ok(());
        }

        // 3. A return of an undeliverable message was itself returned: there
        // is no further sender to fall back to, so hand the envelope to the
        // undeliverable mailbox sender (the end of the line).
        UndeliverableMailboxSender
            .post(message_envelope, /*unused */ monitored_return_handle());
        Ok(())
    }
}
231
232impl CommActor {
233    /// Forward the message to the comm actor on the given peer rank.
234    fn forward(
235        cx: &Instance<Self>,
236        mode: &CommActorMode,
237        rank: usize,
238        message: ForwardMessage,
239    ) -> Result<()> {
240        let child = mode.peer_for_rank(cx.self_id(), rank)?;
241        child.send(cx, message)?;
242        Ok(())
243    }
244
245    fn handle_message(
246        cx: &Context<Self>,
247        mode: &CommActorMode,
248        deliver_here: bool,
249        next_steps: HashMap<usize, Vec<RoutingFrame>>,
250        sender: ActorId,
251        mut message: CastMessageEnvelope,
252        seq: usize,
253        last_seqs: &mut HashMap<usize, usize>,
254    ) -> Result<()> {
255        // Split ports, if any, and update message with new ports. In this
256        // way, children actors will reply to this comm actor's ports, instead
257        // of to the original ports provided by parent.
258        message.data_mut().visit_mut::<UnboundPort>(
259            |UnboundPort(port_id, reducer_spec, reducer_opts, return_undeliverable)| {
260                let split = port_id.split(
261                    cx,
262                    reducer_spec.clone(),
263                    reducer_opts.clone(),
264                    *return_undeliverable,
265                )?;
266
267                #[cfg(test)]
268                tests::collect_split_port(port_id, &split, deliver_here);
269
270                *port_id = split;
271                Ok(())
272            },
273        )?;
274
275        // Deliver message here, if necessary.
276        if deliver_here {
277            let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
278            let cast_rank = message.relative_rank(rank_on_root_mesh)?;
279            // Replace ranks with self ranks.
280            message
281                .data_mut()
282                .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
283                    *rank = Some(cast_rank);
284                    Ok(())
285                })?;
286            let cast_shape = message.shape();
287            let point = cast_shape
288                .extent()
289                .point_of_rank(cast_rank)
290                .expect("rank out of bounds");
291            let mut headers = cx.headers().clone();
292            set_cast_info_on_headers(&mut headers, point, message.sender().clone());
293            cx.post(
294                cx.self_id()
295                    .proc_id()
296                    .actor_id(message.dest_port().actor_name(), 0)
297                    .port_id(message.dest_port().port()),
298                headers,
299                wirevalue::Any::serialize(message.data())?,
300            );
301        }
302
303        // Forward to peers.
304        next_steps
305            .into_iter()
306            .map(|(peer, dests)| {
307                let last_seq = last_seqs.entry(peer).or_default();
308                Self::forward(
309                    cx,
310                    mode,
311                    peer,
312                    ForwardMessage {
313                        dests,
314                        sender: sender.clone(),
315                        message: message.clone(),
316                        seq,
317                        last_seq: *last_seq,
318                    },
319                )?;
320                *last_seq = seq;
321                Ok(())
322            })
323            .collect::<Result<Vec<_>>>()?;
324
325        Ok(())
326    }
327}
328
#[async_trait]
impl Handler<CommActorMode> for CommActor {
    /// Reconfigure this comm actor's mode (see [`CommActorMode`]).
    async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
        self.mode = mode;
        Ok(())
    }
}
336
// TODO(T218630526): reliable casting for mutable topology
#[async_trait]
impl Handler<CastMessage> for CommActor {
    /// Entry point of a cast: assign the next sequence number on this
    /// (mesh, sender) stream and forward the message to the root rank of the
    /// destination slice, where routing begins.
    #[hyperactor::instrument]
    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
        // Always forward the message to the root rank of the slice, casting starts from there.
        let slice = cast_message.dest.slice.clone();
        let selection = cast_message.dest.selection.clone();
        let frame = RoutingFrame::root(selection, slice);
        let rank = frame.slice.location(&frame.here)?;
        // Advance this stream's sequence number: the previous value becomes
        // `last_seq`, the incremented value is this message's `seq`.
        let seq = self
            .send_seq
            .entry(cast_message.message.stream_key())
            .or_default();
        let last_seq = *seq;
        *seq += 1;

        let fwd_message = ForwardMessage {
            dests: vec![frame],
            sender: cx.self_id().clone(),
            message: cast_message.message,
            seq: *seq,
            last_seq,
        };

        // Optimization: if forwarding to ourselves, handle inline instead of
        // going through the message queue
        if self
            .mode
            .self_rank(cx.self_id())
            .is_ok_and(|self_rank| self_rank == rank)
        {
            Handler::<ForwardMessage>::handle(self, cx, fwd_message).await?;
        } else {
            Self::forward(cx, &self.mode, rank, fwd_message)?;
        }
        Ok(())
    }
}
376
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Routing step of a cast: resolve where the message goes from this rank,
    /// then apply it in stream order — handling it now if it is the expected
    /// next message, buffering it if it arrived early, or dropping it if it
    /// is a duplicate.
    #[hyperactor::instrument]
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve/dedup routing frames: decide whether to deliver locally and
        // which peer ranks to forward to next.
        let rank = self.mode.self_rank(cx.self_id())?;
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        // Compare the last seq we processed with the seq that should precede
        // this message.
        match recv_state.seq.cmp(&last_seq) {
            // We got the expected next message to deliver to this host.
            Ordering::Equal => {
                // We got an in-order operation, so handle it now.
                Self::handle_message(
                    cx,
                    &self.mode,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Also deliver any pending operations from the recv buffer that
                // were received out-of-order that are now unblocked.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        &self.mode,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // We got an out-of-order (early) operation, so buffer it for now,
            // until we have received the ones sequenced before it. It is
            // indexed by its predecessor's seq so the loop above can find it.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // We already got this message -- just drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
464
/// Actors and message types used by unit tests that exercise casting
/// through the comm actor.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorId;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;
    use typeuri::Named;

    use super::*;

    /// A reply payload carrying the id of the actor that produced it.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: ActorId,
        pub value: u64,
    }

    /// Messages handled by [`TestActor`]. In `CastAndReply`, only the ports
    /// annotated with `#[binding(include)]` participate in binding and are
    /// therefore split by comm actors along the cast path.
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Intentionally not including 0. As a result, this port will not be
            // split.
            // #[binding(include)]
            reply_to0: PortRef<String>,
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
    }

    /// A test actor that forwards every message it receives to a port owned
    /// by the unit test, so the test can inspect what was delivered.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Forward the received message to this port, so it can be inspected by
        // the unit test.
        forward_port: PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`]: the test-side inspection port.
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {}

    #[async_trait]
    impl hyperactor::RemoteSpawn for TestActor {
        type Params = TestActorParams;

        async fn new(params: Self::Params) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Relay the received message, unchanged, to the inspection port.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
542
543#[cfg(test)]
544mod tests {
545    use std::collections::BTreeMap;
546    use std::collections::HashSet;
547    use std::fmt::Display;
548    use std::hash::Hash;
549    use std::ops::Deref;
550    use std::ops::DerefMut;
551    use std::sync::Arc;
552    use std::sync::Mutex;
553    use std::sync::OnceLock;
554
555    use hyperactor::PortId;
556    use hyperactor::PortRef;
557    use hyperactor::accum;
558    use hyperactor::accum::Accumulator;
559    use hyperactor::accum::ReducerSpec;
560    use hyperactor::channel::ChannelTransport;
561    use hyperactor::clock::Clock;
562    use hyperactor::clock::RealClock;
563    use hyperactor::context;
564    use hyperactor::context::Mailbox;
565    use hyperactor::mailbox::PortReceiver;
566    use hyperactor::mailbox::open_port;
567    use hyperactor::reference::Index;
568    use hyperactor_config;
569    use hyperactor_mesh_macros::sel;
570    use maplit::btreemap;
571    use maplit::hashmap;
572    use ndslice::Extent;
573    use ndslice::Selection;
574    use ndslice::ViewExt as _;
575    use ndslice::extent;
576    use ndslice::selection::test_utils::collect_commactor_routing_tree;
577    use test_utils::*;
578    use timed_test::async_timed_test;
579    use tokio::time::Duration;
580
581    use super::*;
582    use crate::ProcMesh;
583    use crate::actor_mesh::ActorMesh;
584    use crate::actor_mesh::RootActorMesh;
585    use crate::alloc::AllocSpec;
586    use crate::alloc::Allocator;
587    use crate::alloc::LocalAllocator;
588    use crate::proc_mesh::SharedSpawnable;
589    use crate::v1;
590    use crate::v1::testing;
591
    /// A directed edge in a routing/split-port tree.
    struct Edge<T> {
        /// Parent node.
        from: T,
        /// Child node.
        to: T,
        /// Whether `to` is a delivery target (a leaf of the cast).
        is_leaf: bool,
    }

    /// Allow edges to be written compactly as `(from, to, is_leaf)` tuples.
    impl<T> From<(T, T, bool)> for Edge<T> {
        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
            Self { from, to, is_leaf }
        }
    }
603
    // The relationship between original ports and split ports, stored as
    // edges (from = original port, to = split port, is_leaf = deliver_here).
    // Lazily initialized; mutex-guarded since comm actors record into it
    // concurrently with test reads.
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();
607
608    // Collect the relationships between original ports and split ports into
609    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
610    pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
611        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
612        let mut tree = mutex.lock().unwrap();
613
614        tree.deref_mut().push(Edge {
615            from: original.clone(),
616            to: split.clone(),
617            is_leaf: deliver_here,
618        });
619    }
620
621    // There could be other cast calls before the one we want to check, e.g. from
622    // allocating the proc mesh, or spawning the actor mesh. Clear the collected
623    // tree so it will only contain the cast we want to check.
624    fn clear_collected_tree() {
625        if let Some(tree) = SPLIT_PORT_TREE.get() {
626            let mut tree = tree.lock().unwrap();
627            tree.clear();
628        }
629    }
630
    // A representation of a tree:
    //   * the map's keys are the tree's leaves;
    //   * the map's values are the path from the root to that leaf
    //     (root first, leaf last).
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
636
637    // Add a custom Debug trait impl so the result from assert_eq! is readable.
638    impl<T: Display> Debug for PathToLeaves<T> {
639        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
640            fn vec_to_string<T: Display>(v: &[T]) -> String {
641                v.iter()
642                    .map(ToString::to_string)
643                    .collect::<Vec<String>>()
644                    .join(", ")
645            }
646
647            for (src, path) in &self.0 {
648                write!(f, "{} -> {}\n", src, vec_to_string(path))?;
649            }
650            Ok(())
651        }
652    }
653
654    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
655        let mut child_parent_map = HashMap::new();
656        let mut all_nodes = HashSet::new();
657        let mut parents = HashSet::new();
658        let mut children = HashSet::new();
659        let mut dests = HashSet::new();
660
661        // Build parent map and track all nodes and children
662        for Edge { from, to, is_leaf } in edges {
663            child_parent_map.insert(to.clone(), from.clone());
664            all_nodes.insert(from.clone());
665            all_nodes.insert(to.clone());
666            parents.insert(from.clone());
667            children.insert(to.clone());
668            if *is_leaf {
669                dests.insert(to.clone());
670            }
671        }
672
673        // For each leaf, reconstruct path back to root
674        let mut result = BTreeMap::new();
675        for dest in dests {
676            let mut path = vec![dest.clone()];
677            let mut current = dest.clone();
678            while let Some(parent) = child_parent_map.get(&current) {
679                path.push(parent.clone());
680                current = parent.clone();
681            }
682            path.reverse();
683            result.insert(dest, path);
684        }
685
686        PathToLeaves(result)
687    }
688
689    #[test]
690    fn test_build_paths() {
691        // Given the tree:
692        //     0
693        //    / \
694        //   1   4
695        //  / \   \
696        // 2   3   5
697        let edges: Vec<_> = [
698            (0, 1, false),
699            (1, 2, true),
700            (1, 3, true),
701            (0, 4, true),
702            (4, 5, true),
703        ]
704        .into_iter()
705        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
706        .collect();
707
708        let paths = build_paths(&edges);
709
710        let expected = btreemap! {
711            2 => vec![0, 1, 2],
712            3 => vec![0, 1, 3],
713            4 => vec![0, 4],
714            5 => vec![0, 4, 5],
715        };
716
717        assert_eq!(paths.0, expected);
718    }
719
    //  Given a port tree,
    //     * remove the client port, i.e. the 1st element of the path;
    //     * verify all remaining ports are comm actor ports;
    //     * remove the actor information and return a rank-based tree representation.
    //
    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
    //
    //      For example, for a 2x2 slice, the port tree could look like:
    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
    //
    //     The result should be:
    //     0 -> 0
    //     1 -> 0, 1
    //     2 -> 0, 2
    //     3 -> 0, 2, 3
    fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
        let ranks = paths
            .0
            .into_iter()
            .map(|(dst, mut path)| {
                let first = path.remove(0);
                // The first PortId is the client's reply port.
                assert_eq!(&first, client_reply);
                // Other ports' actor IDs must be dest[?].comm[0], where ? is
                // the rank we want to extract here.
                assert!(dst.actor_id().name().contains("comm"));
                let actor_path = path
                    .into_iter()
                    .map(|p| {
                        assert!(p.actor_id().name().contains("comm"));
                        p.actor_id().rank()
                    })
                    .collect();
                (dst.into_actor_id().rank(), actor_path)
            })
            .collect();
        PathToLeaves(ranks)
    }
762
    /// Everything a casting test needs after mesh setup: the actor mesh,
    /// the receivers for the two bound reply ports, and the (split) reply
    /// port refs observed by each destination actor.
    struct MeshSetup {
        actor_mesh: RootActorMesh<'static, TestActor>,
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
769
    /// A placeholder accumulator used where an `Option<A>` accumulator
    /// argument is required but no accumulation should occur; both methods
    /// are `unimplemented!` and must never be called.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }
788
    // Verify the split port paths are the same as the casting paths: the tree
    // of split ports recorded in SPLIT_PORT_TREE during a cast must mirror,
    // rank for rank, the routing tree computed from the selection.
    fn verify_split_port_paths(
        selection: &Selection,
        extent: &Extent,
        reply_port_ref1: &PortRef<u64>,
        reply_port_ref2: &PortRef<MyReply>,
    ) {
        // Get the paths used in casting.
        let sel_paths = PathToLeaves(
            collect_commactor_routing_tree(selection, &extent.to_slice())
                .delivered
                .into_iter()
                .collect(),
        );

        // Get the split port paths collected in SPLIT_PORT_TREE during
        // casting, partitioned by which original reply port each path roots
        // at (path[0] is the client port).
        let (reply1_paths, reply2_paths) = {
            let tree = SPLIT_PORT_TREE.get().unwrap();
            let edges = tree.lock().unwrap();
            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
                .0
                .into_iter()
                .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
            (
                get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
                get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
            )
        };

        // Split port paths should be the same as casting paths.
        assert_eq!(sel_paths, reply1_paths);
        assert_eq!(sel_paths, reply2_paths);
    }
822
    /// Spawns a 4x4x4 local mesh of `TestActor`s, casts a `CastAndReply`
    /// message to every rank, and collects the reply ports each destination
    /// actor actually received (which the comm actors rewrite into split
    /// ports). `accum` optionally installs an accumulator on reply port 1;
    /// when `None`, a plain `u64` port is opened instead.
    async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
        let dest_actor_name = "dest_actor";
        // Each dest actor forwards the message it received back through this
        // port, so the test can inspect the (possibly rewritten) reply ports.
        let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let instance = crate::v1::testing::instance();
        let actor_mesh: RootActorMesh<TestActor> = Arc::clone(&proc_mesh)
            .spawn(&instance, dest_actor_name, &params)
            .await
            .unwrap();

        // Port 0 is expected to pass through the cast unchanged; ports 1 and
        // 2 are expected to be replaced with comm-actor split ports.
        let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => proc_mesh.client().mailbox().open_accum_port(a),
            None => open_port(proc_mesh.client()),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        let selection = sel!(*);
        clear_collected_tree();
        actor_mesh
            .cast(proc_mesh.client(), selection.clone(), message)
            .unwrap();

        // Collect one forwarded message per point in the extent, verifying
        // the port-rewriting invariants on each.
        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it
                    // is not included in MutVisitor.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);

        MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }
909
910    async fn execute_cast_and_reply(
911        ranks: Vec<ActorRef<TestActor>>,
912        instance: &impl context::Actor,
913        mut reply1_rx: PortReceiver<u64>,
914        mut reply2_rx: PortReceiver<MyReply>,
915        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
916    ) {
917        // Reply from each dest actor. The replies should be received by client.
918        {
919            for (dest_actor, (reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
920                let rank = dest_actor.actor_id().rank() as u64;
921                reply_to1.send(instance, rank).unwrap();
922                let my_reply = MyReply {
923                    sender: dest_actor.actor_id().clone(),
924                    value: rank,
925                };
926                reply_to2.send(instance, my_reply.clone()).unwrap();
927
928                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
929                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
930            }
931        }
932
933        tracing::info!("the 1st updates from all dest actors were receivered by client");
934
935        // Now send multiple replies from the dest actors. They should all be
936        // received by client. Replies sent from the same dest actor should
937        // be received in the same order as they were sent out.
938        {
939            let n = 100;
940            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
941            for (dest_actor, (_reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
942                let rank = dest_actor.actor_id().rank();
943                let mut sent2 = vec![];
944                for i in 0..n {
945                    let value = (rank * 100 + i) as u64;
946                    let my_reply = MyReply {
947                        sender: dest_actor.actor_id().clone(),
948                        value,
949                    };
950                    reply_to2.send(instance, my_reply.clone()).unwrap();
951                    sent2.push(my_reply);
952                }
953                assert!(
954                    expected2.insert(rank, sent2).is_none(),
955                    "duplicate rank {rank} in map"
956                );
957            }
958
959            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};
960
961            for _ in 0..(n * ranks.len()) {
962                let my_reply = reply2_rx.recv().await.unwrap();
963                received2
964                    .entry(my_reply.sender.rank())
965                    .or_default()
966                    .push(my_reply);
967            }
968            assert_eq!(received2, expected2);
969        }
970    }
971
972    #[async_timed_test(timeout_secs = 30)]
973    async fn test_cast_and_reply() {
974        let MeshSetup {
975            actor_mesh,
976            reply1_rx,
977            reply2_rx,
978            reply_tos,
979            ..
980        } = setup_mesh::<NoneAccumulator>(None).await;
981        let proc_mesh_client = actor_mesh.proc_mesh().client();
982
983        let ranks = actor_mesh.ranks().clone();
984        execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await;
985    }
986
987    async fn wait_for_with_timeout(
988        receiver: &mut PortReceiver<u64>,
989        expected: u64,
990        dur: Duration,
991    ) -> anyhow::Result<()> {
992        // timeout wraps the entire async block
993        RealClock
994            .timeout(dur, async {
995                loop {
996                    let msg = receiver.recv().await.unwrap();
997                    if msg == expected {
998                        break;
999                    }
1000                }
1001            })
1002            .await?;
1003        Ok(())
1004    }
1005
1006    async fn execute_cast_and_accum(
1007        ranks: Vec<ActorRef<TestActor>>,
1008        instance: &impl context::Actor,
1009        mut reply1_rx: PortReceiver<u64>,
1010        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1011    ) {
1012        // Now send multiple replies from the dest actors. They should all be
1013        // received by client. Replies sent from the same dest actor should
1014        // be received in the same order as they were sent out.
1015        let mut sum = 0;
1016        let n = 100;
1017        for (dest_actor, (reply_to1, _reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
1018            let rank = dest_actor.actor_id().rank();
1019            for i in 0..n {
1020                let value = (rank + i) as u64;
1021                reply_to1.send(instance, value).unwrap();
1022                sum += value;
1023            }
1024        }
1025        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
1026            .await
1027            .unwrap();
1028        // no more messages
1029        RealClock.sleep(Duration::from_secs(2)).await;
1030        let msg = reply1_rx.try_recv().unwrap();
1031        assert_eq!(msg, None);
1032    }
1033
1034    #[async_timed_test(timeout_secs = 30)]
1035    async fn test_cast_and_accum() {
1036        let config = hyperactor_config::global::lock();
1037        // Use temporary config for this test
1038        let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1039
1040        let MeshSetup {
1041            actor_mesh,
1042            reply1_rx,
1043            reply_tos,
1044            ..
1045        } = setup_mesh(Some(accum::sum::<u64>())).await;
1046        let proc_mesh_client = actor_mesh.proc_mesh().client();
1047        let ranks = actor_mesh.ranks().clone();
1048        execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await;
1049    }
1050
    /// Everything a v1 cast-and-reply test needs after mesh setup.
    struct MeshSetupV1 {
        // Client instance used to send and receive on behalf of the test.
        instance: &'static Instance<testing::TestRootClient>,
        // Reference to the spawned `TestActor` mesh.
        actor_mesh_ref: v1::ActorMeshRef<TestActor>,
        // Receiver for reply port 1 (plain or accumulating `u64` port).
        reply1_rx: PortReceiver<u64>,
        // Receiver for reply port 2 (`MyReply` values).
        reply2_rx: PortReceiver<MyReply>,
        // Per-rank split reply ports observed by each dest actor.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
1058
    /// v1 counterpart of `setup_mesh`: spawns a 4x4x4 local mesh of
    /// `TestActor`s via the v1 API, casts a `CastAndReply` message, and
    /// collects the rewritten reply ports from every rank. `accum`
    /// optionally installs an accumulator on reply port 1.
    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let instance = v1::testing::instance();

        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        let proc_mesh = v1::ProcMesh::allocate(instance, Box::new(alloc), "test_local")
            .await
            .unwrap();

        // Each dest actor forwards the message it received back through this
        // port, so the test can inspect the (possibly rewritten) reply ports.
        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_name = v1::Name::new("test").expect("valid test name");
        // Make this actor a "system" actor to avoid spawning a controller actor.
        // This test is verifying the whole comm tree, so we want fewer actors
        // involved.
        let actor_mesh = proc_mesh
            .spawn_with_name(&instance, actor_name, &params, None, true)
            .await
            .unwrap();
        let actor_mesh_ref = actor_mesh.deref().clone();

        // Port 0 is expected to pass through the cast unchanged; ports 1 and
        // 2 are expected to be replaced with comm-actor split ports.
        let (reply_port_handle0, _) = open_port::<String>(instance);
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => instance.mailbox().open_accum_port(a),
            None => open_port(instance),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        // Collect one forwarded message per point in the extent, verifying
        // the port-rewriting invariants on each.
        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it
                    // is not included in MutVisitor.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert!(reply_to1.port_id().actor_id().name().contains("comm"));
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert!(reply_to2.port_id().actor_id().name().contains("comm"));
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        // v1 always uses sel!(*) when casting to a mesh.
        let selection = sel!(*);
        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);

        MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }
1153
1154    #[async_timed_test(timeout_secs = 30)]
1155    async fn test_cast_and_reply_v1() {
1156        let MeshSetupV1 {
1157            instance,
1158            actor_mesh_ref,
1159            reply1_rx,
1160            reply2_rx,
1161            reply_tos,
1162            ..
1163        } = setup_mesh_v1::<NoneAccumulator>(None).await;
1164
1165        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1166        execute_cast_and_reply(ranks, instance, reply1_rx, reply2_rx, reply_tos).await;
1167    }
1168
1169    #[async_timed_test(timeout_secs = 30)]
1170    async fn test_cast_and_accum_v1() {
1171        let config = hyperactor_config::global::lock();
1172        // Use temporary config for this test
1173        let _guard1 = config.override_key(hyperactor::config::SPLIT_MAX_BUFFER_SIZE, 1);
1174
1175        let MeshSetupV1 {
1176            instance,
1177            actor_mesh_ref,
1178            reply1_rx,
1179            reply_tos,
1180            ..
1181        } = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1182
1183        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1184        execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await;
1185    }
1186}