1use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41use crate::Mailbox;
43use crate::RemoteHandles;
44use crate::RemoteMessage;
45use crate::actor::Referable;
46use crate::context;
47use crate::reference;
48
/// A value whose bindable parts (for example port references marked with
/// `#[binding(include)]`) can be extracted into a [`Bindings`] list.
pub trait Unbind: Sized {
    /// Appends this value's bindings to the back of `bindings`.
    ///
    /// # Errors
    /// Returns an error if a binding cannot be recorded (e.g. it fails to
    /// serialize).
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
57
/// The inverse of [`Unbind`]: writes bindings from a [`Bindings`] list back
/// into the value.
pub trait Bind: Sized {
    /// Consumes this value's bindings from the front of `bindings` and applies
    /// them to `self`.
    ///
    /// # Errors
    /// Returns an error if the expected bindings are missing, have the wrong
    /// type, or fail to deserialize.
    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
64
/// Marker trait for remote messages that can be both unbound and re-bound.
pub trait Castable: RemoteMessage + Bind + Unbind {}
// Blanket impl: every type satisfying the bounds is automatically `Castable`.
impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
69
/// An ordered queue of serialized binding values.
///
/// Each entry pairs the typehash of the value's type (as recorded by
/// [`Bindings::push_back`]) with the serialized value. Entries are appended at
/// the back and consumed from the front, so unbinding and binding must visit
/// bindable fields in the same order.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
74
75impl Bindings {
76 pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
78 let ser = wirevalue::Any::serialize(value)?;
79 self.0.push_back((T::typehash(), ser));
80 Ok(())
81 }
82
83 pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
88 match self.0.pop_front() {
89 None => Ok(None),
90 Some((t, v)) => {
91 if t != T::typehash() {
92 anyhow::bail!(
93 "type mismatch: expected {} with hash {}, found {} in binding",
94 T::typename(),
95 T::typehash(),
96 t,
97 );
98 }
99 Ok(Some(v.deserialized::<T>()?))
100 }
101 }
102 }
103
104 pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
106 self.pop_front::<T>()?.ok_or_else(|| {
107 anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
108 })
109 }
110
111 fn visit_mut<T: Serialize + DeserializeOwned + Named>(
112 &mut self,
113 mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
114 ) -> anyhow::Result<()> {
115 for v in self.0.iter_mut() {
116 if v.0 == T::typehash() {
117 let mut t = v.1.deserialized::<T>()?;
118 f(&mut t)?;
119 v.1 = wirevalue::Any::serialize(&t)?;
120 }
121 }
122 Ok(())
123 }
124}
125
/// A message of type `M` together with the [`Bindings`] extracted from it.
///
/// Built by [`Unbound::try_from_message`]; [`Unbound::bind`] writes the
/// bindings back into the message and returns it.
#[derive(Debug, PartialEq)]
pub struct Unbound<M> {
    /// The message whose bindable parts are (re)applied from `bindings`.
    message: M,
    /// Bindings extracted from `message`, in the order they were recorded.
    bindings: Bindings,
}
132
133impl<M> Unbound<M> {
134 pub fn new(message: M, bindings: Bindings) -> Self {
136 Self { message, bindings }
137 }
138
139 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
142 &mut self,
143 f: impl FnMut(&mut T) -> anyhow::Result<()>,
144 ) -> anyhow::Result<()> {
145 self.bindings.visit_mut(f)
146 }
147}
148
149impl<M: Bind> Unbound<M> {
150 pub fn bind(mut self) -> anyhow::Result<M> {
152 self.message.bind(&mut self.bindings)?;
153 anyhow::ensure!(
154 self.bindings.0.is_empty(),
155 "there are still {} elements left in bindings",
156 self.bindings.0.len()
157 );
158 Ok(self.message)
159 }
160}
161
162impl<M: Unbind> Unbound<M> {
163 pub fn try_from_message(message: M) -> anyhow::Result<Self> {
167 let mut bindings = Bindings::default();
168 message.unbind(&mut bindings)?;
169 Ok(Unbound { message, bindings })
170 }
171}
172
/// A type-erased [`Unbound`]. The message is held in serialized form, while
/// its bindings stay individually accessible so they can be rewritten via
/// [`ErasedUnbound::visit_mut`] without knowing the message type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
pub struct ErasedUnbound {
    /// The serialized message (its bindings were extracted before serialization).
    message: wirevalue::Any,
    /// Bindings extracted from the message.
    bindings: Bindings,
}
// Registers `ErasedUnbound` with `wirevalue` (see `wirevalue::register_type!`).
wirevalue::register_type!(ErasedUnbound);
180
181impl ErasedUnbound {
182 pub fn new(message: wirevalue::Any) -> Self {
184 Self {
185 message,
186 bindings: Bindings::default(),
187 }
188 }
189
190 pub fn message(&self) -> &wirevalue::Any {
192 &self.message
193 }
194
195 pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
199 let unbound = Unbound::try_from_message(msg)?;
200 let serialized = wirevalue::Any::serialize(&unbound.message)?;
201 Ok(Self {
202 message: serialized,
203 bindings: unbound.bindings,
204 })
205 }
206
207 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
210 &mut self,
211 f: impl FnMut(&mut T) -> anyhow::Result<()>,
212 ) -> anyhow::Result<()> {
213 self.bindings.visit_mut(f)
214 }
215
216 fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
217 let message: M = self.message.deserialized_unchecked()?;
218 Ok(Unbound {
219 message,
220 bindings: self.bindings,
221 })
222 }
223}
224
225#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
228#[serde(from = "ErasedUnbound")]
229pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
230
231impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
232 pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
233 self.0.downcast()
234 }
235
236 pub fn inner_any(&self) -> &wirevalue::Any {
238 self.0.message()
239 }
240}
241
242impl<M: Bind> IndexedErasedUnbound<M> {
243 pub fn bind_for_test_only<A, C>(
246 actor_ref: reference::ActorRef<A>,
247 cx: C,
248 mailbox: Mailbox,
249 ) -> anyhow::Result<()>
250 where
251 A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
252 M: RemoteMessage,
253 C: context::Actor + Send + Sync + 'static,
254 {
255 let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
256 move |_, m| {
257 let bound_m = m.downcast()?.bind()?;
258 actor_ref.send(&cx, bound_m)?;
259 Ok(())
260 }
261 });
262 port_handle.bind_actor_port();
263 Ok(())
264 }
265}
266
267impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
268 fn from(erased: ErasedUnbound) -> Self {
269 Self(erased, PhantomData)
270 }
271}
272
/// Implements [`Bind`] and [`Unbind`] for a type that carries no bindings.
///
/// `unbind` records nothing. `bind` consumes nothing and fails if any bindings
/// are still queued when it is called.
macro_rules! impl_bind_unbind_basic {
    ($t:ty) => {
        impl Unbind for $t {
            fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
                Ok(())
            }
        }

        impl Bind for $t {
            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
                let remaining = bindings.0.len();
                if remaining > 0 {
                    anyhow::bail!(
                        "bindings for {} should be empty, but found {} elements left",
                        stringify!($t),
                        remaining,
                    );
                }
                Ok(())
            }
        }
    };
}
294
// Leaf types that contain no bindable parts: unbinding records nothing, and
// binding fails if any bindings remain.
impl_bind_unbind_basic!(());
impl_bind_unbind_basic!(bool);
impl_bind_unbind_basic!(i8);
impl_bind_unbind_basic!(u8);
impl_bind_unbind_basic!(i16);
impl_bind_unbind_basic!(u16);
impl_bind_unbind_basic!(i32);
impl_bind_unbind_basic!(u32);
impl_bind_unbind_basic!(i64);
impl_bind_unbind_basic!(u64);
impl_bind_unbind_basic!(i128);
impl_bind_unbind_basic!(u128);
impl_bind_unbind_basic!(isize);
impl_bind_unbind_basic!(usize);
impl_bind_unbind_basic!(String);
impl_bind_unbind_basic!(std::time::Duration);
impl_bind_unbind_basic!(std::time::SystemTime);
312
313impl<T: Unbind> Unbind for Option<T> {
314 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
315 match self {
316 Some(t) => t.unbind(bindings),
317 None => Ok(()),
318 }
319 }
320}
321
322impl<T: Bind> Bind for Option<T> {
323 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
324 match self {
325 Some(t) => t.bind(bindings),
326 None => Ok(()),
327 }
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    // Presumably needed so derive-macro expansions that refer to
    // `hyperactor::...` resolve inside this crate.
    use crate as hyperactor;
    use crate::Bind;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::accum::StreamingReducerOpts;
    use crate::testing::ids::test_port_id_with_pid;

    // Reply payload type used for the second (reducible) port.
    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    // Test message: two plain fields plus two port fields marked
    // `#[binding(include)]`, which the derived `Bind`/`Unbind` extract.
    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: reference::PortRef<String>,
        #[binding(include)]
        reply1: reference::PortRef<MyReply>,
    }

    // Round trip: unbind -> rewrite bindings -> downcast -> bind, checking that
    // only the port bindings change and the plain fields survive.
    #[test]
    fn test_castable() {
        let original_port0 =
            reference::PortRef::attest(test_port_id_with_pid("world_0", "actor", 0, 123));
        let original_port1 = reference::PortRef::attest_reducible(
            test_port_id_with_pid("world_1", "actor1", 0, 456),
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        // Unbinding keeps the full serialized message and records one
        // `UnboundPort` binding per included port, in field order.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: Bindings(
                    [
                        (
                            reference::UnboundPort::typehash(),
                            wirevalue::Any::serialize(&reference::UnboundPort::from(
                                &original_port0
                            ))
                            .unwrap(),
                        ),
                        (
                            reference::UnboundPort::typehash(),
                            wirevalue::Any::serialize(&reference::UnboundPort::from(
                                &original_port1
                            ))
                            .unwrap(),
                        ),
                    ]
                    .into_iter()
                    .collect()
                ),
            }
        );

        // Replacement port ids; they must differ from the originals for the
        // test to be meaningful.
        let new_port_id0 = test_port_id_with_pid("world_0", "comm", 0, 680);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = test_port_id_with_pid("world_1", "comm", 0, 257);
        assert_ne!(&new_port_id1, original_port1.port_id());

        // Rewrite each `UnboundPort` binding in order.
        let mut new_ports = vec![&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<reference::UnboundPort>(|b| {
                let port = new_ports.next().unwrap();
                b.update(port.clone());
                Ok(())
            })
            .unwrap();

        let new_port0 = reference::PortRef::<String>::attest(new_port_id0);
        let new_port1 = reference::PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let new_bindings = Bindings(
            [
                (
                    reference::UnboundPort::typehash(),
                    wirevalue::Any::serialize(&reference::UnboundPort::from(&new_port0)).unwrap(),
                ),
                (
                    reference::UnboundPort::typehash(),
                    wirevalue::Any::serialize(&reference::UnboundPort::from(&new_port1)).unwrap(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        // The serialized message is untouched; only the bindings changed.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: new_bindings.clone(),
            }
        );

        // Downcasting yields the original message with the rewritten bindings.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );
        // Binding applies the new ports to the message's port fields.
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}