hyperactor/mailbox/
durable_mailbox_sender.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use log::*;
10
11use super::*;
12
13/// A [`DurableMailboxSender`] is a [`MailboxSender`] that writes messages to a write-ahead log
14/// before the receiver consume any of them. It allows the receiver to recover from crashes by
15/// replaying the log. It supports any implementation of [`MailboxSender`].
16struct DurableMailboxSender(Buffer<MessageEnvelope>);
17
18impl DurableMailboxSender {
19    #[allow(dead_code)]
20    fn new(
21        write_ahead_log: impl MessageLog<MessageEnvelope> + 'static,
22        inner: impl MailboxSender + 'static,
23    ) -> Self {
24        let write_ahead_log = Arc::new(tokio::sync::Mutex::new(write_ahead_log));
25        let inner = Arc::new(inner);
26        let sequencer =
27            Buffer::new(
28                move |envelope: MessageEnvelope,
29                      return_handle: PortHandle<Undeliverable<MessageEnvelope>>| {
30                    let write_ahead_log = write_ahead_log.clone();
31                    let inner = inner.clone();
32                    let return_handle = return_handle.clone();
33                    async move {
34                        let envelope_copy = envelope.clone(); // we maintain a copy in case we have to mark it failed
35                        let port_id = envelope.dest().clone();
36                        let mut log = write_ahead_log.lock().await;
37                        // TODO: There are potentially two ways to avoid copy; both require interface change.
38                        // (1) use Rc or Arc and (2) implement our own CopyOnDrop struct
39                        let append_result = log.append(envelope).await.map_err(|err| {
40                            MailboxSenderError::new_bound(port_id.clone(), err.into())
41                        });
42
43                        let flush_result = log.flush().await.map_err(|err| {
44                            MailboxSenderError::new_bound(port_id.clone(), err.into())
45                        });
46
47                        drop(log);
48
49                        if append_result.and(flush_result).is_ok() {
50                            inner.post(envelope_copy, return_handle);
51                        } else {
52                            envelope_copy.undeliverable(
53                                DeliveryError::BrokenLink(
54                                    "failed to append or flush in durable sender".to_string(),
55                                ),
56                                return_handle,
57                            );
58                        }
59                    }
60                },
61            );
62
63        Self(sequencer)
64    }
65
66    #[allow(dead_code)]
67    async fn flush(&mut self) -> Result<(), watch::error::RecvError> {
68        self.0.flush().await
69    }
70}
71
72#[async_trait]
73impl MailboxSender for DurableMailboxSender {
74    fn post_unchecked(
75        &self,
76        envelope: MessageEnvelope,
77        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
78    ) {
79        if let Err(mpsc::error::SendError((envelope, return_handle))) =
80            self.0.send((envelope, return_handle))
81        {
82            envelope.undeliverable(
83                DeliveryError::BrokenLink("failed to post in DurableMailboxSender".to_string()),
84                return_handle,
85            );
86        }
87    }
88}
89
90pub mod log {
91
92    //! This module implements a write-ahead log for mailboxes. This can be used to provide
93    //! durable messaging facilities for actors.
94
95    use std::fmt::Debug;
96
97    use async_trait::async_trait;
98    use futures::stream::Stream;
99
100    use crate::RemoteMessage;
101
102    /// A sequence id is a unique identifier for a message.
103    pub type SeqId = u64;
104
105    /// Errors that occur during message log operations.
106    /// This enum is marked non-exhaustive to allow for extensibility.
107    #[derive(thiserror::Error, Debug)]
108    #[non_exhaustive]
109    pub enum MessageLogError {
110        /// An error occured during flushing messages with a sequence id range.
111        #[error("flush: [{0}, {1})")]
112        Flush(SeqId, SeqId, #[source] anyhow::Error),
113
114        /// An error occured during appending a message with an assigned sequence id.
115        #[error("append: {0}")]
116        Append(SeqId, #[source] anyhow::Error),
117
118        /// An error occured during reading a message with the persistent sequence id.
119        #[error("read: {0}")]
120        Read(SeqId, #[source] anyhow::Error),
121
122        /// An error occured during trimming a message with the persistent sequence id.
123        #[error("trim: {0}")]
124        Trim(SeqId, #[source] anyhow::Error),
125
126        /// An other error.
127        #[error(transparent)]
128        Other(#[from] anyhow::Error),
129    }
130
131    /// This [`MessageLog`] is a log that serves as a building block to persist data before the
132    /// consumer can process it. One typical example is to persist messages before an actor handles it.
133    /// In such a case, it can be used as a white-ahead log. It allows the actor to recover from a
134    /// crash without requesting resending the messages. The log is append-only and the messages are
135    /// persisted in order with sequence ids.
136    #[async_trait]
137    pub trait MessageLog<M: RemoteMessage>: Sync + Send {
138        /// The type of the stream returned from read operations on this log.
139        type Stream<'a>: Stream<Item = Result<(SeqId, M), MessageLogError>> + Send
140        where
141            Self: 'a;
142
143        /// Append a message to a buffer. The appended messages will only be persisted and available to
144        /// read after calling [`flush`].
145        async fn append(&mut self, message: M) -> Result<(), MessageLogError>;
146
147        /// Flush the appended messages. Return the next sequence id of the last persistent message.
148        async fn flush(&mut self) -> Result<SeqId, MessageLogError>;
149
150        /// Directly flush the message. All previously buffered messages will be flushed as well.
151        /// This convenience method can prevent an additional copy of the message by directly writing to the log.
152        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError>;
153
154        /// Trim the persistent logs before the given [`new_start`] non-inclusively.
155        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError>;
156
157        /// Given a sequence id, return a stream of message and sequence id tuples that are persisted
158        /// after the given sequence id inclusively. The stream will yield errors when streaming
159        /// messages back if any. It will also yield errors if creating the stream itself fails.
160        async fn read(&self, from: SeqId) -> Result<Self::Stream<'_>, MessageLogError>;
161
162        /// Read exactly one message from the log. If the log is empty, return an error.
163        // Ideally, this method can have a default implmentation. But the compiler complains
164        // about `self` does not live long enough.
165        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError>;
166    }
167}
168
169/// A test util mod so that it can be used beyond the crate
170pub mod test_utils {
171
172    use std::collections::VecDeque;
173
174    use futures::pin_mut;
175    use log::SeqId;
176    use tokio_stream::StreamExt;
177
178    use super::*;
179
180    /// An in-memory log for testing.
181    #[derive(Clone)]
182    pub struct TestLog<M: RemoteMessage> {
183        queue: Arc<Mutex<VecDeque<(SeqId, M)>>>,
184        current_seq_id: Arc<Mutex<SeqId>>,
185        // For outside to validate the values of saved messages.
186        observer: Option<mpsc::UnboundedSender<(String, M)>>,
187    }
188
189    impl<M: RemoteMessage> Default for TestLog<M> {
190        fn default() -> Self {
191            Self::new()
192        }
193    }
194
195    impl<M: RemoteMessage> TestLog<M> {
196        /// Create a new, empty [`TestLog`].
197        pub fn new() -> Self {
198            Self {
199                queue: Arc::new(Mutex::new(VecDeque::new())),
200                current_seq_id: Arc::new(Mutex::new(0)),
201                observer: None,
202            }
203        }
204
205        /// Create a new test log that sends all log operations to the provided
206        /// observer. The observer is sent tuples of `(op, message)`, where `op` is
207        /// either "append" or "read".
208        pub fn new_with_observer(observer: mpsc::UnboundedSender<(String, M)>) -> Self {
209            Self {
210                queue: Arc::new(Mutex::new(VecDeque::new())),
211                current_seq_id: Arc::new(Mutex::new(0)),
212                observer: Some(observer),
213            }
214        }
215    }
216
217    #[async_trait]
218    impl<M: RemoteMessage + Clone> MessageLog<M> for TestLog<M> {
219        type Stream<'a> =
220            futures::stream::Iter<std::vec::IntoIter<Result<(SeqId, M), MessageLogError>>>;
221
222        async fn append(&mut self, message: M) -> Result<(), MessageLogError> {
223            let mut seq_id = self.current_seq_id.lock().unwrap();
224            self.queue
225                .lock()
226                .unwrap()
227                .push_back((*seq_id, message.clone()));
228            *seq_id += 1;
229            if let Some(observer) = &self.observer {
230                observer.send(("append".to_string(), message)).unwrap();
231            }
232            Ok(())
233        }
234
235        async fn flush(&mut self) -> Result<SeqId, MessageLogError> {
236            let seq_id = *self.current_seq_id.lock().unwrap();
237            Ok(seq_id)
238        }
239
240        async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError> {
241            self.append(message.clone()).await?;
242            self.flush().await
243        }
244
245        async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError> {
246            let mut queue = self.queue.lock().unwrap();
247            while let Some((id, _)) = queue.front() {
248                if *id < new_start {
249                    queue.pop_front();
250                } else {
251                    break;
252                }
253            }
254            Ok(())
255        }
256
257        async fn read(&self, seq_id: SeqId) -> Result<Self::Stream<'_>, MessageLogError> {
258            let queue = self.queue.lock().unwrap();
259            let filtered_items: Vec<_> = queue
260                .iter()
261                .filter(move |(id, _)| *id >= seq_id)
262                .map(|(seq_id, msg)| Ok((*seq_id, msg.clone())))
263                .collect();
264            for entry in filtered_items.iter() {
265                if let Some(observer) = &self.observer
266                    && let Ok((_, msg)) = entry.as_ref()
267                {
268                    observer.send(("read".to_string(), msg.clone())).unwrap();
269                }
270            }
271            Ok(futures::stream::iter(filtered_items))
272        }
273
274        async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError> {
275            let it = self.read(seq_id).await?;
276
277            pin_mut!(it);
278            match it.next().await {
279                Some(Ok((result_seq_id, message))) => {
280                    if result_seq_id != seq_id {
281                        panic!("no seq id {}", seq_id);
282                    }
283                    return Ok(message);
284                }
285                Some(Err(err)) => {
286                    return Err(err);
287                }
288                None => {
289                    return Err(MessageLogError::Read(
290                        seq_id,
291                        anyhow::anyhow!("failed to find message with sequence {}", seq_id),
292                    ));
293                }
294            }
295        }
296    }
297}
298
#[cfg(test)]
mod tests {

    use std::assert_matches::assert_matches;

    use futures::StreamExt;

    use super::test_utils::TestLog;
    use super::*;
    use crate::mailbox::log::SeqId;
    use crate::testing::ids::test_actor_id;

    #[tokio::test]
    async fn test_local_write_ahead_log_basic() {
        let mut wal = TestLog::new();
        wal.append(124u64).await.unwrap();
        wal.append(56u64).await.unwrap();
        let seq_id = wal.append_and_flush(&999u64).await.unwrap();
        assert_eq!(seq_id, 3);

        // Simple read given a sequence id
        let mut it = wal.read(1).await.unwrap();
        let (next_seq, message): (SeqId, u64) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 1);
        assert_eq!(message, 56u64);
        let (next_seq, message) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 2);
        assert_eq!(message, 999u64);
        assert_matches!(it.next().await, None);
        // Drop the iterator to release borrow from wal
        drop(it);

        // Trim then append
        wal.trim(2).await.unwrap();
        let seq_id = wal.append_and_flush(&777u64).await.unwrap();
        assert_eq!(seq_id, 4);
        let mut it = wal.read(2).await.unwrap();
        let (next_seq, message): (SeqId, u64) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 2);
        assert_eq!(message, 999u64);
        let (next_seq, message) = it.next().await.unwrap().unwrap();
        assert_eq!(next_seq, 3);
        assert_eq!(message, 777u64);
        assert_matches!(it.next().await, None);
    }

    #[tokio::test]
    async fn test_durable_mailbox_sender() {
        let inner = Mailbox::new_detached(test_actor_id("world0_0", "actor0"));
        let write_ahead_log = TestLog::new();
        let mut durable_mbox = DurableMailboxSender::new(write_ahead_log.clone(), inner.clone());

        let (port1, mut receiver1) = inner.open_port::<u64>();
        // Keep the receiver alive so that port2 stays open; it is never read from.
        let (port2, _receiver2) = inner.open_port::<u64>();

        // Convert to references so that the ports are registered.
        let port1 = port1.bind();
        let port2 = port2.bind();

        durable_mbox.post(
            MessageEnvelope::new_unknown(
                port1.port_id().clone(),
                wirevalue::Any::serialize(&1u64).unwrap(),
            ),
            monitored_return_handle(),
        );
        durable_mbox.post(
            MessageEnvelope::new_unknown(
                port2.port_id().clone(),
                wirevalue::Any::serialize(&2u64).unwrap(),
            ),
            monitored_return_handle(),
        );
        durable_mbox.post(
            MessageEnvelope::new_unknown(
                port1.port_id().clone(),
                wirevalue::Any::serialize(&3u64).unwrap(),
            ),
            monitored_return_handle(),
        );
        assert_eq!(receiver1.recv().await.unwrap(), 1u64);

        durable_mbox.flush().await.unwrap();

        // Every posted envelope must have been persisted exactly once, in posting order.
        let mut it = write_ahead_log.read(0).await.unwrap();
        let (seq, message): (SeqId, MessageEnvelope) = it.next().await.unwrap().unwrap();
        assert_eq!(seq, 0);
        assert_eq!(port1.port_id(), message.dest());
        assert_eq!(1u64, message.deserialized::<u64>().unwrap());
        let (seq, message): (SeqId, MessageEnvelope) = it.next().await.unwrap().unwrap();
        assert_eq!(seq, 1);
        assert_eq!(port2.port_id(), message.dest());
        assert_eq!(2u64, message.deserialized::<u64>().unwrap());
        let (seq, message): (SeqId, MessageEnvelope) = it.next().await.unwrap().unwrap();
        assert_eq!(seq, 2);
        assert_eq!(port1.port_id(), message.dest());
        assert_eq!(3u64, message.deserialized::<u64>().unwrap());
        assert!(it.next().await.is_none());
    }
}