1use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39
40use crate as hyperactor;
41use crate::ActorRef;
42use crate::Mailbox;
43use crate::Named;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::Referable;
47use crate::context;
48use crate::data::Serialized;
49
50pub trait Unbind: Sized {
55 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
57}
58
59pub trait Bind: Sized {
62 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
64}
65
66pub trait Castable: RemoteMessage + Bind + Unbind {}
69impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
70
71#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
74pub struct Bindings(VecDeque<(u64, Serialized)>);
75
76impl Bindings {
77 pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
79 let ser = Serialized::serialize(value)?;
80 self.0.push_back((T::typehash(), ser));
81 Ok(())
82 }
83
84 pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
89 match self.0.pop_front() {
90 None => Ok(None),
91 Some((t, v)) => {
92 if t != T::typehash() {
93 anyhow::bail!(
94 "type mismatch: expected {} with hash {}, found {} in binding",
95 T::typename(),
96 T::typehash(),
97 t,
98 );
99 }
100 Ok(Some(v.deserialized::<T>()?))
101 }
102 }
103 }
104
105 pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
107 self.pop_front::<T>()?.ok_or_else(|| {
108 anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
109 })
110 }
111
112 fn visit_mut<T: Serialize + DeserializeOwned + Named>(
113 &mut self,
114 mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
115 ) -> anyhow::Result<()> {
116 for v in self.0.iter_mut() {
117 if v.0 == T::typehash() {
118 let mut t = v.1.deserialized::<T>()?;
119 f(&mut t)?;
120 v.1 = Serialized::serialize(&t)?;
121 }
122 }
123 Ok(())
124 }
125}
126
127#[derive(Debug, PartialEq)]
129pub struct Unbound<M> {
130 message: M,
131 bindings: Bindings,
132}
133
134impl<M> Unbound<M> {
135 pub fn new(message: M, bindings: Bindings) -> Self {
137 Self { message, bindings }
138 }
139
140 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
143 &mut self,
144 f: impl FnMut(&mut T) -> anyhow::Result<()>,
145 ) -> anyhow::Result<()> {
146 self.bindings.visit_mut(f)
147 }
148}
149
150impl<M: Bind> Unbound<M> {
151 pub fn bind(mut self) -> anyhow::Result<M> {
153 self.message.bind(&mut self.bindings)?;
154 anyhow::ensure!(
155 self.bindings.0.is_empty(),
156 "there are still {} elements left in bindings",
157 self.bindings.0.len()
158 );
159 Ok(self.message)
160 }
161}
162
163impl<M: Unbind> Unbound<M> {
164 pub fn try_from_message(message: M) -> anyhow::Result<Self> {
168 let mut bindings = Bindings::default();
169 message.unbind(&mut bindings)?;
170 Ok(Unbound { message, bindings })
171 }
172}
173
174#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
176pub struct ErasedUnbound {
177 message: Serialized,
178 bindings: Bindings,
179}
180
181impl ErasedUnbound {
182 pub fn new(message: Serialized) -> Self {
184 Self {
185 message,
186 bindings: Bindings::default(),
187 }
188 }
189
190 pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
194 let unbound = Unbound::try_from_message(msg)?;
195 let serialized = Serialized::serialize(&unbound.message)?;
196 Ok(Self {
197 message: serialized,
198 bindings: unbound.bindings,
199 })
200 }
201
202 pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
205 &mut self,
206 f: impl FnMut(&mut T) -> anyhow::Result<()>,
207 ) -> anyhow::Result<()> {
208 self.bindings.visit_mut(f)
209 }
210
211 fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
212 let message: M = self.message.deserialized_unchecked()?;
213 Ok(Unbound {
214 message,
215 bindings: self.bindings,
216 })
217 }
218}
219
220#[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
223#[serde(from = "ErasedUnbound")]
224pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
225
226impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
227 pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
228 self.0.downcast()
229 }
230}
231
232impl<M: Bind> IndexedErasedUnbound<M> {
233 pub fn bind_for_test_only<A, C>(
236 actor_ref: ActorRef<A>,
237 cx: C,
238 mailbox: Mailbox,
239 ) -> anyhow::Result<()>
240 where
241 A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
242 M: RemoteMessage,
243 C: context::Actor + Send + Sync + 'static,
244 {
245 let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
246 move |_, m| {
247 let bound_m = m.downcast()?.bind()?;
248 actor_ref.send(&cx, bound_m)?;
249 Ok(())
250 }
251 });
252 port_handle.bind_to(IndexedErasedUnbound::<M>::port());
253 Ok(())
254 }
255}
256
257impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
258 fn from(erased: ErasedUnbound) -> Self {
259 Self(erased, PhantomData)
260 }
261}
262
263macro_rules! impl_bind_unbind_basic {
264 ($t:ty) => {
265 impl Bind for $t {
266 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
267 anyhow::ensure!(
268 bindings.0.is_empty(),
269 "bindings for {} should be empty, but found {} elements left",
270 stringify!($t),
271 bindings.0.len(),
272 );
273 Ok(())
274 }
275 }
276
277 impl Unbind for $t {
278 fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
279 Ok(())
280 }
281 }
282 };
283}
284
285impl_bind_unbind_basic!(());
286impl_bind_unbind_basic!(bool);
287impl_bind_unbind_basic!(i8);
288impl_bind_unbind_basic!(u8);
289impl_bind_unbind_basic!(i16);
290impl_bind_unbind_basic!(u16);
291impl_bind_unbind_basic!(i32);
292impl_bind_unbind_basic!(u32);
293impl_bind_unbind_basic!(i64);
294impl_bind_unbind_basic!(u64);
295impl_bind_unbind_basic!(i128);
296impl_bind_unbind_basic!(u128);
297impl_bind_unbind_basic!(isize);
298impl_bind_unbind_basic!(usize);
299impl_bind_unbind_basic!(String);
300
301impl<T: Unbind> Unbind for Option<T> {
302 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
303 match self {
304 Some(t) => t.unbind(bindings),
305 None => Ok(()),
306 }
307 }
308}
309
310impl<T: Bind> Bind for Option<T> {
311 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
312 match self {
313 Some(t) => t.bind(bindings),
314 None => Ok(()),
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use hyperactor::PortRef;
    use hyperactor::id;

    use super::*;
    use crate::Bind;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::reference::UnboundPort;

    /// Reply type used to give the second port its own message type.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
    struct MyReply(String);

    /// Message with two plain fields and two ports included in the bindings.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    /// Reducer spec attached to the second reply port throughout the test.
    fn reducer_spec() -> Option<ReducerSpec> {
        Some(ReducerSpec {
            typehash: 123,
            builder_params: None,
        })
    }

    /// Builds the bindings expected for a message holding the two given
    /// ports, in field order.
    fn expected_bindings(reply0: &PortRef<String>, reply1: &PortRef<MyReply>) -> Bindings {
        let entry = |port: UnboundPort| {
            (
                UnboundPort::typehash(),
                Serialized::serialize(&port).unwrap(),
            )
        };
        Bindings(
            vec![
                entry(UnboundPort::from(reply0)),
                entry(UnboundPort::from(reply1)),
            ]
            .into_iter()
            .collect(),
        )
    }

    #[test]
    fn test_castable() {
        let original_port0 = PortRef::<String>::attest(id!(world[0].actor[0][123]));
        let original_port1 =
            PortRef::<MyReply>::attest_reducible(id!(world[1].actor1[0][456]), reducer_spec());
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = Serialized::serialize(&my_message).unwrap();

        // Unbinding keeps the serialized message as-is and records both ports.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: expected_bindings(&original_port0, &original_port1),
            }
        );

        // Replacement port ids must differ from the originals for the test to
        // be meaningful.
        let new_port_id0 = id!(world[0].comm[0][680]);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = id!(world[1].comm[0][257]);
        assert_ne!(&new_port_id1, original_port1.port_id());

        // Rewrite the recorded ports in order.
        let mut replacements = vec![new_port_id0.clone(), new_port_id1.clone()].into_iter();
        erased
            .visit_mut::<UnboundPort>(|slot| {
                slot.update(replacements.next().unwrap());
                Ok(())
            })
            .unwrap();

        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 = PortRef::<MyReply>::attest_reducible(new_port_id1, reducer_spec());
        let new_bindings = expected_bindings(&new_port0, &new_port1);

        // Only the bindings changed; the serialized message is untouched.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: new_bindings.clone(),
            }
        );

        // Downcasting restores the original typed message alongside the
        // updated bindings.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );

        // Binding swaps the new ports into the message.
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}