1use hyperactor::Actor;
12use hyperactor::Context;
13use hyperactor::RemoteHandles;
14use hyperactor::RemoteMessage;
15use hyperactor::actor::Referable;
16use hyperactor::message::Castable;
17use hyperactor::message::ErasedUnbound;
18use hyperactor::message::IndexedErasedUnbound;
19use hyperactor::reference::ActorId;
20use hyperactor_config::attrs::Attrs;
21use hyperactor_config::attrs::declare_attrs;
22use ndslice::Extent;
23use ndslice::Point;
24use ndslice::Shape;
25use ndslice::Slice;
26use ndslice::selection::Selection;
27use ndslice::selection::routing::RoutingFrame;
28use serde::Deserialize;
29use serde::Serialize;
30use typeuri::Named;
31
32use crate::reference::ActorMeshId;
33
/// A [`Slice`] paired with a [`Selection`] over it.
///
/// Used as the destination (`CastMessage::dest`) of a cast.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Uslice {
    /// The slice being addressed.
    pub slice: Slice,
    /// The selection applied to `slice`.
    /// NOTE(review): presumably evaluated against `slice` to choose the
    /// cast destinations — confirm.
    pub selection: Selection,
}
45
/// The payload and routing metadata carried by a cast.
///
/// Fields are private; build one with [`CastMessageEnvelope::new`] or
/// [`CastMessageEnvelope::from_serialized`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
pub struct CastMessageEnvelope {
    /// The actor mesh the message is cast to; its actor name also feeds the
    /// destination port in `new`.
    actor_mesh_id: ActorMeshId,
    /// The actor that sent the cast.
    sender: ActorId,
    /// The actor name and port the message is delivered to.
    dest_port: DestinationPort,
    /// The type-erased message payload.
    data: ErasedUnbound,
    /// The shape against which `relative_rank` interprets root-mesh ranks.
    shape: Shape,
}
wirevalue::register_type!(CastMessageEnvelope);
62
63impl CastMessageEnvelope {
64 pub fn new<A, M>(
66 actor_mesh_id: ActorMeshId,
67 sender: ActorId,
68 shape: Shape,
69 message: M,
70 ) -> Result<Self, anyhow::Error>
71 where
72 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
73 M: Castable + RemoteMessage,
74 {
75 let data = ErasedUnbound::try_from_message(message)?;
76 let actor_name = match &actor_mesh_id {
77 ActorMeshId::V0(_, actor_name) => actor_name.clone(),
78 ActorMeshId::V1(name) => name.to_string(),
79 };
80 Ok(Self {
81 actor_mesh_id,
82 sender,
83 dest_port: DestinationPort::new::<A, M>(actor_name),
84 data,
85 shape,
86 })
87 }
88
89 pub fn from_serialized(
93 actor_mesh_id: ActorMeshId,
94 sender: ActorId,
95 dest_port: DestinationPort,
96 shape: Shape,
97 data: wirevalue::Any,
98 ) -> Self {
99 Self {
100 actor_mesh_id,
101 sender,
102 dest_port,
103 data: ErasedUnbound::new(data),
104 shape,
105 }
106 }
107
108 pub(crate) fn sender(&self) -> &ActorId {
109 &self.sender
110 }
111
112 pub(crate) fn dest_port(&self) -> &DestinationPort {
113 &self.dest_port
114 }
115
116 pub(crate) fn data(&self) -> &ErasedUnbound {
117 &self.data
118 }
119
120 pub(crate) fn data_mut(&mut self) -> &mut ErasedUnbound {
121 &mut self.data
122 }
123
124 pub(crate) fn shape(&self) -> &Shape {
125 &self.shape
126 }
127
128 pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
131 let shape = self.shape();
132 let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
133 anyhow::anyhow!(
134 "fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
135 rank_on_root_mesh,
136 e,
137 shape,
138 )
139 })?;
140 let extent =
141 Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
142 anyhow::anyhow!(
143 "fail to calculate extent for root rank {} due to error: {}; shape is {}",
144 rank_on_root_mesh,
145 e,
146 shape,
147 )
148 })?;
149 let point = extent.point(coords).map_err(|e| {
150 anyhow::anyhow!(
151 "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
152 rank_on_root_mesh,
153 e,
154 extent,
155 shape,
156 )
157 })?;
158 Ok(point.rank())
159 }
160
161 pub(crate) fn stream_key(&self) -> (ActorMeshId, ActorId) {
165 (self.actor_mesh_id.clone(), self.sender.clone())
166 }
167}
168
/// Names the actor and the port a cast message is delivered to.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
pub struct DestinationPort {
    /// Name of the destination actor (derived from the actor mesh id in
    /// `CastMessageEnvelope::new`).
    actor_name: String,
    /// Port identifier, set from `IndexedErasedUnbound::<M>::port()` in
    /// `DestinationPort::new`.
    port: u64,
}
wirevalue::register_type!(DestinationPort);
183
184impl DestinationPort {
185 pub fn new<A, M>(actor_name: String) -> Self
187 where
188 A: Referable + RemoteHandles<IndexedErasedUnbound<M>>,
189 M: Castable + RemoteMessage,
190 {
191 Self {
192 actor_name,
193 port: IndexedErasedUnbound::<M>::port(),
194 }
195 }
196
197 pub fn port(&self) -> u64 {
199 self.port
200 }
201
202 pub fn actor_name(&self) -> &str {
204 &self.actor_name
205 }
206}
207
/// A cast request: the envelope to deliver and the destination it is cast over.
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
pub struct CastMessage {
    /// The destination slice and selection.
    pub dest: Uslice,
    /// The envelope to deliver.
    pub message: CastMessageEnvelope,
}
wirevalue::register_type!(CastMessage);
217
/// A cast envelope being forwarded together with its routing frames and
/// sequencing information.
#[derive(Serialize, Deserialize, Debug, Clone, Named)]
pub(crate) struct ForwardMessage {
    /// The actor sending this forward.
    /// NOTE(review): presumably the forwarding hop, which may differ from the
    /// envelope's originating sender — confirm.
    pub(crate) sender: ActorId,
    /// Routing frames describing where the message goes next.
    pub(crate) dests: Vec<RoutingFrame>,
    /// Sequence number of this message.
    /// NOTE(review): presumably scoped per stream (see
    /// `CastMessageEnvelope::stream_key`) — confirm.
    pub(crate) seq: usize,
    /// NOTE(review): presumably the sequence number of the previous message on
    /// the same stream, used for ordering — confirm.
    pub(crate) last_seq: usize,
    /// The envelope being forwarded.
    pub(crate) message: CastMessageEnvelope,
}
wirevalue::register_type!(ForwardMessage);
235
declare_attrs! {
    // Header holding the actor that originated a cast; written by
    // `set_cast_info_on_headers` and read by `CastInfo::sender` for `Context`.
    pub attr CAST_ORIGINATING_SENDER: ActorId;

    // Header holding the cast point of a message; written by
    // `set_cast_info_on_headers` and read by `CastInfo::cast_point` for `Context`.
    pub attr CAST_POINT: Point;
}
243
/// Records cast metadata on `headers`.
///
/// Stores `cast_point` under [`CAST_POINT`] and `sender` under
/// [`CAST_ORIGINATING_SENDER`]. These are the headers read back by the
/// `CastInfo` implementation for `Context`.
pub fn set_cast_info_on_headers(headers: &mut Attrs, cast_point: Point, sender: ActorId) {
    headers.set(CAST_POINT, cast_point);
    headers.set(CAST_ORIGINATING_SENDER, sender);
}
248
/// Accessors for cast-related metadata (see [`CAST_POINT`] and
/// [`CAST_ORIGINATING_SENDER`]).
pub trait CastInfo {
    /// The point associated with this message within the cast.
    fn cast_point(&self) -> Point;
    /// The actor that originated the cast.
    fn sender(&self) -> &ActorId;
}
257
258impl<A: Actor> CastInfo for Context<'_, A> {
259 fn cast_point(&self) -> Point {
260 match self.headers().get(CAST_POINT) {
261 Some(point) => point.clone(),
262 None => Extent::unity().point_of_rank(0).unwrap(),
263 }
264 }
265
266 fn sender(&self) -> &ActorId {
267 self.headers()
268 .get(CAST_ORIGINATING_SENDER)
269 .expect("has sender header")
270 }
271}