hyperactor/mailbox/
durable_mailbox_sender.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use log::*;
10
11use super::*;
12
13/// A [`DurableMailboxSender`] is a [`MailboxSender`] that writes messages to a write-ahead log
14/// before the receiver consume any of them. It allows the receiver to recover from crashes by
15/// replaying the log. It supports any implementation of [`MailboxSender`].
16#[derive(Debug)]
17struct DurableMailboxSender(Buffer<MessageEnvelope>);
18
19impl DurableMailboxSender {
20    fn new(
21        write_ahead_log: impl MessageLog<MessageEnvelope> + 'static,
22        inner: impl MailboxSender + 'static,
23    ) -> Self {
24        let write_ahead_log = Arc::new(tokio::sync::Mutex::new(write_ahead_log));
25        let inner = Arc::new(inner);
26        let sequencer =
27            Buffer::new(
28                move |envelope: MessageEnvelope,
29                      return_handle: PortHandle<Undeliverable<MessageEnvelope>>| {
30                    let write_ahead_log = write_ahead_log.clone();
31                    let inner = inner.clone();
32                    let return_handle = return_handle.clone();
33                    async move {
34                        let envelope_copy = envelope.clone(); // we maintain a copy in case we have to mark it failed
35                        let port_id = envelope.dest().clone();
36                        let mut log = write_ahead_log.lock().await;
37                        // TODO: There are potentially two ways to avoid copy; both require interface change.
38                        // (1) use Rc or Arc and (2) implement our own CopyOnDrop struct
39                        let append_result = log.append(envelope).await.map_err(|err| {
40                            MailboxSenderError::new_bound(port_id.clone(), err.into())
41                        });
42
43                        let flush_result = log.flush().await.map_err(|err| {
44                            MailboxSenderError::new_bound(port_id.clone(), err.into())
45                        });
46
47                        drop(log);
48
49                        if append_result.and(flush_result).is_ok() {
50                            inner.post(envelope_copy, return_handle);
51                        } else {
52                            envelope_copy.undeliverable(
53                                DeliveryError::BrokenLink(
54                                    "failed to append or flush in durable sender".to_string(),
55                                ),
56                                return_handle,
57                            );
58                        }
59                    }
60                },
61            );
62
63        Self(sequencer)
64    }
65
66    async fn flush(&mut self) -> Result<(), watch::error::RecvError> {
67        self.0.flush().await
68    }
69}
70
71#[async_trait]
72impl MailboxSender for DurableMailboxSender {
73    fn post_unchecked(
74        &self,
75        envelope: MessageEnvelope,
76        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
77    ) {
78        if let Err(mpsc::error::SendError((envelope, return_handle))) =
79            self.0.send((envelope, return_handle))
80        {
81            envelope.undeliverable(
82                DeliveryError::BrokenLink("failed to post in DurableMailboxSender".to_string()),
83                return_handle,
84            );
85        }
86    }
87}
88
89pub mod log {
90
91    //! This module implements a write-ahead log for mailboxes. This can be used to provide
92    //! durable messaging facilities for actors.
93
94    use std::fmt::Debug;
95
96    use async_trait::async_trait;
97    use futures::stream::Stream;
98
99    use crate::RemoteMessage;
100
101    /// A sequence id is a unique identifier for a message.
102    pub type SeqId = u64;
103
104    /// Errors that occur during message log operations.
105    /// This enum is marked non-exhaustive to allow for extensibility.
106    #[derive(thiserror::Error, Debug)]
107    #[non_exhaustive]
108    pub enum MessageLogError {
109        /// An error occured during flushing messages with a sequence id range.
110        #[error("flush: [{0}, {1})")]
111        Flush(SeqId, SeqId, #[source] anyhow::Error),
112
113        /// An error occured during appending a message with an assigned sequence id.
114        #[error("append: {0}")]
115        Append(SeqId, #[source] anyhow::Error),
116
117        /// An error occured during reading a message with the persistent sequence id.
118        #[error("read: {0}")]
119        Read(SeqId, #[source] anyhow::Error),
120
121        /// An error occured during trimming a message with the persistent sequence id.
122        #[error("trim: {0}")]
123        Trim(SeqId, #[source] anyhow::Error),
124
125        /// An other error.
126        #[error(transparent)]
127        Other(#[from] anyhow::Error),
128    }
129
130    /// This [`MessageLog`] is a log that serves as a building block to persist data before the
131    /// consumer can process it. One typical example is to persist messages before an actor handles it.
132    /// In such a case, it can be used as a white-ahead log. It allows the actor to recover from a
133    /// crash without requesting resending the messages. The log is append-only and the messages are
134    /// persisted in order with sequence ids.
135    #[async_trait]
136    pub trait MessageLog<M: RemoteMessage>: Sync + Send + Debug {
137        /// The type of the stream returned from read operations on this log.
138        type Stream<'a>: Stream<Item = Result<(SeqId, M), MessageLogError>> + Send
139        where
140            Self: 'a;
141
142        /// Append a message to a buffer. The appended messages will only be persisted and available to
143        /// read after calling [`flush`].
144        async fn append(&mut self, message: M) -> Result<(), MessageLogError>;
145
146        /// Flush the appended messages. Return the next sequence id of the last persistent message.
147        async fn flush(&mut self) -> Result<SeqId, MessageLogError>;
148
149        /// Directly flush the message. All previously buffered messages will be flushed as well.
150        /// This convenience method can prevent an additional copy of the message by directly writing to the log.
151        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError>;
152
153        /// Trim the persistent logs before the given [`new_start`] non-inclusively.
154        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError>;
155
156        /// Given a sequence id, return a stream of message and sequence id tuples that are persisted
157        /// after the given sequence id inclusively. The stream will yield errors when streaming
158        /// messages back if any. It will also yield errors if creating the stream itself fails.
159        async fn read(&self, from: SeqId) -> Result<Self::Stream<'_>, MessageLogError>;
160
161        /// Read exactly one message from the log. If the log is empty, return an error.
162        // Ideally, this method can have a default implmentation. But the compiler complains
163        // about `self` does not live long enough.
164        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError>;
165    }
166}
167
168/// A test util mod so that it can be used beyond the crate
169pub mod test_utils {
170
171    use std::collections::VecDeque;
172
173    use futures::pin_mut;
174    use log::SeqId;
175    use tokio_stream::StreamExt;
176
177    use super::*;
178
179    /// An in-memory log for testing.
180    #[derive(Debug, Clone)]
181    pub struct TestLog<M: RemoteMessage> {
182        queue: Arc<Mutex<VecDeque<(SeqId, M)>>>,
183        current_seq_id: Arc<Mutex<SeqId>>,
184        // For outside to validate the values of saved messages.
185        observer: Option<mpsc::UnboundedSender<(String, M)>>,
186    }
187
188    impl<M: RemoteMessage> Default for TestLog<M> {
189        fn default() -> Self {
190            Self::new()
191        }
192    }
193
194    impl<M: RemoteMessage> TestLog<M> {
195        /// Create a new, empty [`TestLog`].
196        pub fn new() -> Self {
197            Self {
198                queue: Arc::new(Mutex::new(VecDeque::new())),
199                current_seq_id: Arc::new(Mutex::new(0)),
200                observer: None,
201            }
202        }
203
204        /// Create a new test log that sends all log operations to the provided
205        /// observer. The observer is sent tuples of `(op, message)`, where `op` is
206        /// either "append" or "read".
207        pub fn new_with_observer(observer: mpsc::UnboundedSender<(String, M)>) -> Self {
208            Self {
209                queue: Arc::new(Mutex::new(VecDeque::new())),
210                current_seq_id: Arc::new(Mutex::new(0)),
211                observer: Some(observer),
212            }
213        }
214    }
215
216    #[async_trait]
217    impl<M: RemoteMessage + Clone> MessageLog<M> for TestLog<M> {
218        type Stream<'a> =
219            futures::stream::Iter<std::vec::IntoIter<Result<(SeqId, M), MessageLogError>>>;
220
221        async fn append(&mut self, message: M) -> Result<(), MessageLogError> {
222            let mut seq_id = self.current_seq_id.lock().unwrap();
223            self.queue
224                .lock()
225                .unwrap()
226                .push_back((*seq_id, message.clone()));
227            *seq_id += 1;
228            if let Some(observer) = &self.observer {
229                observer.send(("append".to_string(), message)).unwrap();
230            }
231            Ok(())
232        }
233
234        async fn flush(&mut self) -> Result<SeqId, MessageLogError> {
235            let seq_id = *self.current_seq_id.lock().unwrap();
236            Ok(seq_id)
237        }
238
239        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError> {
240            self.append(message.clone()).await?;
241            self.flush().await
242        }
243
244        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError> {
245            let mut queue = self.queue.lock().unwrap();
246            while let Some((id, _)) = queue.front() {
247                if *id < new_start {
248                    queue.pop_front();
249                } else {
250                    break;
251                }
252            }
253            Ok(())
254        }
255
256        async fn read(&self, seq_id: SeqId) -> Result<Self::Stream<'_>, MessageLogError> {
257            let queue = self.queue.lock().unwrap();
258            let filtered_items: Vec<_> = queue
259                .iter()
260                .filter(move |(id, _)| *id >= seq_id)
261                .map(|(seq_id, msg)| Ok((*seq_id, msg.clone())))
262                .collect();
263            for entry in filtered_items.iter() {
264                if let Some(observer) = &self.observer
265                    && let Ok((_, msg)) = entry.as_ref()
266                {
267                    observer.send(("read".to_string(), msg.clone())).unwrap();
268                }
269            }
270            Ok(futures::stream::iter(filtered_items.into_iter()))
271        }
272
273        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError> {
274            let it = self.read(seq_id).await?;
275
276            pin_mut!(it);
277            match it.next().await {
278                Some(Ok((result_seq_id, message))) => {
279                    if result_seq_id != seq_id {
280                        panic!("no seq id {}", seq_id);
281                    }
282                    return Ok(message);
283                }
284                Some(Err(err)) => {
285                    return Err(err);
286                }
287                None => {
288                    return Err(MessageLogError::Read(
289                        seq_id,
290                        anyhow::anyhow!("failed to find message with sequence {}", seq_id),
291                    ));
292                }
293            }
294        }
295    }
296}
297
#[cfg(test)]
mod tests {

    use std::assert_matches::assert_matches;

    use futures::StreamExt;

    use super::test_utils::TestLog;
    use super::*;
    use crate::id;
    use crate::mailbox::log::SeqId;

    /// Exercises append / flush / read / trim / read_one on the in-memory [`TestLog`].
    #[tokio::test]
    async fn test_local_write_ahead_log_basic() {
        let mut wal = TestLog::new();
        wal.append(124u64).await.unwrap();
        wal.append(56u64).await.unwrap();
        let seq_id = wal.append_and_flush(&999u64).await.unwrap();
        // The three messages received ids 0, 1 and 2; flush reports the next id.
        assert_eq!(seq_id, 3);

        // Simple read given a sequence id
        let mut it = wal.read(1).await.unwrap();
        let (next_seq, message): (SeqId, u64) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 1);
        assert_eq!(message, 56u64);
        let (next_seq, message) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 2);
        assert_eq!(message, 999u64);
        assert_matches!(it.next().await, None);
        // Drop the iterator to release borrow from wal
        drop(it);

        // Trim then append
        wal.trim(2).await.unwrap();
        let seq_id = wal.append_and_flush(&777u64).await.unwrap();
        assert_eq!(seq_id, 4);
        let mut it = wal.read(2).await.unwrap();
        let (next_seq, message): (SeqId, u64) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 2);
        assert_eq!(message, 999u64);
        let (next_seq, message) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 3);
        assert_eq!(message, 777u64);
        assert_matches!(it.next().await, None);

        // Point lookup of a message that survived the trim.
        assert_eq!(wal.read_one(3).await.unwrap(), 777u64);
    }

    /// Posts three envelopes through a [`DurableMailboxSender`]. Checks that each is delivered
    /// to its port and that all are recorded, in order, in the write-ahead log.
    #[tokio::test]
    async fn test_durable_mailbox_sender() {
        let inner = Mailbox::new_detached(id!(world0[0].actor0));
        let write_ahead_log = TestLog::new();
        let mut durable_mbox = DurableMailboxSender::new(write_ahead_log.clone(), inner.clone());

        let (port1, mut receiver1) = inner.open_port::<u64>();
        let (port2, mut receiver2) = inner.open_port::<u64>();

        // Convert to references so that the ports are registered.
        let port1 = port1.bind();
        let port2 = port2.bind();

        durable_mbox.post(
            MessageEnvelope::new_unknown(
                port1.port_id().clone(),
                Serialized::serialize(&1u64).unwrap(),
            ),
            monitored_return_handle(),
        );
        durable_mbox.post(
            MessageEnvelope::new_unknown(
                port2.port_id().clone(),
                Serialized::serialize(&2u64).unwrap(),
            ),
            monitored_return_handle(),
        );
        durable_mbox.post(
            MessageEnvelope::new_unknown(
                port1.port_id().clone(),
                Serialized::serialize(&3u64).unwrap(),
            ),
            monitored_return_handle(),
        );
        assert_eq!(receiver1.recv().await.unwrap(), 1u64);

        durable_mbox.flush().await.unwrap();

        // The remaining envelopes reach their ports as well.
        assert_eq!(receiver1.recv().await.unwrap(), 3u64);
        assert_eq!(receiver2.recv().await.unwrap(), 2u64);

        // `write_ahead_log` is a clone that shares storage with the log given to the sender.
        let mut it = write_ahead_log.read(1).await.unwrap();
        let (seq, message): (SeqId, MessageEnvelope) = it.next().await.unwrap().unwrap();
        assert_eq!(seq, 1);
        assert_eq!(port2.port_id(), message.dest());
        assert_eq!(2u64, message.deserialized::<u64>().unwrap());
        let (seq, message): (SeqId, MessageEnvelope) = it.next().await.unwrap().unwrap();
        assert_eq!(seq, 2);
        assert_eq!(port1.port_id(), message.dest());
        assert_eq!(3u64, message.deserialized::<u64>().unwrap());
        assert_matches!(it.next().await, None);
    }
}