hyperactor_mesh/
comm.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11pub mod multicast;
12
13use std::cmp::Ordering;
14use std::collections::HashMap;
15use std::fmt::Debug;
16
17use anyhow::Result;
18use async_trait::async_trait;
19use hyperactor::Actor;
20use hyperactor::ActorId;
21use hyperactor::ActorRef;
22use hyperactor::Context;
23use hyperactor::Handler;
24use hyperactor::Instance;
25use hyperactor::Named;
26use hyperactor::PortRef;
27use hyperactor::WorldId;
28use hyperactor::data::Serialized;
29use hyperactor::mailbox::DeliveryError;
30use hyperactor::mailbox::MailboxSender;
31use hyperactor::mailbox::Undeliverable;
32use hyperactor::mailbox::UndeliverableMailboxSender;
33use hyperactor::mailbox::UndeliverableMessageError;
34use hyperactor::mailbox::monitored_return_handle;
35use hyperactor::reference::UnboundPort;
36use ndslice::selection::routing::RoutingFrame;
37use serde::Deserialize;
38use serde::Serialize;
39
40use crate::comm::multicast::CastMessage;
41use crate::comm::multicast::CastMessageEnvelope;
42use crate::comm::multicast::ForwardMessage;
43use crate::comm::multicast::set_cast_info_on_headers;
44
/// Parameters to initialize the [`CommActor`].
///
/// The struct currently has no fields: `CommActor::new` ignores its params,
/// and the actor's mode is configured afterwards by sending it a
/// [`CommActorMode`] message.
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}
48
/// A message buffered due to out-of-order delivery.
///
/// It keeps the arguments needed to call `CommActor::handle_message` later,
/// once the message sequenced before it has been handled.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether to deliver this message to the destination actor on this
    /// comm actor's proc.
    deliver_here: bool,
    /// Peer comm actors to forward the message to, keyed by peer rank; the
    /// value holds the routing frames to hand to that peer.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message to deliver.
    message: CastMessageEnvelope,
}
61
/// Bookkeeping to handle sequence numbers and in-order delivery for messages
/// sent to and through this comm actor.
///
/// One `ReceiveState` is kept per stream key (see `CommActor::recv_state`).
#[derive(Debug, Default)]
struct ReceiveState {
    /// The sequence of the last received message that was handled in order.
    /// Starts at 0 (via `Default`).
    seq: usize,
    /// A buffer storing messages we received out-of-order, indexed by the seq
    /// that should precede it.
    buffer: HashMap<usize, Buffered>,
    /// A map of the last sequence number we sent to next steps, indexed by rank.
    last_seqs: HashMap<usize, usize>,
}
74
/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation.
///
/// It handles three message types:
/// - [`CommActorMode`]: replaces the actor's current mode;
/// - [`CastMessage`]: starts a cast by forwarding it to the root rank of the
///   destination slice;
/// - [`ForwardMessage`]: delivers locally and/or forwards to peer comm
///   actors, preserving per-stream ordering.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Sequence numbers are maintained for each (actor mesh id, sender).
    /// The stored value is the sequence number of the last cast sent on
    /// that stream.
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Each sender is a unique stream.
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// The comm actor's mode.
    mode: CommActorMode,
}
95
96/// Configuration for how a `CommActor` determines its own rank and locates peers.
97///
98/// - In `Mesh` mode, the comm actor is assigned an explicit rank and a mapping to each peer by rank.
99/// - In `Implicit` mode, the comm actor infers its rank and peers from its own actor ID.
100#[derive(Debug, Clone, Serialize, Deserialize, Named)]
101pub enum CommActorMode {
102    /// When configured as a mesh, the comm actor is assigned a rank
103    /// and a set of references for each peer rank.
104    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),
105
106    /// In an implicit mode, the comm actor derives its rank and
107    /// peers from its own ID.
108    Implicit,
109
110    /// Like `Implicit`, but override the destination world id.
111    /// This is useful for setups where comm actors may not reside
112    /// in the destination world. It is meant as a temporary bridge
113    /// until we are fully onto ActorMeshes.
114    // TODO: T224926642 Remove this once we are fully onto ActorMeshes.
115    ImplicitWithWorldId(WorldId),
116}
117
118impl Default for CommActorMode {
119    fn default() -> Self {
120        Self::Implicit
121    }
122}
123
124impl CommActorMode {
125    /// Return the peer comm actor for the given rank, given a self id,
126    /// destination port, and rank.
127    fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
128        match self {
129            Self::Mesh(_self_rank, peers) => peers
130                .get(&rank)
131                .cloned()
132                .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
133            Self::Implicit => {
134                let world_id = self_id
135                    .proc_id()
136                    .world_id()
137                    .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
138                let proc_id = world_id.proc_id(rank);
139                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
140                Ok(ActorRef::<CommActor>::attest(actor_id))
141            }
142            Self::ImplicitWithWorldId(world_id) => {
143                let proc_id = world_id.proc_id(rank);
144                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
145                Ok(ActorRef::<CommActor>::attest(actor_id))
146            }
147        }
148    }
149
150    /// Return the rank of the comm actor, given a self id.
151    fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
152        match self {
153            Self::Mesh(rank, _) => Ok(*rank),
154            Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
155                .proc_id()
156                .rank()
157                .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
158        }
159    }
160}
161
#[async_trait]
impl Actor for CommActor {
    type Params = CommActorParams;

