Skip to main content

hyperactor/
message.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module provides a framework for mutating serialized messages without
//! the need to deserialize them. This capability is useful when sending messages
//! to a remote destination through intermediate nodes, where the intermediate
//! nodes do not contain the message's type information.
13//!
14//! Briefly, it works by following these steps:
15//!
16//! 1. On the sender side, mutable information is extracted from the typed
17//!    message through [Unbind], and stored in a [Bindings] object. This object
18//!    is bundled with the serialized message in an [ErasedUnbound] object, which
19//!    is sent over the wire.
//! 2. On intermediate nodes, the [ErasedUnbound] object is relayed. The
//!    mutation is applied on its bindings field, if needed.
//! 3. On the receiver side, the [ErasedUnbound] object is received as
//!    [IndexedErasedUnbound], where the type information is restored. Mutated
//!    information contained in its bindings field is applied to the message
//!    through [Bind], which results in the final typed message.
26//!
//! One main use case of this framework is to mutate the reply ports of a
//! multicast message, so the replies can be relayed through intermediate nodes,
//! rather than directly sent to the original sender. Therefore, a [Castable]
//! trait is defined, which collects requirements for message types using
//! multicast.
32
33use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39use typeuri::Named;
40
41// for macros
42use crate::ActorRef;
43use crate::Mailbox;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::Referable;
47use crate::context;
48use crate::endpoint::Endpoint as _;
49
/// An object `T` that is [`Unbind`] can extract a set of parameters from itself,
/// and store in [`Bindings`]. The extracted parameters in [`Bindings`] can be
/// independently manipulated, and then later reconstituted (rebound) into
/// a `T`-typed object again.
///
/// This is the extraction half of the pair; [`Bind`] is the counterpart that
/// merges the (possibly modified) parameters back into the object.
pub trait Unbind: Sized {
    /// Extract parameters from itself and store them in bindings.
    ///
    /// The receiver is only borrowed immutably, so the object itself is left
    /// unchanged; the extracted parameters are written to `bindings`.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot record a parameter, e.g.
    /// because [`Bindings::push_back`] fails to serialize it.
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
58
/// An object `T` that is [`Bind`] can bind a set of externally provided
/// parameters into itself.
///
/// This is the counterpart of [`Unbind`]. Note that [`Unbound::bind`] fails if
/// any element is still left in the bindings after this trait's `bind`
/// returns, so implementations are expected to consume every parameter that
/// belongs to them.
pub trait Bind: Sized {
    /// Remove parameters from bindings, and use them to update itself.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot take the parameters it
    /// needs out of `bindings`, e.g. because [`Bindings::pop_front`] reports a
    /// type mismatch.
    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
}
65
/// This trait collects the necessary requirements for messages that can be
/// cast.
///
/// It declares no methods of its own. It is implemented automatically, through
/// the blanket impl that follows, for every type that is [`RemoteMessage`],
/// [`Bind`] and [`Unbind`].
pub trait Castable: RemoteMessage + Bind + Unbind {}
// Blanket impl: any type meeting the three bounds is `Castable`.
impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
70
/// Information extracted from a message through [Unbind], which can be merged
/// back to the message through [Bind].
///
/// Internally this is a first-in-first-out queue of `(type hash, serialized
/// value)` pairs: [`Bindings::push_back`] appends a value tagged with the type
/// hash of its type, and [`Bindings::pop_front`] removes the oldest entry and
/// checks that tag against the requested type before deserializing.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Bindings(VecDeque<(u64, wirevalue::Any)>);
75
76impl Bindings {
77    /// Push a value into this bindings.
78    pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
79        let ser = wirevalue::Any::serialize(value)?;
80        self.0.push_back((T::typehash(), ser));
81        Ok(())
82    }
83
84    /// Removes the first pushed element in this bindings, deserialize it into
85    /// type T, and return it. Return [`None`] if this bindings is empty.
86    /// If the type of the first pushed element does not match T, an error is
87    /// returned.
88    pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
89        match self.0.pop_front() {
90            None => Ok(None),
91            Some((t, v)) => {
92                if t != T::typehash() {
93                    anyhow::bail!(
94                        "type mismatch: expected {} with hash {}, found {} in binding",
95                        T::typename(),
96                        T::typehash(),
97                        t,
98                    );
99                }
100                Ok(Some(v.deserialized::<T>()?))
101            }
102        }
103    }
104
105    /// Fallible version of [Bindings::pop_front].
106    pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
107        self.pop_front::<T>()?.ok_or_else(|| {
108            anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
109        })
110    }
111
112    fn visit_mut<T: Serialize + DeserializeOwned + Named>(
113        &mut self,
114        mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
115    ) -> anyhow::Result<()> {
116        for v in self.0.iter_mut() {
117            if v.0 == T::typehash() {
118                let mut t = v.1.deserialized::<T>()?;
119                f(&mut t)?;
120                v.1 = wirevalue::Any::serialize(&t)?;
121            }
122        }
123        Ok(())
124    }
125}
126
/// A typed message paired with the bindings extracted from it through
/// [Unbind].
#[derive(Debug, PartialEq)]
pub struct Unbound<M> {
    /// The typed message the bindings belong to.
    message: M,
    /// The parameters extracted from `message`; they are merged back into it
    /// by [`Unbound::bind`].
    bindings: Bindings,
}
133
134impl<M> Unbound<M> {
135    /// Build a new object.
136    pub fn new(message: M, bindings: Bindings) -> Self {
137        Self { message, bindings }
138    }
139
140    /// Use the provided function to update values inside bindings in the same
141    /// order as they were pushed into bindings.
142    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
143        &mut self,
144        f: impl FnMut(&mut T) -> anyhow::Result<()>,
145    ) -> anyhow::Result<()> {
146        self.bindings.visit_mut(f)
147    }
148}
149
150impl<M: Bind> Unbound<M> {
151    /// Bind its bindings to its message through [Bind], and return the result.
152    pub fn bind(mut self) -> anyhow::Result<M> {
153        self.message.bind(&mut self.bindings)?;
154        anyhow::ensure!(
155            self.bindings.0.is_empty(),
156            "there are still {} elements left in bindings",
157            self.bindings.0.len()
158        );
159        Ok(self.message)
160    }
161}
162
163impl<M: Unbind> Unbound<M> {
164    /// Create an object from a typed message.
165    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
166    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
167    pub fn try_from_message(message: M) -> anyhow::Result<Self> {
168        let mut bindings = Bindings::default();
169        message.unbind(&mut bindings)?;
170        Ok(Unbound { message, bindings })
171    }
172}
173
/// Unbound, with its message type M erased through serialization.
///
/// Because the message is held in serialized form, this type can be relayed
/// and its bindings can be updated (see [`ErasedUnbound::visit_mut`]) without
/// knowing the concrete message type.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
pub struct ErasedUnbound {
    /// The message in serialized, type-erased form.
    message: wirevalue::Any,
    /// The parameters that accompany the message.
    bindings: Bindings,
}
// NOTE(review): presumably registers this type with wirevalue's type
// registry; confirm against the `wirevalue` crate.
wirevalue::register_type!(ErasedUnbound);
181
182impl ErasedUnbound {
183    /// Create an object directly from Any without binding.
184    pub fn new(message: wirevalue::Any) -> Self {
185        Self {
186            message,
187            bindings: Bindings::default(),
188        }
189    }
190
191    /// Access the inner serialized message.
192    pub fn message(&self) -> &wirevalue::Any {
193        &self.message
194    }
195
196    /// Create an object from a typed message.
197    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
198    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
199    pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
200        let unbound = Unbound::try_from_message(msg)?;
201        let serialized = wirevalue::Any::serialize(&unbound.message)?;
202        Ok(Self {
203            message: serialized,
204            bindings: unbound.bindings,
205        })
206    }
207
208    /// Use the provided function to update values inside bindings in the same
209    /// order as they were pushed into bindings.
210    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
211        &mut self,
212        f: impl FnMut(&mut T) -> anyhow::Result<()>,
213    ) -> anyhow::Result<()> {
214        self.bindings.visit_mut(f)
215    }
216
217    fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
218        let message: M = self.message.deserialized_unchecked()?;
219        Ok(Unbound {
220            message,
221            bindings: self.bindings,
222        })
223    }
224}
225
/// Type used for indexing an erased unbound.
/// Has the same serialized representation as `ErasedUnbound`.
///
/// The type parameter `M` is carried only through [`PhantomData`]; the value
/// itself is just the wrapped [`ErasedUnbound`]. `M` is the type the message
/// is deserialized into by `downcast`.
// NOTE(review): `Deserialize` goes through `ErasedUnbound` (via the
// `serde(from = ...)` attribute below), while `Serialize` is derived
// field-wise on the tuple struct. The two presumably agree only for wire
// formats where `PhantomData` takes no space and tuple structs add no framing
// — confirm against the serialization format in use.
#[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
#[serde(from = "ErasedUnbound")]
pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
231
232impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
233    pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
234        self.0.downcast()
235    }
236
237    /// Access the inner serialized message.
238    pub fn inner_any(&self) -> &wirevalue::Any {
239        self.0.message()
240    }
241}
242
243impl<M: Bind> IndexedErasedUnbound<M> {
244    /// Used in unit tests to bind CastBlobT<M> to the given actor. Do not use in
245    /// production.
246    pub fn bind_for_test_only<A, C>(
247        actor_ref: ActorRef<A>,
248        cx: C,
249        mailbox: Mailbox,
250    ) -> anyhow::Result<()>
251    where
252        A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
253        M: RemoteMessage,
254        C: context::Actor + Send + Sync + 'static,
255    {
256        let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
257            move |_, m| {
258                let bound_m = m.downcast()?.bind()?;
259                (&actor_ref).post(&cx, bound_m);
260                Ok(())
261            }
262        });
263        port_handle.bind_handler_port();
264        Ok(())
265    }
266}
267
268impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
269    fn from(erased: ErasedUnbound) -> Self {
270        Self(erased, PhantomData)
271    }
272}
273
274macro_rules! impl_bind_unbind_basic {
275    ($t:ty) => {
276        impl Bind for $t {
277            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
278                anyhow::ensure!(
279                    bindings.0.is_empty(),
280                    "bindings for {} should be empty, but found {} elements left",
281                    stringify!($t),
282                    bindings.0.len(),
283                );
284                Ok(())
285            }
286        }
287
288        impl Unbind for $t {
289            fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
290                Ok(())
291            }
292        }
293    };
294}
295
296impl_bind_unbind_basic!(());
297impl_bind_unbind_basic!(bool);
298impl_bind_unbind_basic!(i8);
299impl_bind_unbind_basic!(u8);
300impl_bind_unbind_basic!(i16);
301impl_bind_unbind_basic!(u16);
302impl_bind_unbind_basic!(i32);
303impl_bind_unbind_basic!(u32);
304impl_bind_unbind_basic!(i64);
305impl_bind_unbind_basic!(u64);
306impl_bind_unbind_basic!(i128);
307impl_bind_unbind_basic!(u128);
308impl_bind_unbind_basic!(isize);
309impl_bind_unbind_basic!(usize);
310impl_bind_unbind_basic!(String);
311impl_bind_unbind_basic!(std::time::Duration);
312impl_bind_unbind_basic!(std::time::SystemTime);
313
314impl<T: Unbind> Unbind for Option<T> {
315    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
316        match self {
317            Some(t) => t.unbind(bindings),
318            None => Ok(()),
319        }
320    }
321}
322
323impl<T: Bind> Bind for Option<T> {
324    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
325        match self {
326            Some(t) => t.bind(bindings),
327            None => Ok(()),
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate as hyperactor; // for macros
    use crate::Bind;
    use crate::PortRef;
    use crate::Unbind;
    use crate::UnboundPort;
    use crate::accum::ReducerSpec;
    use crate::accum::StreamingReducerOpts;
    use crate::testing::ids::test_port_id;

    // Used to demonstrate a user defined reply type.
    #[derive(Debug, PartialEq, Serialize, Deserialize, typeuri::Named)]
    struct MyReply(String);

    // Used to demonstrate a two-way message type.
    // The two reply ports are tagged `#[binding(include)]`; the test below
    // expects exactly these two fields, in declaration order, to appear as
    // bindings, while `arg0` and `arg1` do not.
    #[derive(
        Debug,
        Clone,
        PartialEq,
        Serialize,
        Deserialize,
        typeuri::Named,
        Bind,
        Unbind
    )]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    // Round trip through the whole pipeline: typed message -> ErasedUnbound ->
    // rewrite the ports inside the bindings -> downcast -> bind -> typed
    // message carrying the rewritten ports.
    #[test]
    fn test_castable() {
        // Two reply ports; the second one also carries a reducer spec, which
        // is expected to survive the port rewrite below.
        let original_port0 = PortRef::attest(test_port_id("world_0", "actor", 123));
        let original_port1 = PortRef::attest_reducible(
            test_port_id("world_1", "actor1", 456),
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        // Serialized form of the untouched message, used as the expected
        // `message` field in the comparisons below.
        let serialized_my_message = wirevalue::Any::serialize(&my_message).unwrap();

        // Convert to ErasedUnbound: the message is serialized as is, and each
        // included port shows up as one UnboundPort binding, in field order.
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: Bindings(
                    [
                        (
                            UnboundPort::typehash(),
                            wirevalue::Any::serialize(&UnboundPort::from(&original_port0)).unwrap(),
                        ),
                        (
                            UnboundPort::typehash(),
                            wirevalue::Any::serialize(&UnboundPort::from(&original_port1)).unwrap(),
                        ),
                    ]
                    .into_iter()
                    .collect()
                ),
            }
        );

        // Modify the ports in the erased message. The replacement ids must
        // differ from the originals, otherwise the test would prove nothing.
        let new_port_id0 = test_port_id("world_0", "comm", 680);
        assert_ne!(&new_port_id0, original_port0.port_addr());
        let new_port_id1 = test_port_id("world_1", "comm", 257);
        assert_ne!(&new_port_id1, original_port1.port_addr());

        // visit_mut walks the bindings in push order, so the n-th visited
        // binding receives the n-th replacement id.
        let mut new_ports = vec![&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<UnboundPort>(|b| {
                let port = new_ports.next().unwrap();
                b.update(port.clone());
                Ok(())
            })
            .unwrap();

        // Expected ports after the rewrite; the second keeps the same reducer
        // spec as the original.
        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 = PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
            StreamingReducerOpts::default(),
        );
        let new_bindings = Bindings(
            [
                (
                    UnboundPort::typehash(),
                    wirevalue::Any::serialize(&UnboundPort::from(&new_port0)).unwrap(),
                ),
                (
                    UnboundPort::typehash(),
                    wirevalue::Any::serialize(&UnboundPort::from(&new_port1)).unwrap(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        // Only the bindings changed; the serialized message is still the
        // original one.
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: new_bindings.clone(),
            }
        );

        // Convert back to MyMessage. After downcast the typed message still
        // holds the original ports; the new ones live only in the bindings.
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );
        // Binding applies the rewritten ports to the message.
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }
}