1use hyperactor::Actor;
12use hyperactor::Context;
13use hyperactor::RemoteHandles;
14use hyperactor::RemoteMessage;
15use hyperactor::actor::Referable;
16use hyperactor::message::Castable;
17use hyperactor::message::ErasedUnbound;
18use hyperactor::message::IndexedErasedUnbound;
19use hyperactor::reference as hyperactor_reference;
20use hyperactor_config::Flattrs;
21use hyperactor_config::attrs::declare_attrs;
22use ndslice::Extent;
23use ndslice::Point;
24use ndslice::Region;
25use ndslice::Shape;
26use ndslice::Slice;
27use ndslice::selection::Selection;
28use ndslice::selection::routing::RoutingFrame;
29use serde::Deserialize;
30use serde::Serialize;
31use typeuri::Named;
32use uuid::Uuid;
33
34use crate::Name;
35use crate::ValueMesh;
36use crate::comm::CommMeshConfig;
37use crate::reference::ActorMeshId;
38
39pub(crate) trait CastEnvelope {
42 fn dest_port(&self) -> &DestinationPort;
43 fn headers(&self) -> &Flattrs;
44 fn sender(&self) -> &hyperactor_reference::ActorId;
45 fn cast_point(&self, config: &CommMeshConfig) -> anyhow::Result<Point>;
46 fn data(&self) -> &ErasedUnbound;
47 fn data_mut(&mut self) -> &mut ErasedUnbound;
48}
49
/// A cast destination: a `Slice` paired with a `Selection`.
// NOTE(review): presumably `selection` chooses which ranks of `slice`
// receive the cast — confirm against the routing code.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Uslice {
    /// The slice paired with `selection`.
    pub slice: Slice,
    /// The selection paired with `slice`.
    pub selection: Selection,
}
61
/// Envelope for a message cast to an actor mesh (the non-V1 format; compare
/// `CastMessageV1`).
// Field order is part of the serialized form; do not reorder.
#[derive(Debug, Serialize, Deserialize, Clone, Named)]
pub struct CastMessageEnvelope {
    /// Mesh the message is cast to. Its name (`.0`) becomes the destination
    /// actor name in `new`, and it is one half of `stream_key`.
    actor_mesh_id: ActorMeshId,
    /// Headers carried alongside the message.
    headers: Flattrs,
    /// Sending actor; the other half of `stream_key`.
    sender: hyperactor_reference::ActorId,
    /// Destination actor name and handler port.
    dest_port: DestinationPort,
    /// Type-erased message payload.
    data: ErasedUnbound,
    /// Shape of the cast; used by `relative_rank` and `cast_point` to locate
    /// a receiving rank.
    shape: Shape,
}
wirevalue::register_type!(CastMessageEnvelope);
80
81impl CastEnvelope for CastMessageEnvelope {
82 fn sender(&self) -> &hyperactor_reference::ActorId {
83 &self.sender
84 }
85
86 fn headers(&self) -> &Flattrs {
87 &self.headers
88 }
89
90 fn dest_port(&self) -> &DestinationPort {
91 &self.dest_port
92 }
93
94 fn data(&self) -> &ErasedUnbound {
95 &self.data
96 }
97
98 fn data_mut(&mut self) -> &mut ErasedUnbound {
99 &mut self.data
100 }
101
102 fn cast_point(&self, config: &CommMeshConfig) -> anyhow::Result<Point> {
103 let rank_on_root_mesh = config.self_rank();
104 let cast_rank = self.relative_rank(rank_on_root_mesh)?;
105 let cast_shape = self.shape();
106 let cast_point = cast_shape
107 .extent()
108 .point_of_rank(cast_rank)
109 .expect("rank out of bounds");
110 Ok(cast_point)
111 }
112}
113
114impl CastMessageEnvelope {
115 pub fn new<A, M>(
117 actor_mesh_id: ActorMeshId,
118 sender: hyperactor_reference::ActorId,
119 shape: Shape,
120 headers: Flattrs,
121 message: M,
122 ) -> Result<Self, anyhow::Error>
123 where
124 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
125 M: Castable + RemoteMessage,
126 {
127 let data = ErasedUnbound::try_from_message(message)?;
128 let actor_name = actor_mesh_id.0.to_string();
129 Ok(Self {
130 actor_mesh_id,
131 headers,
132 sender,
133 dest_port: DestinationPort::new::<A, M>(actor_name),
134 data,
135 shape,
136 })
137 }
138
139 pub fn from_serialized(
143 actor_mesh_id: ActorMeshId,
144 sender: hyperactor_reference::ActorId,
145 dest_port: DestinationPort,
146 shape: Shape,
147 headers: Flattrs,
148 data: wirevalue::Any,
149 ) -> Self {
150 Self {
151 actor_mesh_id,
152 sender,
153 headers,
154 dest_port,
155 data: ErasedUnbound::new(data),
156 shape,
157 }
158 }
159
160 pub(crate) fn shape(&self) -> &Shape {
161 &self.shape
162 }
163
164 pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
167 let shape = self.shape();
168 let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
169 anyhow::anyhow!(
170 "fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
171 rank_on_root_mesh,
172 e,
173 shape,
174 )
175 })?;
176 let extent =
177 Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
178 anyhow::anyhow!(
179 "fail to calculate extent for root rank {} due to error: {}; shape is {}",
180 rank_on_root_mesh,
181 e,
182 shape,
183 )
184 })?;
185 let point = extent.point(coords).map_err(|e| {
186 anyhow::anyhow!(
187 "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
188 rank_on_root_mesh,
189 e,
190 extent,
191 shape,
192 )
193 })?;
194 Ok(point.rank())
195 }
196
197 pub(crate) fn stream_key(&self) -> (ActorMeshId, hyperactor_reference::ActorId) {
201 (self.actor_mesh_id.clone(), self.sender.clone())
202 }
203}
204
/// Where a cast message is delivered: an actor name plus a handler port.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
pub struct DestinationPort {
    /// Destination actor name (the mesh name in the constructors above).
    actor_name: String,
    /// Port number, as produced by `IndexedErasedUnbound::<M>::port()` in `new`.
    port: u64,
}
wirevalue::register_type!(DestinationPort);
219
220impl DestinationPort {
221 pub fn new<A, M>(actor_name: String) -> Self
223 where
224 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
225 M: Castable + RemoteMessage,
226 {
227 Self {
228 actor_name,
229 port: IndexedErasedUnbound::<M>::port(),
230 }
231 }
232
233 pub fn port(&self) -> u64 {
235 self.port
236 }
237
238 pub fn actor_name(&self) -> &str {
240 &self.actor_name
241 }
242}
243
/// A cast request: the destination together with the message envelope.
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
pub struct CastMessage {
    /// The slice/selection to cast to.
    pub dest: Uslice,
    /// The envelope being cast.
    pub message: CastMessageEnvelope,
}
wirevalue::register_type!(CastMessage);
253
/// A cast envelope paired with the routing frames it is forwarded along.
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
pub(crate) struct ForwardMessage {
    /// The actor forwarding this message.
    // NOTE(review): may be the forwarding comm actor rather than the original
    // cast sender (which lives in `message`) — confirm.
    pub(crate) sender: hyperactor_reference::ActorId,
    /// Routing frames to forward to.
    pub(crate) dests: Vec<RoutingFrame>,
    // NOTE(review): presumably this message's sequence number within its
    // stream (see `CastMessageEnvelope::stream_key`) — confirm.
    pub(crate) seq: usize,
    // NOTE(review): presumably the sequence number of the previous message in
    // the same stream, used for ordering — confirm.
    pub(crate) last_seq: usize,
    /// The envelope being forwarded.
    pub(crate) message: CastMessageEnvelope,
}
wirevalue::register_type!(ForwardMessage);
271
/// V1 cast message: addresses a `Region` directly instead of a `Shape`.
// Field order is part of the serialized form; do not reorder.
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
pub(crate) struct CastMessageV1 {
    /// Headers carried alongside the message.
    pub(super) headers: Flattrs,
    /// Sending actor.
    pub(super) sender: hyperactor_reference::ActorId,
    // NOTE(review): presumably identifies the sender's cast session that
    // `seqs` are numbered within — confirm.
    pub(super) session_id: Uuid,
    // NOTE(review): presumably a per-destination sequence number for this
    // message — confirm.
    pub(super) seqs: ValueMesh<u64>,
    /// Region the cast targets. `cast_point` maps a receiver's root-mesh
    /// rank into this region via `point_of_base_rank`.
    pub(super) dest_region: Region,
    /// Destination actor name and handler port.
    pub(super) dest_port: DestinationPort,
    /// Type-erased message payload.
    pub(super) data: ErasedUnbound,
}
291
292impl CastEnvelope for CastMessageV1 {
293 fn sender(&self) -> &hyperactor_reference::ActorId {
294 &self.sender
295 }
296
297 fn headers(&self) -> &Flattrs {
298 &self.headers
299 }
300
301 fn dest_port(&self) -> &DestinationPort {
302 &self.dest_port
303 }
304
305 fn data(&self) -> &ErasedUnbound {
306 &self.data
307 }
308
309 fn data_mut(&mut self) -> &mut ErasedUnbound {
310 &mut self.data
311 }
312
313 fn cast_point(&self, config: &CommMeshConfig) -> anyhow::Result<Point> {
314 let rank_on_root_mesh = config.self_rank();
315 let cast_point = self.dest_region.point_of_base_rank(rank_on_root_mesh)?;
316 Ok(cast_point)
317 }
318}
319
320impl CastMessageV1 {
321 #[allow(unused)]
323 pub(crate) fn new<A, M>(
324 sender: hyperactor_reference::ActorId,
325 dest_mesh: &Name,
326 dest_region: Region,
327 headers: Flattrs,
328 message: M,
329 session_id: Uuid,
330 seqs: ValueMesh<u64>,
331 ) -> Result<Self, anyhow::Error>
332 where
333 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
334 M: Castable + RemoteMessage,
335 {
336 let data = ErasedUnbound::try_from_message(message)?;
337 Ok(Self {
338 headers,
339 sender,
340 session_id,
341 seqs,
342 dest_region,
343 dest_port: DestinationPort::new::<A, M>(dest_mesh.to_string()),
344 data,
345 })
346 }
347}
348
/// A V1 cast message paired with the routing frames it is forwarded along.
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
pub(super) struct ForwardMessageV1 {
    /// Routing frames to forward to.
    pub(super) dests: Vec<RoutingFrame>,
    /// The message being forwarded.
    pub(super) message: CastMessageV1,
}
359
declare_attrs! {
    // Originating sender of a cast. Written by `set_cast_info_on_headers`
    // and read back by `CastInfo::sender`.
    pub attr CAST_ORIGINATING_SENDER: hyperactor_reference::ActorId;

    // Cast point of the receiver. Written by `set_cast_info_on_headers`
    // and read back by `CastInfo::cast_point`.
    pub attr CAST_POINT: Point;
}
367
368pub fn set_cast_info_on_headers(
369 headers: &mut Flattrs,
370 cast_point: Point,
371 sender: hyperactor_reference::ActorId,
372) {
373 headers.set(
378 hyperactor::mailbox::headers::SENDER_ACTOR_ID_HASH,
379 hyperactor_telemetry::hash_to_u64(&sender),
380 );
381 headers.set(CAST_POINT, cast_point);
382 headers.set(CAST_ORIGINATING_SENDER, sender);
383}
384
385pub trait CastInfo {
386 fn cast_point(&self) -> Point;
391 fn sender(&self) -> hyperactor_reference::ActorId;
392}
393
394impl<A: Actor> CastInfo for Context<'_, A> {
395 fn cast_point(&self) -> Point {
396 match self.headers().get(CAST_POINT) {
397 Some(point) => point,
398 None => Extent::unity().point_of_rank(0).unwrap(),
399 }
400 }
401
402 fn sender(&self) -> hyperactor_reference::ActorId {
403 self.headers()
404 .get(CAST_ORIGINATING_SENDER)
405 .expect("has sender header")
406 }
407}