hyperactor_mesh/
casting.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Casting utilities for actor meshes.
10
11use std::collections::BTreeSet;
12
13use hyperactor::RemoteHandles;
14use hyperactor::RemoteMessage;
15use hyperactor::actor::Referable;
16use hyperactor::context;
17use hyperactor::mailbox;
18use hyperactor::mailbox::MailboxSenderError;
19use hyperactor::mailbox::MessageEnvelope;
20use hyperactor::mailbox::Undeliverable;
21use hyperactor::message::Castable;
22use hyperactor::message::IndexedErasedUnbound;
23use hyperactor::reference as hyperactor_reference;
24use hyperactor_config::Flattrs;
25use hyperactor_config::attrs::declare_attrs;
26use ndslice::Selection;
27use ndslice::Shape;
28use ndslice::ShapeError;
29use ndslice::SliceError;
30use ndslice::reshape::Limit;
31use ndslice::reshape::ReshapeError;
32use ndslice::reshape::ReshapeSliceExt;
33use ndslice::reshape::reshape_selection;
34use ndslice::selection;
35use ndslice::selection::EvalOpts;
36use ndslice::selection::ReifySlice;
37use ndslice::selection::normal;
38
39use crate::CommActor;
40use crate::comm::multicast::CAST_ORIGINATING_SENDER;
41use crate::comm::multicast::CastMessage;
42use crate::comm::multicast::CastMessageEnvelope;
43use crate::comm::multicast::Uslice;
44use crate::config::MAX_CAST_DIMENSION_SIZE;
45use crate::metrics;
46use crate::reference::ActorMeshId;
47
declare_attrs! {
    /// The id of the actor mesh a message was cast to.
    ///
    /// It is carried as a message header for undeliverable-message
    /// handling: at that point the `CastMessageEnvelope` is serialized and
    /// its content cannot be inspected, so the destination mesh has to be
    /// readable from the headers instead.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
54
55/// An undeliverable might have its sender address set as the comm actor instead
56/// of the original sender. Update it based on the headers present in the message
57/// so it matches the sender.
58pub fn update_undeliverable_envelope_for_casting(
59    mut envelope: Undeliverable<MessageEnvelope>,
60) -> Undeliverable<MessageEnvelope> {
61    let old_actor = envelope.0.sender().clone();
62    if let Some(actor_id) = envelope.0.headers().get(CAST_ORIGINATING_SENDER) {
63        tracing::debug!(
64            actor_id = %old_actor,
65            "remapped comm-actor id to id from CAST_ORIGINATING_SENDER {}", actor_id
66        );
67        envelope.0.update_sender(actor_id);
68    }
69    // Else do nothing, it wasn't from a comm actor.
70    envelope
71}
72
73/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast
74/// an `M`-typed message
75#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
76#[tracing::instrument(level = "debug", skip_all)]
77pub(crate) fn actor_mesh_cast<A, M>(
78    cx: &impl context::Actor,
79    actor_mesh_id: ActorMeshId,
80    comm_actor_ref: &hyperactor_reference::ActorRef<CommActor>,
81    selection_of_root: Selection,
82    root_mesh_shape: &Shape,
83    cast_mesh_shape: &Shape,
84    message: M,
85) -> Result<(), CastError>
86where
87    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
88    M: Castable + RemoteMessage,
89{
90    let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
91        "message_type" => M::typename(),
92        "message_variant" => message.arm().unwrap_or_default(),
93    ));
94
95    let mut headers = Flattrs::new();
96    mailbox::headers::set_send_timestamp(&mut headers);
97    mailbox::headers::set_rust_message_type::<M>(&mut headers);
98    headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id.clone());
99    let message = CastMessageEnvelope::new::<A, M>(
100        actor_mesh_id.clone(),
101        cx.mailbox().actor_id().clone(),
102        cast_mesh_shape.clone(),
103        headers,
104        message,
105    )?;
106
107    // Mesh's shape might have large extents on some dimensions. Those
108    // dimensions would cause large fanout in our comm actor
109    // implementation. To avoid that, we reshape it by increasing
110    // dimensionality and limiting the extent of each dimension. Note
111    // the reshape is only visible to the internal algorithm. The
112    // shape that user sees maintains intact.
113    //
114    // For example, a typical shape is [hosts=1024, gpus=8]. By using
115    // limit 8, it becomes [8, 8, 8, 2, 8] during casting. In other
116    // words, it adds 3 extra layers to the comm actor tree, while
117    // keeping the fanout in each layer per dimension at 8 or smaller.
118    //
119    // An important note here is that max dimension size != max fanout.
120    // Rank 0 must send a message to all ranks at index 0 for every dimension.
121    // If our reshaped shape is [8, 8, 8, 2, 8], rank 0 must send
122    // 7 + 7 + 7 + 1 + 7 = 21 messages.
123
124    let slice_of_root = root_mesh_shape.slice();
125
126    let max_cast_dimension_size = hyperactor_config::global::get(MAX_CAST_DIMENSION_SIZE);
127
128    let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
129
130    let selection_of_cast =
131        reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
132
133    let cast_message = CastMessage {
134        dest: Uslice {
135            slice: slice_of_cast,
136            selection: selection_of_cast,
137        },
138        message,
139    };
140
141    // TEMPORARY: remove with v0 support
142    let mut headers = Flattrs::new();
143    headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
144
145    comm_actor_ref
146        .port()
147        .send_with_headers(cx, headers, cast_message)?;
148
149    Ok(())
150}
151
152#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
153pub(crate) fn cast_to_sliced_mesh<A, M>(
154    cx: &impl context::Actor,
155    actor_mesh_id: ActorMeshId,
156    comm_actor_ref: &hyperactor_reference::ActorRef<CommActor>,
157    sel_of_sliced: &Selection,
158    message: M,
159    sliced_shape: &Shape,
160    root_mesh_shape: &Shape,
161) -> Result<(), CastError>
162where
163    A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
164    M: Castable + RemoteMessage,
165{
166    let root_slice = root_mesh_shape.slice();
167
168    // Casting to `*`?
169    let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
170        // Reify this view into base.
171        root_slice.reify_slice(sliced_shape.slice())?
172    } else {
173        // No, fall back on `of_ranks`.
174        let ranks = sel_of_sliced
175            .eval(&EvalOpts::strict(), sliced_shape.slice())?
176            .collect::<BTreeSet<_>>();
177        Selection::of_ranks(root_slice, &ranks)?
178    };
179
180    // Cast.
181    actor_mesh_cast::<A, M>(
182        cx,
183        actor_mesh_id,
184        comm_actor_ref,
185        sel_of_root,
186        root_mesh_shape,
187        sliced_shape,
188        message,
189    )
190}
191
/// The type of error of casting operations.
///
/// Variants marked `#[from]` can be produced with `?` from the wrapped
/// error type; variants marked `transparent` display exactly as the error
/// they wrap.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection was invalid; carries the selection and the shape error
    /// reported for it.
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// A send failed for a particular rank; carries the rank and the
    /// mailbox sender error.
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// The selection is not supported; the string describes it.
    #[error("unsupported selection: {0}")]
    SelectionNotSupported(String),

    /// A mailbox sender error that is not associated with a rank.
    // NOTE(review): presumably raised when sending to the root comm actor --
    // confirm against callers.
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    /// A shape error, converted as-is.
    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    /// A slice error, converted as-is.
    #[error(transparent)]
    SliceError(#[from] SliceError),

    /// A `bincode` (de)serialization error, converted as-is.
    #[error(transparent)]
    SerializationError(#[from] bincode::Error),

    /// Any other error, carried as an `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An error from reshaping a slice or selection, converted as-is.
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}