1use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39
40use crate as hyperactor;
41use crate::ActorRef;
42use crate::Mailbox;
43use crate::Named;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::RemoteActor;
47use crate::data::Serialized;
48
/// Extracts the "bound" parts of a message into a [`Bindings`] queue.
///
/// Together with [`Bind`], this lets the extracted values be inspected or
/// rewritten while the message itself stays opaque (for example through
/// [`ErasedUnbound::visit_mut`]), and then written back into the message.
pub trait Unbind: Sized {
    /// Records this value's bindings into `bindings`.
    ///
    /// The order in which bindings are recorded must match the order in which
    /// the corresponding [`Bind::bind`] implementation consumes them.
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
57
/// Inverse of [`Unbind`]: restores the bound parts of a message from a
/// [`Bindings`] queue.
pub trait Bind: Sized {
    /// Takes this value's bindings out of `bindings` (typically via
    /// [`Bindings::pop_front`]) and writes them back into `self`.
    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
64
/// Marker trait bundling [`RemoteMessage`], [`Bind`] and [`Unbind`].
///
/// It is blanket-implemented for every type that satisfies those bounds, so it
/// never needs to be implemented by hand.
pub trait Castable: RemoteMessage + Bind + Unbind {}
impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
69
/// An ordered queue of type-tagged, serialized values.
///
/// [`Unbind`] fills it and [`Bind`] drains it. Each entry pairs the
/// [`Named::typehash`] of the value's type with the value's serialized form.
/// The tag lets [`Bindings::pop_front`] reject an entry whose type does not
/// match the requested one.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Bindings(VecDeque<(u64, Serialized)>);
74
75impl Bindings {
76 pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
78 let ser = Serialized::serialize(value)?;
79 self.0.push_back((T::typehash(), ser));
80 Ok(())
81 }
82
83 pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
88 match self.0.pop_front() {
89 None => Ok(None),
90 Some((t, v)) => {
91 if t != T::typehash() {
92 anyhow::bail!(
93 "type mismatch: expected {} with hash {}, found {} in binding",
94 T::typename(),
95 T::typehash(),
96 t,
97 );
98 }
99 Ok(Some(v.deserialized::<T>()?))
100 }
101 }
102 }
103
104 pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
106 self.pop_front::<T>()?.ok_or_else(|| {
107 anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
108 })
109 }
110
111 fn visit_mut<T: Serialize + DeserializeOwned + Named>(
112 &mut self,
113 mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
114 ) -> anyhow::Result<()> {
115 for v in self.0.iter_mut() {
116 if v.0 == T::typehash() {
117 let mut t = v.1.deserialized::<T>()?;
118 f(&mut t)?;
119 v.1 = Serialized::serialize(&t)?;
120 }
121 }
122 Ok(())
123 }
124}
125
/// A message together with the [`Bindings`] extracted from it.
///
/// Produced by [`Unbound::try_from_message`] or by downcasting an
/// [`ErasedUnbound`]. [`Unbound::bind`] writes the bindings back into the
/// message and returns it.
#[derive(Debug, PartialEq)]
pub struct Unbound<M> {
    message: M,
    bindings: Bindings,
}
132
133impl<M> Unbound<M> {
134 pub fn new(message: M, bindings: Bindings) -> Self {
136 Self { message, bindings }
137 }
138}
139
140impl<M: Bind> Unbound<M> {
141 pub fn bind(mut self) -> anyhow::Result<M> {
143 self.message.bind(&mut self.bindings)?;
144 anyhow::ensure!(
145 self.bindings.0.is_empty(),
146 "there are still {} elements left in bindings",
147 self.bindings.0.len()
148 );
149 Ok(self.message)
150 }
151}
152
153impl<M: Unbind> Unbound<M> {
154 pub fn try_from_message(message: M) -> anyhow::Result<Self> {
158 let mut bindings = Bindings::default();
159 message.unbind(&mut bindings)?;
160 Ok(Unbound { message, bindings })
161 }
162}
163
/// Type-erased counterpart of [`Unbound`].
///
/// The message is held in serialized form, so its bindings can be inspected or
/// rewritten with [`ErasedUnbound::visit_mut`] without knowing the concrete
/// message type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub struct ErasedUnbound {
    message: Serialized,
    bindings: Bindings,
}
170
171impl ErasedUnbound {
172 pub fn new(message: Serialized) -> Self {
174 Self {
175 message,
176 bindings: Bindings::default(),
177 }
178 }
179
180 pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
184 let unbound = Unbound::try_from_message(msg)?;
185 let serialized = Serialized::serialize(&unbound.message)?;
186 Ok(Self {
187 message: serialized,
188 bindings: unbound.bindings,
189 })
190 }
191
192 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
195 &mut self,
196 f: impl FnMut(&mut T) -> anyhow::Result<()>,
197 ) -> anyhow::Result<()> {
198 self.bindings.visit_mut(f)
199 }
200
201 fn downcast<M: DeserializeOwned>(self) -> anyhow::Result<Unbound<M>> {
202 let message: M = self.message.deserialized()?;
203 Ok(Unbound {
204 message,
205 bindings: self.bindings,
206 })
207 }
208}
209
/// An [`ErasedUnbound`] tagged, at the type level only, with the message type
/// `M` it is expected to contain.
///
/// The `PhantomData<M>` carries no data. It gives each message type a distinct
/// type here, which `bind_for_test_only` relies on through the
/// `RemoteHandles<IndexedErasedUnbound<M>>` bound.
///
/// NOTE(review): `#[serde(from = "ErasedUnbound")]` only affects
/// *deserialization*. The derived `Serialize` still encodes this tuple struct
/// as `(ErasedUnbound, PhantomData<M>)`, so a value serialized as
/// `IndexedErasedUnbound` may not deserialize back into one. Confirm that
/// senders always serialize a plain `ErasedUnbound`.
#[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
#[serde(from = "ErasedUnbound")]
pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
215
216impl<M: DeserializeOwned> IndexedErasedUnbound<M> {
217 pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
218 self.0.downcast()
219 }
220}
221
222impl<M: Bind> IndexedErasedUnbound<M> {
223 pub fn bind_for_test_only<A>(actor_ref: ActorRef<A>, mailbox: &Mailbox) -> anyhow::Result<()>
226 where
227 A: RemoteActor + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
228 M: RemoteMessage,
229 {
230 let mailbox_clone = mailbox.clone();
231 let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>(move |_, m| {
232 let bound_m = m.downcast()?.bind()?;
233 actor_ref.send(&mailbox_clone, bound_m)?;
234 Ok(())
235 });
236 port_handle.bind_to(IndexedErasedUnbound::<M>::port());
237 Ok(())
238 }
239}
240
241impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
242 fn from(erased: ErasedUnbound) -> Self {
243 Self(erased, PhantomData)
244 }
245}
246
247macro_rules! impl_bind_unbind_basic {
248 ($t:ty) => {
249 impl Bind for $t {
250 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
251 anyhow::ensure!(
252 bindings.0.is_empty(),
253 "bindings for {} should be empty, but found {} elements left",
254 stringify!($t),
255 bindings.0.len(),
256 );
257 Ok(())
258 }
259 }
260
261 impl Unbind for $t {
262 fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
263 Ok(())
264 }
265 }
266 };
267}
268
269impl_bind_unbind_basic!(());
270impl_bind_unbind_basic!(bool);
271impl_bind_unbind_basic!(i8);
272impl_bind_unbind_basic!(u8);
273impl_bind_unbind_basic!(i16);
274impl_bind_unbind_basic!(u16);
275impl_bind_unbind_basic!(i32);
276impl_bind_unbind_basic!(u32);
277impl_bind_unbind_basic!(i64);
278impl_bind_unbind_basic!(u64);
279impl_bind_unbind_basic!(i128);
280impl_bind_unbind_basic!(u128);
281impl_bind_unbind_basic!(isize);
282impl_bind_unbind_basic!(usize);
283impl_bind_unbind_basic!(String);
284
285impl<T: Unbind> Unbind for Option<T> {
286 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
287 match self {
288 Some(t) => t.unbind(bindings),
289 None => Ok(()),
290 }
291 }
292}
293
294impl<T: Bind> Bind for Option<T> {
295 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
296 match self {
297 Some(t) => t.bind(bindings),
298 None => Ok(()),
299 }
300 }
301}
302
#[cfg(test)]
mod tests {
    use hyperactor::PortRef;
    use hyperactor::id;

