1use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41use crate::ActorRef;
43use crate::Mailbox;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::Referable;
47use crate::context;
48use crate::endpoint::Endpoint as _;
49
/// Extracts the bindable parts of a value into a [`Bindings`] queue.
///
/// Because [`Bindings`] is a FIFO queue (`push_back` / `pop_front`), a
/// matching [`Bind`] implementation is expected to consume the entries in
/// the same order they were pushed here.
pub trait Unbind: Sized {
    /// Appends this value's bindings (if any) to the back of `bindings`.
    ///
    /// Takes `&self`, so the value itself is not modified by unbinding.
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
58
/// Re-populates a value from a [`Bindings`] queue; the inverse of [`Unbind`].
pub trait Bind: Sized {
    /// Writes bindings taken from `bindings` back into `self`.
    ///
    /// Implementations that carry bindings read them from the front of the
    /// queue (see [`Bindings::pop_front`]); types without bindings consume
    /// nothing.
    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
65
66pub trait Castable: RemoteMessage + Bind + Unbind {}
69impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
70
/// An ordered queue of serialized binding values.
///
/// Each entry pairs the value's type hash (from [`Named::typehash`]) with
/// its serialized form, so a consumer can verify the type before
/// deserializing. Entries are appended with [`Bindings::push_back`] and
/// consumed in FIFO order with [`Bindings::pop_front`].
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
75
76impl Bindings {
77 pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
79 let ser = wirevalue::Any::serialize(value)?;
80 self.0.push_back((T::typehash(), ser));
81 Ok(())
82 }
83
84 pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
89 match self.0.pop_front() {
90 None => Ok(None),
91 Some((t, v)) => {
92 if t != T::typehash() {
93 anyhow::bail!(
94 "type mismatch: expected {} with hash {}, found {} in binding",
95 T::typename(),
96 T::typehash(),
97 t,
98 );
99 }
100 Ok(Some(v.deserialized::<T>()?))
101 }
102 }
103 }
104
105 pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
107 self.pop_front::<T>()?.ok_or_else(|| {
108 anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
109 })
110 }
111
112 fn visit_mut<T: Serialize + DeserializeOwned + Named>(
113 &mut self,
114 mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
115 ) -> anyhow::Result<()> {
116 for v in self.0.iter_mut() {
117 if v.0 == T::typehash() {
118 let mut t = v.1.deserialized::<T>()?;
119 f(&mut t)?;
120 v.1 = wirevalue::Any::serialize(&t)?;
121 }
122 }
123 Ok(())
124 }
125}
126
/// A message whose bindings are held in a separate [`Bindings`] queue, so
/// they can be inspected or rewritten before the message is re-bound.
#[derive(Debug, PartialEq)]
pub struct Unbound<M> {
    /// The message itself.
    message: M,
    /// Bindings extracted from `message` (or supplied via `Unbound::new`).
    bindings: Bindings,
}
133
134impl<M> Unbound<M> {
135 pub fn new(message: M, bindings: Bindings) -> Self {
137 Self { message, bindings }
138 }
139
140 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
143 &mut self,
144 f: impl FnMut(&mut T) -> anyhow::Result<()>,
145 ) -> anyhow::Result<()> {
146 self.bindings.visit_mut(f)
147 }
148}
149
150impl<M: Bind> Unbound<M> {
151 pub fn bind(mut self) -> anyhow::Result<M> {
153 self.message.bind(&mut self.bindings)?;
154 anyhow::ensure!(
155 self.bindings.0.is_empty(),
156 "there are still {} elements left in bindings",
157 self.bindings.0.len()
158 );
159 Ok(self.message)
160 }
161}
162
163impl<M: Unbind> Unbound<M> {
164 pub fn try_from_message(message: M) -> anyhow::Result<Self> {
168 let mut bindings = Bindings::default();
169 message.unbind(&mut bindings)?;
170 Ok(Unbound { message, bindings })
171 }
172}
173
/// A type-erased [`Unbound`].
///
/// The message is stored serialized as a [`wirevalue::Any`], while its
/// bindings stay separately accessible, so they can be rewritten without
/// knowing the message type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
pub struct ErasedUnbound {
    /// The serialized message, with its bindings extracted.
    message: wirevalue::Any,
    /// Bindings extracted from the message.
    bindings: Bindings,
}
// Registers `ErasedUnbound` with wirevalue. NOTE(review): presumably this is
// what allows it to be carried inside a `wirevalue::Any`; confirm against
// the wirevalue docs.
wirevalue::register_type!(ErasedUnbound);
181
182impl ErasedUnbound {
183 pub fn new(message: wirevalue::Any) -> Self {
185 Self {
186 message,
187 bindings: Bindings::default(),
188 }
189 }
190
191 pub fn message(&self) -> &wirevalue::Any {
193 &self.message
194 }
195
196 pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
200 let unbound = Unbound::try_from_message(msg)?;
201 let serialized = wirevalue::Any::serialize(&unbound.message)?;
202 Ok(Self {
203 message: serialized,
204 bindings: unbound.bindings,
205 })
206 }
207
208 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
211 &mut self,
212 f: impl FnMut(&mut T) -> anyhow::Result<()>,
213 ) -> anyhow::Result<()> {
214 self.bindings.visit_mut(f)
215 }
216
217 fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
218 let message: M = self.message.deserialized_unchecked()?;
219 Ok(Unbound {
220 message,
221 bindings: self.bindings,
222 })
223 }
224}
225
/// An [`ErasedUnbound`] tagged at the type level with the message type `M`
/// it is expected to carry (the tag is only a `PhantomData<M>`).
///
/// Deserialization goes through [`ErasedUnbound`] via `#[serde(from = ...)]`.
///
/// NOTE(review): serialization uses the derived tuple-struct encoding of
/// `(ErasedUnbound, PhantomData<M>)` rather than `ErasedUnbound` directly,
/// so it may not round-trip in self-describing formats (e.g. JSON). Confirm
/// that this asymmetry is intended.
#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
#[serde(from = "ErasedUnbound")]
pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
231
232impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
233 pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
234 self.0.downcast()
235 }
236
237 pub fn inner_any(&self) -> &wirevalue::Any {
239 self.0.message()
240 }
241}
242
impl<M: Bind> IndexedErasedUnbound<M> {
    /// Test helper: opens an enqueue port on `mailbox` for
    /// `IndexedErasedUnbound<M>` messages. Each received message is
    /// downcast to `M`, re-bound, and posted to `actor_ref` using `cx`.
    /// The port is then bound via `bind_handler_port`.
    ///
    /// This function itself currently always returns `Ok(())`. Downcast and
    /// bind errors surface later, from inside the port callback.
    pub fn bind_for_test_only<A, C>(
        actor_ref: ActorRef<A>,
        cx: C,
        mailbox: Mailbox,
    ) -> anyhow::Result<()>
    where
        A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
        M: RemoteMessage,
        C: context::Actor + Send + Sync + 'static,
    {
        let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
            move |_, m| {
                // Restore the concrete message and its bindings before forwarding.
                let bound_m = m.downcast()?.bind()?;
                // `post` comes from the `Endpoint` trait (imported as `_`).
                // NOTE(review): presumably implemented for `&ActorRef<A>`,
                // hence the explicit borrow.
                (&actor_ref).post(&cx, bound_m);
                Ok(())
            }
        });
        port_handle.bind_handler_port();
        Ok(())
    }
}
267
268impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
269 fn from(erased: ErasedUnbound) -> Self {
270 Self(erased, PhantomData)
271 }
272}
273
/// Implements [`Bind`] and [`Unbind`] for a type that carries no bindings.
///
/// `unbind` pushes nothing. `bind` consumes nothing and fails if `bindings`
/// is non-empty when it is called.
macro_rules! impl_bind_unbind_basic {
    ($t:ty) => {
        impl Unbind for $t {
            fn unbind(&self, _: &mut Bindings) -> anyhow::Result<()> {
                // Nothing to extract.
                Ok(())
            }
        }

        impl Bind for $t {
            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
                let remaining = bindings.0.len();
                if remaining != 0 {
                    anyhow::bail!(
                        "bindings for {} should be empty, but found {} elements left",
                        stringify!($t),
                        remaining,
                    );
                }
                Ok(())
            }
        }
    };
}
295
296impl_bind_unbind_basic!(());
297impl_bind_unbind_basic!(bool);
298impl_bind_unbind_basic!(i8);
299impl_bind_unbind_basic!(u8);
300impl_bind_unbind_basic!(i16);
301impl_bind_unbind_basic!(u16);
302impl_bind_unbind_basic!(i32);
303impl_bind_unbind_basic!(u32);
304impl_bind_unbind_basic!(i64);
305impl_bind_unbind_basic!(u64);
306impl_bind_unbind_basic!(i128);
307impl_bind_unbind_basic!(u128);
308impl_bind_unbind_basic!(isize);
309impl_bind_unbind_basic!(usize);
310impl_bind_unbind_basic!(String);
311impl_bind_unbind_basic!(std::time::Duration);
312impl_bind_unbind_basic!(std::time::SystemTime);
313
314impl<T: Unbind> Unbind for Option<T> {
315 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
316 match self {
317 Some(t) => t.unbind(bindings),
318 None => Ok(()),
319 }
320 }
321}
322
323impl<T: Bind> Bind for Option<T> {
324 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
325 match self {
326 Some(t) => t.bind(bindings),
327 None => Ok(()),
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    // Presumably lets derive macros that emit `hyperactor::...` paths resolve
    // inside this crate.
    use crate as hyperactor;
    use crate::Bind;
    use crate::PortRef;
    use crate::Unbind;
    use crate::UnboundPort;
    use crate::accum::ReducerSpec;
    use crate::accum::StreamingReducerOpts;
    use crate::testing::ids::test_port_id;

    /// Reply payload used as the element type of `MyMessage::reply1`.
    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    /// Test message with two plain fields and two port fields. Only the
    /// ports are marked `#[binding(include)]`, so only they become bindings.
    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    /// Round-trips a message through erase → rewrite bindings → downcast →
    /// bind, and checks that the ports were replaced while other fields
    /// were preserved.
    #[test]
    fn test_castable() {
        let original_port0 = PortRef::attest(test_port_id("world_0", "actor", 123));
        let original_port1 = PortRef::attest_reducible(
            test_port_id("world_1", "actor1", 456),
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        // Erasing extracts both ports, in field order, as `UnboundPort` bindings.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: Bindings(
                    [
                        (
                            UnboundPort::typehash(),
                            wirevalue::Any::serialize(&UnboundPort::from(&original_port0)).unwrap(),
                        ),
                        (
                            UnboundPort::typehash(),
                            wirevalue::Any::serialize(&UnboundPort::from(&original_port1)).unwrap(),
                        ),
                    ]
                    .into_iter()
                    .collect()
                ),
            }
        );

        let new_port_id0 = test_port_id("world_0", "comm", 680);
        assert_ne!(&new_port_id0, original_port0.port_addr());
        let new_port_id1 = test_port_id("world_1", "comm", 257);
        assert_ne!(&new_port_id1, original_port1.port_addr());

        // Rewrite each port binding, in order, to point at the new port ids.
        let mut new_ports = [&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<UnboundPort>(|b| {
                let port = new_ports.next().unwrap();
                b.update(port.clone());
                Ok(())
            })
            .unwrap();

        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 = PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let new_bindings = Bindings(
            [
                (
                    UnboundPort::typehash(),
                    wirevalue::Any::serialize(&UnboundPort::from(&new_port0)).unwrap(),
                ),
                (
                    UnboundPort::typehash(),
                    wirevalue::Any::serialize(&UnboundPort::from(&new_port1)).unwrap(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        // Only the bindings changed; the serialized message is untouched.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message,
                bindings: new_bindings.clone(),
            }
        );

        // Downcasting restores the original message alongside the new bindings.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );

        // Binding writes the rewritten ports back into the message.
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}