    /// Create a comm actor with empty sequence state and the default mode.
    /// The params are ignored.
    async fn new(_params: Self::Params) -> Result<Self> {
        Ok(Self {
            send_seq: HashMap::new(),
            recv_state: HashMap::new(),
            mode: Default::default(),
        })
    }

    // This is an override of the default actor behavior.
    //
    // An undeliverable envelope is routed back towards whoever originated
    // the cast, using one of three strategies tried in order:
    //   1. the envelope is a `ForwardMessage`: return it to the sender
    //      recorded in the wrapped cast message;
    //   2. the envelope carries a `CAST_ORIGINATING_SENDER` header: return
    //      it to that sender;
    //   3. otherwise: hand it to `UndeliverableMailboxSender`.
    // In cases 1 and 2, a failure to send the return is recorded on the
    // envelope as a `BrokenLink` error and surfaced as the handler's error.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // 1. Case delivery failure at a "forwarding" step.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = PortRef::attest_message_port(sender);
            return_port
                // A clone is sent so the original envelope stays available
                // for error annotation below.
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    message_envelope
                        .try_set_error(DeliveryError::BrokenLink(format!("send failure: {err}")));
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // 2. Case delivery failure at a "deliver here" step.
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = PortRef::attest_message_port(sender);
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    message_envelope
                        .try_set_error(DeliveryError::BrokenLink(format!("send failure: {err}")));
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // 3. A return of an undeliverable message was itself returned.
        UndeliverableMailboxSender
            .post(message_envelope, /*unused */ monitored_return_handle());
        Ok(())
    }
}
217
218impl CommActor {
219    /// Forward the message to the comm actor on the given peer rank.
220    fn forward(
221        cx: &Instance<Self>,
222        mode: &CommActorMode,
223        rank: usize,
224        message: ForwardMessage,
225    ) -> Result<()> {
226        let child = mode.peer_for_rank(cx.self_id(), rank)?;
227        child.send(cx, message)?;
228        Ok(())
229    }
230
231    fn handle_message(
232        cx: &Context<Self>,
233        mode: &CommActorMode,
234        deliver_here: bool,
235        next_steps: HashMap<usize, Vec<RoutingFrame>>,
236        sender: ActorId,
237        mut message: CastMessageEnvelope,
238        seq: usize,
239        last_seqs: &mut HashMap<usize, usize>,
240    ) -> Result<()> {
241        // Split ports, if any, and update message with new ports. In this
242        // way, children actors will reply to this comm actor's ports, instead
243        // of to the original ports provided by parent.
244        message
245            .data_mut()
246            .visit_mut::<UnboundPort>(|UnboundPort(port_id, reducer_spec)| {
247                let split = port_id.split(cx, reducer_spec.clone())?;
248
249                #[cfg(test)]
250                tests::collect_split_port(port_id, &split, deliver_here);
251
252                *port_id = split;
253                Ok(())
254            })?;
255
256        // Deliver message here, if necessary.
257        if deliver_here {
258            let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
259            let cast_rank = message.relative_rank(rank_on_root_mesh)?;
260            let cast_shape = message.shape();
261            let mut headers = cx.headers().clone();
262            set_cast_info_on_headers(
263                &mut headers,
264                cast_rank,
265                cast_shape.clone(),
266                message.sender().clone(),
267            );
268            cx.post(
269                cx.self_id()
270                    .proc_id()
271                    .actor_id(message.dest_port().actor_name(), 0)
272                    .port_id(message.dest_port().port()),
273                headers,
274                Serialized::serialize(message.data())?,
275            );
276        }
277
278        // Forward to peers.
279        next_steps
280            .into_iter()
281            .map(|(peer, dests)| {
282                let last_seq = last_seqs.entry(peer).or_default();
283                Self::forward(
284                    cx,
285                    mode,
286                    peer,
287                    ForwardMessage {
288                        dests,
289                        sender: sender.clone(),
290                        message: message.clone(),
291                        seq,
292                        last_seq: *last_seq,
293                    },
294                )?;
295                *last_seq = seq;
296                Ok(())
297            })
298            .collect::<Result<Vec<_>>>()?;
299
300        Ok(())
301    }
302}
303
#[async_trait]
impl Handler<CommActorMode> for CommActor {
    /// Replace the comm actor's current mode with `mode`. Always succeeds.
    async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
        self.mode = mode;
        Ok(())
    }
}
311
312// TODO(T218630526): reliable casting for mutable topology
313#[async_trait]
314impl Handler<CastMessage> for CommActor {
315    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
316        // Always forward the message to the root rank of the slice, casting starts from there.
317        let slice = cast_message.dest.slice.clone();
318        let selection = cast_message.dest.selection.clone();
319        let frame = RoutingFrame::root(selection, slice);
320        let rank = frame.slice.location(&frame.here)?;
321        let seq = self
322            .send_seq
323            .entry(cast_message.message.stream_key())
324            .or_default();
325        let last_seq = *seq;
326        *seq += 1;
327        Self::forward(
328            cx,
329            &self.mode,
330            rank,
331            ForwardMessage {
332                dests: vec![frame],
333                sender: cx.self_id().clone(),
334                message: cast_message.message,
335                seq: *seq,
336                last_seq,
337            },
338        )?;
339        Ok(())
340    }
341}
342
#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    /// Handle a message forwarded by another comm actor (or by this actor's
    /// own `CastMessage` handler).
    ///
    /// The routing frames are resolved into "deliver here" plus a set of next
    /// peers. The message is then handled immediately if it is the next one
    /// expected on its stream, buffered if it arrived early, or dropped if it
    /// was already seen.
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve/dedup routing frames.
        let rank = self.mode.self_rank(cx.self_id())?;
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        // Ordering is tracked per stream; `last_seq` names the message that
        // must have been handled before this one.
        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            // We got the expected next message to deliver to this host.
            Ordering::Equal => {
                // We got an in-order operation, so handle it now.
                Self::handle_message(
                    cx,
                    &self.mode,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Also deliver any pending operations from the recv buffer that
                // were received out-of-order that are now unblocked.
                // The buffer is keyed by the predecessor's seq, so each
                // handled message may unblock exactly one successor.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        &self.mode,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // We got an out-of-order operation, so buffer it for now, until we
            // have received the ones sequenced before it.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // We already got this message -- just drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}
429
/// Message and actor types used by the comm actor tests.
pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorId;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Named;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;

    use super::*;

    /// A reply payload that records which actor sent it.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: ActorId,
        pub value: u64,
    }

    /// Messages understood by [`TestActor`].
    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        /// A plain string payload.
        Forward(String),
        /// A cast carrying three reply ports, two of which are marked with
        /// `#[binding(include)]`.
        CastAndReply {
            arg: String,
            // Intentionally not including 0. As a result, this port will not be
            // split.
            // #[binding(include)]
            reply_to0: PortRef<String>,
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
    }

    /// An actor that relays every [`TestMessage`] it receives to a fixed port.
    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Forward the received message to this port, so it can be inspected by
        // the unit test.
        forward_port: PortRef<TestMessage>,
    }

    /// Spawn parameters for [`TestActor`].
    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        /// The port every received message is relayed to.
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {
        type Params = TestActorParams;

        /// Build the actor from its params, keeping the forward port.
        async fn new(params: Self::Params) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        /// Relay `msg` unchanged to the configured forward port.
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}
504
505#[cfg(test)]
506mod tests {
507    use std::collections::BTreeMap;
508    use std::collections::HashSet;
509    use std::fmt::Display;
510    use std::hash::Hash;
511    use std::ops::DerefMut;
512    use std::sync::Arc;
513    use std::sync::Mutex;
514    use std::sync::OnceLock;
515
516    use hyperactor::PortId;
517    use hyperactor::PortRef;
518    use hyperactor::accum;
519    use hyperactor::accum::Accumulator;
520    use hyperactor::accum::ReducerSpec;
521    use hyperactor::clock::Clock;
522    use hyperactor::clock::RealClock;
523    use hyperactor::config;
524    use hyperactor::mailbox::PortReceiver;
525    use hyperactor::mailbox::open_port;
526    use hyperactor::reference::Index;
527    use hyperactor_mesh_macros::sel;
528    use maplit::btreemap;
529    use maplit::hashmap;
530    use ndslice::Selection;
531    use ndslice::extent;
532    use ndslice::selection::test_utils::collect_commactor_routing_tree;
533    use test_utils::*;
534    use timed_test::async_timed_test;
535    use tokio::time::Duration;
536
537    use super::*;
538    use crate::ProcMesh;
539    use crate::actor_mesh::ActorMesh;
540    use crate::actor_mesh::RootActorMesh;
541    use crate::alloc::AllocSpec;
542    use crate::alloc::Allocator;
543    use crate::alloc::LocalAllocator;
544    use crate::proc_mesh::SharedSpawnable;
545
546    struct Edge<T> {
547        from: T,
548        to: T,
549        is_leaf: bool,
550    }
551
552    impl<T> From<(T, T, bool)> for Edge<T> {
553        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
554            Self { from, to, is_leaf }
555        }
556    }
557
    // The relationship between original ports and split ports. Each `Edge`
    // records `from` = original port, `to` = split port, and `is_leaf` = the
    // `deliver_here` flag that was in effect when the port was split.
