hyperactor/
message.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module provides a framework for mutating serialized messages without
//! the need to deserialize them. This capability is useful when sending messages
//! to a remote destination through intermediate nodes, where the intermediate
//! nodes do not contain the message's type information.
13//!
14//! Briefly, it works by following these steps:
15//!
16//! 1. On the sender side, mutable information is extracted from the typed
17//!    message through [Unbind], and stored in a [Bindings] object. This object
18//!    is bundled with the serialized message in an [ErasedUnbound] object, which
19//!    is sent over the wire.
//! 2. On intermediate nodes, the [ErasedUnbound] object is relayed. The
//!    mutation is applied on its bindings field, if needed.
//! 3. On the receiver side, the [ErasedUnbound] object is received as
//!    [IndexedErasedUnbound], where the type information is restored. Mutated
//!    information contained in its bindings field is applied to the message
//!    through [Bind], which results in the final typed message.
26//!
//! One main use case of this framework is to mutate the reply ports of a
//! multicast message, so the replies can be relayed through intermediate nodes,
//! rather than directly sent to the original sender. Therefore, a [Castable]
//! trait is defined, which collects requirements for message types using
//! multicast.
32
33use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41// for macros
42use crate::Mailbox;
43use crate::RemoteHandles;
44use crate::RemoteMessage;
45use crate::actor::Referable;
46use crate::context;
47use crate::reference;
48
49/// An object `T` that is [`Unbind`] can extract a set of parameters from itself,
50/// and store in [`Bindings`]. The extracted parameters in [`Bindings`] can be
51/// independently manipulated, and then later reconstituted (rebound) into
52/// a `T`-typed object again.
53pub trait Unbind: Sized {
54    /// Extract parameters from itself and store them in bindings.
55    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
56}
57
58/// An object `T` that is [`Bind`] can bind a set of externally provided
59/// parameters into itself.
60pub trait Bind: Sized {
61    /// Remove parameters from bindings, and use them to update itself.
62    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
63}
64
65/// This trait collects the necessary requirements for messages that are can be
66/// cast.
67pub trait Castable: RemoteMessage + Bind + Unbind {}
68impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
69
70/// Information extracted from a message through [Unbind], which can be merged
71/// back to the message through [Bind].
72#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
73pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
74
75impl Bindings {
76    /// Push a value into this bindings.
77    pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
78        let ser = wirevalue::Any::serialize(value)?;
79        self.0.push_back((T::typehash(), ser));
80        Ok(())
81    }
82
83    /// Removes the first pushed element in this bindings, deserialize it into
84    /// type T, and return it. Return [`None`] if this bindings is empty.
85    /// If the type of the first pushed element does not match T, an error is
86    /// returned.
87    pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
88        match self.0.pop_front() {
89            None => Ok(None),
90            Some((t, v)) => {
91                if t != T::typehash() {
92                    anyhow::bail!(
93                        "type mismatch: expected {} with hash {}, found {} in binding",
94                        T::typename(),
95                        T::typehash(),
96                        t,
97                    );
98                }
99                Ok(Some(v.deserialized::<T>()?))
100            }
101        }
102    }
103
104    /// Fallible version of [Bindings::pop_front].
105    pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
106        self.pop_front::<T>()?.ok_or_else(|| {
107            anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
108        })
109    }
110
111    fn visit_mut<T: Serialize + DeserializeOwned + Named>(
112        &mut self,
113        mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
114    ) -> anyhow::Result<()> {
115        for v in self.0.iter_mut() {
116            if v.0 == T::typehash() {
117                let mut t = v.1.deserialized::<T>()?;
118                f(&mut t)?;
119                v.1 = wirevalue::Any::serialize(&t)?;
120            }
121        }
122        Ok(())
123    }
124}
125
126/// An object contains a message, and its bindings extracted through [Unbind].
127#[derive(Debug, PartialEq)]
128pub struct Unbound<M> {
129    message: M,
130    bindings: Bindings,
131}
132
133impl<M> Unbound<M> {
134    /// Build a new object.
135    pub fn new(message: M, bindings: Bindings) -> Self {
136        Self { message, bindings }
137    }
138
139    /// Use the provided function to update values inside bindings in the same
140    /// order as they were pushed into bindings.
141    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
142        &mut self,
143        f: impl FnMut(&mut T) -> anyhow::Result<()>,
144    ) -> anyhow::Result<()> {
145        self.bindings.visit_mut(f)
146    }
147}
148
149impl<M: Bind> Unbound<M> {
150    /// Bind its bindings to its message through [Bind], and return the result.
151    pub fn bind(mut self) -> anyhow::Result<M> {
152        self.message.bind(&mut self.bindings)?;
153        anyhow::ensure!(
154            self.bindings.0.is_empty(),
155            "there are still {} elements left in bindings",
156            self.bindings.0.len()
157        );
158        Ok(self.message)
159    }
160}
161
162impl<M: Unbind> Unbound<M> {
163    /// Create an object from a typed message.
164    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
165    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
166    pub fn try_from_message(message: M) -> anyhow::Result<Self> {
167        let mut bindings = Bindings::default();
168        message.unbind(&mut bindings)?;
169        Ok(Unbound { message, bindings })
170    }
171}
172
173/// Unbound, with its message type M erased through serialization.
174#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
175pub struct ErasedUnbound {
176    message: wirevalue::Any,
177    bindings: Bindings,
178}
179wirevalue::register_type!(ErasedUnbound);
180
181impl ErasedUnbound {
182    /// Create an object directly from Any without binding.
183    pub fn new(message: wirevalue::Any) -> Self {
184        Self {
185            message,
186            bindings: Bindings::default(),
187        }
188    }
189
190    /// Access the inner serialized message.
191    pub fn message(&self) -> &wirevalue::Any {
192        &self.message
193    }
194
195    /// Create an object from a typed message.
196    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
197    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
198    pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
199        let unbound = Unbound::try_from_message(msg)?;
200        let serialized = wirevalue::Any::serialize(&unbound.message)?;
201        Ok(Self {
202            message: serialized,
203            bindings: unbound.bindings,
204        })
205    }
206
207    /// Use the provided function to update values inside bindings in the same
208    /// order as they were pushed into bindings.
209    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
210        &mut self,
211        f: impl FnMut(&mut T) -> anyhow::Result<()>,
212    ) -> anyhow::Result<()> {
213        self.bindings.visit_mut(f)
214    }
215
216    fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
217        let message: M = self.message.deserialized_unchecked()?;
218        Ok(Unbound {
219            message,
220            bindings: self.bindings,
221        })
222    }
223}
224
225/// Type used for indexing an erased unbound.
226/// Has the same serialized representation as `ErasedUnbound`.
227#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
228#[serde(from = "ErasedUnbound")]
229pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
230
231impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
232    pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
233        self.0.downcast()
234    }
235
236    /// Access the inner serialized message.
237    pub fn inner_any(&self) -> &wirevalue::Any {
238        self.0.message()
239    }
240}
241
242impl<M: Bind> IndexedErasedUnbound<M> {
243    /// Used in unit tests to bind CastBlobT<M> to the given actor. Do not use in
244    /// production.
245    pub fn bind_for_test_only<A, C>(
246        actor_ref: reference::ActorRef<A>,
247        cx: C,
248        mailbox: Mailbox,
249    ) -> anyhow::Result<()>
250    where
251        A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
252        M: RemoteMessage,
253        C: context::Actor + Send + Sync + 'static,
254    {
255        let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
256            move |_, m| {
257                let bound_m = m.downcast()?.bind()?;
258                actor_ref.send(&cx, bound_m)?;
259                Ok(())
260            }
261        });
262        port_handle.bind_actor_port();
263        Ok(())
264    }
265}
266
267impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
268    fn from(erased: ErasedUnbound) -> Self {
269        Self(erased, PhantomData)
270    }
271}
272
273macro_rules! impl_bind_unbind_basic {
274    ($t:ty) => {
275        impl Bind for $t {
276            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
277                anyhow::ensure!(
278                    bindings.0.is_empty(),
279                    "bindings for {} should be empty, but found {} elements left",
280                    stringify!($t),
281                    bindings.0.len(),
282                );
283                Ok(())
284            }
285        }
286
287        impl Unbind for $t {
288            fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
289                Ok(())
290            }
291        }
292    };
293}
294
295impl_bind_unbind_basic!(());
296impl_bind_unbind_basic!(bool);
297impl_bind_unbind_basic!(i8);
298impl_bind_unbind_basic!(u8);
299impl_bind_unbind_basic!(i16);
300impl_bind_unbind_basic!(u16);
301impl_bind_unbind_basic!(i32);
302impl_bind_unbind_basic!(u32);
303impl_bind_unbind_basic!(i64);
304impl_bind_unbind_basic!(u64);
305impl_bind_unbind_basic!(i128);
306impl_bind_unbind_basic!(u128);
307impl_bind_unbind_basic!(isize);
308impl_bind_unbind_basic!(usize);
309impl_bind_unbind_basic!(String);
310impl_bind_unbind_basic!(std::time::Duration);
311impl_bind_unbind_basic!(std::time::SystemTime);
312
313impl<T: Unbind> Unbind for Option<T> {
314    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
315        match self {
316            Some(t) => t.unbind(bindings),
317            None => Ok(()),
318        }
319    }
320}
321
322impl<T: Bind> Bind for Option<T> {
323    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
324        match self {
325            Some(t) => t.bind(bindings),
326            None => Ok(()),
327        }
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate as hyperactor; // for macros
    use crate::Bind;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::accum::StreamingReducerOpts;
    use crate::testing::ids::test_port_id_with_pid;

    // A user-defined reply type, for demonstration.
    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    // A two-way message type, for demonstration: both reply ports take part
    // in binding.
    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: reference::PortRef<String>,
        #[binding(include)]
        reply1: reference::PortRef<MyReply>,
    }

    // The reducer spec attached to the second reply port, before and after
    // the port is rewritten.
    fn reducer_spec() -> Option<ReducerSpec> {
        Some(ReducerSpec {
            typehash: 123,
            builder_params: None,
        })
    }

    // The bindings entry expected for one unbound port.
    fn binding_entry(port: reference::UnboundPort) -> (u64, wirevalue::Any) {
        (
            reference::UnboundPort::typehash(),
            wirevalue::Any::serialize(&port).unwrap(),
        )
    }

    #[test]
    fn test_castable() {
        let original_port0 =
            reference::PortRef::attest(test_port_id_with_pid("world_0", "actor", 0, 123));
        let original_port1 = reference::PortRef::attest_reducible(
            test_port_id_with_pid("world_1", "actor1", 0, 456),
            reducer_spec(),
            StreamingReducerOpts::default(),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        // Typed message -> ErasedUnbound: both ports show up in the bindings.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        let original_bindings = Bindings(VecDeque::from([
            binding_entry(reference::UnboundPort::from(&original_port0)),
            binding_entry(reference::UnboundPort::from(&original_port1)),
        ]));
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: original_bindings,
            }
        );

        // Rewrite the ports held in the erased bindings.
        let new_port_id0 = test_port_id_with_pid("world_0", "comm", 0, 680);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = test_port_id_with_pid("world_1", "comm", 0, 257);
        assert_ne!(&new_port_id1, original_port1.port_id());

        let mut replacements = vec![&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<reference::UnboundPort>(|unbound_port| {
                let replacement = replacements.next().unwrap();
                unbound_port.update(replacement.clone());
                Ok(())
            })
            .unwrap();

        let new_port0 = reference::PortRef::<String>::attest(new_port_id0);
        let new_port1 = reference::PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            reducer_spec(),
            StreamingReducerOpts::default(),
        );
        let new_bindings = Bindings(VecDeque::from([
            binding_entry(reference::UnboundPort::from(&new_port0)),
            binding_entry(reference::UnboundPort::from(&new_port1)),
        ]));
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message,
                bindings: new_bindings.clone(),
            }
        );

        // ErasedUnbound -> typed message: the rewritten ports are bound back.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );
        let rebound = unbound.bind().unwrap();
        let expected = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: new_port0,
            reply1: new_port1,
        };
        assert_eq!(rebound, expected);
    }
}
481}