hyperactor/
message.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module provides a framework for mutating serialized messages without
10//! the need to deserialize them. This capability is useful when sending messages
11//! to a remote destination through intermediate nodes, where the intermediate
12//! nodes do not contain the message's type information.
13//!
14//! Briefly, it works by following these steps:
15//!
16//! 1. On the sender side, mutable information is extracted from the typed
17//!    message through [Unbind], and stored in a [Bindings] object. This object
18//!    is bundled with the serialized message in an [ErasedUnbound] object, which
19//!    is sent over the wire.
20//! 2. On intermediate nodes, the [ErasedUnbound] object is relayed. The
21//!    mutation is applied to its bindings field, if needed.
22//! 3. On the receiver side, the [ErasedUnbound] object is received as
23//!    [IndexedErasedUnbound], where the type information is restored. Mutated
24//!    information contained in its bindings field is applied to the message
25//!    through [Bind], which results in the final typed message.
26//!
27//! One main use case of this framework is to mutate the reply ports of a
28//! multicast message, so the replies can be relayed through intermediate nodes,
29//! rather than directly sent to the original sender. Therefore, a [Castable]
30//! trait is defined, which collects requirements for message types using
31//! multicast.
32
33use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41// for macros
42use crate::ActorRef;
43use crate::Mailbox;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::Referable;
47use crate::context;
48
49/// An object `T` that is [`Unbind`] can extract a set of parameters from itself,
50/// and store in [`Bindings`]. The extracted parameters in [`Bindings`] can be
51/// independently manipulated, and then later reconstituted (rebound) into
52/// a `T`-typed object again.
53pub trait Unbind: Sized {
54    /// Extract parameters from itself and store them in bindings.
55    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
56}
57
58/// An object `T` that is [`Bind`] can bind a set of externally provided
59/// parameters into itself.
60pub trait Bind: Sized {
61    /// Remove parameters from bindings, and use them to update itself.
62    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
63}
64
65/// This trait collects the necessary requirements for messages that are can be
66/// cast.
67pub trait Castable: RemoteMessage + Bind + Unbind {}
68impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
69
70/// Information extracted from a message through [Unbind], which can be merged
71/// back to the message through [Bind].
72#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
73pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
74
75impl Bindings {
76    /// Push a value into this bindings.
77    pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
78        let ser = wirevalue::Any::serialize(value)?;
79        self.0.push_back((T::typehash(), ser));
80        Ok(())
81    }
82
83    /// Removes the first pushed element in this bindings, deserialize it into
84    /// type T, and return it. Return [`None`] if this bindings is empty.
85    /// If the type of the first pushed element does not match T, an error is
86    /// returned.
87    pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
88        match self.0.pop_front() {
89            None => Ok(None),
90            Some((t, v)) => {
91                if t != T::typehash() {
92                    anyhow::bail!(
93                        "type mismatch: expected {} with hash {}, found {} in binding",
94                        T::typename(),
95                        T::typehash(),
96                        t,
97                    );
98                }
99                Ok(Some(v.deserialized::<T>()?))
100            }
101        }
102    }
103
104    /// Fallible version of [Bindings::pop_front].
105    pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
106        self.pop_front::<T>()?.ok_or_else(|| {
107            anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
108        })
109    }
110
111    fn visit_mut<T: Serialize + DeserializeOwned + Named>(
112        &mut self,
113        mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
114    ) -> anyhow::Result<()> {
115        for v in self.0.iter_mut() {
116            if v.0 == T::typehash() {
117                let mut t = v.1.deserialized::<T>()?;
118                f(&mut t)?;
119                v.1 = wirevalue::Any::serialize(&t)?;
120            }
121        }
122        Ok(())
123    }
124}
125
126/// An object contains a message, and its bindings extracted through [Unbind].
127#[derive(Debug, PartialEq)]
128pub struct Unbound<M> {
129    message: M,
130    bindings: Bindings,
131}
132
133impl<M> Unbound<M> {
134    /// Build a new object.
135    pub fn new(message: M, bindings: Bindings) -> Self {
136        Self { message, bindings }
137    }
138
139    /// Use the provided function to update values inside bindings in the same
140    /// order as they were pushed into bindings.
141    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
142        &mut self,
143        f: impl FnMut(&mut T) -> anyhow::Result<()>,
144    ) -> anyhow::Result<()> {
145        self.bindings.visit_mut(f)
146    }
147}
148
149impl<M: Bind> Unbound<M> {
150    /// Bind its bindings to its message through [Bind], and return the result.
151    pub fn bind(mut self) -> anyhow::Result<M> {
152        self.message.bind(&mut self.bindings)?;
153        anyhow::ensure!(
154            self.bindings.0.is_empty(),
155            "there are still {} elements left in bindings",
156            self.bindings.0.len()
157        );
158        Ok(self.message)
159    }
160}
161
162impl<M: Unbind> Unbound<M> {
163    /// Create an object from a typed message.
164    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
165    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
166    pub fn try_from_message(message: M) -> anyhow::Result<Self> {
167        let mut bindings = Bindings::default();
168        message.unbind(&mut bindings)?;
169        Ok(Unbound { message, bindings })
170    }
171}
172
173/// Unbound, with its message type M erased through serialization.
174#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
175pub struct ErasedUnbound {
176    message: wirevalue::Any,
177    bindings: Bindings,
178}
179wirevalue::register_type!(ErasedUnbound);
180
181impl ErasedUnbound {
182    /// Create an object directly from Any without binding.
183    pub fn new(message: wirevalue::Any) -> Self {
184        Self {
185            message,
186            bindings: Bindings::default(),
187        }
188    }
189
190    /// Create an object from a typed message.
191    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
192    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
193    pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
194        let unbound = Unbound::try_from_message(msg)?;
195        let serialized = wirevalue::Any::serialize(&unbound.message)?;
196        Ok(Self {
197            message: serialized,
198            bindings: unbound.bindings,
199        })
200    }
201
202    /// Use the provided function to update values inside bindings in the same
203    /// order as they were pushed into bindings.
204    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
205        &mut self,
206        f: impl FnMut(&mut T) -> anyhow::Result<()>,
207    ) -> anyhow::Result<()> {
208        self.bindings.visit_mut(f)
209    }
210
211    fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
212        let message: M = self.message.deserialized_unchecked()?;
213        Ok(Unbound {
214            message,
215            bindings: self.bindings,
216        })
217    }
218}
219
220/// Type used for indexing an erased unbound.
221/// Has the same serialized representation as `ErasedUnbound`.
222#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
223#[serde(from = "ErasedUnbound")]
224pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
225
226impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
227    pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
228        self.0.downcast()
229    }
230}
231
232impl<M: Bind> IndexedErasedUnbound<M> {
233    /// Used in unit tests to bind CastBlobT<M> to the given actor. Do not use in
234    /// production.
235    pub fn bind_for_test_only<A, C>(
236        actor_ref: ActorRef<A>,
237        cx: C,
238        mailbox: Mailbox,
239    ) -> anyhow::Result<()>
240    where
241        A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
242        M: RemoteMessage,
243        C: context::Actor + Send + Sync + 'static,
244    {
245        let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
246            move |_, m| {
247                let bound_m = m.downcast()?.bind()?;
248                actor_ref.send(&cx, bound_m)?;
249                Ok(())
250            }
251        });
252        port_handle.bind_actor_port();
253        Ok(())
254    }
255}
256
257impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
258    fn from(erased: ErasedUnbound) -> Self {
259        Self(erased, PhantomData)
260    }
261}
262
263macro_rules! impl_bind_unbind_basic {
264    ($t:ty) => {
265        impl Bind for $t {
266            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
267                anyhow::ensure!(
268                    bindings.0.is_empty(),
269                    "bindings for {} should be empty, but found {} elements left",
270                    stringify!($t),
271                    bindings.0.len(),
272                );
273                Ok(())
274            }
275        }
276
277        impl Unbind for $t {
278            fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
279                Ok(())
280            }
281        }
282    };
283}
284
285impl_bind_unbind_basic!(());
286impl_bind_unbind_basic!(bool);
287impl_bind_unbind_basic!(i8);
288impl_bind_unbind_basic!(u8);
289impl_bind_unbind_basic!(i16);
290impl_bind_unbind_basic!(u16);
291impl_bind_unbind_basic!(i32);
292impl_bind_unbind_basic!(u32);
293impl_bind_unbind_basic!(i64);
294impl_bind_unbind_basic!(u64);
295impl_bind_unbind_basic!(i128);
296impl_bind_unbind_basic!(u128);
297impl_bind_unbind_basic!(isize);
298impl_bind_unbind_basic!(usize);
299impl_bind_unbind_basic!(String);
300
301impl<T: Unbind> Unbind for Option<T> {
302    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
303        match self {
304            Some(t) => t.unbind(bindings),
305            None => Ok(()),
306        }
307    }
308}
309
310impl<T: Bind> Bind for Option<T> {
311    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
312        match self {
313            Some(t) => t.bind(bindings),
314            None => Ok(()),
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate as hyperactor; // for macros
    use crate::Bind;
    use crate::PortRef;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::id;
    use crate::reference::UnboundPort;

    // A user defined reply type.
    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    // A two-way message type: both reply ports take part in the bindings.
    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    // Builds the bindings expected for a message whose two reply ports unbind
    // to `ports`, in that order.
    fn bindings_of(ports: [UnboundPort; 2]) -> Bindings {
        Bindings(
            ports
                .into_iter()
                .map(|port| {
                    (
                        UnboundPort::typehash(),
                        wirevalue::Any::serialize(&port).unwrap(),
                    )
                })
                .collect(),
        )
    }

    #[test]
    fn test_castable() {
        // The reducer spec attached to the second reply port, before and
        // after the port is rewritten.
        let reducer_spec = || {
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            })
        };

        let original_port0 = PortRef::attest(id!(world[0].actor[0][123]));
        let original_port1 =
            PortRef::attest_reducible(id!(world[1].actor1[0][456]), reducer_spec(), None);
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };
        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        // Erasing the message extracts both reply ports into the bindings.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        let expected_erased = ErasedUnbound {
            message: serialized_my_message.clone(),
            bindings: bindings_of([
                UnboundPort::from(&original_port0),
                UnboundPort::from(&original_port1),
            ]),
        };
        assert_eq!(erased, expected_erased);

        // Rewrite both ports while the message is still erased.
        let new_port_id0 = id!(world[0].comm[0][680]);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = id!(world[1].comm[0][257]);
        assert_ne!(&new_port_id1, original_port1.port_id());

        let mut replacements = vec![&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<UnboundPort>(|unbound_port| {
                let replacement = replacements.next().unwrap();
                unbound_port.update(replacement.clone());
                Ok(())
            })
            .unwrap();

        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 = PortRef::<MyReply>::attest_reducible(new_port_id1, reducer_spec(), None);
        let new_bindings = bindings_of([
            UnboundPort::from(&new_port0),
            UnboundPort::from(&new_port1),
        ]);

        // Only the bindings changed; the serialized message is untouched.
        let expected_rewritten = ErasedUnbound {
            message: serialized_my_message.clone(),
            bindings: new_bindings.clone(),
        };
        assert_eq!(erased, expected_rewritten);

        // Restore the type: the message still holds the original ports until
        // the bindings are merged back in.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        let expected_unbound = Unbound {
            message: my_message,
            bindings: new_bindings,
        };
        assert_eq!(unbound, expected_unbound);

        let new_my_message = unbound.bind().unwrap();
        let expected_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: new_port0,
            reply1: new_port1,
        };
        assert_eq!(new_my_message, expected_message);
    }
}