hyperactor_mesh/
casting.rs1use std::collections::BTreeSet;
12
13use hyperactor::ActorRef;
14use hyperactor::RemoteEndpoint as _;
15use hyperactor::RemoteHandles;
16use hyperactor::RemoteMessage;
17use hyperactor::actor::Referable;
18use hyperactor::config::ENABLE_DEST_ACTOR_REORDERING_BUFFER;
19use hyperactor::context;
20use hyperactor::mailbox;
21use hyperactor::mailbox::MailboxSenderError;
22use hyperactor::mailbox::MessageEnvelope;
23use hyperactor::mailbox::Undeliverable;
24use hyperactor::message::Castable;
25use hyperactor::message::IndexedErasedUnbound;
26use hyperactor_config::Flattrs;
27use hyperactor_config::attrs::declare_attrs;
28use ndslice::Selection;
29use ndslice::Shape;
30use ndslice::ShapeError;
31use ndslice::SliceError;
32use ndslice::reshape::Limit;
33use ndslice::reshape::ReshapeError;
34use ndslice::reshape::ReshapeSliceExt;
35use ndslice::reshape::reshape_selection;
36use ndslice::selection;
37use ndslice::selection::EvalOpts;
38use ndslice::selection::ReifySlice;
39use ndslice::selection::normal;
40
41use crate::CommActor;
42use crate::comm::ENABLE_NATIVE_V1_CASTING;
43use crate::comm::multicast::CAST_ORIGINATING_SENDER;
44use crate::comm::multicast::CastMessage;
45use crate::comm::multicast::CastMessageEnvelope;
46use crate::comm::multicast::Uslice;
47use crate::config::MAX_CAST_DIMENSION_SIZE;
48use crate::mesh_id::ActorMeshId;
49use crate::metrics;
50
51pub(crate) fn v1_casting_enabled() -> bool {
54 let enabled = hyperactor_config::global::get(ENABLE_NATIVE_V1_CASTING);
55 if enabled {
56 assert!(
57 hyperactor_config::global::get(ENABLE_DEST_ACTOR_REORDERING_BUFFER),
58 "native V1 casting requires ENABLE_DEST_ACTOR_REORDERING_BUFFER to be enabled",
59 );
60 }
61 enabled
62}
63
declare_attrs! {
    /// Header attribute carrying the [`ActorMeshId`] of the mesh a message is
    /// being cast to. `actor_mesh_cast` sets it both on the headers stored in
    /// the cast message envelope and on the headers of the message posted to
    /// the comm actor.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
70
71pub fn update_undeliverable_envelope_for_casting(
75 mut envelope: Undeliverable<MessageEnvelope>,
76) -> Undeliverable<MessageEnvelope> {
77 let Some(message) = envelope.as_message_mut() else {
78 return envelope;
79 };
80 let old_actor = message.sender().clone();
81 if let Some(actor_id) = message.headers().get(CAST_ORIGINATING_SENDER) {
82 tracing::debug!(
83 actor_id = %old_actor,
84 "remapped comm-actor id to id from CAST_ORIGINATING_SENDER {}", actor_id
85 );
86 message.update_sender(actor_id);
87 }
88 envelope
90}
91
92#[allow(clippy::result_large_err)] #[tracing::instrument(level = "debug", skip_all)]
100pub(crate) fn actor_mesh_cast<A, M>(
101 cx: &impl context::Actor,
102 actor_mesh_id: ActorMeshId,
103 comm_actor_ref: &ActorRef<CommActor>,
104 selection_of_root: Selection,
105 root_mesh_shape: &Shape,
106 cast_mesh_shape: &Shape,
107 message: M,
108 caller_headers: &Flattrs,
109) -> Result<(), CastError>
110where
111 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
112 M: Castable + RemoteMessage,
113{
114 let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
115 "message_type" => M::typename(),
116 "message_variant" => message.arm().unwrap_or_default(),
117 ));
118
119 let mut headers = caller_headers.clone();
123 mailbox::headers::set_send_timestamp(&mut headers);
124 mailbox::headers::set_rust_message_type::<M>(&mut headers);
125 headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id.clone());
126 let message = CastMessageEnvelope::new::<A, M>(
127 actor_mesh_id.clone(),
128 cx.mailbox().actor_addr().clone(),
129 cast_mesh_shape.clone(),
130 headers,
131 message,
132 )?;
133
134 let slice_of_root = root_mesh_shape.slice();
152
153 let max_cast_dimension_size = hyperactor_config::global::get(MAX_CAST_DIMENSION_SIZE);
154
155 let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
156
157 let selection_of_cast =
158 reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
159
160 let cast_message = CastMessage {
161 dest: Uslice {
162 slice: slice_of_cast,
163 selection: selection_of_cast,
164 },
165 message,
166 };
167
168 let mut headers = caller_headers.clone();
172 headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
173
174 comm_actor_ref
175 .port()
176 .post_with_headers(cx, headers, cast_message);
177
178 Ok(())
179}
180
181#[allow(clippy::result_large_err)] pub(crate) fn cast_to_sliced_mesh<A, M>(
183 cx: &impl context::Actor,
184 actor_mesh_id: ActorMeshId,
185 comm_actor_ref: &ActorRef<CommActor>,
186 sel_of_sliced: &Selection,
187 message: M,
188 sliced_shape: &Shape,
189 root_mesh_shape: &Shape,
190 caller_headers: &Flattrs,
191) -> Result<(), CastError>
192where
193 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
194 M: Castable + RemoteMessage,
195{
196 let root_slice = root_mesh_shape.slice();
197
198 let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
200 root_slice.reify_slice(sliced_shape.slice())?
202 } else {
203 let ranks = sel_of_sliced
205 .eval(&EvalOpts::strict(), sliced_shape.slice())?
206 .collect::<BTreeSet<_>>();
207 Selection::of_ranks(root_slice, &ranks)?
208 };
209
210 actor_mesh_cast::<A, M>(
212 cx,
213 actor_mesh_id,
214 comm_actor_ref,
215 sel_of_root,
216 root_mesh_shape,
217 sliced_shape,
218 message,
219 caller_headers,
220 )
221}
222
/// Errors produced while casting a message to an actor mesh.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection was invalid; carries the offending selection and the
    /// underlying [`ShapeError`].
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// Sending to a specific rank (field 0) failed with the given mailbox
    /// error.
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// The selection is not supported; the string describes why (presumably
    /// a rendering of the selection — confirm at construction sites).
    #[error("unsupported selection: {0}")]
    SelectionNotSupported(String),

    /// A mailbox send error not attributed to a particular rank.
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    /// Wrapped [`ShapeError`] from `ndslice`.
    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    /// Wrapped [`SliceError`] from `ndslice`.
    #[error(transparent)]
    SliceError(#[from] SliceError),

    /// Failure encoding a value with `bincode` (presumably while building
    /// the cast message envelope — TODO confirm).
    #[error(transparent)]
    SerializationEncodeError(#[from] bincode::error::EncodeError),

    /// Failure decoding a value with `bincode`.
    #[error(transparent)]
    SerializationDecodeError(#[from] bincode::error::DecodeError),

    /// Catch-all for errors surfaced as [`anyhow::Error`].
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Wrapped [`ReshapeError`] from `ndslice::reshape`.
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}