hyperactor_mesh/comm/
multicast.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The comm actor that provides message casting and result accumulation.
10
11use hyperactor::Actor;
12use hyperactor::Context;
13use hyperactor::RemoteHandles;
14use hyperactor::RemoteMessage;
15use hyperactor::actor::Referable;
16use hyperactor::message::Castable;
17use hyperactor::message::ErasedUnbound;
18use hyperactor::message::IndexedErasedUnbound;
19use hyperactor::reference as hyperactor_reference;
20use hyperactor_config::Flattrs;
21use hyperactor_config::attrs::declare_attrs;
22use ndslice::Extent;
23use ndslice::Point;
24use ndslice::Region;
25use ndslice::Shape;
26use ndslice::Slice;
27use ndslice::selection::Selection;
28use ndslice::selection::routing::RoutingFrame;
29use serde::Deserialize;
30use serde::Serialize;
31use typeuri::Named;
32use uuid::Uuid;
33
34use crate::Name;
35use crate::ValueMesh;
36use crate::comm::CommMeshConfig;
37use crate::reference::ActorMeshId;
38
39// A temporary trait used to share code in v0/v1 migration. Can be deleted after
40// v0 casting is deleted.
41pub(crate) trait CastEnvelope {
42    fn dest_port(&self) -> &DestinationPort;
43    fn headers(&self) -> &Flattrs;
44    fn sender(&self) -> &hyperactor_reference::ActorId;
45    fn cast_point(&self, config: &CommMeshConfig) -> anyhow::Result<Point>;
46    fn data(&self) -> &ErasedUnbound;
47    fn data_mut(&mut self) -> &mut ErasedUnbound;
48}
49
50/// A union of slices that can be used to represent arbitrary subset of
51/// ranks in a gang. It is represented by a Slice together with a Selection.
52/// This is used to define the destination of a cast message or the source of
53/// accumulation request.
54#[derive(Serialize, Deserialize, Debug, Clone)]
55pub struct Uslice {
56    /// A slice representing a whole gang.
57    pub slice: Slice,
58    /// A selection used to represent any subset of the gang.
59    pub selection: Selection,
60}
61
62/// An envelope that carries a message destined to a group of actors.
63#[derive(Debug, Serialize, Deserialize, Clone, Named)]
64pub struct CastMessageEnvelope {
65    /// The destination actor mesh id.
66    actor_mesh_id: ActorMeshId,
67    /// The end-to-end message headers.
68    headers: Flattrs,
69    /// The sender of this message.
70    sender: hyperactor_reference::ActorId,
71    /// The destination port of the message. It could match multiple actors with
72    /// rank wildcard.
73    dest_port: DestinationPort,
74    /// The serialized message.
75    data: ErasedUnbound,
76    /// The shape of the cast.
77    shape: Shape,
78}
79wirevalue::register_type!(CastMessageEnvelope);
80
81impl CastEnvelope for CastMessageEnvelope {
82    fn sender(&self) -> &hyperactor_reference::ActorId {
83        &self.sender
84    }
85
86    fn headers(&self) -> &Flattrs {
87        &self.headers
88    }
89
90    fn dest_port(&self) -> &DestinationPort {
91        &self.dest_port
92    }
93
94    fn data(&self) -> &ErasedUnbound {
95        &self.data
96    }
97
98    fn data_mut(&mut self) -> &mut ErasedUnbound {
99        &mut self.data
100    }
101
102    fn cast_point(&self, config: &CommMeshConfig) -> anyhow::Result<Point> {
103        let rank_on_root_mesh = config.self_rank();
104        let cast_rank = self.relative_rank(rank_on_root_mesh)?;
105        let cast_shape = self.shape();
106        let cast_point = cast_shape
107            .extent()
108            .point_of_rank(cast_rank)
109            .expect("rank out of bounds");
110        Ok(cast_point)
111    }
112}
113
114impl CastMessageEnvelope {
115    /// Create a new CastMessageEnvelope.
116    pub fn new<A, M>(
117        actor_mesh_id: ActorMeshId,
118        sender: hyperactor_reference::ActorId,
119        shape: Shape,
120        headers: Flattrs,
121        message: M,
122    ) -> Result<Self, anyhow::Error>
123    where
124        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
125        M: Castable + RemoteMessage,
126    {
127        let data = ErasedUnbound::try_from_message(message)?;
128        let actor_name = actor_mesh_id.0.to_string();
129        Ok(Self {
130            actor_mesh_id,
131            headers,
132            sender,
133            dest_port: DestinationPort::new::<A, M>(actor_name),
134            data,
135            shape,
136        })
137    }
138
139    /// Create a new CastMessageEnvelope from serialized data. Only use this
140    /// when the message do not contain reply ports. Or it does but you are okay
141    /// with the destination actors reply to the client actor directly.
142    pub fn from_serialized(
143        actor_mesh_id: ActorMeshId,
144        sender: hyperactor_reference::ActorId,
145        dest_port: DestinationPort,
146        shape: Shape,
147        headers: Flattrs,
148        data: wirevalue::Any,
149    ) -> Self {
150        Self {
151            actor_mesh_id,
152            sender,
153            headers,
154            dest_port,
155            data: ErasedUnbound::new(data),
156            shape,
157        }
158    }
159
160    pub(crate) fn shape(&self) -> &Shape {
161        &self.shape
162    }
163
164    /// Given a rank in the root shape, return the corresponding point in the
165    /// provided shape, which is a view of the root shape.
166    pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
167        let shape = self.shape();
168        let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
169            anyhow::anyhow!(
170                "fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
171                rank_on_root_mesh,
172                e,
173                shape,
174            )
175        })?;
176        let extent =
177            Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
178                anyhow::anyhow!(
179                    "fail to calculate extent for root rank {} due to error: {}; shape is {}",
180                    rank_on_root_mesh,
181                    e,
182                    shape,
183                )
184            })?;
185        let point = extent.point(coords).map_err(|e| {
186            anyhow::anyhow!(
187                "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
188                rank_on_root_mesh,
189                e,
190                extent,
191                shape,
192            )
193        })?;
194        Ok(point.rank())
195    }
196
197    /// The unique key used to indicate the stream to which to deliver this message.
198    /// Concretely, the comm actors along the path should use this key to manage
199    /// sequence numbers and reorder buffers.
200    pub(crate) fn stream_key(&self) -> (ActorMeshId, hyperactor_reference::ActorId) {
201        (self.actor_mesh_id.clone(), self.sender.clone())
202    }
203}
204
205/// Destination port id of a message. It is a `PortId` with the rank masked out,
206/// and the messege is always sent to the root actor because only root actor
207/// can be accessed externally. The rank is resolved by the destination Selection
208/// of the message. We can use `DestinationPort::port_id(rank)` to get the actual
209/// `PortId` of the message.
210#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
211pub struct DestinationPort {
212    /// The actor name to which the message should be delivered.
213    actor_name: String,
214    /// The port index of the destination actors, it is derived from the
215    /// message type and cached here.
216    port: u64,
217}
218wirevalue::register_type!(DestinationPort);
219
220impl DestinationPort {
221    /// Create a new DestinationPort for a global actor name and message type.
222    pub fn new<A, M>(actor_name: String) -> Self
223    where
224        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
225        M: Castable + RemoteMessage,
226    {
227        Self {
228            actor_name,
229            port: IndexedErasedUnbound::<M>::port(),
230        }
231    }
232
233    /// The port id of the destination.
234    pub fn port(&self) -> u64 {
235        self.port
236    }
237
238    /// Get the actor name of the destination.
239    pub fn actor_name(&self) -> &str {
240        &self.actor_name
241    }
242}
243
244/// The is used to start casting a message to a group of actors.
245#[derive(Serialize, Deserialize, Debug, Clone, Named)]
246pub struct CastMessage {
247    /// The cast destination.
248    pub dest: Uslice,
249    /// The message to cast.
250    pub message: CastMessageEnvelope,
251}
252wirevalue::register_type!(CastMessage);
253
254/// Forward a message to procs of next hops. This is used by comm actor to
255/// forward a message to other comm actors following the selection topology.
256/// This message is not visible to the clients.
257#[derive(Serialize, Deserialize, Debug, Clone, Named)]
258pub(crate) struct ForwardMessage {
259    /// The comm actor who originally casted the message.
260    pub(crate) sender: hyperactor_reference::ActorId,
261    /// The destination of the message.
262    pub(crate) dests: Vec<RoutingFrame>,
263    /// The sequence number of this message.
264    pub(crate) seq: usize,
265    /// The sequence number of the previous message receieved.
266    pub(crate) last_seq: usize,
267    /// The message to distribute.
268    pub(crate) message: CastMessageEnvelope,
269}
270wirevalue::register_type!(ForwardMessage);
271
272/// The is used to start casting a message to a group of actors.
273#[derive(Serialize, Deserialize, Debug, Clone, Named)]
274pub(crate) struct CastMessageV1 {
275    /// The additional end-to-end message headers.
276    pub(super) headers: Flattrs,
277    /// The client who sent this message.
278    pub(super) sender: hyperactor_reference::ActorId,
279    /// The client-assigned session id of this message.
280    pub(super) session_id: Uuid,
281    /// The client-assigned sequence numbers of this message.
282    pub(super) seqs: ValueMesh<u64>,
283    /// The destination mesh's region.
284    pub(super) dest_region: Region,
285    /// The destination port of the message. It could match multiple actors with
286    /// rank wildcard.
287    pub(super) dest_port: DestinationPort,
288    /// The serialized message.
289    pub(super) data: ErasedUnbound,
290}
291
292impl CastEnvelope for CastMessageV1 {
293    fn sender(&self) -> &hyperactor_reference::ActorId {
294        &self.sender
295    }
296
297    fn headers(&self) -> &Flattrs {
298        &self.headers
299    }
300
301    fn dest_port(&self) -> &DestinationPort {
302        &self.dest_port
303    }
304
305    fn data(&self) -> &ErasedUnbound {
306        &self.data
307    }
308
309    fn data_mut(&mut self) -> &mut ErasedUnbound {
310        &mut self.data
311    }
312
313    fn cast_point(&self, config: &CommMeshConfig) -> anyhow::Result<Point> {
314        let rank_on_root_mesh = config.self_rank();
315        let cast_point = self.dest_region.point_of_base_rank(rank_on_root_mesh)?;
316        Ok(cast_point)
317    }
318}
319
320impl CastMessageV1 {
321    /// Create a new CastMessageEnvelope.
322    #[allow(unused)]
323    pub(crate) fn new<A, M>(
324        sender: hyperactor_reference::ActorId,
325        dest_mesh: &Name,
326        dest_region: Region,
327        headers: Flattrs,
328        message: M,
329        session_id: Uuid,
330        seqs: ValueMesh<u64>,
331    ) -> Result<Self, anyhow::Error>
332    where
333        A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
334        M: Castable + RemoteMessage,
335    {
336        let data = ErasedUnbound::try_from_message(message)?;
337        Ok(Self {
338            headers,
339            sender,
340            session_id,
341            seqs,
342            dest_region,
343            dest_port: DestinationPort::new::<A, M>(dest_mesh.to_string()),
344            data,
345        })
346    }
347}
348
349/// Forward a message to procs of next hops. This is used by comm actor to
350/// forward a message to other comm actors following the selection topology.
351/// This message is not visible to the clients.
352#[derive(Serialize, Deserialize, Debug, Clone, Named)]
353pub(super) struct ForwardMessageV1 {
354    /// The destination of the message.
355    pub(super) dests: Vec<RoutingFrame>,
356    /// The message to distribute.
357    pub(super) message: CastMessageV1,
358}
359
360declare_attrs! {
361    /// Used inside headers to store the originating sender of a cast.
362    pub attr CAST_ORIGINATING_SENDER: hyperactor_reference::ActorId;
363
364    /// The point in the casted region that this message was sent to.
365    pub attr CAST_POINT: Point;
366}
367
368pub fn set_cast_info_on_headers(
369    headers: &mut Flattrs,
370    cast_point: Point,
371    sender: hyperactor_reference::ActorId,
372) {
373    // Pre-set the telemetry sender hash to the originating actor,
374    // so post_unchecked() does not overwrite it with the comm actor.
375    // TODO: consider merging SENDER_ACTOR_ID_HASH and
376    // CAST_ORIGINATING_SENDER -- they carry overlapping sender identity.
377    headers.set(
378        hyperactor::mailbox::headers::SENDER_ACTOR_ID_HASH,
379        hyperactor_telemetry::hash_to_u64(&sender),
380    );
381    headers.set(CAST_POINT, cast_point);
382    headers.set(CAST_ORIGINATING_SENDER, sender);
383}
384
385pub trait CastInfo {
386    /// Get the cast rank and cast shape.
387    /// If something wasn't explicitly sent via a cast, then
388    /// we represent it as the only member of a 0-dimensonal cast shape,
389    /// which is the same as a singleton.
390    fn cast_point(&self) -> Point;
391    fn sender(&self) -> hyperactor_reference::ActorId;
392}
393
394impl<A: Actor> CastInfo for Context<'_, A> {
395    fn cast_point(&self) -> Point {
396        match self.headers().get(CAST_POINT) {
397            Some(point) => point,
398            None => Extent::unity().point_of_rank(0).unwrap(),
399        }
400    }
401
402    fn sender(&self) -> hyperactor_reference::ActorId {
403        self.headers()
404            .get(CAST_ORIGINATING_SENDER)
405            .expect("has sender header")
406    }
407}