hyperactor_mesh/
comm.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10use crate::reference::ActorMeshId;
11use crate::resource;
12pub mod multicast;
13
14use std::cmp::Ordering;
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use anyhow::Result;
19use async_trait::async_trait;
20use hyperactor::Actor;
21use hyperactor::ActorId;
22use hyperactor::ActorRef;
23use hyperactor::Context;
24use hyperactor::Handler;
25use hyperactor::Instance;
26use hyperactor::Named;
27use hyperactor::PortRef;
28use hyperactor::WorldId;
29use hyperactor::data::Serialized;
30use hyperactor::mailbox::DeliveryError;
31use hyperactor::mailbox::MailboxSender;
32use hyperactor::mailbox::Undeliverable;
33use hyperactor::mailbox::UndeliverableMailboxSender;
34use hyperactor::mailbox::UndeliverableMessageError;
35use hyperactor::mailbox::monitored_return_handle;
36use hyperactor::reference::UnboundPort;
37use ndslice::selection::routing::RoutingFrame;
38use serde::Deserialize;
39use serde::Serialize;
40
41use crate::comm::multicast::CastMessage;
42use crate::comm::multicast::CastMessageEnvelope;
43use crate::comm::multicast::ForwardMessage;
44use crate::comm::multicast::set_cast_info_on_headers;
45
46/// Parameters to initialize the CommActor
47#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
48pub struct CommActorParams {}
49
50/// A message buffered due to out-of-order delivery.
51#[derive(Debug)]
52struct Buffered {
53    /// Sequence number of this message.
54    seq: usize,
55    /// Whether to deliver this message to this comm-actors actors.
56    deliver_here: bool,
57    /// Peer comm actors to forward message to.
58    next_steps: HashMap<usize, Vec<RoutingFrame>>,
59    /// The message to deliver.
60    message: CastMessageEnvelope,
61}
62
63/// Bookkeeping to handle sequence numbers and in-order delivery for messages
64/// sent to and through this comm actor.
65#[derive(Debug, Default)]
66struct ReceiveState {
67    /// The sequence of the last received message.
68    seq: usize,
69    /// A buffer storing messages we received out-of-order, indexed by the seq
70    /// that should precede it.
71    buffer: HashMap<usize, Buffered>,
72    /// A map of the last sequence number we sent to next steps, indexed by rank.
73    last_seqs: HashMap<usize, usize>,
74}
75
76/// This is the comm actor used for efficient and scalable message multicasting
77/// and result accumulation.
78#[derive(Debug)]
79#[hyperactor::export(
80    spawn = true,
81    handlers = [
82        CommActorMode,
83        CastMessage,
84        ForwardMessage,
85    ],
86)]
87pub struct CommActor {
88    /// Sequence numbers are maintained for each (actor mesh id, sender).
89    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
90    /// Each sender is a unique stream.
91    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,
92
93    /// The comm actor's mode.
94    mode: CommActorMode,
95}
96
97/// Configuration for how a `CommActor` determines its own rank and locates peers.
98///
99/// - In `Mesh` mode, the comm actor is assigned an explicit rank and a mapping to each peer by rank.
100/// - In `Implicit` mode, the comm actor infers its rank and peers from its own actor ID.
101#[derive(Debug, Clone, Serialize, Deserialize, Named)]
102pub enum CommActorMode {
103    /// When configured as a mesh, the comm actor is assigned a rank
104    /// and a set of references for each peer rank.
105    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),
106
107    /// In an implicit mode, the comm actor derives its rank and
108    /// peers from its own ID.
109    Implicit,
110
111    /// Like `Implicit`, but override the destination world id.
112    /// This is useful for setups where comm actors may not reside
113    /// in the destination world. It is meant as a temporary bridge
114    /// until we are fully onto ActorMeshes.
115    // TODO: T224926642 Remove this once we are fully onto ActorMeshes.
116    ImplicitWithWorldId(WorldId),
117}
118
119impl Default for CommActorMode {
120    fn default() -> Self {
121        Self::Implicit
122    }
123}
124
125impl CommActorMode {
126    /// Return the peer comm actor for the given rank, given a self id,
127    /// destination port, and rank.
128    fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
129        match self {
130            Self::Mesh(_self_rank, peers) => peers
131                .get(&rank)
132                .cloned()
133                .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
134            Self::Implicit => {
135                let world_id = self_id
136                    .proc_id()
137                    .world_id()
138                    .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
139                let proc_id = world_id.proc_id(rank);
140                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
141                Ok(ActorRef::<CommActor>::attest(actor_id))
142            }
143            Self::ImplicitWithWorldId(world_id) => {
144                let proc_id = world_id.proc_id(rank);
145                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
146                Ok(ActorRef::<CommActor>::attest(actor_id))
147            }
148        }
149    }
150
151    /// Return the rank of the comm actor, given a self id.
152    fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
153        match self {
154            Self::Mesh(rank, _) => Ok(*rank),
155            Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
156                .proc_id()
157                .rank()
158                .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
159        }
160    }
161}
162
163#[async_trait]
164impl Actor for CommActor {
165    type Params = CommActorParams;
166
167    async fn new(_params: Self::Params) -> Result<Self> {
168        Ok(Self {
169            send_seq: HashMap::new(),
170            recv_state: HashMap::new(),
171            mode: Default::default(),
172        })
173    }
174
175    // This is an override of the default actor behavior.
176    async fn handle_undeliverable_message(
177        &mut self,
178        cx: &Instance<Self>,
179        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
180    ) -> Result<(), anyhow::Error> {
181        let Undeliverable(mut message_envelope) = undelivered;
182
183        // 1. Case delivery failure at a "forwarding" step.
184        if let Ok(ForwardMessage { message, .. }) =
185            message_envelope.deserialized::<ForwardMessage>()
186        {
187            let sender = message.sender();
188            let return_port = PortRef::attest_message_port(sender);
189            message_envelope.set_error(DeliveryError::Multicast(format!(
190                "comm actor {} failed to forward the cast message; return to \
191                its original sender's port {}",
192                cx.self_id(),
193                return_port.port_id(),
194            )));
195            return_port
196                .send(cx, Undeliverable(message_envelope.clone()))
197                .map_err(|err| {
198                    let error = DeliveryError::BrokenLink(format!(
199                        "error occured when returning ForwardMessage to the original \
200                        sender's port {}; error is: {}",
201                        return_port.port_id(),
202                        err,
203                    ));
204                    message_envelope.set_error(error);
205                    UndeliverableMessageError::ReturnFailure {
206                        envelope: message_envelope,
207                    }
208                })?;
209            return Ok(());
210        }
211
212        // 2. Case delivery failure at a "deliver here" step.
213        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
214            let return_port = PortRef::attest_message_port(sender);
215            message_envelope.set_error(DeliveryError::Multicast(format!(
216                "comm actor {} failed to deliver the cast message to the dest \
217                actor; return to its original sender's port {}",
218                cx.self_id(),
219                return_port.port_id(),
220            )));
221            return_port
222                .send(cx, Undeliverable(message_envelope.clone()))
223                .map_err(|err| {
224                    let error = DeliveryError::BrokenLink(format!(
225                        "error occured when returning cast message to the original \
226                        sender's port {}; error is: {}",
227                        return_port.port_id(),
228                        err,
229                    ));
230                    message_envelope.set_error(error);
231                    UndeliverableMessageError::ReturnFailure {
232                        envelope: message_envelope,
233                    }
234                })?;
235            return Ok(());
236        }
237
238        // 3. A return of an undeliverable message was itself returned.
239        UndeliverableMailboxSender
240            .post(message_envelope, /*unused */ monitored_return_handle());
241        Ok(())
242    }
243}
244
245impl CommActor {
246    /// Forward the message to the comm actor on the given peer rank.
247    fn forward(
248        cx: &Instance<Self>,
249        mode: &CommActorMode,
250        rank: usize,
251        message: ForwardMessage,
252    ) -> Result<()> {
253        let child = mode.peer_for_rank(cx.self_id(), rank)?;
254        child.send(cx, message)?;
255        Ok(())
256    }
257
258    fn handle_message(
259        cx: &Context<Self>,
260        mode: &CommActorMode,
261        deliver_here: bool,
262        next_steps: HashMap<usize, Vec<RoutingFrame>>,
263        sender: ActorId,
264        mut message: CastMessageEnvelope,
265        seq: usize,
266        last_seqs: &mut HashMap<usize, usize>,
267    ) -> Result<()> {
268        // Split ports, if any, and update message with new ports. In this
269        // way, children actors will reply to this comm actor's ports, instead
270        // of to the original ports provided by parent.
271        message.data_mut().visit_mut::<UnboundPort>(
272            |UnboundPort(port_id, reducer_spec, reducer_opts, return_undeliverable)| {
273                let split = port_id.split(
274                    cx,
275                    reducer_spec.clone(),
276                    reducer_opts.clone(),
277                    *return_undeliverable,
278                )?;
279
280                #[cfg(test)]
281                tests::collect_split_port(port_id, &split, deliver_here);
282
283                *port_id = split;
284                Ok(())
285            },
286        )?;
287
288        // Deliver message here, if necessary.
289        if deliver_here {
290            let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
291            let cast_rank = message.relative_rank(rank_on_root_mesh)?;
292            // Replace ranks with self ranks.
293            message
294                .data_mut()
295                .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
296                    *rank = Some(cast_rank);
297                    Ok(())
298                })?;
299            let cast_shape = message.shape();
300            let point = cast_shape
301                .extent()
302                .point_of_rank(cast_rank)
303                .expect("rank out of bounds");
304            let mut headers = cx.headers().clone();
305            set_cast_info_on_headers(&mut headers, point, message.sender().clone());
306            cx.post(
307                cx.self_id()
308                    .proc_id()
309                    .actor_id(message.dest_port().actor_name(), 0)
310                    .port_id(message.dest_port().port()),
311                headers,
312                Serialized::serialize(message.data())?,
313            );
314        }
315
316        // Forward to peers.
317        next_steps
318            .into_iter()
319            .map(|(peer, dests)| {
320                let last_seq = last_seqs.entry(peer).or_default();
321                Self::forward(
322                    cx,
323                    mode,
324                    peer,
325                    ForwardMessage {
326                        dests,
327                        sender: sender.clone(),
328                        message: message.clone(),
329                        seq,
330                        last_seq: *last_seq,
331                    },
332                )?;
333                *last_seq = seq;
334                Ok(())
335            })
336            .collect::<Result<Vec<_>>>()?;
337
338        Ok(())
339    }
340}
341
342#[async_trait]
343impl Handler<CommActorMode> for CommActor {
344    async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
345        self.mode = mode;
346        Ok(())
347    }
348}
349
350// TODO(T218630526): reliable casting for mutable topology
351#[async_trait]
352impl Handler<CastMessage> for CommActor {
353    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
354        // Always forward the message to the root rank of the slice, casting starts from there.
355        let slice = cast_message.dest.slice.clone();
356        let selection = cast_message.dest.selection.clone();
357        let frame = RoutingFrame::root(selection, slice);
358        let rank = frame.slice.location(&frame.here)?;
359        let seq = self
360            .send_seq
361            .entry(cast_message.message.stream_key())
362            .or_default();
363        let last_seq = *seq;
364        *seq += 1;
365        Self::forward(
366            cx,
367            &self.mode,
368            rank,
369            ForwardMessage {
370                dests: vec![frame],
371                sender: cx.self_id().clone(),
372                message: cast_message.message,
373                seq: *seq,
374                last_seq,
375            },
376        )?;
377        Ok(())
378    }
379}
380
381#[async_trait]
382impl Handler<ForwardMessage> for CommActor {
383    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
384        let ForwardMessage {
385            sender,
386            dests,
387            message,
388            seq,
389            last_seq,
390        } = fwd_message;
391
392        // Resolve/dedup routing frames.
393        let rank = self.mode.self_rank(cx.self_id())?;
394        let (deliver_here, next_steps) =
395            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
396                panic!("Choice encountered in CommActor routing")
397            })?;
398
399        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
400        match recv_state.seq.cmp(&last_seq) {
401            // We got the expected next message to deliver to this host.
402            Ordering::Equal => {
403                // We got an in-order operation, so handle it now.
404                Self::handle_message(
405                    cx,
406                    &self.mode,
407                    deliver_here,
408                    next_steps,
409                    sender.clone(),
410                    message,
411                    seq,
412                    &mut recv_state.last_seqs,
413                )?;
414                recv_state.seq = seq;
415
416                // Also deliver any pending operations from the recv buffer that
417                // were received out-of-order that are now unblocked.
418                while let Some(Buffered {
419                    seq,
420                    deliver_here,
421                    next_steps,
422                    message,
423                }) = recv_state.buffer.remove(&recv_state.seq)
424                {
425                    Self::handle_message(
426                        cx,
427                        &self.mode,
428                        deliver_here,
429                        next_steps,
430                        sender.clone(),
431                        message,
432                        seq,
433                        &mut recv_state.last_seqs,
434                    )?;
435                    recv_state.seq = seq;
436                }
437            }
438            // We got an out-of-order operation, so buffer it for now, until we
439            // recieved the onces sequenced before it.
440            Ordering::Less => {
441                tracing::warn!(
442                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
443                    seq,
444                    last_seq,
445                    recv_state.seq,
446                    message
447                );
448                recv_state.buffer.insert(
449                    last_seq,
450                    Buffered {
451                        seq,
452                        deliver_here,
453                        next_steps,
454                        message,
455                    },
456                );
457            }
458            // We already got this message -- just drop it.
459            Ordering::Greater => {
460                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
461            }
462        }
463
464        Ok(())
465    }
466}
467
468pub mod test_utils {
469    use anyhow::Result;
470    use async_trait::async_trait;
471    use hyperactor::Actor;
472    use hyperactor::ActorId;
473    use hyperactor::Bind;
474    use hyperactor::Context;
475    use hyperactor::Handler;
476    use hyperactor::Named;
477    use hyperactor::PortRef;
478    use hyperactor::Unbind;
479    use serde::Deserialize;
480    use serde::Serialize;
481
482    use super::*;
483
484    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
485    pub struct MyReply {
486        pub sender: ActorId,
487        pub value: u64,
488    }
489
490    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
491    pub enum TestMessage {
492        Forward(String),
493        CastAndReply {
494            arg: String,
495            // Intentionally not including 0. As a result, this port will not be
496            // split.
497            // #[binding(include)]
498            reply_to0: PortRef<String>,
499            #[binding(include)]
500            reply_to1: PortRef<u64>,
501            #[binding(include)]
502            reply_to2: PortRef<MyReply>,
503        },
504    }
505
506    #[derive(Debug)]
507    #[hyperactor::export(
508        spawn = true,
509        handlers = [
510            TestMessage { cast = true },
511        ],
512    )]
513    pub struct TestActor {
514        // Forward the received message to this port, so it can be inspected by
515        // the unit test.
516        forward_port: PortRef<TestMessage>,
517    }
518
519    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
520    pub struct TestActorParams {
521        pub forward_port: PortRef<TestMessage>,
522    }
523
524    #[async_trait]
525    impl Actor for TestActor {
526        type Params = TestActorParams;
527
528        async fn new(params: Self::Params) -> Result<Self> {
529            let Self::Params { forward_port } = params;
530            Ok(Self { forward_port })
531        }
532    }
533
534    #[async_trait]
535    impl Handler<TestMessage> for TestActor {
536        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
537            self.forward_port.send(cx, msg)?;
538            Ok(())
539        }
540    }
541}
542
543#[cfg(test)]
544mod tests {
545    use std::collections::BTreeMap;
546    use std::collections::HashSet;
547    use std::fmt::Display;
548    use std::hash::Hash;
549    use std::ops::Deref;
550    use std::ops::DerefMut;
551    use std::sync::Arc;
552    use std::sync::Mutex;
553    use std::sync::OnceLock;
554
555    use hyperactor::PortId;
556    use hyperactor::PortRef;
557    use hyperactor::accum;
558    use hyperactor::accum::Accumulator;
559    use hyperactor::accum::ReducerSpec;
560    use hyperactor::channel::ChannelTransport;
561    use hyperactor::clock::Clock;
562    use hyperactor::clock::RealClock;
563    use hyperactor::config;
564    use hyperactor::context::Mailbox;
565    use hyperactor::mailbox::PortReceiver;
566    use hyperactor::mailbox::open_port;
567    use hyperactor::reference::Index;
568    use hyperactor_mesh_macros::sel;
569    use maplit::btreemap;
570    use maplit::hashmap;
571    use ndslice::Extent;
572    use ndslice::Selection;
573    use ndslice::ViewExt as _;
574    use ndslice::extent;
575    use ndslice::selection::test_utils::collect_commactor_routing_tree;
576    use test_utils::*;
577    use timed_test::async_timed_test;
578    use tokio::time::Duration;
579
580    use super::*;
581    use crate::ProcMesh;
582    use crate::actor_mesh::ActorMesh;
583    use crate::actor_mesh::RootActorMesh;
584    use crate::alloc::AllocSpec;
585    use crate::alloc::Allocator;
586    use crate::alloc::LocalAllocator;
587    use crate::proc_mesh::SharedSpawnable;
588    use crate::v1;
589
590    struct Edge<T> {
591        from: T,
592        to: T,
593        is_leaf: bool,
594    }
595
596    impl<T> From<(T, T, bool)> for Edge<T> {
597        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
598            Self { from, to, is_leaf }
599        }
600    }
601
602    // The relationship between original ports and split ports. The elements in
603    // the tuple are (original port, split port, deliver_here).
604    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();
605
606    // Collect the relationships between original ports and split ports into
607    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
608    pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
609        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
610        let mut tree = mutex.lock().unwrap();
611
612        tree.deref_mut().push(Edge {
613            from: original.clone(),
614            to: split.clone(),
615            is_leaf: deliver_here,
616        });
617    }
618
619    // There could be other cast calls before the one we want to check, e.g. from
620    // allocating the proc mesh, or spawning the actor mesh. Clear the collected
621    // tree so it will only contain the cast we want to check.
622    fn clear_collected_tree() {
623        if let Some(tree) = SPLIT_PORT_TREE.get() {
624            let mut tree = tree.lock().unwrap();
625            tree.clear();
626        }
627    }
628
629    // A representation of a tree.
630    //   * Map's keys are the tree's leafs;
631    //   * Map's values are the path from the root to that leaf.
632    #[derive(PartialEq)]
633    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);
634
635    // Add a custom Debug trait impl so the result from assert_eq! is readable.
636    impl<T: Display> Debug for PathToLeaves<T> {
637        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
638            fn vec_to_string<T: Display>(v: &[T]) -> String {
639                v.iter()
640                    .map(ToString::to_string)
641                    .collect::<Vec<String>>()
642                    .join(", ")
643            }
644
645            for (src, path) in &self.0 {
646                write!(f, "{} -> {}\n", src, vec_to_string(path))?;
647            }
648            Ok(())
649        }
650    }
651
652    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
653        let mut child_parent_map = HashMap::new();
654        let mut all_nodes = HashSet::new();
655        let mut parents = HashSet::new();
656        let mut children = HashSet::new();
657        let mut dests = HashSet::new();
658
659        // Build parent map and track all nodes and children
660        for Edge { from, to, is_leaf } in edges {
661            child_parent_map.insert(to.clone(), from.clone());
662            all_nodes.insert(from.clone());
663            all_nodes.insert(to.clone());
664            parents.insert(from.clone());
665            children.insert(to.clone());
666            if *is_leaf {
667                dests.insert(to.clone());
668            }
669        }
670
671        // For each leaf, reconstruct path back to root
672        let mut result = BTreeMap::new();
673        for dest in dests {
674            let mut path = vec![dest.clone()];
675            let mut current = dest.clone();
676            while let Some(parent) = child_parent_map.get(&current) {
677                path.push(parent.clone());
678                current = parent.clone();
679            }
680            path.reverse();
681            result.insert(dest, path);
682        }
683
684        PathToLeaves(result)
685    }
686
687    #[test]
688    fn test_build_paths() {
689        // Given the tree:
690        //     0
691        //    / \
692        //   1   4
693        //  / \   \
694        // 2   3   5
695        let edges: Vec<_> = [
696            (0, 1, false),
697            (1, 2, true),
698            (1, 3, true),
699            (0, 4, true),
700            (4, 5, true),
701        ]
702        .into_iter()
703        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
704        .collect();
705
706        let paths = build_paths(&edges);
707
708        let expected = btreemap! {
709            2 => vec![0, 1, 2],
710            3 => vec![0, 1, 3],
711            4 => vec![0, 4],
712            5 => vec![0, 4, 5],
713        };
714
715        assert_eq!(paths.0, expected);
716    }
717
718    //  Given a port tree,
719    //     * remove the client port, i.e. the 1st element of the path;
720    //     * verify all remaining ports are comm actor ports;
721    //     * remove the actor information and return a rank-based tree representation.
722    //
723    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
724    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
725    //
726    //      For example, for a 2x2 slice, the port tree could look like:
727    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
728    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
729    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
730    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
731    //
732    //     The result should be:
733    //     0 -> 0
734    //     1 -> 0, 1
735    //     2 -> 0, 2
736    //     3 -> 0, 2, 3
737    fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
738        let ranks = paths
739            .0
740            .into_iter()
741            .map(|(dst, mut path)| {
742                let first = path.remove(0);
743                // The first PortId is the client's reply port.
744                assert_eq!(&first, client_reply);
745                // Other ports's actor ID must be dest[?].comm[0], where ? is
746                // the rank we want to extract here.
747                assert!(dst.actor_id().name().contains("comm"));
748                let actor_path = path
749                    .into_iter()
750                    .map(|p| {
751                        assert!(p.actor_id().name().contains("comm"));
752                        p.actor_id().rank()
753                    })
754                    .collect();
755                (dst.into_actor_id().rank(), actor_path)
756            })
757            .collect();
758        PathToLeaves(ranks)
759    }
760
761    struct MeshSetup {
762        actor_mesh: RootActorMesh<'static, TestActor>,
763        reply1_rx: PortReceiver<u64>,
764        reply2_rx: PortReceiver<MyReply>,
765        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
766    }
767
768    struct NoneAccumulator;
769
770    impl Accumulator for NoneAccumulator {
771        type State = u64;
772        type Update = u64;
773
774        fn accumulate(
775            &self,
776            _state: &mut Self::State,
777            _update: Self::Update,
778        ) -> anyhow::Result<()> {
779            unimplemented!()
780        }
781
782        fn reducer_spec(&self) -> Option<ReducerSpec> {
783            unimplemented!()
784        }
785    }
786
787    // Verify the split port paths are the same as the casting paths.
788    fn verify_split_port_paths(
789        selection: &Selection,
790        extent: &Extent,
791        reply_port_ref1: &PortRef<u64>,
792        reply_port_ref2: &PortRef<MyReply>,
793    ) {
794        // Get the paths used in casting
795        let sel_paths = PathToLeaves(
796            collect_commactor_routing_tree(selection, &extent.to_slice())
797                .delivered
798                .into_iter()
799                .collect(),
800        );
801
802        // Get the split port paths collected in SPLIT_PORT_TREE during casting
803        let (reply1_paths, reply2_paths) = {
804            let tree = SPLIT_PORT_TREE.get().unwrap();
805            let edges = tree.lock().unwrap();
806            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
807                .0
808                .into_iter()
809                .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
810            (
811                get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
812                get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
813            )
814        };
815
816        // split port paths should be the same as casting paths
817        assert_eq!(sel_paths, reply1_paths);
818        assert_eq!(sel_paths, reply2_paths);
819    }
820
821    async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
822    where
823        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
824    {
825        let extent = extent!(replica = 4, host = 4, gpu = 4);
826        let alloc = LocalAllocator
827            .allocate(AllocSpec {
828                extent: extent.clone(),
829                constraints: Default::default(),
830                proc_name: None,
831                transport: ChannelTransport::Local,
832                proc_allocation_mode: Default::default(),
833            })
834            .await
835            .unwrap();
836
837        let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
838        let dest_actor_name = "dest_actor";
839        let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
840        let params = TestActorParams {
841            forward_port: tx.bind(),
842        };
843        let instance = crate::v1::testing::instance().await;
844        let actor_mesh = Arc::clone(&proc_mesh)
845            .spawn::<TestActor>(&instance, dest_actor_name, &params)
846            .await
847            .unwrap();
848
849        let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
850        let reply_port_ref0 = reply_port_handle0.bind();
851        let (reply_port_handle1, reply1_rx) = match accum {
852            Some(a) => proc_mesh.client().mailbox().open_accum_port(a),
853            None => open_port(proc_mesh.client()),
854        };
855        let reply_port_ref1 = reply_port_handle1.bind();
856        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
857        let reply_port_ref2 = reply_port_handle2.bind();
858        let message = TestMessage::CastAndReply {
859            arg: "abc".to_string(),
860            reply_to0: reply_port_ref0.clone(),
861            reply_to1: reply_port_ref1.clone(),
862            reply_to2: reply_port_ref2.clone(),
863        };
864
865        let selection = sel!(*);
866        clear_collected_tree();
867        actor_mesh
868            .cast(proc_mesh.client(), selection.clone(), message)
869            .unwrap();
870
871        let mut reply_tos = vec![];
872        for _ in extent.points() {
873            let msg = rx.recv().await.expect("missing");
874            match msg {
875                TestMessage::CastAndReply {
876                    arg,
877                    reply_to0,
878                    reply_to1,
879                    reply_to2,
880                } => {
881                    assert_eq!(arg, "abc");
882                    // port 0 is still the same as the original one because it
883                    // is not included in MutVisitor.
884                    assert_eq!(reply_to0, reply_port_ref0);
885                    // ports have been replaced by comm actor's split ports.
886                    assert_ne!(reply_to1, reply_port_ref1);
887                    assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
888                    assert_ne!(reply_to2, reply_port_ref2);
889                    assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
890                    reply_tos.push((reply_to1, reply_to2));
891                }
892                _ => {
893                    panic!("unexpected message: {:?}", msg);
894                }
895            }
896        }
897
898        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);
899
900        MeshSetup {
901            actor_mesh,
902            reply1_rx,
903            reply2_rx,
904            reply_tos,
905        }
906    }
907
908    async fn execute_cast_and_reply(
909        ranks: Vec<ActorRef<TestActor>>,
910        instance: &Instance<()>,
911        mut reply1_rx: PortReceiver<u64>,
912        mut reply2_rx: PortReceiver<MyReply>,
913        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
914    ) {
915        // Reply from each dest actor. The replies should be received by client.
916        {
917            for (dest_actor, (reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
918                let rank = dest_actor.actor_id().rank() as u64;
919                reply_to1.send(instance, rank).unwrap();
920                let my_reply = MyReply {
921                    sender: dest_actor.actor_id().clone(),
922                    value: rank,
923                };
924                reply_to2.send(instance, my_reply.clone()).unwrap();
925
926                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
927                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
928            }
929        }
930
931        tracing::info!("the 1st updates from all dest actors were receivered by client");
932
933        // Now send multiple replies from the dest actors. They should all be
934        // received by client. Replies sent from the same dest actor should
935        // be received in the same order as they were sent out.
936        {
937            let n = 100;
938            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
939            for (dest_actor, (_reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
940                let rank = dest_actor.actor_id().rank();
941                let mut sent2 = vec![];
942                for i in 0..n {
943                    let value = (rank * 100 + i) as u64;
944                    let my_reply = MyReply {
945                        sender: dest_actor.actor_id().clone(),
946                        value,
947                    };
948                    reply_to2.send(instance, my_reply.clone()).unwrap();
949                    sent2.push(my_reply);
950                }
951                assert!(
952                    expected2.insert(rank, sent2).is_none(),
953                    "duplicate rank {rank} in map"
954                );
955            }
956
957            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};
958
959            for _ in 0..(n * ranks.len()) {
960                let my_reply = reply2_rx.recv().await.unwrap();
961                received2
962                    .entry(my_reply.sender.rank())
963                    .or_default()
964                    .push(my_reply);
965            }
966            assert_eq!(received2, expected2);
967        }
968    }
969
970    #[async_timed_test(timeout_secs = 30)]
971    async fn test_cast_and_reply() {
972        let MeshSetup {
973            actor_mesh,
974            reply1_rx,
975            reply2_rx,
976            reply_tos,
977            ..
978        } = setup_mesh::<NoneAccumulator>(None).await;
979        let proc_mesh_client = actor_mesh.proc_mesh().client();
980
981        let ranks = actor_mesh.ranks().clone();
982        execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await;
983    }
984
985    async fn wait_for_with_timeout(
986        receiver: &mut PortReceiver<u64>,
987        expected: u64,
988        dur: Duration,
989    ) -> anyhow::Result<()> {
990        // timeout wraps the entire async block
991        RealClock
992            .timeout(dur, async {
993                loop {
994                    let msg = receiver.recv().await.unwrap();
995                    if msg == expected {
996                        break;
997                    }
998                }
999            })
1000            .await?;
1001        Ok(())
1002    }
1003
1004    async fn execute_cast_and_accum(
1005        ranks: Vec<ActorRef<TestActor>>,
1006        instance: &Instance<()>,
1007        mut reply1_rx: PortReceiver<u64>,
1008        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1009    ) {
1010        // Now send multiple replies from the dest actors. They should all be
1011        // received by client. Replies sent from the same dest actor should
1012        // be received in the same order as they were sent out.
1013        let mut sum = 0;
1014        let n = 100;
1015        for (dest_actor, (reply_to1, _reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
1016            let rank = dest_actor.actor_id().rank();
1017            for i in 0..n {
1018                let value = (rank + i) as u64;
1019                reply_to1.send(instance, value).unwrap();
1020                sum += value;
1021            }
1022        }
1023        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
1024            .await
1025            .unwrap();
1026        // no more messages
1027        RealClock.sleep(Duration::from_secs(2)).await;
1028        let msg = reply1_rx.try_recv().unwrap();
1029        assert_eq!(msg, None);
1030    }
1031
1032    #[async_timed_test(timeout_secs = 30)]
1033    async fn test_cast_and_accum() {
1034        let config = config::global::lock();
1035        // Use temporary config for this test
1036        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);
1037
1038        let MeshSetup {
1039            actor_mesh,
1040            reply1_rx,
1041            reply_tos,
1042            ..
1043        } = setup_mesh(Some(accum::sum::<u64>())).await;
1044        let proc_mesh_client = actor_mesh.proc_mesh().client();
1045        let ranks = actor_mesh.ranks().clone();
1046        execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await;
1047    }
1048
1049    struct MeshSetupV1 {
1050        instance: &'static Instance<()>,
1051        actor_mesh_ref: v1::ActorMeshRef<TestActor>,
1052        reply1_rx: PortReceiver<u64>,
1053        reply2_rx: PortReceiver<MyReply>,
1054        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
1055    }
1056
1057    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
1058    where
1059        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
1060    {
1061        let instance = v1::testing::instance().await;
1062
1063        let extent = extent!(replica = 4, host = 4, gpu = 4);
1064        let alloc = LocalAllocator
1065            .allocate(AllocSpec {
1066                extent: extent.clone(),
1067                constraints: Default::default(),
1068                proc_name: None,
1069                transport: ChannelTransport::Local,
1070                proc_allocation_mode: Default::default(),
1071            })
1072            .await
1073            .unwrap();
1074
1075        let proc_mesh = v1::ProcMesh::allocate(instance, Box::new(alloc), "test.local")
1076            .await
1077            .unwrap();
1078
1079        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
1080        let params = TestActorParams {
1081            forward_port: tx.bind(),
1082        };
1083        let actor_mesh = proc_mesh.spawn(&instance, "test", &params).await.unwrap();
1084        let actor_mesh_ref = actor_mesh.deref().clone();
1085
1086        let (reply_port_handle0, _) = open_port::<String>(instance);
1087        let reply_port_ref0 = reply_port_handle0.bind();
1088        let (reply_port_handle1, reply1_rx) = match accum {
1089            Some(a) => instance.mailbox().open_accum_port(a),
1090            None => open_port(instance),
1091        };
1092        let reply_port_ref1 = reply_port_handle1.bind();
1093        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
1094        let reply_port_ref2 = reply_port_handle2.bind();
1095        let message = TestMessage::CastAndReply {
1096            arg: "abc".to_string(),
1097            reply_to0: reply_port_ref0.clone(),
1098            reply_to1: reply_port_ref1.clone(),
1099            reply_to2: reply_port_ref2.clone(),
1100        };
1101
1102        clear_collected_tree();
1103        actor_mesh_ref.cast(instance, message).unwrap();
1104
1105        let mut reply_tos = vec![];
1106        for _ in extent.points() {
1107            let msg = rx.recv().await.expect("missing");
1108            match msg {
1109                TestMessage::CastAndReply {
1110                    arg,
1111                    reply_to0,
1112                    reply_to1,
1113                    reply_to2,
1114                } => {
1115                    assert_eq!(arg, "abc");
1116                    // port 0 is still the same as the original one because it
1117                    // is not included in MutVisitor.
1118                    assert_eq!(reply_to0, reply_port_ref0);
1119                    // ports have been replaced by comm actor's split ports.
1120                    assert_ne!(reply_to1, reply_port_ref1);
1121                    assert!(reply_to1.port_id().actor_id().name().contains("comm"));
1122                    assert_ne!(reply_to2, reply_port_ref2);
1123                    assert!(reply_to2.port_id().actor_id().name().contains("comm"));
1124                    reply_tos.push((reply_to1, reply_to2));
1125                }
1126                _ => {
1127                    panic!("unexpected message: {:?}", msg);
1128                }
1129            }
1130        }
1131
1132        // v1 always uses sel!(*) when casting to a mesh.
1133        let selection = sel!(*);
1134        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);
1135
1136        MeshSetupV1 {
1137            instance,
1138            actor_mesh_ref,
1139            reply1_rx,
1140            reply2_rx,
1141            reply_tos,
1142        }
1143    }
1144
1145    #[async_timed_test(timeout_secs = 30)]
1146    async fn test_cast_and_reply_v1() {
1147        let MeshSetupV1 {
1148            instance,
1149            actor_mesh_ref,
1150            reply1_rx,
1151            reply2_rx,
1152            reply_tos,
1153            ..
1154        } = setup_mesh_v1::<NoneAccumulator>(None).await;
1155
1156        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1157        execute_cast_and_reply(ranks, instance, reply1_rx, reply2_rx, reply_tos).await;
1158    }
1159
1160    #[async_timed_test(timeout_secs = 30)]
1161    async fn test_cast_and_accum_v1() {
1162        let config = config::global::lock();
1163        // Use temporary config for this test
1164        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);
1165
1166        let MeshSetupV1 {
1167            instance,
1168            actor_mesh_ref,
1169            reply1_rx,
1170            reply_tos,
1171            ..
1172        } = setup_mesh_v1(Some(accum::sum::<u64>())).await;
1173
1174        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
1175        execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await;
1176    }
1177}