1use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41use crate::ActorRef;
43use crate::Mailbox;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::Referable;
47use crate::context;
48
49pub trait Unbind: Sized {
54 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
56}
57
58pub trait Bind: Sized {
61 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
63}
64
65pub trait Castable: RemoteMessage + Bind + Unbind {}
68impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
69
70#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
73pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
74
75impl Bindings {
76 pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
78 let ser = wirevalue::Any::serialize(value)?;
79 self.0.push_back((T::typehash(), ser));
80 Ok(())
81 }
82
83 pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
88 match self.0.pop_front() {
89 None => Ok(None),
90 Some((t, v)) => {
91 if t != T::typehash() {
92 anyhow::bail!(
93 "type mismatch: expected {} with hash {}, found {} in binding",
94 T::typename(),
95 T::typehash(),
96 t,
97 );
98 }
99 Ok(Some(v.deserialized::<T>()?))
100 }
101 }
102 }
103
104 pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
106 self.pop_front::<T>()?.ok_or_else(|| {
107 anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
108 })
109 }
110
111 fn visit_mut<T: Serialize + DeserializeOwned + Named>(
112 &mut self,
113 mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
114 ) -> anyhow::Result<()> {
115 for v in self.0.iter_mut() {
116 if v.0 == T::typehash() {
117 let mut t = v.1.deserialized::<T>()?;
118 f(&mut t)?;
119 v.1 = wirevalue::Any::serialize(&t)?;
120 }
121 }
122 Ok(())
123 }
124}
125
126#[derive(Debug, PartialEq)]
128pub struct Unbound<M> {
129 message: M,
130 bindings: Bindings,
131}
132
133impl<M> Unbound<M> {
134 pub fn new(message: M, bindings: Bindings) -> Self {
136 Self { message, bindings }
137 }
138
139 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
142 &mut self,
143 f: impl FnMut(&mut T) -> anyhow::Result<()>,
144 ) -> anyhow::Result<()> {
145 self.bindings.visit_mut(f)
146 }
147}
148
149impl<M: Bind> Unbound<M> {
150 pub fn bind(mut self) -> anyhow::Result<M> {
152 self.message.bind(&mut self.bindings)?;
153 anyhow::ensure!(
154 self.bindings.0.is_empty(),
155 "there are still {} elements left in bindings",
156 self.bindings.0.len()
157 );
158 Ok(self.message)
159 }
160}
161
162impl<M: Unbind> Unbound<M> {
163 pub fn try_from_message(message: M) -> anyhow::Result<Self> {
167 let mut bindings = Bindings::default();
168 message.unbind(&mut bindings)?;
169 Ok(Unbound { message, bindings })
170 }
171}
172
173#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
175pub struct ErasedUnbound {
176 message: wirevalue::Any,
177 bindings: Bindings,
178}
179wirevalue::register_type!(ErasedUnbound);
180
181impl ErasedUnbound {
182 pub fn new(message: wirevalue::Any) -> Self {
184 Self {
185 message,
186 bindings: Bindings::default(),
187 }
188 }
189
190 pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
194 let unbound = Unbound::try_from_message(msg)?;
195 let serialized = wirevalue::Any::serialize(&unbound.message)?;
196 Ok(Self {
197 message: serialized,
198 bindings: unbound.bindings,
199 })
200 }
201
202 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
205 &mut self,
206 f: impl FnMut(&mut T) -> anyhow::Result<()>,
207 ) -> anyhow::Result<()> {
208 self.bindings.visit_mut(f)
209 }
210
211 fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
212 let message: M = self.message.deserialized_unchecked()?;
213 Ok(Unbound {
214 message,
215 bindings: self.bindings,
216 })
217 }
218}
219
220#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
223#[serde(from = "ErasedUnbound")]
224pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
225
226impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
227 pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
228 self.0.downcast()
229 }
230}
231
232impl<M: Bind> IndexedErasedUnbound<M> {
233 pub fn bind_for_test_only<A, C>(
236 actor_ref: ActorRef<A>,
237 cx: C,
238 mailbox: Mailbox,
239 ) -> anyhow::Result<()>
240 where
241 A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
242 M: RemoteMessage,
243 C: context::Actor + Send + Sync + 'static,
244 {
245 let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
246 move |_, m| {
247 let bound_m = m.downcast()?.bind()?;
248 actor_ref.send(&cx, bound_m)?;
249 Ok(())
250 }
251 });
252 port_handle.bind_actor_port();
253 Ok(())
254 }
255}
256
257impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
258 fn from(erased: ErasedUnbound) -> Self {
259 Self(erased, PhantomData)
260 }
261}
262
263macro_rules! impl_bind_unbind_basic {
264 ($t:ty) => {
265 impl Bind for $t {
266 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
267 anyhow::ensure!(
268 bindings.0.is_empty(),
269 "bindings for {} should be empty, but found {} elements left",
270 stringify!($t),
271 bindings.0.len(),
272 );
273 Ok(())
274 }
275 }
276
277 impl Unbind for $t {
278 fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
279 Ok(())
280 }
281 }
282 };
283}
284
285impl_bind_unbind_basic!(());
286impl_bind_unbind_basic!(bool);
287impl_bind_unbind_basic!(i8);
288impl_bind_unbind_basic!(u8);
289impl_bind_unbind_basic!(i16);
290impl_bind_unbind_basic!(u16);
291impl_bind_unbind_basic!(i32);
292impl_bind_unbind_basic!(u32);
293impl_bind_unbind_basic!(i64);
294impl_bind_unbind_basic!(u64);
295impl_bind_unbind_basic!(i128);
296impl_bind_unbind_basic!(u128);
297impl_bind_unbind_basic!(isize);
298impl_bind_unbind_basic!(usize);
299impl_bind_unbind_basic!(String);
300
301impl<T: Unbind> Unbind for Option<T> {
302 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
303 match self {
304 Some(t) => t.unbind(bindings),
305 None => Ok(()),
306 }
307 }
308}
309
310impl<T: Bind> Bind for Option<T> {
311 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
312 match self {
313 Some(t) => t.bind(bindings),
314 None => Ok(()),
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate as hyperactor; // presumably needed by derive expansions that refer to `hyperactor::`
    use crate::Bind;
    use crate::PortRef;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::id;
    use crate::reference::UnboundPort;

    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    /// The reducer spec attached to the `reply1` port before and after
    /// rebinding.
    fn test_reducer_spec() -> Option<ReducerSpec> {
        Some(ReducerSpec {
            typehash: 123,
            builder_params: None,
        })
    }

    /// Builds the `Bindings` expected for a message whose bound ports are
    /// `ports`, in field order.
    fn expected_bindings(ports: &[UnboundPort]) -> Bindings {
        Bindings(
            ports
                .iter()
                .map(|port| {
                    (
                        UnboundPort::typehash(),
                        wirevalue::Any::serialize(port).unwrap(),
                    )
                })
                .collect(),
        )
    }

    #[test]
    fn test_castable() {
        let original_port0 = PortRef::attest(id!(world[0].actor[0][123]));
        let original_port1 =
            PortRef::attest_reducible(id!(world[1].actor1[0][456]), test_reducer_spec(), None);
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };
        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        // Unbinding keeps the message body intact and records both ports in
        // field order.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        let original_bindings = expected_bindings(&[
            UnboundPort::from(&original_port0),
            UnboundPort::from(&original_port1),
        ]);
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: original_bindings,
            }
        );

        // Redirect both bindings to fresh ports.
        let new_port_id0 = id!(world[0].comm[0][680]);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = id!(world[1].comm[0][257]);
        assert_ne!(&new_port_id1, original_port1.port_id());
        {
            // Scoped so the borrows of the new ids end before they are moved below.
            let mut replacements = [&new_port_id0, &new_port_id1].into_iter();
            erased
                .visit_mut::<UnboundPort>(|port| {
                    let replacement = replacements.next().unwrap();
                    port.update(replacement.clone());
                    Ok(())
                })
                .unwrap();
        }

        // Only the bindings changed; the serialized message body did not.
        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 =
            PortRef::<MyReply>::attest_reducible(new_port_id1, test_reducer_spec(), None);
        let new_bindings =
            expected_bindings(&[UnboundPort::from(&new_port0), UnboundPort::from(&new_port1)]);
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message,
                bindings: new_bindings.clone(),
            }
        );

        // Downcasting recovers the original message next to the new bindings...
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );

        // ...and binding substitutes the new ports into the message.
        assert_eq!(
            unbound.bind().unwrap(),
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}