hyperactor_mesh/
casting.rs1use std::collections::BTreeSet;
12
13use hyperactor::RemoteHandles;
14use hyperactor::RemoteMessage;
15use hyperactor::actor::Referable;
16use hyperactor::context;
17use hyperactor::mailbox;
18use hyperactor::mailbox::MailboxSenderError;
19use hyperactor::mailbox::MessageEnvelope;
20use hyperactor::mailbox::Undeliverable;
21use hyperactor::message::Castable;
22use hyperactor::message::IndexedErasedUnbound;
23use hyperactor::reference as hyperactor_reference;
24use hyperactor_config::Flattrs;
25use hyperactor_config::attrs::declare_attrs;
26use ndslice::Selection;
27use ndslice::Shape;
28use ndslice::ShapeError;
29use ndslice::SliceError;
30use ndslice::reshape::Limit;
31use ndslice::reshape::ReshapeError;
32use ndslice::reshape::ReshapeSliceExt;
33use ndslice::reshape::reshape_selection;
34use ndslice::selection;
35use ndslice::selection::EvalOpts;
36use ndslice::selection::ReifySlice;
37use ndslice::selection::normal;
38
39use crate::CommActor;
40use crate::comm::multicast::CAST_ORIGINATING_SENDER;
41use crate::comm::multicast::CastMessage;
42use crate::comm::multicast::CastMessageEnvelope;
43use crate::comm::multicast::Uslice;
44use crate::config::MAX_CAST_DIMENSION_SIZE;
45use crate::metrics;
46use crate::reference::ActorMeshId;
47
declare_attrs! {
    /// Header attribute carrying the [`ActorMeshId`] a message was cast to.
    /// `actor_mesh_cast` sets it both on the headers embedded in the cast
    /// envelope and on the headers of the send to the comm actor.
    pub attr CAST_ACTOR_MESH_ID: ActorMeshId;
}
54
55pub fn update_undeliverable_envelope_for_casting(
59 mut envelope: Undeliverable<MessageEnvelope>,
60) -> Undeliverable<MessageEnvelope> {
61 let old_actor = envelope.0.sender().clone();
62 if let Some(actor_id) = envelope.0.headers().get(CAST_ORIGINATING_SENDER) {
63 tracing::debug!(
64 actor_id = %old_actor,
65 "remapped comm-actor id to id from CAST_ORIGINATING_SENDER {}", actor_id
66 );
67 envelope.0.update_sender(actor_id);
68 }
69 envelope
71}
72
73#[allow(clippy::result_large_err)] #[tracing::instrument(level = "debug", skip_all)]
77pub(crate) fn actor_mesh_cast<A, M>(
78 cx: &impl context::Actor,
79 actor_mesh_id: ActorMeshId,
80 comm_actor_ref: &hyperactor_reference::ActorRef<CommActor>,
81 selection_of_root: Selection,
82 root_mesh_shape: &Shape,
83 cast_mesh_shape: &Shape,
84 message: M,
85) -> Result<(), CastError>
86where
87 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
88 M: Castable + RemoteMessage,
89{
90 let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
91 "message_type" => M::typename(),
92 "message_variant" => message.arm().unwrap_or_default(),
93 ));
94
95 let mut headers = Flattrs::new();
96 mailbox::headers::set_send_timestamp(&mut headers);
97 mailbox::headers::set_rust_message_type::<M>(&mut headers);
98 headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id.clone());
99 let message = CastMessageEnvelope::new::<A, M>(
100 actor_mesh_id.clone(),
101 cx.mailbox().actor_id().clone(),
102 cast_mesh_shape.clone(),
103 headers,
104 message,
105 )?;
106
107 let slice_of_root = root_mesh_shape.slice();
125
126 let max_cast_dimension_size = hyperactor_config::global::get(MAX_CAST_DIMENSION_SIZE);
127
128 let slice_of_cast = slice_of_root.reshape_with_limit(Limit::from(max_cast_dimension_size));
129
130 let selection_of_cast =
131 reshape_selection(selection_of_root, root_mesh_shape.slice(), &slice_of_cast)?;
132
133 let cast_message = CastMessage {
134 dest: Uslice {
135 slice: slice_of_cast,
136 selection: selection_of_cast,
137 },
138 message,
139 };
140
141 let mut headers = Flattrs::new();
143 headers.set(CAST_ACTOR_MESH_ID, actor_mesh_id);
144
145 comm_actor_ref
146 .port()
147 .send_with_headers(cx, headers, cast_message)?;
148
149 Ok(())
150}
151
152#[allow(clippy::result_large_err)] pub(crate) fn cast_to_sliced_mesh<A, M>(
154 cx: &impl context::Actor,
155 actor_mesh_id: ActorMeshId,
156 comm_actor_ref: &hyperactor_reference::ActorRef<CommActor>,
157 sel_of_sliced: &Selection,
158 message: M,
159 sliced_shape: &Shape,
160 root_mesh_shape: &Shape,
161) -> Result<(), CastError>
162where
163 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
164 M: Castable + RemoteMessage,
165{
166 let root_slice = root_mesh_shape.slice();
167
168 let sel_of_root = if selection::normalize(sel_of_sliced) == normal::NormalizedSelection::True {
170 root_slice.reify_slice(sliced_shape.slice())?
172 } else {
173 let ranks = sel_of_sliced
175 .eval(&EvalOpts::strict(), sliced_shape.slice())?
176 .collect::<BTreeSet<_>>();
177 Selection::of_ranks(root_slice, &ranks)?
178 };
179
180 actor_mesh_cast::<A, M>(
182 cx,
183 actor_mesh_id,
184 comm_actor_ref,
185 sel_of_root,
186 root_mesh_shape,
187 sliced_shape,
188 message,
189 )
190}
191
/// Errors produced while casting a message to an actor mesh.
#[derive(Debug, thiserror::Error)]
pub enum CastError {
    /// A selection rejected against a shape, together with the [`ShapeError`]
    /// explaining why.
    #[error("invalid selection {0}: {1}")]
    InvalidSelection(Selection, ShapeError),

    /// A send failure attributed to a specific rank (the `usize`).
    #[error("send on rank {0}: {1}")]
    MailboxSenderError(usize, MailboxSenderError),

    /// A selection this code path does not support, with a description.
    #[error("unsupported selection: {0}")]
    SelectionNotSupported(String),

    /// A mailbox send failure not attributed to a rank.
    /// It is produced by `?` on a [`MailboxSenderError`].
    #[error(transparent)]
    RootMailboxSenderError(#[from] MailboxSenderError),

    /// Wraps a [`ShapeError`] (converted via `?`).
    #[error(transparent)]
    ShapeError(#[from] ShapeError),

    /// Wraps a [`SliceError`] (converted via `?`).
    #[error(transparent)]
    SliceError(#[from] SliceError),

    /// Wraps a `bincode` serialization error (converted via `?`).
    #[error(transparent)]
    SerializationError(#[from] bincode::Error),

    /// Catch-all for errors reported as [`anyhow::Error`].
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// Wraps a [`ReshapeError`] (converted via `?`).
    #[error(transparent)]
    ReshapeError(#[from] ReshapeError),
}