hyperactor_mesh/comm/
multicast.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The comm actor that provides message casting and result accumulation.
10
11use hyperactor::Actor;
12use hyperactor::Context;
13use hyperactor::Named;
14use hyperactor::RemoteHandles;
15use hyperactor::RemoteMessage;
16use hyperactor::actor::Referable;
17use hyperactor::attrs::Attrs;
18use hyperactor::data::Serialized;
19use hyperactor::declare_attrs;
20use hyperactor::message::Castable;
21use hyperactor::message::ErasedUnbound;
22use hyperactor::message::IndexedErasedUnbound;
23use hyperactor::reference::ActorId;
24use ndslice::Extent;
25use ndslice::Point;
26use ndslice::Shape;
27use ndslice::Slice;
28use ndslice::selection::Selection;
29use ndslice::selection::routing::RoutingFrame;
30use serde::Deserialize;
31use serde::Serialize;
32
33use crate::reference::ActorMeshId;
34
35/// A union of slices that can be used to represent arbitrary subset of
36/// ranks in a gang. It is represented by a Slice together with a Selection.
37/// This is used to define the destination of a cast message or the source of
38/// accumulation request.
39#[derive(Serialize, Deserialize, Debug, Clone)]
40pub struct Uslice {
41    /// A slice representing a whole gang.
42    pub slice: Slice,
43    /// A selection used to represent any subset of the gang.
44    pub selection: Selection,
45}
46
47/// An envelope that carries a message destined to a group of actors.
48#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
49pub struct CastMessageEnvelope {
50    /// The destination actor mesh id.
51    actor_mesh_id: ActorMeshId,
52    /// The sender of this message.
53    sender: ActorId,
54    /// The destination port of the message. It could match multiple actors with
55    /// rank wildcard.
56    dest_port: DestinationPort,
57    /// The serialized message.
58    data: ErasedUnbound,
59    /// The shape of the cast.
60    shape: Shape,
61}
62
63impl CastMessageEnvelope {
64    /// Create a new CastMessageEnvelope.
65    pub fn new<A, M>(
66        actor_mesh_id: ActorMeshId,
67        sender: ActorId,
68        shape: Shape,
69        message: M,
70    ) -> Result<Self, anyhow::Error>
71    where
72        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
73        M: Castable + RemoteMessage,
74    {
75        let data = ErasedUnbound::try_from_message(message)?;
76        let actor_name = match &actor_mesh_id {
77            ActorMeshId::V0(_, actor_name) => actor_name.clone(),
78            ActorMeshId::V1(name) => name.to_string(),
79        };
80        Ok(Self {
81            actor_mesh_id,
82            sender,
83            dest_port: DestinationPort::new::<A, M>(actor_name),
84            data,
85            shape,
86        })
87    }
88
89    /// Create a new CastMessageEnvelope from serialized data. Only use this
90    /// when the message do not contain reply ports. Or it does but you are okay
91    /// with the destination actors reply to the client actor directly.
92    pub fn from_serialized(
93        actor_mesh_id: ActorMeshId,
94        sender: ActorId,
95        dest_port: DestinationPort,
96        shape: Shape,
97        data: Serialized,
98    ) -> Self {
99        Self {
100            actor_mesh_id,
101            sender,
102            dest_port,
103            data: ErasedUnbound::new(data),
104            shape,
105        }
106    }
107
108    pub(crate) fn sender(&self) -> &ActorId {
109        &self.sender
110    }
111
112    pub(crate) fn dest_port(&self) -> &DestinationPort {
113        &self.dest_port
114    }
115
116    pub(crate) fn data(&self) -> &ErasedUnbound {
117        &self.data
118    }
119
120    pub(crate) fn data_mut(&mut self) -> &mut ErasedUnbound {
121        &mut self.data
122    }
123
124    pub(crate) fn shape(&self) -> &Shape {
125        &self.shape
126    }
127
128    /// Given a rank in the root shape, return the corresponding point in the
129    /// provided shape, which is a view of the root shape.
130    pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
131        let shape = self.shape();
132        let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
133            anyhow::anyhow!(
134                "fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
135                rank_on_root_mesh,
136                e,
137                shape,
138            )
139        })?;
140        let extent =
141            Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
142                anyhow::anyhow!(
143                    "fail to calculate extent for root rank {} due to error: {}; shape is {}",
144                    rank_on_root_mesh,
145                    e,
146                    shape,
147                )
148            })?;
149        let point = extent.point(coords).map_err(|e| {
150            anyhow::anyhow!(
151                "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
152                rank_on_root_mesh,
153                e,
154                extent,
155                shape,
156            )
157        })?;
158        Ok(point.rank())
159    }
160
161    /// The unique key used to indicate the stream to which to deliver this message.
162    /// Concretely, the comm actors along the path should use this key to manage
163    /// sequence numbers and reorder buffers.
164    pub(crate) fn stream_key(&self) -> (ActorMeshId, ActorId) {
165        (self.actor_mesh_id.clone(), self.sender.clone())
166    }
167}
168
169/// Destination port id of a message. It is a `PortId` with the rank masked out,
170/// and the messege is always sent to the root actor because only root actor
171/// can be accessed externally. The rank is resolved by the destination Selection
172/// of the message. We can use `DestinationPort::port_id(rank)` to get the actual
173/// `PortId` of the message.
174#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
175pub struct DestinationPort {
176    /// The actor name to which the message should be delivered.
177    actor_name: String,
178    /// The port index of the destination actors, it is derived from the
179    /// message type and cached here.
180    port: u64,
181}
182
183impl DestinationPort {
184    /// Create a new DestinationPort for a global actor name and message type.
185    pub fn new<A, M>(actor_name: String) -> Self
186    where
187        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
188        M: Castable + RemoteMessage,
189    {
190        Self {
191            actor_name,
192            port: IndexedErasedUnbound::<M>::port(),
193        }
194    }
195
196    /// The port id of the destination.
197    pub fn port(&self) -> u64 {
198        self.port
199    }
200
201    /// Get the actor name of the destination.
202    pub fn actor_name(&self) -> &str {
203        &self.actor_name
204    }
205}
206
207/// The is used to start casting a message to a group of actors.
208#[derive(Serialize, Deserialize, Debug, Clone, Named)]
209pub struct CastMessage {
210    /// The cast destination.
211    pub dest: Uslice,
212    /// The message to cast.
213    pub message: CastMessageEnvelope,
214}
215
216/// Forward a message to procs of next hops. This is used by comm actor to
217/// forward a message to other comm actors following the selection topology.
218/// This message is not visible to the clients.
219#[derive(Serialize, Deserialize, Debug, Clone, Named)]
220pub(crate) struct ForwardMessage {
221    /// The comm actor who originally casted the message.
222    pub(crate) sender: ActorId,
223    /// The destination of the message.
224    pub(crate) dests: Vec<RoutingFrame>,
225    /// The sequence number of this message.
226    pub(crate) seq: usize,
227    /// The sequence number of the previous message receieved.
228    pub(crate) last_seq: usize,
229    /// The message to distribute.
230    pub(crate) message: CastMessageEnvelope,
231}
232
declare_attrs! {
    /// Used inside headers to store the originating sender of a cast.
    /// Written by `set_cast_info_on_headers` and read back by
    /// `CastInfo::sender`.
    pub attr CAST_ORIGINATING_SENDER: ActorId;

    /// The point in the casted region that this message was sent to.
    /// Written by `set_cast_info_on_headers` and read back by
    /// `CastInfo::cast_point`.
    pub attr CAST_POINT: Point;
}
240
241pub fn set_cast_info_on_headers(headers: &mut Attrs, cast_point: Point, sender: ActorId) {
242    headers.set(CAST_POINT, cast_point);
243    headers.set(CAST_ORIGINATING_SENDER, sender);
244}
245
246pub trait CastInfo {
247    /// Get the cast rank and cast shape.
248    /// If something wasn't explicitly sent via a cast, then
249    /// we represent it as the only member of a 0-dimensonal cast shape,
250    /// which is the same as a singleton.
251    fn cast_point(&self) -> Point;
252    fn sender(&self) -> &ActorId;
253}
254
255impl<A: Actor> CastInfo for Context<'_, A> {
256    fn cast_point(&self) -> Point {
257        match self.headers().get(CAST_POINT) {
258            Some(point) => point.clone(),
259            None => Extent::unity().point_of_rank(0).unwrap(),
260        }
261    }
262
263    fn sender(&self) -> &ActorId {
264        self.headers()
265            .get(CAST_ORIGINATING_SENDER)
266            .expect("has sender header")
267    }
268}