hyperactor/mailbox/
durable_mailbox_sender.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use log::*;
10
11use super::*;
12
/// A [`DurableMailboxSender`] is a [`MailboxSender`] that writes messages to a write-ahead log
/// before the receiver consumes any of them. It allows the receiver to recover from crashes by
/// replaying the log. It supports any implementation of [`MailboxSender`].
///
/// The single field is the [`Buffer`] that posted envelopes are sent into; the buffer is
/// constructed in `new` with the closure that persists each envelope and then forwards it
/// to the wrapped sender.
#[derive(Debug)]
struct DurableMailboxSender(Buffer<MessageEnvelope>);
18
19impl DurableMailboxSender {
20    fn new(
21        write_ahead_log: impl MessageLog<MessageEnvelope> + 'static,
22        inner: impl MailboxSender + 'static,
23    ) -> Self {
24        let write_ahead_log = Arc::new(tokio::sync::Mutex::new(write_ahead_log));
25        let inner = Arc::new(inner);
26        let sequencer =
27            Buffer::new(
28                move |envelope: MessageEnvelope,
29                      return_handle: PortHandle<Undeliverable<MessageEnvelope>>| {
30                    let write_ahead_log = write_ahead_log.clone();
31                    let inner = inner.clone();
32                    let return_handle = return_handle.clone();
33                    async move {
34                        let envelope_copy = envelope.clone(); // we maintain a copy in case we have to mark it failed
35                        let port_id = envelope.dest().clone();
36                        let mut log = write_ahead_log.lock().await;
37                        // TODO: There are potentially two ways to avoid copy; both require interface change.
38                        // (1) use Rc or Arc and (2) implement our own CopyOnDrop struct
39                        let append_result = log.append(envelope).await.map_err(|err| {
40                            MailboxSenderError::new_bound(port_id.clone(), err.into())
41                        });
42
43                        let flush_result = log.flush().await.map_err(|err| {
44                            MailboxSenderError::new_bound(port_id.clone(), err.into())
45                        });
46
47                        drop(log);
48
49                        if append_result.and(flush_result).is_ok() {
50                            inner.post(envelope_copy, return_handle);
51                        } else {
52                            envelope_copy.undeliverable(
53                                DeliveryError::BrokenLink(
54                                    "failed to append or flush in durable sender".to_string(),
55                                ),
56                                return_handle,
57                            );
58                        }
59                    }
60                },
61            );
62
63        Self(sequencer)
64    }
65
66    async fn flush(&mut self) -> Result<(), watch::error::RecvError> {
67        self.0.flush().await
68    }
69}
70
71#[async_trait]
72impl MailboxSender for DurableMailboxSender {
73    fn post(
74        &self,
75        envelope: MessageEnvelope,
76        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
77    ) {
78        if let Err(mpsc::error::SendError((envelope, return_handle))) =
79            self.0.send((envelope, return_handle))
80        {
81            envelope.undeliverable(
82                DeliveryError::BrokenLink("failed to post in DurableMailboxSender".to_string()),
83                return_handle,
84            );
85        }
86    }
87}
88
pub mod log {

    //! This module implements a write-ahead log for mailboxes. This can be used to provide
    //! durable messaging facilities for actors.

    use std::fmt::Debug;

    use async_trait::async_trait;
    use futures::stream::Stream;

    use crate::RemoteMessage;

    /// A sequence id is a unique identifier for a message.
    pub type SeqId = u64;