560    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();
561
562    // Collect the relationships between original ports and split ports into
563    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
564    pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
565        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
566        let mut tree = mutex.lock().unwrap();
567
568        tree.deref_mut().push(Edge {
569            from: original.clone(),
570            to: split.clone(),
571            is_leaf: deliver_here,
572        });
573    }
574
575    // A representation of a tree.
576    //   * Map's keys are the tree's leafs;
577    //   * Map's values are the path from the root to that leaf.
578    #[derive(PartialEq)]
579    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
580
581    // Add a custom Debug trait impl so the result from assert_eq! is readable.
582    impl<T: Display> Debug for PathToLeaves<T> {
583        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
584            fn vec_to_string<T: Display>(v: &[T]) -> String {
585                v.iter()
586                    .map(ToString::to_string)
587                    .collect::<Vec<String>>()
588                    .join(", ")
589            }
590
591            for (src, path) in &self.0 {
592                write!(f, "{} -> {}\n", src, vec_to_string(path))?;
593            }
594            Ok(())
595        }
596    }
597
598    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
599        let mut child_parent_map = HashMap::new();
600        let mut all_nodes = HashSet::new();
601        let mut parents = HashSet::new();
602        let mut children = HashSet::new();
603        let mut dests = HashSet::new();
604
605        // Build parent map and track all nodes and children
606        for Edge { from, to, is_leaf } in edges {
607            child_parent_map.insert(to.clone(), from.clone());
608            all_nodes.insert(from.clone());
609            all_nodes.insert(to.clone());
610            parents.insert(from.clone());
611            children.insert(to.clone());
612            if *is_leaf {
613                dests.insert(to.clone());
614            }
615        }
616
617        // For each leaf, reconstruct path back to root
618        let mut result = BTreeMap::new();
619        for dest in dests {
620            let mut path = vec![dest.clone()];
621            let mut current = dest.clone();
622            while let Some(parent) = child_parent_map.get(&current) {
623                path.push(parent.clone());
624                current = parent.clone();
625            }
626            path.reverse();
627            result.insert(dest, path);
628        }
629
630        PathToLeaves(result)
631    }
632
633    #[test]
634    fn test_build_paths() {
635        // Given the tree:
636        //     0
637        //    / \
638        //   1   4
639        //  / \   \
640        // 2   3   5
641        let edges: Vec<_> = [
642            (0, 1, false),
643            (1, 2, true),
644            (1, 3, true),
645            (0, 4, true),
646            (4, 5, true),
647        ]
648        .into_iter()
649        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
650        .collect();
651
652        let paths = build_paths(&edges);
653
654        let expected = btreemap! {
655            2 => vec![0, 1, 2],
656            3 => vec![0, 1, 3],
657            4 => vec![0, 4],
658            5 => vec![0, 4, 5],
659        };
660
661        assert_eq!(paths.0, expected);
662    }
663
664    //  Given a port tree,
665    //     * remove the client port, i.e. the 1st element of the path;
666    //     * verify all remaining ports are comm actor ports;
667    //     * remove the actor information and return a rank-based tree representation.
668    //
669    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
670    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
671    //
672    //      For example, for a 2x2 slice, the port tree could look like:
673    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
674    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
675    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
676    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
677    //
678    //     The result should be:
679    //     0 -> 0
680    //     1 -> 0, 1
681    //     2 -> 0, 2
682    //     3 -> 0, 2, 3
683    fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
684        let ranks = paths
685            .0
686            .into_iter()
687            .map(|(dst, mut path)| {
688                let first = path.remove(0);
689                // The first PortId is the client's reply port.
690                assert_eq!(&first, client_reply);
691                // Other ports's actor ID must be dest[?].comm[0], where ? is
692                // the rank we want to extract here.
693                assert_eq!(dst.actor_id().name(), "comm");
694                let actor_path = path
695                    .into_iter()
696                    .map(|p| {
697                        assert_eq!(p.actor_id().name(), "comm");
698                        p.actor_id().rank()
699                    })
700                    .collect();
701                (dst.into_actor_id().rank(), actor_path)
702            })
703            .collect();
704        PathToLeaves(ranks)
705    }
706
    // Everything `setup_mesh` hands back to a test.
    struct MeshSetup {
        // The spawned mesh of `TestActor`s the cast was sent to.
        actor_mesh: RootActorMesh<'static, TestActor>,
        // Client-side receiver for the `reply_to1` port.
        reply1_rx: PortReceiver<u64>,
        // Client-side receiver for the `reply_to2` port.
        reply2_rx: PortReceiver<MyReply>,
        // The (reply_to1, reply_to2) port pair observed in each received
        // cast message, in the order the messages arrived.
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }
713
    // A placeholder accumulator type whose methods are all unimplemented.
    // It gives `setup_mesh::<A>` a concrete `A` when the caller passes
    // `None`; calling either method panics.
    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        // Never expected to be called; panics via `unimplemented!()`.
        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        // Never expected to be called; panics via `unimplemented!()`.
        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }
