// hyperactor/mailbox/durable_mailbox_sender.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use log::*;
10
11use super::*;
12
13/// A [`DurableMailboxSender`] is a [`MailboxSender`] that writes messages to a write-ahead log
14/// before the receiver consume any of them. It allows the receiver to recover from crashes by
15/// replaying the log. It supports any implementation of [`MailboxSender`].
16struct DurableMailboxSender(Buffer<MessageEnvelope>);
17
18impl DurableMailboxSender {
19    fn new(
20        write_ahead_log: impl MessageLog<MessageEnvelope> + 'static,
21        inner: impl MailboxSender + 'static,
22    ) -> Self {
23        let write_ahead_log = Arc::new(tokio::sync::Mutex::new(write_ahead_log));
24        let inner = Arc::new(inner);
25        let sequencer =
26            Buffer::new(
27                move |envelope: MessageEnvelope,
28                      return_handle: PortHandle<Undeliverable<MessageEnvelope>>| {
29                    let write_ahead_log = write_ahead_log.clone();
30                    let inner = inner.clone();
31                    let return_handle = return_handle.clone();
32                    async move {
33                        let envelope_copy = envelope.clone(); // we maintain a copy in case we have to mark it failed
34                        let port_id = envelope.dest().clone();
35                        let mut log = write_ahead_log.lock().await;
36                        // TODO: There are potentially two ways to avoid copy; both require interface change.
37                        // (1) use Rc or Arc and (2) implement our own CopyOnDrop struct
38                        let append_result = log.append(envelope).await.map_err(|err| {
39                            MailboxSenderError::new_bound(port_id.clone(), err.into())
40                        });
41
42                        let flush_result = log.flush().await.map_err(|err| {
43                            MailboxSenderError::new_bound(port_id.clone(), err.into())
44                        });
45
46                        drop(log);
47
48                        if append_result.and(flush_result).is_ok() {
49                            inner.post(envelope_copy, return_handle);
50                        } else {
51                            envelope_copy.undeliverable(
52                                DeliveryError::BrokenLink(
53                                    "failed to append or flush in durable sender".to_string(),
54                                ),
55                                return_handle,
56                            );
57                        }
58                    }
59                },
60            );
61
62        Self(sequencer)
63    }
64
65    async fn flush(&mut self) -> Result<(), watch::error::RecvError> {
66        self.0.flush().await
67    }
68}
69
70#[async_trait]
71impl MailboxSender for DurableMailboxSender {
72    fn post_unchecked(
73        &self,
74        envelope: MessageEnvelope,
75        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
76    ) {
77        if let Err(mpsc::error::SendError((envelope, return_handle))) =
78            self.0.send((envelope, return_handle))
79        {
80            envelope.undeliverable(
81                DeliveryError::BrokenLink("failed to post in DurableMailboxSender".to_string()),
82                return_handle,
83            );
84        }
85    }
86}
87
88pub mod log {
89
90    //! This module implements a write-ahead log for mailboxes. This can be used to provide
91    //! durable messaging facilities for actors.
92
93    use std::fmt::Debug;
94
95    use async_trait::async_trait;
96    use futures::stream::Stream;
97
98    use crate::RemoteMessage;
99
100    /// A sequence id is a unique identifier for a message.
101    pub type SeqId = u64;
102
103    /// Errors that occur during message log operations.
104    /// This enum is marked non-exhaustive to allow for extensibility.
105    #[derive(thiserror::Error, Debug)]
106    #[non_exhaustive]
107    pub enum MessageLogError {
108        /// An error occured during flushing messages with a sequence id range.
109        #[error("flush: [{0}, {1})")]
110        Flush(SeqId, SeqId, #[source] anyhow::Error),
111
112        /// An error occured during appending a message with an assigned sequence id.
113        #[error("append: {0}")]
114        Append(SeqId, #[source] anyhow::Error),
115
116        /// An error occured during reading a message with the persistent sequence id.
117        #[error("read: {0}")]
118        Read(SeqId, #[source] anyhow::Error),
119
120        /// An error occured during trimming a message with the persistent sequence id.
121        #[error("trim: {0}")]
122        Trim(SeqId, #[source] anyhow::Error),
123
124        /// An other error.
125        #[error(transparent)]
126        Other(#[from] anyhow::Error),
127    }
128
129    /// This [`MessageLog`] is a log that serves as a building block to persist data before the
130    /// consumer can process it. One typical example is to persist messages before an actor handles it.
131    /// In such a case, it can be used as a white-ahead log. It allows the actor to recover from a
132    /// crash without requesting resending the messages. The log is append-only and the messages are
133    /// persisted in order with sequence ids.
134    #[async_trait]
135    pub trait MessageLog<M: RemoteMessage>: Sync + Send {
136        /// The type of the stream returned from read operations on this log.
137        type Stream<'a>: Stream<Item = Result<(SeqId, M), MessageLogError>> + Send
138        where
139            Self: 'a;
140
141        /// Append a message to a buffer. The appended messages will only be persisted and available to
142        /// read after calling [`flush`].
143        async fn append(&mut self, message: M) -> Result<(), MessageLogError>;
144
145        /// Flush the appended messages. Return the next sequence id of the last persistent message.
146        async fn flush(&mut self) -> Result<SeqId, MessageLogError>;
147
148        /// Directly flush the message. All previously buffered messages will be flushed as well.
149        /// This convenience method can prevent an additional copy of the message by directly writing to the log.
150        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError>;
151
152        /// Trim the persistent logs before the given [`new_start`] non-inclusively.
153        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError>;
154
155        /// Given a sequence id, return a stream of message and sequence id tuples that are persisted
156        /// after the given sequence id inclusively. The stream will yield errors when streaming
157        /// messages back if any. It will also yield errors if creating the stream itself fails.
158        async fn read(&self, from: SeqId) -> Result<Self::Stream<'_>, MessageLogError>;
159
160        /// Read exactly one message from the log. If the log is empty, return an error.
161        // Ideally, this method can have a default implmentation. But the compiler complains
162        // about `self` does not live long enough.
163        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError>;
164    }
165}
166
167/// A test util mod so that it can be used beyond the crate
168pub mod test_utils {
169
170    use std::collections::VecDeque;
171
172    use futures::pin_mut;
173    use log::SeqId;
174    use tokio_stream::StreamExt;
175
176    use super::*;
177
178    /// An in-memory log for testing.
179    #[derive(Clone)]
180    pub struct TestLog<M: RemoteMessage> {
181        queue: Arc<Mutex<VecDeque<(SeqId, M)>>>,
182        current_seq_id: Arc<Mutex<SeqId>>,
183        // For outside to validate the values of saved messages.
184        observer: Option<mpsc::UnboundedSender<(String, M)>>,
185    }
186
187    impl<M: RemoteMessage> Default for TestLog<M> {
188        fn default() -> Self {
189            Self::new()
190        }
191    }
192
193    impl<M: RemoteMessage> TestLog<M> {
194        /// Create a new, empty [`TestLog`].
195        pub fn new() -> Self {
196            Self {
197                queue: Arc::new(Mutex::new(VecDeque::new())),
198                current_seq_id: Arc::new(Mutex::new(0)),
199                observer: None,
200            }
201        }
202
203        /// Create a new test log that sends all log operations to the provided
204        /// observer. The observer is sent tuples of `(op, message)`, where `op` is
205        /// either "append" or "read".
206        pub fn new_with_observer(observer: mpsc::UnboundedSender<(String, M)>) -> Self {
207            Self {
208                queue: Arc::new(Mutex::new(VecDeque::new())),
209                current_seq_id: Arc::new(Mutex::new(0)),
210                observer: Some(observer),
211            }
212        }
213    }
214
215    #[async_trait]
216    impl<M: RemoteMessage + Clone> MessageLog<M> for TestLog<M> {
217        type Stream<'a> =
218            futures::stream::Iter<std::vec::IntoIter<Result<(SeqId, M), MessageLogError>>>;
219
220        async fn append(&mut self, message: M) -> Result<(), MessageLogError> {
221            let mut seq_id = self.current_seq_id.lock().unwrap();
222            self.queue
223                .lock()
224                .unwrap()
225                .push_back((*seq_id, message.clone()));
226            *seq_id += 1;
227            if let Some(observer) = &self.observer {
228                observer.send(("append".to_string(), message)).unwrap();
229            }
230            Ok(())
231        }
232
233        async fn flush(&mut self) -> Result<SeqId, MessageLogError> {
234            let seq_id = *self.current_seq_id.lock().unwrap();
235            Ok(seq_id)
236        }
237
238        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError> {
239            self.append(message.clone()).await?;
240            self.flush().await
241        }
242
243        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError> {
244            let mut queue = self.queue.lock().unwrap();
245            while let Some((id, _)) = queue.front() {
246                if *id < new_start {
247                    queue.pop_front();
248                } else {
249                    break;
250                }
251            }
252            Ok(())
253        }
254
255        async fn read(&self, seq_id: SeqId) -> Result<Self::Stream<'_>, MessageLogError> {
256            let queue = self.queue.lock().unwrap();
257            let filtered_items: Vec<_> = queue
258                .iter()
259                .filter(move |(id, _)| *id >= seq_id)
260                .map(|(seq_id, msg)| Ok((*seq_id, msg.clone())))
261                .collect();
262            for entry in filtered_items.iter() {
263                if let Some(observer) = &self.observer
264                    && let Ok((_, msg)) = entry.as_ref()
265                {
266                    observer.send(("read".to_string(), msg.clone())).unwrap();
267                }
268            }
269            Ok(futures::stream::iter(filtered_items))
270        }
271
272        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError> {
273            let it = self.read(seq_id).await?;
274
275            pin_mut!(it);
276            match it.next().await {
277                Some(Ok((result_seq_id, message))) => {
278                    if result_seq_id != seq_id {
279                        panic!("no seq id {}", seq_id);
280                    }
281                    return Ok(message);
282                }
283                Some(Err(err)) => {
284                    return Err(err);
285                }
286                None => {
287                    return Err(MessageLogError::Read(
288                        seq_id,
289                        anyhow::anyhow!("failed to find message with sequence {}", seq_id),
290                    ));
291                }
292            }
293        }
294    }
295}
296
#[cfg(test)]
mod tests {

    use std::assert_matches::assert_matches;
    use std::mem::drop;

    use futures::StreamExt;

    use super::test_utils::TestLog;
    use super::*;
    use crate::id;
    use crate::mailbox::log::SeqId;

    /// Exercises append, flush, read, and trim on the in-memory [`TestLog`].
    #[tokio::test]
    async fn test_local_write_ahead_log_basic() {
        let mut journal = TestLog::new();
        for value in [124u64, 56u64] {
            journal.append(value).await.unwrap();
        }
        // Ids 0, 1 and 2 are now taken, so the next id reported by the flush is 3.
        assert_eq!(journal.append_and_flush(&999u64).await.unwrap(), 3);

        // Reading from id 1 yields every message from that id onward.
        let mut entries = journal.read(1).await.unwrap();
        for expected in [(1, 56u64), (2, 999u64)] {
            let entry: (SeqId, u64) = entries.next().await.unwrap().unwrap();
            assert_eq!(entry, expected);
        }
        assert_matches!(entries.next().await, None);
        // The stream borrows the log; release it before mutating the log again.
        drop(entries);

        // Trimming to 2 drops ids 0 and 1. Sequence ids keep increasing, so the next
        // message gets id 3 and the flush reports 4.
        journal.trim(2).await.unwrap();
        assert_eq!(journal.append_and_flush(&777u64).await.unwrap(), 4);
        let mut entries = journal.read(2).await.unwrap();
        for expected in [(2, 999u64), (3, 777u64)] {
            let entry: (SeqId, u64) = entries.next().await.unwrap().unwrap();
            assert_eq!(entry, expected);
        }
        assert_matches!(entries.next().await, None);
    }

    /// Envelopes posted through a [`DurableMailboxSender`] reach the inner mailbox and are
    /// also recorded, in posting order, in the write-ahead log.
    #[tokio::test]
    async fn test_durable_mailbox_sender() {
        let mailbox = Mailbox::new_detached(id!(world0[0].actor0));
        let journal = TestLog::new();
        let mut durable = DurableMailboxSender::new(journal.clone(), mailbox.clone());

        let (handle1, mut rx1) = mailbox.open_port::<u64>();
        let (handle2, _rx2) = mailbox.open_port::<u64>();

        // Binding turns the handles into references, which registers the ports.
        let port1 = handle1.bind();
        let port2 = handle2.bind();

        for (port_id, value) in [
            (port1.port_id(), 1u64),
            (port2.port_id(), 2u64),
            (port1.port_id(), 3u64),
        ] {
            durable.post(
                MessageEnvelope::new_unknown(
                    port_id.clone(),
                    wirevalue::Any::serialize(&value).unwrap(),
                ),
                monitored_return_handle(),
            );
        }
        assert_eq!(rx1.recv().await.unwrap(), 1u64);

        durable.flush().await.unwrap();

        // Id 0 holds the first envelope; ids 1 and 2 hold the second and third.
        let mut entries = journal.read(1).await.unwrap();
        let (seq, envelope): (SeqId, MessageEnvelope) = entries.next().await.unwrap().unwrap();
        assert_eq!(seq, 1);
        assert_eq!(port2.port_id(), envelope.dest());
        assert_eq!(2u64, envelope.deserialized::<u64>().unwrap());
        let (seq, envelope): (SeqId, MessageEnvelope) = entries.next().await.unwrap().unwrap();
        assert_eq!(seq, 2);
        assert_eq!(3u64, envelope.deserialized::<u64>().unwrap());
    }
}