hyperactor_mesh/comm/
multicast.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The comm actor that provides message casting and result accumulation.
10
11use hyperactor::Actor;
12use hyperactor::Context;
13use hyperactor::RemoteHandles;
14use hyperactor::RemoteMessage;
15use hyperactor::actor::Referable;
16use hyperactor::message::Castable;
17use hyperactor::message::ErasedUnbound;
18use hyperactor::message::IndexedErasedUnbound;
19use hyperactor::reference::ActorId;
20use hyperactor_config::attrs::Attrs;
21use hyperactor_config::attrs::declare_attrs;
22use ndslice::Extent;
23use ndslice::Point;
24use ndslice::Shape;
25use ndslice::Slice;
26use ndslice::selection::Selection;
27use ndslice::selection::routing::RoutingFrame;
28use serde::Deserialize;
29use serde::Serialize;
30use typeuri::Named;
31
32use crate::reference::ActorMeshId;
33
34/// A union of slices that can be used to represent arbitrary subset of
35/// ranks in a gang. It is represented by a Slice together with a Selection.
36/// This is used to define the destination of a cast message or the source of
37/// accumulation request.
38#[derive(Serialize, Deserialize, Debug, Clone)]
39pub struct Uslice {
40    /// A slice representing a whole gang.
41    pub slice: Slice,
42    /// A selection used to represent any subset of the gang.
43    pub selection: Selection,
44}
45
46/// An envelope that carries a message destined to a group of actors.
47#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
48pub struct CastMessageEnvelope {
49    /// The destination actor mesh id.
50    actor_mesh_id: ActorMeshId,
51    /// The sender of this message.
52    sender: ActorId,
53    /// The destination port of the message. It could match multiple actors with
54    /// rank wildcard.
55    dest_port: DestinationPort,
56    /// The serialized message.
57    data: ErasedUnbound,
58    /// The shape of the cast.
59    shape: Shape,
60}
61wirevalue::register_type!(CastMessageEnvelope);
62
63impl CastMessageEnvelope {
64    /// Create a new CastMessageEnvelope.
65    pub fn new<A, M>(
66        actor_mesh_id: ActorMeshId,
67        sender: ActorId,
68        shape: Shape,
69        message: M,
70    ) -> Result<Self, anyhow::Error>
71    where
72        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
73        M: Castable + RemoteMessage,
74    {
75        let data = ErasedUnbound::try_from_message(message)?;
76        let actor_name = match &actor_mesh_id {
77            ActorMeshId::V0(_, actor_name) => actor_name.clone(),
78            ActorMeshId::V1(name) => name.to_string(),
79        };
80        Ok(Self {
81            actor_mesh_id,
82            sender,
83            dest_port: DestinationPort::new::<A, M>(actor_name),
84            data,
85            shape,
86        })
87    }
88
89    /// Create a new CastMessageEnvelope from serialized data. Only use this
90    /// when the message do not contain reply ports. Or it does but you are okay
91    /// with the destination actors reply to the client actor directly.
92    pub fn from_serialized(
93        actor_mesh_id: ActorMeshId,
94        sender: ActorId,
95        dest_port: DestinationPort,
96        shape: Shape,
97        data: wirevalue::Any,
98    ) -> Self {
99        Self {
100            actor_mesh_id,
101            sender,
102            dest_port,
103            data: ErasedUnbound::new(data),
104            shape,
105        }
106    }
107
108    pub(crate) fn sender(&self) -> &ActorId {
109        &self.sender
110    }
111
112    pub(crate) fn dest_port(&self) -> &DestinationPort {
113        &self.dest_port
114    }
115
116    pub(crate) fn data(&self) -> &ErasedUnbound {
117        &self.data
118    }
119
120    pub(crate) fn data_mut(&mut self) -> &mut ErasedUnbound {
121        &mut self.data
122    }
123
124    pub(crate) fn shape(&self) -> &Shape {
125        &self.shape
126    }
127
128    /// Given a rank in the root shape, return the corresponding point in the
129    /// provided shape, which is a view of the root shape.
130    pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
131        let shape = self.shape();
132        let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
133            anyhow::anyhow!(
134                "fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
135                rank_on_root_mesh,
136                e,
137                shape,
138            )
139        })?;
140        let extent =
141            Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
142                anyhow::anyhow!(
143                    "fail to calculate extent for root rank {} due to error: {}; shape is {}",
144                    rank_on_root_mesh,
145                    e,
146                    shape,
147                )
148            })?;
149        let point = extent.point(coords).map_err(|e| {
150            anyhow::anyhow!(
151                "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
152                rank_on_root_mesh,
153                e,
154                extent,
155                shape,
156            )
157        })?;
158        Ok(point.rank())
159    }
160
161    /// The unique key used to indicate the stream to which to deliver this message.
162    /// Concretely, the comm actors along the path should use this key to manage
163    /// sequence numbers and reorder buffers.
164    pub(crate) fn stream_key(&self) -> (ActorMeshId, ActorId) {
165        (self.actor_mesh_id.clone(), self.sender.clone())
166    }
167}
168
169/// Destination port id of a message. It is a `PortId` with the rank masked out,
170/// and the messege is always sent to the root actor because only root actor
171/// can be accessed externally. The rank is resolved by the destination Selection
172/// of the message. We can use `DestinationPort::port_id(rank)` to get the actual
173/// `PortId` of the message.
174#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
175pub struct DestinationPort {
176    /// The actor name to which the message should be delivered.
177    actor_name: String,
178    /// The port index of the destination actors, it is derived from the
179    /// message type and cached here.
180    port: u64,
181}
182wirevalue::register_type!(DestinationPort);
183
184impl DestinationPort {
185    /// Create a new DestinationPort for a global actor name and message type.
186    pub fn new<A, M>(actor_name: String) -> Self
187    where
188        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
189        M: Castable + RemoteMessage,
190    {
191        Self {
192            actor_name,
193            port: IndexedErasedUnbound::<M>::port(),
194        }
195    }
196
197    /// The port id of the destination.
198    pub fn port(&self) -> u64 {
199        self.port
200    }
201
202    /// Get the actor name of the destination.
203    pub fn actor_name(&self) -> &str {
204        &self.actor_name
205    }
206}
207
208/// The is used to start casting a message to a group of actors.
209#[derive(Serialize, Deserialize, Debug, Clone, Named)]
210pub struct CastMessage {
211    /// The cast destination.
212    pub dest: Uslice,
213    /// The message to cast.
214    pub message: CastMessageEnvelope,
215}
216wirevalue::register_type!(CastMessage);
217
218/// Forward a message to procs of next hops. This is used by comm actor to
219/// forward a message to other comm actors following the selection topology.
220/// This message is not visible to the clients.
221#[derive(Serialize, Deserialize, Debug, Clone, Named)]
222pub(crate) struct ForwardMessage {
223    /// The comm actor who originally casted the message.
224    pub(crate) sender: ActorId,
225    /// The destination of the message.
226    pub(crate) dests: Vec<RoutingFrame>,
227    /// The sequence number of this message.
228    pub(crate) seq: usize,
229    /// The sequence number of the previous message receieved.
230    pub(crate) last_seq: usize,
231    /// The message to distribute.
232    pub(crate) message: CastMessageEnvelope,
233}
234wirevalue::register_type!(ForwardMessage);
235
236declare_attrs! {
237    /// Used inside headers to store the originating sender of a cast.
238    pub attr CAST_ORIGINATING_SENDER: ActorId;
239
240    /// The point in the casted region that this message was sent to.
241    pub attr CAST_POINT: Point;
242}
243
244pub fn set_cast_info_on_headers(headers: &mut Attrs, cast_point: Point, sender: ActorId) {
245    headers.set(CAST_POINT, cast_point);
246    headers.set(CAST_ORIGINATING_SENDER, sender);
247}
248
249pub trait CastInfo {
250    /// Get the cast rank and cast shape.
251    /// If something wasn't explicitly sent via a cast, then
252    /// we represent it as the only member of a 0-dimensonal cast shape,
253    /// which is the same as a singleton.
254    fn cast_point(&self) -> Point;
255    fn sender(&self) -> &ActorId;
256}
257
258impl<A: Actor> CastInfo for Context<'_, A> {
259    fn cast_point(&self) -> Point {
260        match self.headers().get(CAST_POINT) {
261            Some(point) => point.clone(),
262            None => Extent::unity().point_of_rank(0).unwrap(),
263        }
264    }
265
266    fn sender(&self) -> &ActorId {
267        self.headers()
268            .get(CAST_ORIGINATING_SENDER)
269            .expect("has sender header")
270    }
271}