hyperactor_mesh/comm/
multicast.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The comm actor that provides message casting and result accumulation.
10
11use hyperactor::Actor;
12use hyperactor::Context;
13use hyperactor::Named;
14use hyperactor::RemoteHandles;
15use hyperactor::RemoteMessage;
16use hyperactor::actor::RemoteActor;
17use hyperactor::attrs::Attrs;
18use hyperactor::data::Serialized;
19use hyperactor::declare_attrs;
20use hyperactor::message::Castable;
21use hyperactor::message::ErasedUnbound;
22use hyperactor::message::IndexedErasedUnbound;
23use hyperactor::reference::ActorId;
24use ndslice::Extent;
25use ndslice::Shape;
26use ndslice::Slice;
27use ndslice::selection::Selection;
28use ndslice::selection::routing::RoutingFrame;
29use serde::Deserialize;
30use serde::Serialize;
31
32use crate::reference::ActorMeshId;
33
34/// A union of slices that can be used to represent arbitrary subset of
35/// ranks in a gang. It is represented by a Slice together with a Selection.
36/// This is used to define the destination of a cast message or the source of
37/// accumulation request.
38#[derive(Serialize, Deserialize, Debug, Clone)]
39pub struct Uslice {
40    /// A slice representing a whole gang.
41    pub slice: Slice,
42    /// A selection used to represent any subset of the gang.
43    pub selection: Selection,
44}
45
46/// An envelope that carries a message destined to a group of actors.
47#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
48pub struct CastMessageEnvelope {
49    /// The destination actor mesh id.
50    actor_mesh_id: ActorMeshId,
51    /// The sender of this message.
52    sender: ActorId,
53    /// The destination port of the message. It could match multiple actors with
54    /// rank wildcard.
55    dest_port: DestinationPort,
56    /// The serialized message.
57    data: ErasedUnbound,
58    /// The shape of the cast.
59    shape: Shape,
60}
61
62impl CastMessageEnvelope {
63    /// Create a new CastMessageEnvelope.
64    pub fn new<A, M>(
65        actor_mesh_id: ActorMeshId,
66        sender: ActorId,
67        shape: Shape,
68        message: M,
69    ) -> Result<Self, anyhow::Error>
70    where
71        A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
72        M: Castable + RemoteMessage,
73    {
74        let data = ErasedUnbound::try_from_message(message)?;
75        let actor_name = actor_mesh_id.1.to_string();
76        Ok(Self {
77            actor_mesh_id,
78            sender,
79            dest_port: DestinationPort::new::<A, M>(actor_name),
80            data,
81            shape,
82        })
83    }
84
85    /// Create a new CastMessageEnvelope from serialized data. Only use this
86    /// when the message do not contain reply ports. Or it does but you are okay
87    /// with the destination actors reply to the client actor directly.
88    pub fn from_serialized(
89        actor_mesh_id: ActorMeshId,
90        sender: ActorId,
91        dest_port: DestinationPort,
92        shape: Shape,
93        data: Serialized,
94    ) -> Self {
95        Self {
96            actor_mesh_id,
97            sender,
98            dest_port,
99            data: ErasedUnbound::new(data),
100            shape,
101        }
102    }
103
104    pub(crate) fn sender(&self) -> &ActorId {
105        &self.sender
106    }
107
108    pub(crate) fn dest_port(&self) -> &DestinationPort {
109        &self.dest_port
110    }
111
112    pub(crate) fn data(&self) -> &ErasedUnbound {
113        &self.data
114    }
115
116    pub(crate) fn data_mut(&mut self) -> &mut ErasedUnbound {
117        &mut self.data
118    }
119
120    pub(crate) fn shape(&self) -> &Shape {
121        &self.shape
122    }
123
124    /// Given a rank in the root shape, return the corresponding point in the
125    /// provided shape, which is a view of the root shape.
126    pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
127        let shape = self.shape();
128        let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
129            anyhow::anyhow!(
130                "fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
131                rank_on_root_mesh,
132                e,
133                shape,
134            )
135        })?;
136        let extent =
137            Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
138                anyhow::anyhow!(
139                    "fail to calculate extent for root rank {} due to error: {}; shape is {}",
140                    rank_on_root_mesh,
141                    e,
142                    shape,
143                )
144            })?;
145        let point = extent.point(coords).map_err(|e| {
146            anyhow::anyhow!(
147                "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
148                rank_on_root_mesh,
149                e,
150                extent,
151                shape,
152            )
153        })?;
154        Ok(point.rank())
155    }
156
157    /// The unique key used to indicate the stream to which to deliver this message.
158    /// Concretely, the comm actors along the path should use this key to manage
159    /// sequence numbers and reorder buffers.
160    pub(crate) fn stream_key(&self) -> (ActorMeshId, ActorId) {
161        (self.actor_mesh_id.clone(), self.sender.clone())
162    }
163}
164
165/// Destination port id of a message. It is a `PortId` with the rank masked out,
166/// and the messege is always sent to the root actor because only root actor
167/// can be accessed externally. The rank is resolved by the destination Selection
168/// of the message. We can use `DestinationPort::port_id(rank)` to get the actual
169/// `PortId` of the message.
170#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
171pub struct DestinationPort {
172    /// The actor name to which the message should be delivered.
173    actor_name: String,
174    /// The port index of the destination actors, it is derived from the
175    /// message type and cached here.
176    port: u64,
177}
178
179impl DestinationPort {
180    /// Create a new DestinationPort for a global actor name and message type.
181    pub fn new<A, M>(actor_name: String) -> Self
182    where
183        A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
184        M: Castable + RemoteMessage,
185    {
186        Self {
187            actor_name,
188            port: IndexedErasedUnbound::<M>::port(),
189        }
190    }
191
192    /// The port id of the destination.
193    pub fn port(&self) -> u64 {
194        self.port
195    }
196
197    /// Get the actor name of the destination.
198    pub fn actor_name(&self) -> &str {
199        &self.actor_name
200    }
201}
202
203/// The is used to start casting a message to a group of actors.
204#[derive(Serialize, Deserialize, Debug, Clone, Named)]
205pub struct CastMessage {
206    /// The cast destination.
207    pub dest: Uslice,
208    /// The message to cast.
209    pub message: CastMessageEnvelope,
210}
211
212/// Forward a message to procs of next hops. This is used by comm actor to
213/// forward a message to other comm actors following the selection topology.
214/// This message is not visible to the clients.
215#[derive(Serialize, Deserialize, Debug, Clone, Named)]
216pub(crate) struct ForwardMessage {
217    /// The comm actor who originally casted the message.
218    pub(crate) sender: ActorId,
219    /// The destination of the message.
220    pub(crate) dests: Vec<RoutingFrame>,
221    /// The sequence number of this message.
222    pub(crate) seq: usize,
223    /// The sequence number of the previous message receieved.
224    pub(crate) last_seq: usize,
225    /// The message to distribute.
226    pub(crate) message: CastMessageEnvelope,
227}
228
// Header attribute keys that carry cast metadata. They are written by
// `set_cast_info_on_headers` and read back through the `CastInfo` trait.
declare_attrs! {
    /// Used inside headers for cast messages to store
    /// the rank of the receiver.
    attr CAST_RANK: usize;
    /// Used inside headers to store the shape of the
    /// actor mesh that a message was cast to.
    attr CAST_SHAPE: Shape;
    /// Used inside headers to store the originating sender of a cast.
    pub attr CAST_ORIGINATING_SENDER: ActorId;
}
239
/// Record cast metadata on `headers`.
///
/// Stores `cast_rank` under `CAST_RANK`, `cast_shape` under `CAST_SHAPE`, and
/// `sender` under `CAST_ORIGINATING_SENDER`.
pub fn set_cast_info_on_headers(
    headers: &mut Attrs,
    cast_rank: usize,
    cast_shape: Shape,
    sender: ActorId,
) {
    headers.set(CAST_RANK, cast_rank);
    headers.set(CAST_SHAPE, cast_shape);
    headers.set(CAST_ORIGINATING_SENDER, sender);
}
250
/// Access to the cast metadata associated with a message.
pub trait CastInfo {
    /// Get the cast rank and cast shape.
    /// If something wasn't explicitly sent via a cast, then
    /// we represent it as the only member of a 0-dimensional cast shape,
    /// which is the same as a singleton.
    fn cast_info(&self) -> (usize, Shape);
    /// Get the originating sender of the cast.
    fn sender(&self) -> &ActorId;
}
259
260impl<A: Actor> CastInfo for Context<'_, A> {
261    fn cast_info(&self) -> (usize, Shape) {
262        let headers = self.headers();
263        match (headers.get(CAST_RANK), headers.get(CAST_SHAPE)) {
264            (Some(rank), Some(shape)) => (*rank, shape.clone()),
265            (None, None) => (0, Shape::unity()),
266            _ => panic!("Expected either both rank and shape or neither"),
267        }
268    }
269    fn sender(&self) -> &ActorId {
270        self.headers()
271            .get(CAST_ORIGINATING_SENDER)
272            .expect("has sender header")
273    }
274}