    /// Errors that occur during message log operations.
    /// This enum is marked non-exhaustive to allow for extensibility.
    #[derive(thiserror::Error, Debug)]
    #[non_exhaustive]
    pub enum MessageLogError {
        /// An error occurred while flushing messages in a sequence id range. The two ids
        /// are rendered as the half-open range `[first, second)`.
        #[error("flush: [{0}, {1})")]
        Flush(SeqId, SeqId, #[source] anyhow::Error),

        /// An error occurred while appending a message with an assigned sequence id.
        #[error("append: {0}")]
        Append(SeqId, #[source] anyhow::Error),

        /// An error occurred while reading a message with the persistent sequence id.
        #[error("read: {0}")]
        Read(SeqId, #[source] anyhow::Error),

        /// An error occurred while trimming a message with the persistent sequence id.
        #[error("trim: {0}")]
        Trim(SeqId, #[source] anyhow::Error),

        /// Any other error.
        #[error(transparent)]
        Other(#[from] anyhow::Error),
    }

    /// This [`MessageLog`] is a log that serves as a building block to persist data before the
    /// consumer can process it. One typical example is to persist messages before an actor handles
    /// them. In such a case, it can be used as a write-ahead log. It allows the actor to recover
    /// from a crash without requesting resending the messages. The log is append-only and the
    /// messages are persisted in order with sequence ids.
    #[async_trait]
    pub trait MessageLog<M: RemoteMessage>: Sync + Send + Debug {
        /// The type of the stream returned from read operations on this log.
        type Stream<'a>: Stream<Item = Result<(SeqId, M), MessageLogError>> + Send
        where
            Self: 'a;

        /// Append a message to a buffer. The appended messages will only be persisted and available to
        /// read after calling [`MessageLog::flush`].
        async fn append(&mut self, message: M) -> Result<(), MessageLogError>;

        /// Flush the appended messages. Returns the sequence id that follows the last
        /// persisted message.
        async fn flush(&mut self) -> Result<SeqId, MessageLogError>;

        /// Directly flush the message. All previously buffered messages will be flushed as well.
        /// This convenience method can prevent an additional copy of the message by directly writing to the log.
        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError>;

        /// Trim the persisted entries that come before `new_start`. The entry at
        /// `new_start` itself is not trimmed.
        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError>;

        /// Given a sequence id, return a stream of message and sequence id tuples that are persisted
        /// at or after the given sequence id. The stream will yield errors when streaming
        /// messages back if any. An error is also returned if creating the stream itself fails.
        async fn read(&self, from: SeqId) -> Result<Self::Stream<'_>, MessageLogError>;

        /// Read exactly one message from the log. If the log is empty, return an error.
        // Ideally, this method would have a default implementation. But the compiler complains
        // that `self` does not live long enough.
        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError>;
    }
}
167
168/// A test util mod so that it can be used beyond the crate
169pub mod test_utils {
170
171    use std::collections::VecDeque;
172
173    use futures::pin_mut;
174    use log::SeqId;
175    use tokio_stream::StreamExt;
176
177    use super::*;
178
179    /// An in-memory log for testing.
180    #[derive(Debug, Clone)]
181    pub struct TestLog<M: RemoteMessage> {
182        queue: Arc<Mutex<VecDeque<(SeqId, M)>>>,
183        current_seq_id: Arc<Mutex<SeqId>>,
184        // For outside to validate the values of saved messages.
185        observer: Option<mpsc::UnboundedSender<(String, M)>>,
186    }
187
188    impl<M: RemoteMessage> TestLog<M> {
189        /// Create a new, empty [`TestLog`].
190        pub fn new() -> Self {
191            Self {
192                queue: Arc::new(Mutex::new(VecDeque::new())),
193                current_seq_id: Arc::new(Mutex::new(0)),
194                observer: None,
195            }
196        }
197
198        /// Create a new test log that sends all log operations to the provided
199        /// observer. The observer is sent tuples of `(op, message)`, where `op` is
200        /// either "append" or "read".
201        pub fn new_with_observer(observer: mpsc::UnboundedSender<(String, M)>) -> Self {
202            Self {
203                queue: Arc::new(Mutex::new(VecDeque::new())),
204                current_seq_id: Arc::new(Mutex::new(0)),
205                observer: Some(observer),
206            }
207        }
208    }
209
210    #[async_trait]
211    impl<M: RemoteMessage + Clone> MessageLog<M> for TestLog<M> {
212        type Stream<'a> =
213            futures::stream::Iter<std::vec::IntoIter<Result<(SeqId, M), MessageLogError>>>;
214
215        async fn append(&mut self, message: M) -> Result<(), MessageLogError> {
216            let mut seq_id = self.current_seq_id.lock().unwrap();
217            self.queue
218                .lock()
219                .unwrap()
220                .push_back((*seq_id, message.clone()));
221            *seq_id += 1;
222            if let Some(observer) = &self.observer {
223                observer.send(("append".to_string(), message)).unwrap();
224            }
225            Ok(())
226        }
227
228        async fn flush(&mut self) -> Result<SeqId, MessageLogError> {
229            let seq_id = *self.current_seq_id.lock().unwrap();
230            Ok(seq_id)
231        }
232
233        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError> {
234            self.append(message.clone()).await?;
235            self.flush().await
236        }
237
238        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError> {
239            let mut queue = self.queue.lock().unwrap();
240            while let Some((id, _)) = queue.front() {
241                if *id < new_start {
242                    queue.pop_front();
243                } else {
244                    break;
245                }
246            }
247            Ok(())
248        }
249
250        async fn read(&self, seq_id: SeqId) -> Result<Self::Stream<'_>, MessageLogError> {
251            let queue = self.queue.lock().unwrap();
252            let filtered_items: Vec<_> = queue
253                .iter()
254                .filter(move |(id, _)| *id >= seq_id)
255                .map(|(seq_id, msg)| Ok((*seq_id, msg.clone())))
256                .collect();
257            for entry in filtered_items.iter() {
258                if let Some(observer) = &self.observer {
259                    if let Ok((_, msg)) = entry.as_ref() {
260                        observer.send(("read".to_string(), msg.clone())).unwrap();
261                    }
262                }
263            }
264            Ok(futures::stream::iter(filtered_items.into_iter()))
265        }
266
267        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError> {
268            let it = self.read(seq_id).await?;
269
270            pin_mut!(it);
271            match it.next().await {
272                Some(Ok((result_seq_id, message))) => {
273                    if result_seq_id != seq_id {
274                        panic!("no seq id {}", seq_id);
275                    }
276                    return Ok(message);
277                }
278                Some(Err(err)) => {
279                    return Err(err);
280                }
281                None => {
282                    return Err(MessageLogError::Read(
283                        seq_id,
284                        anyhow::anyhow!("failed to find message with sequence {}", seq_id),
285                    ));
286                }
287            }
288        }
289    }
290}
291
#[cfg(test)]
mod tests {

    use futures::StreamExt;

    use super::test_utils::TestLog;
    use super::*;
    use crate::id;
    use crate::mailbox::log::SeqId;

    /// Appends, reads, and trims an in-memory log and checks the sequence ids and
    /// payloads that come back.
    #[tokio::test]
    async fn test_local_write_ahead_log_basic() {
        let mut wal = TestLog::new();
        for value in &[124u64, 56u64] {
            wal.append(*value).await.unwrap();
        }
        let next_id = wal.append_and_flush(&999u64).await.unwrap();
        assert_eq!(next_id, 3);

        // Reading from sequence id 1 yields exactly the second and third entries.
        let from_one: Vec<(SeqId, u64)> = wal
            .read(1)
            .await
            .unwrap()
            .map(|entry| entry.unwrap())
            .collect()
            .await;
        assert_eq!(from_one, vec![(1, 56u64), (2, 999u64)]);

        // Trimming does not reset sequence ids: the next append continues from 3.
        wal.trim(2).await.unwrap();
        let next_id = wal.append_and_flush(&777u64).await.unwrap();
        assert_eq!(next_id, 4);

        let from_two: Vec<(SeqId, u64)> = wal
            .read(2)
            .await
            .unwrap()
            .map(|entry| entry.unwrap())
            .collect()
            .await;
        assert_eq!(from_two, vec![(2, 999u64), (3, 777u64)]);
    }

    /// Posts three messages through a durable sender and checks both delivery to the
    /// wrapped mailbox and the contents of the write-ahead log.
    #[tokio::test]
    async fn test_durable_mailbox_sender() {
        let inner = Mailbox::new_detached(id!(world0[0].actor0));
        let write_ahead_log = TestLog::new();
        let mut durable_mbox = DurableMailboxSender::new(write_ahead_log.clone(), inner.clone());

        let (port1, mut receiver1) = inner.open_port::<u64>();
        let (port2, _receiver2) = inner.open_port::<u64>();

        // Convert to references so that the ports are registered.
        let port1 = port1.bind();
        let port2 = port2.bind();

        // (destination, payload) pairs, posted in order.
        for (port, payload) in &[(&port1, 1u64), (&port2, 2u64), (&port1, 3u64)] {
            durable_mbox.post(
                MessageEnvelope::new_unknown(
                    port.port_id().clone(),
                    Serialized::serialize(payload).unwrap(),
                ),
                monitored_return_handle(),
            );
        }
        assert_eq!(receiver1.recv().await.unwrap(), 1u64);

        durable_mbox.flush().await.unwrap();

        // Skip the first logged entry and inspect the remaining two.
        let mut logged = write_ahead_log.read(1).await.unwrap();

        let (seq, second): (SeqId, MessageEnvelope) = logged.next().await.unwrap().unwrap();
        assert_eq!(seq, 1);
        assert_eq!(port2.port_id(), second.dest());
        assert_eq!(2u64, second.deserialized::<u64>().unwrap());

        let (seq, third): (SeqId, MessageEnvelope) = logged.next().await.unwrap().unwrap();
        assert_eq!(seq, 2);
        assert_eq!(3u64, third.deserialized::<u64>().unwrap());
    }
}
386}