hyperactor/
message.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module provides a framework for mutating serialized messages without
//! the need to deserialize them. This capability is useful when sending messages
//! to a remote destination through intermediate nodes, where the intermediate
//! nodes do not have the message's type information.
//!
//! Briefly, it works by following these steps:
//!
//! 1. On the sender side, mutable information is extracted from the typed
//!    message through [Unbind], and stored in a [Bindings] object. This object
//!    is bundled with the serialized message in an [ErasedUnbound] object, which
//!    is sent over the wire.
//! 2. On intermediate nodes, the [ErasedUnbound] object is relayed. Any
//!    mutation is applied to its bindings field, if needed.
//! 3. On the receiver side, the [ErasedUnbound] object is received as
//!    [IndexedErasedUnbound], where the type information is restored. Mutated
//!    information contained in its bindings field is applied to the message
//!    through [Bind], which results in the final typed message.
//!
//! One main use case of this framework is to mutate the reply ports of a
//! multicast message, so the replies can be relayed through intermediate nodes
//! rather than sent directly to the original sender. The [Castable] trait
//! collects the requirements for message types that are sent via multicast.
32
33use std::collections::VecDeque;
34use std::marker::PhantomData;
35
36use serde::Deserialize;
37use serde::Serialize;
38use serde::de::DeserializeOwned;
39
40use crate as hyperactor;
41use crate::ActorRef;
42use crate::Mailbox;
43use crate::Named;
44use crate::RemoteHandles;
45use crate::RemoteMessage;
46use crate::actor::Referable;
47use crate::context;
48use crate::data::Serialized;
49
50/// An object `T` that is [`Unbind`] can extract a set of parameters from itself,
51/// and store in [`Bindings`]. The extracted parameters in [`Bindings`] can be
52/// independently manipulated, and then later reconstituted (rebound) into
53/// a `T`-typed object again.
54pub trait Unbind: Sized {
55    /// Extract parameters from itself and store them in bindings.
56    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()>;
57}
58
59/// An object `T` that is [`Bind`] can bind a set of externally provided
60/// parameters into itself.
61pub trait Bind: Sized {
62    /// Remove parameters from bindings, and use them to update itself.
63    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()>;
64}
65
66/// This trait collects the necessary requirements for messages that are can be
67/// cast.
68pub trait Castable: RemoteMessage + Bind + Unbind {}
69impl<T: RemoteMessage + Bind + Unbind> Castable for T {}
70
71/// Information extracted from a message through [Unbind], which can be merged
72/// back to the message through [Bind].
73#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
74pub struct Bindings(VecDeque<(u64, Serialized)>);
75
76impl Bindings {
77    /// Push a value into this bindings.
78    pub fn push_back<T: Serialize + Named>(&mut self, value: &T) -> anyhow::Result<()> {
79        let ser = Serialized::serialize(value)?;
80        self.0.push_back((T::typehash(), ser));
81        Ok(())
82    }
83
84    /// Removes the first pushed element in this bindings, deserialize it into
85    /// type T, and return it. Return [`None`] if this bindings is empty.
86    /// If the type of the first pushed element does not match T, an error is
87    /// returned.
88    pub fn pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<Option<T>> {
89        match self.0.pop_front() {
90            None => Ok(None),
91            Some((t, v)) => {
92                if t != T::typehash() {
93                    anyhow::bail!(
94                        "type mismatch: expected {} with hash {}, found {} in binding",
95                        T::typename(),
96                        T::typehash(),
97                        t,
98                    );
99                }
100                Ok(Some(v.deserialized::<T>()?))
101            }
102        }
103    }
104
105    /// Fallible version of [Bindings::pop_front].
106    pub fn try_pop_front<T: DeserializeOwned + Named>(&mut self) -> anyhow::Result<T> {
107        self.pop_front::<T>()?.ok_or_else(|| {
108            anyhow::anyhow!("expect a {} binding, but none was found", T::typename())
109        })
110    }
111
112    fn visit_mut<T: Serialize + DeserializeOwned + Named>(
113        &mut self,
114        mut f: impl FnMut(&mut T) -> anyhow::Result<()>,
115    ) -> anyhow::Result<()> {
116        for v in self.0.iter_mut() {
117            if v.0 == T::typehash() {
118                let mut t = v.1.deserialized::<T>()?;
119                f(&mut t)?;
120                v.1 = Serialized::serialize(&t)?;
121            }
122        }
123        Ok(())
124    }
125}
126
127/// An object contains a message, and its bindings extracted through [Unbind].
128#[derive(Debug, PartialEq)]
129pub struct Unbound<M> {
130    message: M,
131    bindings: Bindings,
132}
133
134impl<M> Unbound<M> {
135    /// Build a new object.
136    pub fn new(message: M, bindings: Bindings) -> Self {
137        Self { message, bindings }
138    }
139
140    /// Use the provided function to update values inside bindings in the same
141    /// order as they were pushed into bindings.
142    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
143        &mut self,
144        f: impl FnMut(&mut T) -> anyhow::Result<()>,
145    ) -> anyhow::Result<()> {
146        self.bindings.visit_mut(f)
147    }
148}
149
150impl<M: Bind> Unbound<M> {
151    /// Bind its bindings to its message through [Bind], and return the result.
152    pub fn bind(mut self) -> anyhow::Result<M> {
153        self.message.bind(&mut self.bindings)?;
154        anyhow::ensure!(
155            self.bindings.0.is_empty(),
156            "there are still {} elements left in bindings",
157            self.bindings.0.len()
158        );
159        Ok(self.message)
160    }
161}
162
163impl<M: Unbind> Unbound<M> {
164    /// Create an object from a typed message.
165    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
166    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
167    pub fn try_from_message(message: M) -> anyhow::Result<Self> {
168        let mut bindings = Bindings::default();
169        message.unbind(&mut bindings)?;
170        Ok(Unbound { message, bindings })
171    }
172}
173
174/// Unbound, with its message type M erased through serialization.
175#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
176pub struct ErasedUnbound {
177    message: Serialized,
178    bindings: Bindings,
179}
180
181impl ErasedUnbound {
182    /// Create an object directly from Serialized without binding.
183    pub fn new(message: Serialized) -> Self {
184        Self {
185            message,
186            bindings: Bindings::default(),
187        }
188    }
189
190    /// Create an object from a typed message.
191    // Note: cannot implement TryFrom<T> due to conflict with core crate's blanket impl.
192    // More can be found in this issue: https://github.com/rust-lang/rust/issues/50133
193    pub fn try_from_message<T: Unbind + Serialize + Named>(msg: T) -> Result<Self, anyhow::Error> {
194        let unbound = Unbound::try_from_message(msg)?;
195        let serialized = Serialized::serialize(&unbound.message)?;
196        Ok(Self {
197            message: serialized,
198            bindings: unbound.bindings,
199        })
200    }
201
202    /// Use the provided function to update values inside bindings in the same
203    /// order as they were pushed into bindings.
204    pub fn visit_mut<T: Serialize + DeserializeOwned + Named>(
205        &mut self,
206        f: impl FnMut(&mut T) -> anyhow::Result<()>,
207    ) -> anyhow::Result<()> {
208        self.bindings.visit_mut(f)
209    }
210
211    fn downcast<M: DeserializeOwned + Named>(self) -> anyhow::Result<Unbound<M>> {
212        let message: M = self.message.deserialized_unchecked()?;
213        Ok(Unbound {
214            message,
215            bindings: self.bindings,
216        })
217    }
218}
219
220/// Type used for indexing an erased unbound.
221/// Has the same serialized representation as `ErasedUnbound`.
222#[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
223#[serde(from = "ErasedUnbound")]
224pub struct IndexedErasedUnbound<M>(ErasedUnbound, PhantomData<M>);
225
226impl<M: DeserializeOwned + Named> IndexedErasedUnbound<M> {
227    pub(crate) fn downcast(self) -> anyhow::Result<Unbound<M>> {
228        self.0.downcast()
229    }
230}
231
232impl<M: Bind> IndexedErasedUnbound<M> {
233    /// Used in unit tests to bind CastBlobT<M> to the given actor. Do not use in
234    /// production.
235    pub fn bind_for_test_only<A, C>(
236        actor_ref: ActorRef<A>,
237        cx: C,
238        mailbox: Mailbox,
239    ) -> anyhow::Result<()>
240    where
241        A: Referable + RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
242        M: RemoteMessage,
243        C: context::Actor + Send + Sync + 'static,
244    {
245        let port_handle = mailbox.open_enqueue_port::<IndexedErasedUnbound<M>>({
246            move |_, m| {
247                let bound_m = m.downcast()?.bind()?;
248                actor_ref.send(&cx, bound_m)?;
249                Ok(())
250            }
251        });
252        port_handle.bind_to(IndexedErasedUnbound::<M>::port());
253        Ok(())
254    }
255}
256
257impl<M> From<ErasedUnbound> for IndexedErasedUnbound<M> {
258    fn from(erased: ErasedUnbound) -> Self {
259        Self(erased, PhantomData)
260    }
261}
262
263macro_rules! impl_bind_unbind_basic {
264    ($t:ty) => {
265        impl Bind for $t {
266            fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
267                anyhow::ensure!(
268                    bindings.0.is_empty(),
269                    "bindings for {} should be empty, but found {} elements left",
270                    stringify!($t),
271                    bindings.0.len(),
272                );
273                Ok(())
274            }
275        }
276
277        impl Unbind for $t {
278            fn unbind(&self, _bindings: &mut Bindings) -> anyhow::Result<()> {
279                Ok(())
280            }
281        }
282    };
283}
284
285impl_bind_unbind_basic!(());
286impl_bind_unbind_basic!(bool);
287impl_bind_unbind_basic!(i8);
288impl_bind_unbind_basic!(u8);
289impl_bind_unbind_basic!(i16);
290impl_bind_unbind_basic!(u16);
291impl_bind_unbind_basic!(i32);
292impl_bind_unbind_basic!(u32);
293impl_bind_unbind_basic!(i64);
294impl_bind_unbind_basic!(u64);
295impl_bind_unbind_basic!(i128);
296impl_bind_unbind_basic!(u128);
297impl_bind_unbind_basic!(isize);
298impl_bind_unbind_basic!(usize);
299impl_bind_unbind_basic!(String);
300
301impl<T: Unbind> Unbind for Option<T> {
302    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
303        match self {
304            Some(t) => t.unbind(bindings),
305            None => Ok(()),
306        }
307    }
308}
309
310impl<T: Bind> Bind for Option<T> {
311    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
312        match self {
313            Some(t) => t.bind(bindings),
314            None => Ok(()),
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use hyperactor::PortRef;
    use hyperactor::id;

    use super::*;
    use crate::Bind;
    use crate::Unbind;
    use crate::accum::ReducerSpec;
    use crate::reference::UnboundPort;

    // Used to demonstrate a user defined reply type.
    #[derive(Debug, PartialEq, Serialize, Deserialize, Named)]
    struct MyReply(String);

    // Used to demonstrate a two-way message type.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
    struct MyMessage {
        arg0: bool,
        arg1: u32,
        #[binding(include)]
        reply0: PortRef<String>,
        #[binding(include)]
        reply1: PortRef<MyReply>,
    }

    #[test]
    fn test_castable() {
        let original_port0 = PortRef::attest(id!(world[0].actor[0][123]));
        let original_port1 = PortRef::attest_reducible(
            id!(world[1].actor1[0][456]),
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
        );
        let my_message = MyMessage {
            arg0: true,
            arg1: 42,
            reply0: original_port0.clone(),
            reply1: original_port1.clone(),
        };

        let serialized_my_message = Serialized::serialize(&my_message).unwrap();

        // convert to ErasedUnbound
        let mut erased = ErasedUnbound::try_from_message(my_message.clone()).unwrap();
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: Bindings(
                    [
                        (
                            UnboundPort::typehash(),
                            Serialized::serialize(&UnboundPort::from(&original_port0)).unwrap(),
                        ),
                        (
                            UnboundPort::typehash(),
                            Serialized::serialize(&UnboundPort::from(&original_port1)).unwrap(),
                        ),
                    ]
                    .into_iter()
                    .collect()
                ),
            }
        );

        // Modify the port in the erased
        let new_port_id0 = id!(world[0].comm[0][680]);
        assert_ne!(&new_port_id0, original_port0.port_id());
        let new_port_id1 = id!(world[1].comm[0][257]);
        assert_ne!(&new_port_id1, original_port1.port_id());

        let mut new_ports = [&new_port_id0, &new_port_id1].into_iter();
        erased
            .visit_mut::<UnboundPort>(|b| {
                let port = new_ports.next().unwrap();
                b.update(port.clone());
                Ok(())
            })
            .unwrap();
        // Every port binding must have been visited, consuming every
        // replacement port exactly once.
        assert!(new_ports.next().is_none());

        let new_port0 = PortRef::<String>::attest(new_port_id0);
        let new_port1 = PortRef::<MyReply>::attest_reducible(
            new_port_id1,
            Some(ReducerSpec {
                typehash: 123,
                builder_params: None,
            }),
        );
        let new_bindings = Bindings(
            [
                (
                    UnboundPort::typehash(),
                    Serialized::serialize(&UnboundPort::from(&new_port0)).unwrap(),
                ),
                (
                    UnboundPort::typehash(),
                    Serialized::serialize(&UnboundPort::from(&new_port1)).unwrap(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        assert_eq!(
            erased,
            ErasedUnbound {
                message: serialized_my_message.clone(),
                bindings: new_bindings.clone(),
            }
        );

        // convert back to MyMessage
        let unbound = erased.downcast::<MyMessage>().unwrap();
        assert_eq!(
            unbound,
            Unbound {
                message: my_message,
                bindings: new_bindings,
            }
        );
        let new_my_message = unbound.bind().unwrap();
        assert_eq!(
            new_my_message,
            MyMessage {
                arg0: true,
                arg1: 42,
                reply0: new_port0,
                reply1: new_port1,
            }
        );
    }

    #[test]
    fn test_bindings_push_pop() {
        let mut bindings = Bindings::default();
        bindings.push_back(&MyReply("first".to_string())).unwrap();
        bindings.push_back(&MyReply("second".to_string())).unwrap();

        // Elements come back out in push order.
        assert_eq!(
            bindings.try_pop_front::<MyReply>().unwrap(),
            MyReply("first".to_string())
        );
        assert_eq!(
            bindings.pop_front::<MyReply>().unwrap(),
            Some(MyReply("second".to_string()))
        );

        // An empty bindings yields `None` from `pop_front` and an error from
        // `try_pop_front`.
        assert_eq!(bindings.pop_front::<MyReply>().unwrap(), None);
        assert!(bindings.try_pop_front::<MyReply>().is_err());

        // Popping with a type other than the one that was pushed is an error.
        bindings.push_back(&MyReply("third".to_string())).unwrap();
        assert!(bindings.pop_front::<UnboundPort>().is_err());
    }
}