hyperactor_mesh/comm.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use crate::comm::multicast::CAST_ORIGINATING_SENDER;
use crate::reference::ActorMeshId;
use crate::resource;
pub mod multicast;

use std::cmp::Ordering;
use std::collections::HashMap;
use std::fmt::Debug;

use anyhow::Result;
use async_trait::async_trait;
use hyperactor::Actor;
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Context;
use hyperactor::Handler;
use hyperactor::Instance;
use hyperactor::Named;
use hyperactor::PortRef;
use hyperactor::WorldId;
use hyperactor::data::Serialized;
use hyperactor::mailbox::DeliveryError;
use hyperactor::mailbox::MailboxSender;
use hyperactor::mailbox::Undeliverable;
use hyperactor::mailbox::UndeliverableMailboxSender;
use hyperactor::mailbox::UndeliverableMessageError;
use hyperactor::mailbox::monitored_return_handle;
use hyperactor::reference::UnboundPort;
use ndslice::selection::routing::RoutingFrame;
use serde::Deserialize;
use serde::Serialize;

use crate::comm::multicast::CastMessage;
use crate::comm::multicast::CastMessageEnvelope;
use crate::comm::multicast::ForwardMessage;
use crate::comm::multicast::set_cast_info_on_headers;

/// Parameters to initialize the CommActor
#[derive(Debug, Clone, Serialize, Deserialize, Named, Default)]
pub struct CommActorParams {}

/// A message buffered due to out-of-order delivery.
#[derive(Debug)]
struct Buffered {
    /// Sequence number of this message.
    seq: usize,
    /// Whether to deliver this message to this comm actor's local actors.
    deliver_here: bool,
    /// Peer comm actors to forward message to.
    next_steps: HashMap<usize, Vec<RoutingFrame>>,
    /// The message to deliver.
    message: CastMessageEnvelope,
}

/// Bookkeeping to handle sequence numbers and in-order delivery for messages
/// sent to and through this comm actor.
#[derive(Debug, Default)]
struct ReceiveState {
    /// The sequence of the last received message.
    seq: usize,
    /// A buffer storing messages we received out-of-order, indexed by the seq
    /// that should precede it.
    buffer: HashMap<usize, Buffered>,
    /// A map of the last sequence number we sent to next steps, indexed by rank.
    last_seqs: HashMap<usize, usize>,
}
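
// A worked sketch (not part of the implementation) of how the per-stream
// ordering state above behaves. Every `ForwardMessage` carries a pair
// (last_seq, seq): `last_seq` is the sequence number of the previous message
// sent down the same edge, and `seq` is the number of the current one. On the
// receiving side, `ReceiveState::seq` is the sequence number of the last
// message handled:
//
//   recv seq = 0; (last_seq = 0, seq = 1) arrives -> in order, handle it, recv seq = 1
//   recv seq = 1; (last_seq = 3, seq = 4) arrives -> out of order, buffer it under key 3
//   recv seq = 1; (last_seq = 1, seq = 3) arrives -> handle it, recv seq = 3, then drain
//                 the buffered entry keyed by 3 and handle seq 4 as well
//   anything with last_seq < recv seq is a duplicate and is dropped
//
// The concrete sequence numbers are illustrative only; see `Handler<ForwardMessage>`
// below for the authoritative logic.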

/// This is the comm actor used for efficient and scalable message multicasting
/// and result accumulation.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        CommActorMode,
        CastMessage,
        ForwardMessage,
    ],
)]
pub struct CommActor {
    /// Sequence numbers are maintained for each (actor mesh id, sender).
    send_seq: HashMap<(ActorMeshId, ActorId), usize>,
    /// Each sender is a unique stream.
    recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,

    /// The comm actor's mode.
    mode: CommActorMode,
}

/// Configuration for how a `CommActor` determines its own rank and locates peers.
///
/// - In `Mesh` mode, the comm actor is assigned an explicit rank and a mapping to each peer by rank.
/// - In `Implicit` mode, the comm actor infers its rank and peers from its own actor ID.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub enum CommActorMode {
    /// When configured as a mesh, the comm actor is assigned a rank
    /// and a set of references for each peer rank.
    Mesh(usize, HashMap<usize, ActorRef<CommActor>>),

    /// In an implicit mode, the comm actor derives its rank and
    /// peers from its own ID.
    Implicit,

    /// Like `Implicit`, but override the destination world id.
    /// This is useful for setups where comm actors may not reside
    /// in the destination world. It is meant as a temporary bridge
    /// until we are fully onto ActorMeshes.
    // TODO: T224926642 Remove this once we are fully onto ActorMeshes.
    ImplicitWithWorldId(WorldId),
}

impl Default for CommActorMode {
    fn default() -> Self {
        Self::Implicit
    }
}

impl CommActorMode {
    /// Return a reference to the peer comm actor for the given rank, given
    /// this comm actor's own id.
    fn peer_for_rank(&self, self_id: &ActorId, rank: usize) -> Result<ActorRef<CommActor>> {
        match self {
            Self::Mesh(_self_rank, peers) => peers
                .get(&rank)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)),
            Self::Implicit => {
                let world_id = self_id
                    .proc_id()
                    .world_id()
                    .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?;
                let proc_id = world_id.proc_id(rank);
                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
                Ok(ActorRef::<CommActor>::attest(actor_id))
            }
            Self::ImplicitWithWorldId(world_id) => {
                let proc_id = world_id.proc_id(rank);
                let actor_id = ActorId::root(proc_id, self_id.name().to_string());
                Ok(ActorRef::<CommActor>::attest(actor_id))
            }
        }
    }

    /// Return the rank of the comm actor, given a self id.
    fn self_rank(&self, self_id: &ActorId) -> Result<usize> {
        match self {
            Self::Mesh(rank, _) => Ok(*rank),
            Self::Implicit | Self::ImplicitWithWorldId(_) => self_id
                .proc_id()
                .rank()
                .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")),
        }
    }
}
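
// A minimal sketch (not part of this module) of how a controller could place a
// comm actor into `Mesh` mode. `CommActorMode` is itself a handled message (see
// `Handler<CommActorMode>` below), so the mode can be updated at runtime; the
// variable names (`world_id`, `world_size`, `self_rank`, `comm_actor_ref`, `cx`)
// are hypothetical:
//
//     let peers: HashMap<usize, ActorRef<CommActor>> = (0..world_size)
//         .map(|r| {
//             let actor_id = ActorId::root(world_id.proc_id(r), "comm".to_string());
//             (r, ActorRef::<CommActor>::attest(actor_id))
//         })
//         .collect();
//     comm_actor_ref.send(cx, CommActorMode::Mesh(self_rank, peers))?;
//
// In `Implicit` mode no such message is needed: a comm actor whose id is
// `world[3].comm[0]` reports `self_rank() == 3`, and `peer_for_rank(.., 5)`
// resolves to `world[5].comm[0]`.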

#[async_trait]
impl Actor for CommActor {
    type Params = CommActorParams;

    async fn new(_params: Self::Params) -> Result<Self> {
        Ok(Self {
            send_seq: HashMap::new(),
            recv_state: HashMap::new(),
            mode: Default::default(),
        })
    }

    // This is an override of the default actor behavior.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        undelivered: hyperactor::mailbox::Undeliverable<hyperactor::mailbox::MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        let Undeliverable(mut message_envelope) = undelivered;

        // Case 1: delivery failure at a "forwarding" step.
        if let Ok(ForwardMessage { message, .. }) =
            message_envelope.deserialized::<ForwardMessage>()
        {
            let sender = message.sender();
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to forward the cast message; return to \
                its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occurred when returning ForwardMessage to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // Case 2: delivery failure at a "deliver here" step.
        if let Some(sender) = message_envelope.headers().get(CAST_ORIGINATING_SENDER) {
            let return_port = PortRef::attest_message_port(sender);
            message_envelope.set_error(DeliveryError::Multicast(format!(
                "comm actor {} failed to deliver the cast message to the dest \
                actor; return to its original sender's port {}",
                cx.self_id(),
                return_port.port_id(),
            )));
            return_port
                .send(cx, Undeliverable(message_envelope.clone()))
                .map_err(|err| {
                    let error = DeliveryError::BrokenLink(format!(
                        "error occurred when returning cast message to the original \
                        sender's port {}; error is: {}",
                        return_port.port_id(),
                        err,
                    ));
                    message_envelope.set_error(error);
                    UndeliverableMessageError::return_failure(&message_envelope)
                })?;
            return Ok(());
        }

        // Case 3: a return of an undeliverable message was itself returned.
        UndeliverableMailboxSender
            .post(message_envelope, /* unused */ monitored_return_handle());
        Ok(())
    }
}

impl CommActor {
    /// Forward the message to the comm actor on the given peer rank.
    fn forward(
        cx: &Instance<Self>,
        mode: &CommActorMode,
        rank: usize,
        message: ForwardMessage,
    ) -> Result<()> {
        let child = mode.peer_for_rank(cx.self_id(), rank)?;
        child.send(cx, message)?;
        Ok(())
    }

    fn handle_message(
        cx: &Context<Self>,
        mode: &CommActorMode,
        deliver_here: bool,
        next_steps: HashMap<usize, Vec<RoutingFrame>>,
        sender: ActorId,
        mut message: CastMessageEnvelope,
        seq: usize,
        last_seqs: &mut HashMap<usize, usize>,
    ) -> Result<()> {
        // Split ports, if any, and update the message with the new ports. This
        // way, child actors reply to this comm actor's ports instead of to the
        // original ports provided by the parent.
        message.data_mut().visit_mut::<UnboundPort>(
            |UnboundPort(port_id, reducer_spec, reducer_opts)| {
                let split = port_id.split(cx, reducer_spec.clone(), reducer_opts.clone())?;

                #[cfg(test)]
                tests::collect_split_port(port_id, &split, deliver_here);

                *port_id = split;
                Ok(())
            },
        )?;

        // Deliver message here, if necessary.
        if deliver_here {
            let rank_on_root_mesh = mode.self_rank(cx.self_id())?;
            let cast_rank = message.relative_rank(rank_on_root_mesh)?;
            // Replace ranks with self ranks.
            message
                .data_mut()
                .visit_mut::<resource::Rank>(|resource::Rank(rank)| {
                    *rank = Some(cast_rank);
                    Ok(())
                })?;
            let cast_shape = message.shape();
            let point = cast_shape
                .extent()
                .point_of_rank(cast_rank)
                .expect("rank out of bounds");
            let mut headers = cx.headers().clone();
            set_cast_info_on_headers(&mut headers, point, message.sender().clone());
            cx.post(
                cx.self_id()
                    .proc_id()
                    .actor_id(message.dest_port().actor_name(), 0)
                    .port_id(message.dest_port().port()),
                headers,
                Serialized::serialize(message.data())?,
            );
        }

        // Forward to peers.
        next_steps
            .into_iter()
            .map(|(peer, dests)| {
                let last_seq = last_seqs.entry(peer).or_default();
                Self::forward(
                    cx,
                    mode,
                    peer,
                    ForwardMessage {
                        dests,
                        sender: sender.clone(),
                        message: message.clone(),
                        seq,
                        last_seq: *last_seq,
                    },
                )?;
                *last_seq = seq;
                Ok(())
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(())
    }
}
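
// A small illustration (assumptions noted, not part of this module) of the
// reply path that the port splitting in `handle_message` sets up. When a cast
// message carries a reply `PortRef<T>`, each comm actor on the route rewrites
// it to a freshly split port of its own, so by the time the message reaches a
// destination actor the ref points at the local comm actor. Replies therefore
// flow back up the same tree (and can be reduced along the way) before landing
// on the client's original port. From the destination actor's handler this is
// just an ordinary send; `reply_to` is assumed to be the port carried in the
// cast message:
//
//     reply_to.send(cx, result)?;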

#[async_trait]
impl Handler<CommActorMode> for CommActor {
    async fn handle(&mut self, _cx: &Context<Self>, mode: CommActorMode) -> Result<()> {
        self.mode = mode;
        Ok(())
    }
}

// TODO(T218630526): reliable casting for mutable topology
#[async_trait]
impl Handler<CastMessage> for CommActor {
    async fn handle(&mut self, cx: &Context<Self>, cast_message: CastMessage) -> Result<()> {
        // Always forward the message to the root rank of the slice; casting starts from there.
        let slice = cast_message.dest.slice.clone();
        let selection = cast_message.dest.selection.clone();
        let frame = RoutingFrame::root(selection, slice);
        let rank = frame.slice.location(&frame.here)?;
        let seq = self
            .send_seq
            .entry(cast_message.message.stream_key())
            .or_default();
        let last_seq = *seq;
        *seq += 1;
        Self::forward(
            cx,
            &self.mode,
            rank,
            ForwardMessage {
                dests: vec![frame],
                sender: cx.self_id().clone(),
                message: cast_message.message,
                seq: *seq,
                last_seq,
            },
        )?;
        Ok(())
    }
}

#[async_trait]
impl Handler<ForwardMessage> for CommActor {
    async fn handle(&mut self, cx: &Context<Self>, fwd_message: ForwardMessage) -> Result<()> {
        let ForwardMessage {
            sender,
            dests,
            message,
            seq,
            last_seq,
        } = fwd_message;

        // Resolve/dedup routing frames.
        let rank = self.mode.self_rank(cx.self_id())?;
        let (deliver_here, next_steps) =
            ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| {
                panic!("Choice encountered in CommActor routing")
            })?;

        let recv_state = self.recv_state.entry(message.stream_key()).or_default();
        match recv_state.seq.cmp(&last_seq) {
            // We got the expected next message to deliver to this host.
            Ordering::Equal => {
                // We got an in-order operation, so handle it now.
                Self::handle_message(
                    cx,
                    &self.mode,
                    deliver_here,
                    next_steps,
                    sender.clone(),
                    message,
                    seq,
                    &mut recv_state.last_seqs,
                )?;
                recv_state.seq = seq;

                // Also deliver any pending operations from the recv buffer that
                // were received out-of-order and are now unblocked.
                while let Some(Buffered {
                    seq,
                    deliver_here,
                    next_steps,
                    message,
                }) = recv_state.buffer.remove(&recv_state.seq)
                {
                    Self::handle_message(
                        cx,
                        &self.mode,
                        deliver_here,
                        next_steps,
                        sender.clone(),
                        message,
                        seq,
                        &mut recv_state.last_seqs,
                    )?;
                    recv_state.seq = seq;
                }
            }
            // We got an out-of-order operation, so buffer it for now, until we
            // have received the ones sequenced before it.
            Ordering::Less => {
                tracing::warn!(
                    "buffering out-of-order message with seq {} (last {}), expected {}: {:?}",
                    seq,
                    last_seq,
                    recv_state.seq,
                    message
                );
                recv_state.buffer.insert(
                    last_seq,
                    Buffered {
                        seq,
                        deliver_here,
                        next_steps,
                        message,
                    },
                );
            }
            // We already got this message -- just drop it.
            Ordering::Greater => {
                tracing::warn!("received duplicate message with seq {}: {:?}", seq, message);
            }
        }

        Ok(())
    }
}

pub mod test_utils {
    use anyhow::Result;
    use async_trait::async_trait;
    use hyperactor::Actor;
    use hyperactor::ActorId;
    use hyperactor::Bind;
    use hyperactor::Context;
    use hyperactor::Handler;
    use hyperactor::Named;
    use hyperactor::PortRef;
    use hyperactor::Unbind;
    use serde::Deserialize;
    use serde::Serialize;

    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
    pub struct MyReply {
        pub sender: ActorId,
        pub value: u64,
    }

    #[derive(Debug, Named, Serialize, Deserialize, PartialEq, Clone, Bind, Unbind)]
    pub enum TestMessage {
        Forward(String),
        CastAndReply {
            arg: String,
            // Intentionally not marked with #[binding(include)]. As a result,
            // this port will not be split.
            // #[binding(include)]
            reply_to0: PortRef<String>,
            #[binding(include)]
            reply_to1: PortRef<u64>,
            #[binding(include)]
            reply_to2: PortRef<MyReply>,
        },
    }

    #[derive(Debug)]
    #[hyperactor::export(
        spawn = true,
        handlers = [
            TestMessage { cast = true },
        ],
    )]
    pub struct TestActor {
        // Forward the received message to this port, so it can be inspected by
        // the unit test.
        forward_port: PortRef<TestMessage>,
    }

    #[derive(Debug, Clone, Named, Serialize, Deserialize)]
    pub struct TestActorParams {
        pub forward_port: PortRef<TestMessage>,
    }

    #[async_trait]
    impl Actor for TestActor {
        type Params = TestActorParams;

        async fn new(params: Self::Params) -> Result<Self> {
            let Self::Params { forward_port } = params;
            Ok(Self { forward_port })
        }
    }

    #[async_trait]
    impl Handler<TestMessage> for TestActor {
        async fn handle(&mut self, cx: &Context<Self>, msg: TestMessage) -> anyhow::Result<()> {
            self.forward_port.send(cx, msg)?;
            Ok(())
        }
    }
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::collections::HashSet;
    use std::fmt::Display;
    use std::hash::Hash;
    use std::ops::Deref;
    use std::ops::DerefMut;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::sync::OnceLock;

    use hyperactor::PortId;
    use hyperactor::PortRef;
    use hyperactor::accum;
    use hyperactor::accum::Accumulator;
    use hyperactor::accum::ReducerSpec;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor::config;
    use hyperactor::context::Mailbox;
    use hyperactor::mailbox::PortReceiver;
    use hyperactor::mailbox::open_port;
    use hyperactor::reference::Index;
    use hyperactor_mesh_macros::sel;
    use maplit::btreemap;
    use maplit::hashmap;
    use ndslice::Extent;
    use ndslice::Selection;
    use ndslice::ViewExt as _;
    use ndslice::extent;
    use ndslice::selection::test_utils::collect_commactor_routing_tree;
    use test_utils::*;
    use timed_test::async_timed_test;
    use tokio::time::Duration;

    use super::*;
    use crate::ProcMesh;
    use crate::actor_mesh::ActorMesh;
    use crate::actor_mesh::RootActorMesh;
    use crate::alloc::AllocSpec;
    use crate::alloc::Allocator;
    use crate::alloc::LocalAllocator;
    use crate::proc_mesh::SharedSpawnable;
    use crate::v1;

    struct Edge<T> {
        from: T,
        to: T,
        is_leaf: bool,
    }

    impl<T> From<(T, T, bool)> for Edge<T> {
        fn from((from, to, is_leaf): (T, T, bool)) -> Self {
            Self { from, to, is_leaf }
        }
    }

    // The relationship between original ports and split ports. Each edge
    // records (original port, split port, deliver_here).
    static SPLIT_PORT_TREE: OnceLock<Mutex<Vec<Edge<PortId>>>> = OnceLock::new();

    // Collect the relationships between original ports and split ports into
    // SPLIT_PORT_TREE. This is used by tests to verify that ports are split as expected.
    pub(crate) fn collect_split_port(original: &PortId, split: &PortId, deliver_here: bool) {
        let mutex = SPLIT_PORT_TREE.get_or_init(|| Mutex::new(vec![]));
        let mut tree = mutex.lock().unwrap();

        tree.deref_mut().push(Edge {
            from: original.clone(),
            to: split.clone(),
            is_leaf: deliver_here,
        });
    }

    // There could be other cast calls before the one we want to check, e.g. from
    // allocating the proc mesh, or spawning the actor mesh. Clear the collected
    // tree so it will only contain the cast we want to check.
    fn clear_collected_tree() {
        if let Some(tree) = SPLIT_PORT_TREE.get() {
            let mut tree = tree.lock().unwrap();
            tree.clear();
        }
    }

    // A representation of a tree.
    //   * The map's keys are the tree's leaves;
    //   * The map's values are the paths from the root to each leaf.
    #[derive(PartialEq)]
    struct PathToLeaves<T>(BTreeMap<T, Vec<T>>);

    // Add a custom Debug trait impl so the result from assert_eq! is readable.
    impl<T: Display> Debug for PathToLeaves<T> {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fn vec_to_string<T: Display>(v: &[T]) -> String {
                v.iter()
                    .map(ToString::to_string)
                    .collect::<Vec<String>>()
                    .join(", ")
            }

            for (src, path) in &self.0 {
                write!(f, "{} -> {}\n", src, vec_to_string(path))?;
            }
            Ok(())
        }
    }

    fn build_paths<T: Clone + Eq + Hash + Ord>(edges: &[Edge<T>]) -> PathToLeaves<T> {
        let mut child_parent_map = HashMap::new();
        let mut all_nodes = HashSet::new();
        let mut parents = HashSet::new();
        let mut children = HashSet::new();
        let mut dests = HashSet::new();

        // Build parent map and track all nodes and children
        for Edge { from, to, is_leaf } in edges {
            child_parent_map.insert(to.clone(), from.clone());
            all_nodes.insert(from.clone());
            all_nodes.insert(to.clone());
            parents.insert(from.clone());
            children.insert(to.clone());
            if *is_leaf {
                dests.insert(to.clone());
            }
        }

        // For each leaf, reconstruct path back to root
        let mut result = BTreeMap::new();
        for dest in dests {
            let mut path = vec![dest.clone()];
            let mut current = dest.clone();
            while let Some(parent) = child_parent_map.get(&current) {
                path.push(parent.clone());
                current = parent.clone();
            }
            path.reverse();
            result.insert(dest, path);
        }

        PathToLeaves(result)
    }

    #[test]
    fn test_build_paths() {
        // Given the tree:
        //     0
        //    / \
        //   1   4
        //  / \   \
        // 2   3   5
        let edges: Vec<_> = [
            (0, 1, false),
            (1, 2, true),
            (1, 3, true),
            (0, 4, true),
            (4, 5, true),
        ]
        .into_iter()
        .map(|(from, to, is_leaf)| Edge { from, to, is_leaf })
        .collect();

        let paths = build_paths(&edges);

        let expected = btreemap! {
            2 => vec![0, 1, 2],
            3 => vec![0, 1, 3],
            4 => vec![0, 4],
            5 => vec![0, 4, 5],
        };

        assert_eq!(paths.0, expected);
    }

    //  Given a port tree,
    //     * remove the client port, i.e. the 1st element of the path;
    //     * verify all remaining ports are comm actor ports;
    //     * remove the actor information and return a rank-based tree representation.
    //
    //  The rank-based tree representation is what [collect_commactor_routing_tree] returns.
    //  This conversion enables us to compare the path against [collect_commactor_routing_tree]'s result.
    //
    //      For example, for a 2x2 slice, the port tree could look like:
    //      dest[0].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028]]
    //      dest[1].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[1].comm[0][1028]]
    //      dest[2].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028]]
    //      dest[3].comm[0][1028] -> [client[0].client_user[0][1025], dest[0].comm[0][1028], dest[2].comm[0][1028], dest[3].comm[0][1028]]
    //
    //     The result should be:
    //     0 -> 0
    //     1 -> 0, 1
    //     2 -> 0, 2
    //     3 -> 0, 2, 3
    fn get_ranks(paths: PathToLeaves<PortId>, client_reply: &PortId) -> PathToLeaves<Index> {
        let ranks = paths
            .0
            .into_iter()
            .map(|(dst, mut path)| {
                let first = path.remove(0);
                // The first PortId is the client's reply port.
                assert_eq!(&first, client_reply);
                // The other ports' actor IDs must be dest[?].comm[0], where ? is
                // the rank we want to extract here.
                assert!(dst.actor_id().name().contains("comm"));
                let actor_path = path
                    .into_iter()
                    .map(|p| {
                        assert!(p.actor_id().name().contains("comm"));
                        p.actor_id().rank()
                    })
                    .collect();
                (dst.into_actor_id().rank(), actor_path)
            })
            .collect();
        PathToLeaves(ranks)
    }

    struct MeshSetup {
        actor_mesh: RootActorMesh<'static, TestActor>,
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }

    struct NoneAccumulator;

    impl Accumulator for NoneAccumulator {
        type State = u64;
        type Update = u64;

        fn accumulate(
            &self,
            _state: &mut Self::State,
            _update: Self::Update,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        fn reducer_spec(&self) -> Option<ReducerSpec> {
            unimplemented!()
        }
    }

    // Verify the split port paths are the same as the casting paths.
    fn verify_split_port_paths(
        selection: &Selection,
        extent: &Extent,
        reply_port_ref1: &PortRef<u64>,
        reply_port_ref2: &PortRef<MyReply>,
    ) {
        // Get the paths used in casting
        let sel_paths = PathToLeaves(
            collect_commactor_routing_tree(selection, &extent.to_slice())
                .delivered
                .into_iter()
                .collect(),
        );

        // Get the split port paths collected in SPLIT_PORT_TREE during casting
        let (reply1_paths, reply2_paths) = {
            let tree = SPLIT_PORT_TREE.get().unwrap();
            let edges = tree.lock().unwrap();
            let (reply1, reply2): (BTreeMap<_, _>, BTreeMap<_, _>) = build_paths(&edges)
                .0
                .into_iter()
                .partition(|(_dst, path)| &path[0] == reply_port_ref1.port_id());
            (
                get_ranks(PathToLeaves(reply1), reply_port_ref1.port_id()),
                get_ranks(PathToLeaves(reply2), reply_port_ref2.port_id()),
            )
        };

        // split port paths should be the same as casting paths
        assert_eq!(sel_paths, reply1_paths);
        assert_eq!(sel_paths, reply2_paths);
    }

    async fn setup_mesh<A>(accum: Option<A>) -> MeshSetup
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
            })
            .await
            .unwrap();

        let proc_mesh = Arc::new(ProcMesh::allocate(alloc).await.unwrap());
        let dest_actor_name = "dest_actor";
        let (tx, mut rx) = hyperactor::mailbox::open_port(proc_mesh.client());
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_mesh = proc_mesh
            .clone()
            .spawn::<TestActor>(dest_actor_name, &params)
            .await
            .unwrap();

        let (reply_port_handle0, _) = open_port::<String>(proc_mesh.client());
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => proc_mesh.client().mailbox().open_accum_port(a),
            None => open_port(proc_mesh.client()),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(proc_mesh.client());
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        let selection = sel!(*);
        clear_collected_tree();
        actor_mesh
            .cast(proc_mesh.client(), selection.clone(), message)
            .unwrap();

        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it is
                    // not marked with #[binding(include)], so it was not split.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert_eq!(reply_to1.port_id().actor_id().name(), "comm");
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert_eq!(reply_to2.port_id().actor_id().name(), "comm");
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);

        MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }

    async fn execute_cast_and_reply(
        ranks: Vec<ActorRef<TestActor>>,
        instance: &Instance<()>,
        mut reply1_rx: PortReceiver<u64>,
        mut reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    ) {
        // Reply from each dest actor. The replies should be received by the client.
        {
            for (dest_actor, (reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
                let rank = dest_actor.actor_id().rank() as u64;
                reply_to1.send(instance, rank).unwrap();
                let my_reply = MyReply {
                    sender: dest_actor.actor_id().clone(),
                    value: rank,
                };
                reply_to2.send(instance, my_reply.clone()).unwrap();

                assert_eq!(reply1_rx.recv().await.unwrap(), rank);
                assert_eq!(reply2_rx.recv().await.unwrap(), my_reply);
            }
        }

        tracing::info!("the first updates from all dest actors were received by the client");

        // Now send multiple replies from the dest actors. They should all be
        // received by the client. Replies sent from the same dest actor should
        // be received in the same order as they were sent out.
        {
            let n = 100;
            let mut expected2: HashMap<usize, Vec<MyReply>> = hashmap! {};
            for (dest_actor, (_reply_to1, reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
                let rank = dest_actor.actor_id().rank();
                let mut sent2 = vec![];
                for i in 0..n {
                    let value = (rank * 100 + i) as u64;
                    let my_reply = MyReply {
                        sender: dest_actor.actor_id().clone(),
                        value,
                    };
                    reply_to2.send(instance, my_reply.clone()).unwrap();
                    sent2.push(my_reply);
                }
                assert!(
                    expected2.insert(rank, sent2).is_none(),
                    "duplicate rank {rank} in map"
                );
            }

            let mut received2: HashMap<usize, Vec<MyReply>> = hashmap! {};

            for _ in 0..(n * ranks.len()) {
                let my_reply = reply2_rx.recv().await.unwrap();
                received2
                    .entry(my_reply.sender.rank())
                    .or_default()
                    .push(my_reply);
            }
            assert_eq!(received2, expected2);
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_reply() {
        let MeshSetup {
            actor_mesh,
            reply1_rx,
            reply2_rx,
            reply_tos,
            ..
        } = setup_mesh::<NoneAccumulator>(None).await;
        let proc_mesh_client = actor_mesh.proc_mesh().client();

        let ranks = actor_mesh.ranks.clone();
        execute_cast_and_reply(ranks, proc_mesh_client, reply1_rx, reply2_rx, reply_tos).await;
    }

    async fn wait_for_with_timeout(
        receiver: &mut PortReceiver<u64>,
        expected: u64,
        dur: Duration,
    ) -> anyhow::Result<()> {
        // timeout wraps the entire async block
        RealClock
            .timeout(dur, async {
                loop {
                    let msg = receiver.recv().await.unwrap();
                    if msg == expected {
                        break;
                    }
                }
            })
            .await?;
        Ok(())
    }

    async fn execute_cast_and_accum(
        ranks: Vec<ActorRef<TestActor>>,
        instance: &Instance<()>,
        mut reply1_rx: PortReceiver<u64>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    ) {
        // Send multiple replies from the dest actors. Because the client's port
        // is an accumulation port, the client should eventually observe the
        // accumulated sum rather than every individual update.
        let mut sum = 0;
        let n = 100;
        for (dest_actor, (reply_to1, _reply_to2)) in ranks.iter().zip(reply_tos.iter()) {
            let rank = dest_actor.actor_id().rank();
            for i in 0..n {
                let value = (rank + i) as u64;
                reply_to1.send(instance, value).unwrap();
                sum += value;
            }
        }
        wait_for_with_timeout(&mut reply1_rx, sum, Duration::from_secs(2))
            .await
            .unwrap();
        // no more messages
        RealClock.sleep(Duration::from_secs(2)).await;
        let msg = reply1_rx.try_recv().unwrap();
        assert_eq!(msg, None);
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_accum() {
        let config = config::global::lock();
        // Use temporary config for this test
        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);

        let MeshSetup {
            actor_mesh,
            reply1_rx,
            reply_tos,
            ..
        } = setup_mesh(Some(accum::sum::<u64>())).await;
        let proc_mesh_client = actor_mesh.proc_mesh().client();
        let ranks = actor_mesh.ranks.clone();
        execute_cast_and_accum(ranks, proc_mesh_client, reply1_rx, reply_tos).await;
    }

    struct MeshSetupV1 {
        instance: &'static Instance<()>,
        actor_mesh_ref: v1::ActorMeshRef<TestActor>,
        reply1_rx: PortReceiver<u64>,
        reply2_rx: PortReceiver<MyReply>,
        reply_tos: Vec<(PortRef<u64>, PortRef<MyReply>)>,
    }

    async fn setup_mesh_v1<A>(accum: Option<A>) -> MeshSetupV1
    where
        A: Accumulator<Update = u64, State = u64> + Send + Sync + 'static,
    {
        let instance = v1::testing::instance().await;

        let extent = extent!(replica = 4, host = 4, gpu = 4);
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent.clone(),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Local,
            })
            .await
            .unwrap();

        let proc_mesh = v1::ProcMesh::allocate(instance, Box::new(alloc), "test.local")
            .await
            .unwrap();

        let (tx, mut rx) = hyperactor::mailbox::open_port(instance);
        let params = TestActorParams {
            forward_port: tx.bind(),
        };
        let actor_mesh = proc_mesh.spawn(&instance, "test", &params).await.unwrap();
        let actor_mesh_ref = actor_mesh.deref().clone();

        let (reply_port_handle0, _) = open_port::<String>(instance);
        let reply_port_ref0 = reply_port_handle0.bind();
        let (reply_port_handle1, reply1_rx) = match accum {
            Some(a) => instance.mailbox().open_accum_port(a),
            None => open_port(instance),
        };
        let reply_port_ref1 = reply_port_handle1.bind();
        let (reply_port_handle2, reply2_rx) = open_port::<MyReply>(instance);
        let reply_port_ref2 = reply_port_handle2.bind();
        let message = TestMessage::CastAndReply {
            arg: "abc".to_string(),
            reply_to0: reply_port_ref0.clone(),
            reply_to1: reply_port_ref1.clone(),
            reply_to2: reply_port_ref2.clone(),
        };

        clear_collected_tree();
        actor_mesh_ref.cast(instance, message).unwrap();

        let mut reply_tos = vec![];
        for _ in extent.points() {
            let msg = rx.recv().await.expect("missing");
            match msg {
                TestMessage::CastAndReply {
                    arg,
                    reply_to0,
                    reply_to1,
                    reply_to2,
                } => {
                    assert_eq!(arg, "abc");
                    // port 0 is still the same as the original one because it is
                    // not marked with #[binding(include)], so it was not split.
                    assert_eq!(reply_to0, reply_port_ref0);
                    // ports have been replaced by comm actor's split ports.
                    assert_ne!(reply_to1, reply_port_ref1);
                    assert!(reply_to1.port_id().actor_id().name().contains("comm"));
                    assert_ne!(reply_to2, reply_port_ref2);
                    assert!(reply_to2.port_id().actor_id().name().contains("comm"));
                    reply_tos.push((reply_to1, reply_to2));
                }
                _ => {
                    panic!("unexpected message: {:?}", msg);
                }
            }
        }

        // v1 always uses sel!(*) when casting to a mesh.
        let selection = sel!(*);
        verify_split_port_paths(&selection, &extent, &reply_port_ref1, &reply_port_ref2);

        MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_reply_v1() {
        let MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply2_rx,
            reply_tos,
            ..
        } = setup_mesh_v1::<NoneAccumulator>(None).await;

        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
        execute_cast_and_reply(ranks, instance, reply1_rx, reply2_rx, reply_tos).await;
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_cast_and_accum_v1() {
        let config = config::global::lock();
        // Use temporary config for this test
        let _guard1 = config.override_key(config::SPLIT_MAX_BUFFER_SIZE, 1);

        let MeshSetupV1 {
            instance,
            actor_mesh_ref,
            reply1_rx,
            reply_tos,
            ..
        } = setup_mesh_v1(Some(accum::sum::<u64>())).await;

        let ranks = actor_mesh_ref.values().collect::<Vec<_>>();
        execute_cast_and_accum(ranks, instance, reply1_rx, reply_tos).await;
    }
}