732
    // Shared test fixture. It:
    //   1. allocates a local 4x4x4 proc mesh and spawns a `TestActor` mesh
    //      on it;
    //   2. casts one `TestMessage::CastAndReply` to every actor (`sel!(*)`);
    //   3. waits for each actor to relay the message back and checks that
    //      `reply_to0` is unchanged while `reply_to1`/`reply_to2` were
    //      replaced by comm actor ports;
    //   4. checks that the split-port paths recorded in SPLIT_PORT_TREE
    //      match the routing tree computed for the selection.
    //
    // When `accum` is `Some`, the `reply_to1` port is opened as an
    // accumulating port using that accumulator; otherwise it is a plain port.
    //
    // Panics (via `unwrap`/`expect`/`assert!`) if any step fails.
    async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
        let dest_actor_name = "dest_actor";
        // Every TestActor relays what it receives to `tx`, so `rx` sees one
        // message per destination actor.
        let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_mesh = proc_mesh
            .clone()
            .spawn::<TestActor>(dest_actor_name, &params)
            .await
            .unwrap();

        // The receiver for port 0 is discarded here; only the port reference
        // is used, to check it is passed through unchanged.
        let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => proc_mesh.client().open_accum_port(a),
            None => open_port(proc_mesh.client()),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        let selection = sel!(*);
        actor_mesh
            .cast(proc_mesh.client(), selection.clone(), message)
            .unwrap();

        // Expect exactly one relayed message per point of the extent.
        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it
                    // is not included in MutVisitor.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        // Verify the split port paths are the same as the casting paths.
        {
            // Get the paths used in casting
            let sel_paths = PathToLeaves(
                collect_commactor_routing_tree(&selection, &extent.to_slice())
                    .delivered
                    .into_iter()
                    .collect(),
            );

            // Get the split port paths collected in SPLIT_PORT_TREE during casting
            let (reply1_paths, reply2_paths) = {
                let tree = SPLIT_PORT_TREE.get().unwrap();
                let edges = tree.lock().unwrap();
                // Paths rooted at reply port 1 go to the first map; all
                // other paths go to the second and are then checked against
                // reply port 2 inside `get_ranks`.
                let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
                    .0
                    .into_iter()
                    .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
                (
                    get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
                    get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
                )
            };

            // split port paths should be the same as casting paths
            assert_eq!(sel_paths, reply1_paths);
            assert_eq!(sel_paths, reply2_paths);
        }

        MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }
