Skip to main content

hyperactor_mesh/
casting.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Casting utilities for actor meshes.
10
11use std::collections::BTreeSet;
12
13use hyperactor::ActorRef;
14use hyperactor::RemoteEndpoint as _;
15use hyperactor::RemoteHandles;
16use hyperactor::RemoteMessage;
17use hyperactor::actor::Referable;
18use hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER;
19use hyperactor::context;
20use hyperactor::mailbox;
21use hyperactor::mailbox::MailboxSenderError;
22use hyperactor::mailbox::MessageEnvelope;
23use hyperactor::mailbox::Undeliverable;
24use hyperactor::message::Castable;
25use hyperactor::message::IndexedErasedUnbound;
26use hyperactor_config::Flattrs;
27use hyperactor_config::attrs::declare_attrs;
28use ndslice::Selection;
29use ndslice::Shape;
30use ndslice::ShapeError;
31use ndslice::SliceError;
32use ndslice::reshape::Limit;
33use ndslice::reshape::ReshapeError;
34use ndslice::reshape::ReshapeSliceExt;
35use ndslice::reshape::reshape_selection;
36use ndslice::selection;
37use ndslice::selection::EvalOpts;
38use ndslice::selection::ReifySlice;
39use ndslice::selection::normal;
40
41use crate::CommActor;
42use crate::comm::ENABLE_NATIVE_V1_CASTING;
43use crate::comm::multicast::CAST_ORIGINATING_SENDER;
44use crate::comm::multicast::CastMessage;
45use crate::comm::multicast::CastMessageEnvelope;
46use crate::comm::multicast::Uslice;
47use crate::config::MAX_CAST_DIMENSION_SIZE;
48use crate::mesh_id::ActorMeshId;
49use crate::metrics;
50
51/// Returns true if native V1 casting is enabled. Panics if V1 casting
52/// is on but the required dest actor reordering buffer is not.
53pub(crate) fn v1_casting_enabled() -> bool {
54    let enabled = hyperactor_config::global::get(ENABLE_NATIVE_V1_CASTING);
55    if enabled {
56        assert!(
57            hyperactor_config::global::get(ENABLE_DEST_ACTOR_REORDERING_BUFFER),
58            "native V1 casting requires ENABLE_DEST_ACTOR_REORDERING_BUFFER to be enabled",
59        );
60    }
61    enabled
62}
63
declare_attrs! {
    /// Which mesh this message was cast to. Used for undeliverable message
    /// handling, where the CastMessageEnvelope is serialized, and its content
    /// cannot be inspected.
    ///
    /// `actor_mesh_cast` sets this on both the inner `CastMessageEnvelope`
    /// headers and the outer headers it posts to the comm actor.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
70
71/// An undeliverable might have its sender address set as the comm actor instead
72/// of the original sender. Update it based on the headers present in the message
73/// so it matches the sender.
74pub fn update_undeliverable_envelope_for_casting(
75    mut envelope: Undeliverable<MessageEnvelope>,
76) -> Undeliverable<MessageEnvelope> {
77    let Some(message) = envelope.as_message_mut() else {
78        return envelope;
79    };
80    let old_actor = message.sender().clone();
81    if let Some(actor_id) = message.headers().get(CAST_ORIGINATING_SENDER) {
82        tracing::debug!(
83            actor_id = %old_actor,
84            "remapped comm-actor id to id from CAST_ORIGINATING_SENDER {}", actor_id
85        );
86        message.update_sender(actor_id);
87    }
88    // Else do nothing, it wasn't from a comm actor.
89    envelope
90}
91
92/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast
93/// an `M`-typed message.
94///
95/// `caller_headers` are caller-supplied envelope headers (e.g.
96/// operation-context keys) merged into the inner envelope headers so
97/// receivers see them on `cx.headers()`.
98#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
99#[tracing::instrument(level = "debug", skip_all)]
100pub(crate) fn actor_mesh_cast<A, M>(
101    cx: &impl context::Actor,
102    actor_mesh_id: ActorMeshId,
103    comm_actor_ref: &ActorRef<CommActor>,
104    selection_of_root: Selection,
105    root_mesh_shape: &Shape,
106    cast_mesh_shape: &Shape,
107    message: M,
108    caller_headers: &Flattrs,
109) -> Result<(), CastError>
110where
111    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
112    M: Castable + RemoteMessage,
113{
114    let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
115        "message_type" => M::typename(),
116        "message_variant" => message.arm().unwrap_or_default(),
117    ));
118
119    // Caller-known headers ride first; cast-info (timestamp,
120    // message type, mesh id) is stamped afterward and wins on
121    // collision because those keys are owned by this layer.
122    let mut headers = caller_headers.clone();
123    mailbox::headers::set_send_timestamp(&mut headers);
124    mailbox::headers::set_rust_message_type::<M>(&mut headers);
125    headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id.clone());
126    let message = CastMessageEnvelope::new::<A, M>(
127        actor_mesh_id.clone(),
128        cx.mailbox().actor_addr().clone(),
129        cast_mesh_shape.clone(),
130        headers,
131        message,
132    )?;
133
134    // Mesh's shape might have large extents on some dimensions. Those
135    // dimensions would cause large fanout in our comm actor
136    // implementation. To avoid that, we reshape it by increasing
137    // dimensionality and limiting the extent of each dimension. Note
138    // the reshape is only visible to the internal algorithm. The
139    // shape that user sees maintains intact.
140    //
141    // For example, a typical shape is [hosts=1024, gpus=8]. By using
142    // limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
143    // words, it adds 3 extra layers to the comm actor tree, while
144    // keeping the fanout in each layer per dimension at 8 or smaller.
145    //
146    // An important note here is that max dimension size != max fanout.
147    // Rank 0 must send a message to all ranks at index 0 for every dimension.
148    // If our reshaped shape is [8, 8, 8, 2, 8], rank 0 must send
149    // 7 + 7 + 7 + 1 + 7 = 21 messages.
150
151    let slice_of_root = root_mesh_shape.slice();
152
153    let max_cast_dimension_size = hyperactor_config::global::get(MAX_CAST_DIMENSION_SIZE);
154
155    let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
156
157    let selection_of_cast =
158        reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
159
160    let cast_message = CastMessage {
161        dest: Uslice {
162            slice: slice_of_cast,
163            selection: selection_of_cast,
164        },
165        message,
166    };
167
168    // TEMPORARY: remove with v0 support. Same ownership rule as
169    // the inner envelope: caller-known headers ride first, cast-info
170    // wins on collision.
171    let mut headers = caller_headers.clone();
172    headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
173
174    comm_actor_ref
175        .port()
176        .post_with_headers(cx, headers, cast_message);
177
178    Ok(())
179}
180
181#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
182pub(crate) fn cast_to_sliced_mesh<A, M>(
183    cx: &impl context::Actor,
184    actor_mesh_id: ActorMeshId,
185    comm_actor_ref: &ActorRef<CommActor>,
186    sel_of_sliced: &Selection,
187    message: M,
188    sliced_shape: &Shape,
189    root_mesh_shape: &Shape,
190    caller_headers: &Flattrs,
191) -> Result<(), CastError>
192where
193    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
194    M: Castable + RemoteMessage,
195{
196    let root_slice = root_mesh_shape.slice();
197
198    // Casting to `*`?
199    let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
200        // Reify this view into base.
201        root_slice.reify_slice(sliced_shape.slice())?
202    } else {
203        // No, fall back on `of_ranks`.
204        let ranks = sel_of_sliced
205            .eval(&EvalOpts::strict(), sliced_shape.slice())?
206            .collect::<BTreeSet<_>>();
207        Selection::of_ranks(root_slice, &ranks)?
208    };
209
210    // Cast.
211    actor_mesh_cast::<A, M>(
212        cx,
213        actor_mesh_id,
214        comm_actor_ref,
215        sel_of_root,
216        root_mesh_shape,
217        sliced_shape,
218        message,
219        caller_headers,
220    )
221}
222
/// Errors produced by casting operations.
///
/// Returned by `actor_mesh_cast` and `cast_to_sliced_mesh`. Variants marked
/// `#[error(transparent)]` forward their display and source to the wrapped
/// error. `#[from]` variants let `?` convert the wrapped error automatically.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection could not be applied; carries the offending selection and
    /// the underlying `ShapeError`.
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// A send to the rank given by the first field failed.
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// The selection is not supported; the string is included in the
    /// display message.
    #[error("unsupported selection: {0}")]
    SelectionNotSupported(String),

    /// A mailbox send failure that, unlike `MailboxSenderError`, carries
    /// no rank.
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    /// Wraps a `ShapeError`.
    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    /// Wraps a `SliceError`. NOTE(review): presumably from the slice
    /// operations (`reify_slice`, `eval`, `of_ranks`) in
    /// `cast_to_sliced_mesh` — confirm.
    #[error(transparent)]
    SliceError(#[from] SliceError),

    /// Wraps a bincode encoding failure. NOTE(review): presumably raised
    /// while serializing a message into a `CastMessageEnvelope` — confirm.
    #[error(transparent)]
    SerializationEncodeError(#[from] bincode::error::EncodeError),

    /// Wraps a bincode decoding failure.
    #[error(transparent)]
    SerializationDecodeError(#[from] bincode::error::DecodeError),

    /// Any other error, carried as `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Wraps a `ReshapeError`. NOTE(review): presumably from
    /// `reshape_selection` in `actor_mesh_cast` — confirm.
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}
251    Other(#[from] anyhow::Error),
252
253    #[error(transparent)]
254    ReshapeError(#[from] ReshapeError),
255}