1use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41use crate::Mailbox;
43use crate::RemoteHandles;
44use crate::RemoteMessage;
45use crate::actor::Referable;
46use crate::context;
47use crate::reference;
48
49pub trait Unbind: Sized {
54 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
56}
57
58pub trait Bind: Sized {
61 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
63}
64
65pub trait Castable: RemoteMessage + Bind + Unbind {}
68impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
69
70#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
73pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
74
75impl Bindings {
76 pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
78 let ser = wirevalue::Any::serialize(value)?;
79 self.0.push_back((T::typehash(), ser));
80 Ok(())
81 }
82
83 pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
88 match self.0.pop_front() {
89 None => Ok(None),
90 Some((t, v)) => {
91 if t != T::typehash() {
92 anyhow::bail!(
93 "type mismatch: expected {} with hash {}, found {} in binding",
94 T::typename(),
95 T::typehash(),
96 t,
97 );
98 }
99 Ok(Some(v.deserialized::<T>()?))
100 }
101 }
102 }
103
104 pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
106 self.pop_front::<T>()?.ok_or_else(|| {
107 anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
108 })
109 }
110
111 fn visit_mut<T: Serialize + DeserializeOwned + Named>(
112 &mut self,
113 mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
114 ) -> anyhow::Result<()> {
115 for v in self.0.iter_mut() {
116 if v.0 == T::typehash() {
117 let mut t = v.1.deserialized::<T>()?;
118 f(&mut t)?;
119 v.1 = wirevalue::Any::serialize(&t)?;
120 }
121 }
122 Ok(())
123 }
124}
125
126#[derive(Debug, PartialEq)]
128pub struct Unbound<M> {
129 message: M,
130 bindings: Bindings,
131}
132
133impl<M> Unbound<M> {
134 pub fn new(message: M, bindings: Bindings) -> Self {
136 Self { message, bindings }
137 }
138
139 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
142 &mut self,
143 f: impl FnMut(&mut T) -> anyhow::Result<()>,
144 ) -> anyhow::Result<()> {
145 self.bindings.visit_mut(f)
146 }
147}
148
149impl<M: Bind> Unbound<M> {
150 pub fn bind(mut self) -> anyhow::Result<M> {
152 self.message.bind(&mut self.bindings)?;
153 anyhow::ensure!(
154 self.bindings.0.is_empty(),
155 "there are still {} elements left in bindings",
156 self.bindings.0.len()
157 );
158 Ok(self.message)
159 }
160}
161
162impl<M: Unbind> Unbound<M> {
163 pub fn try_from_message(message: M) -> anyhow::Result<Self> {
167 let mut bindings = Bindings::default();
168 message.unbind(&mut bindings)?;
169 Ok(Unbound { message, bindings })
170 }
171}
172
173#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
175pub struct ErasedUnbound {
176 message: wirevalue::Any,
177 bindings: Bindings,
178}
179wirevalue::register_type!(ErasedUnbound);
180
181impl ErasedUnbound {
182 pub fn new(message: wirevalue::Any) -> Self {
184 Self {
185 message,
186 bindings: Bindings::default(),
187 }
188 }
189
190 pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
194 let unbound = Unbound::try_from_message(msg)?;
195 let serialized = wirevalue::Any::serialize(&unbound.message)?;
196 Ok(Self {
197 message: serialized,
198 bindings: unbound.bindings,
199 })
200 }
201
202 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
205 &mut self,
206 f: impl FnMut(&mut T) -> anyhow::Result<()>,
207 ) -> anyhow::Result<()> {
208 self.bindings.visit_mut(f)
209 }
210
211 fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
212 let message: M = self.message.deserialized_unchecked()?;
213 Ok(Unbound {
214 message,
215 bindings: self.bindings,
216 })
217 }
218}
219
220#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
223#[serde(from = "ErasedUnbound")]
224pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
225
226impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
227 pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
228 self.0.downcast()
229 }
230}
231
232impl<M: Bind> IndexedErasedUnbound<M> {
233 pub fn bind_for_test_only<A, C>(
236 actor_ref: reference::ActorRef<A>,
237 cx: C,
238 mailbox: Mailbox,
239 ) -> anyhow::Result<()>
240 where
241 A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
242 M: RemoteMessage,
243 C: context::Actor + Send + Sync + 'static,
244 {
245 let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
246 move |_, m| {
247 let bound_m = m.downcast()?.bind()?;
248 actor_ref.send(&cx, bound_m)?;
249 Ok(())
250 }
251 });
252 port_handle.bind_actor_port();
253 Ok(())
254 }
255}
256
257impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
258 fn from(erased: ErasedUnbound) -> Self {
259 Self(erased, PhantomData)
260 }
261}
262
/// Implement [`Bind`] and [`Unbind`] for a type that carries no bindings.
///
/// `unbind` extracts nothing. `bind` consumes nothing and fails if any
/// bindings are present.
macro_rules! impl_bind_unbind_basic {
    ($t:ty) => {
        impl Unbind for $t {
            fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
                // Plain values have nothing to extract.
                Ok(())
            }
        }

        impl Bind for $t {
            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
                let remaining = bindings.0.len();
                anyhow::ensure!(
                    remaining == 0,
                    "bindings for {} should be empty, but found {} elements left",
                    stringify!($t),
                    remaining,
                );
                Ok(())
            }
        }
    };
}
284
285impl_bind_unbind_basic!(());
286impl_bind_unbind_basic!(bool);
287impl_bind_unbind_basic!(i8);
288impl_bind_unbind_basic!(u8);
289impl_bind_unbind_basic!(i16);
290impl_bind_unbind_basic!(u16);
291impl_bind_unbind_basic!(i32);
292impl_bind_unbind_basic!(u32);
293impl_bind_unbind_basic!(i64);
294impl_bind_unbind_basic!(u64);
295impl_bind_unbind_basic!(i128);
296impl_bind_unbind_basic!(u128);
297impl_bind_unbind_basic!(isize);
298impl_bind_unbind_basic!(usize);
299impl_bind_unbind_basic!(String);
300
301impl<T: Unbind> Unbind for Option<T> {
302 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
303 match self {
304 Some(t) => t.unbind(bindings),
305 None => Ok(()),
306 }
307 }
308}
309
310impl<T: Bind> Bind for Option<T> {
311 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
312 match self {
313 Some(t) => t.bind(bindings),
314 None => Ok(()),
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    // Alias this crate as `hyperactor`, presumably so that code generated by
    // the derives below can use `hyperactor::...` paths from inside the crate.
    // NOTE(review): inferred from the alias; confirm against the derive macros.
    use crate as hyperactor;
    use crate::Bind;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::accum::StreamingReducerOpts;
    use crate::testing::ids::test_port_id_with_pid;

    // Reply payload type for the second port in `MyMessage`.
    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    // Test message with two plain fields and two port references marked
    // `#[binding(include)]`; the derived `Bind`/`Unbind` impls handle the ports.
    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: reference::PortRef<String>,
        #[binding(include)]
        reply1: reference::PortRef<MyReply>,
    }

    // Round trip: unbind a message into an `ErasedUnbound`, rewrite its port
    // bindings in place, downcast, and bind it back. The result must carry the
    // rewritten ports while the plain fields stay unchanged.
    #[test]
    fn test_castable() {
        let original_port0 =
            reference::PortRef::attest(test_port_id_with_pid("world_0", "actor", 0, 123));
        // The second port carries a reducer spec, so the test also checks that
        // reducer information survives unbind/rewrite/bind.
        let original_port1 = reference::PortRef::attest_reducible(
            test_port_id_with_pid("world_1", "actor1", 0, 456),
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        // The erased message is the full original message, serialized; the
        // bindings are one `UnboundPort` per included field, in field order.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: Bindings(
                    [
                        (
                            reference::UnboundPort::typehash(),
                            wirevalue::Any::serialize(&reference::UnboundPort::from(
                                &original_port0
                            ))
                            .unwrap(),
                        ),
                        (
                            reference::UnboundPort::typehash(),
                            wirevalue::Any::serialize(&reference::UnboundPort::from(
                                &original_port1
                            ))
                            .unwrap(),
                        ),
                    ]
                    .into_iter()
                    .collect()
                ),
            }
        );

        // Replacement port ids; they must differ from the originals for the
        // test to be meaningful.
        let new_port_id0 = test_port_id_with_pid("world_0", "comm", 0, 680);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = test_port_id_with_pid("world_1", "comm", 0, 257);
        assert_ne!(&new_port_id1, original_port1.port_id());

        // `visit_mut` visits bindings in queue order, so the first binding
        // (reply0) gets `new_port_id0` and the second (reply1) gets `new_port_id1`.
        let mut new_ports = vec![&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<reference::UnboundPort>(|b| {
                let port = new_ports.next().unwrap();
                b.update(port.clone());
                Ok(())
            })
            .unwrap();

        // Expected ports after rewriting: new ids, same reducer configuration.
        let new_port0 = reference::PortRef::<String>::attest(new_port_id0);
        let new_port1 = reference::PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let new_bindings = Bindings(
            [
                (
                    reference::UnboundPort::typehash(),
                    wirevalue::Any::serialize(&reference::UnboundPort::from(&new_port0)).unwrap(),
                ),
                (
                    reference::UnboundPort::typehash(),
                    wirevalue::Any::serialize(&reference::UnboundPort::from(&new_port1)).unwrap(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        // Only the bindings changed; the serialized message is untouched.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: new_bindings.clone(),
            }
        );

        let unbound = erased.downcast::<MyMessage>().unwrap();
        // Downcasting restores the original (still old-port) message next to
        // the rewritten bindings.
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );
        // Binding writes the rewritten ports into the message.
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}