842
843    #[async_timed_test(timeout_secs = 30)]
844    async fn test_cast_and_reply() {
845        let MeshSetup {
846            actor_mesh,
847            mut reply1_rx,
848            mut reply2_rx,
849            reply_tos,
850            ..
851        } = setup_mesh::<NoneAccumulator>(None).await;
852        let proc_mesh_client = actor_mesh.proc_mesh().client();
853
854        // Reply from each dest actor. The replies should be received by client.
855        {
856            for (dest_actor, (reply_to1, reply_to2)) in
857                actor_mesh.ranks.iter().zip(reply_tos.iter())
858            {
859                let rank = dest_actor.actor_id().rank() as u64;
860                reply_to1.send(proc_mesh_client, rank).unwrap();
861                let my_reply = MyReply {
862                    sender: dest_actor.actor_id().clone(),
863                    value: rank,
864                };
865                reply_to2.send(proc_mesh_client, my_reply.clone()).unwrap();
866
867                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
868                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
869            }
870        }
871
872        tracing::info!("the 1st updates from all dest actors were receivered by client");
873
874        // Now send multiple replies from the dest actors. They should all be
875        // received by client. Replies sent from the same dest actor should
876        // be received in the same order as they were sent out.
877        {
878            let n = 100;
879            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
880            for (dest_actor, (_reply_to1, reply_to2)) in
881                actor_mesh.ranks.iter().zip(reply_tos.iter())
882            {
883                let rank = dest_actor.actor_id().rank();
884                let mut sent2 = vec![];
885                for i in 0..n {
886                    let value = (rank * 100 + i) as u64;
887                    let my_reply = MyReply {
888                        sender: dest_actor.actor_id().clone(),
889                        value,
890                    };
891                    reply_to2.send(proc_mesh_client, my_reply.clone()).unwrap();
892                    sent2.push(my_reply);
893                }
894                assert!(
895                    expected2.insert(rank, sent2).is_none(),
896                    "duplicate rank {rank} in map"
897                );
898            }
899
900            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};
901
902            for _ in 0..(n * actor_mesh.ranks.len()) {
903                let my_reply = reply2_rx.recv().await.unwrap();
904                received2
905                    .entry(my_reply.sender.rank())
906                    .or_default()
907                    .push(my_reply);
908            }
909            assert_eq!(received2, expected2);
910        }
911    }
912
913    async fn wait_for_with_timeout(
914        receiver: &mut PortReceiver<u64>,
915        expected: u64,
916        dur: Duration,
917    ) -> anyhow::Result<()> {
918        // timeout wraps the entire async block
919        RealClock
920            .timeout(dur, async {
921                loop {
922                    let msg = receiver.recv().await.unwrap();
923                    if msg == expected {
924                        break;
925                    }
926                }
927            })
928            .await?;
929        Ok(())
930    }
931
932    #[async_timed_test(timeout_secs = 30)]
933    async fn test_cast_and_accum() -> Result<()> {
934        let config = config::global::lock();
935        // Use temporary config for this test
936        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);
937
938        let MeshSetup {
939            actor_mesh,
940            mut reply1_rx,
941            reply_tos,
942            ..
943        } = setup_mesh(Some(accum::sum::<u64>())).await;
944        let proc_mesh_client = actor_mesh.proc_mesh().client();
945
946        // Now send multiple replies from the dest actors. They should all be
947        // received by client. Replies sent from the same dest actor should
948        // be received in the same order as they were sent out.
949        {
950            let mut sum = 0;
951            let n = 100;
952            for (dest_actor, (reply_to1, _reply_to2)) in
953                actor_mesh.ranks.iter().zip(reply_tos.iter())
954            {
955                let rank = dest_actor.actor_id().rank();
956                for i in 0..n {
957                    let value = (rank + i) as u64;
958                    reply_to1.send(proc_mesh_client, value).unwrap();
959                    sum += value;
960                }
961            }
962            wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
963                .await
964                .unwrap();
965            // no more messages
966            RealClock.sleep(Duration::from_secs(2)).await;
967            let msg = reply1_rx.try_recv().unwrap();
968            assert_eq!(msg, None);
969        }
970        Ok(())
971    }
972}