    use super::*;
    use crate::Bind;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::reference::UnboundPort;

    // Reply payload; gives `reply1` a port type distinct from `reply0`.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
    struct MyReply(String);

    // Message mixing plain fields with two port fields marked
    // `#[binding(include)]`. Only the marked fields take part in
    // unbinding/binding, as the assertions below show.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    // Round-trips a message through unbind -> rewrite bindings -> downcast ->
    // bind, and checks that the ports are replaced while the other fields are
    // preserved.
    #[test]
    fn test_castable() {
        let original_port0 = PortRef::attest(id!(world[0].actor[0][123]));
        let original_port1 = PortRef::attest_reducible(
            id!(world[1].actor1[0][456]),
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = Serialized::serialize(&my_message).unwrap();

        // Unbinding leaves the serialized message unchanged and records both
        // ports, in field order, as `UnboundPort` bindings.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: Bindings(
                    [
                        (
                            UnboundPort::typehash(),
                            Serialized::serialize(&UnboundPort::from(&original_port0)).unwrap(),
                        ),
                        (
                            UnboundPort::typehash(),
                            Serialized::serialize(&UnboundPort::from(&original_port1)).unwrap(),
                        ),
                    ]
                    .into_iter()
                    .collect()
                ),
            }
        );

        // Pick replacement port ids that differ from the originals.
        let new_port_id0 = id!(world[0].comm[0][680]);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = id!(world[1].comm[0][257]);
        assert_ne!(&new_port_id1, original_port1.port_id());

        // `visit_mut` visits `UnboundPort` bindings in queue order, so the first
        // replacement goes to reply0 and the second to reply1.
        let mut new_ports = vec![&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<UnboundPort>(|b| {
                let port = new_ports.next().unwrap();
                b.update(port.clone());
                Ok(())
            })
            .unwrap();

        // Expected ports after the rewrite; reply1 keeps its reducer spec.
        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 = PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
        );
        let new_bindings = Bindings(
            [
                (
                    UnboundPort::typehash(),
                    Serialized::serialize(&UnboundPort::from(&new_port0)).unwrap(),
                ),
                (
                    UnboundPort::typehash(),
                    Serialized::serialize(&UnboundPort::from(&new_port1)).unwrap(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        // Only the bindings changed; the serialized message is untouched.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: new_bindings.clone(),
            }
        );

        // Downcasting restores the typed message, which still holds the
        // original ports, alongside the rewritten bindings.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );
        // Binding writes the rewritten ports back into the message.
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}