1use log::*;
10
11use super::*;
12
13#[derive(Debug)]
17struct DurableMailboxSender(Buffer<MessageEnvelope>);
18
19impl DurableMailboxSender {
20 fn new(
21 write_ahead_log: impl MessageLog<MessageEnvelope> + 'static,
22 inner: impl MailboxSender + 'static,
23 ) -> Self {
24 let write_ahead_log = Arc::new(tokio::sync::Mutex::new(write_ahead_log));
25 let inner = Arc::new(inner);
26 let sequencer =
27 Buffer::new(
28 move |envelope: MessageEnvelope,
29 return_handle: PortHandle<Undeliverable<MessageEnvelope>>| {
30 let write_ahead_log = write_ahead_log.clone();
31 let inner = inner.clone();
32 let return_handle = return_handle.clone();
33 async move {
34 let envelope_copy = envelope.clone(); let port_id = envelope.dest().clone();
36 let mut log = write_ahead_log.lock().await;
37 let append_result = log.append(envelope).await.map_err(|err| {
40 MailboxSenderError::new_bound(port_id.clone(), err.into())
41 });
42
43 let flush_result = log.flush().await.map_err(|err| {
44 MailboxSenderError::new_bound(port_id.clone(), err.into())
45 });
46
47 drop(log);
48
49 if append_result.and(flush_result).is_ok() {
50 inner.post(envelope_copy, return_handle);
51 } else {
52 envelope_copy.undeliverable(
53 DeliveryError::BrokenLink(
54 "failed to append or flush in durable sender".to_string(),
55 ),
56 return_handle,
57 );
58 }
59 }
60 },
61 );
62
63 Self(sequencer)
64 }
65
66 async fn flush(&mut self) -> Result<(), watch::error::RecvError> {
67 self.0.flush().await
68 }
69}
70
71#[async_trait]
72impl MailboxSender for DurableMailboxSender {
73 fn post_unchecked(
74 &self,
75 envelope: MessageEnvelope,
76 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
77 ) {
78 if let Err(mpsc::error::SendError((envelope, return_handle))) =
79 self.0.send((envelope, return_handle))
80 {
81 envelope.undeliverable(
82 DeliveryError::BrokenLink("failed to post in DurableMailboxSender".to_string()),
83 return_handle,
84 );
85 }
86 }
87}
88
89pub mod log {
90
91 use std::fmt::Debug;
95
96 use async_trait::async_trait;
97 use futures::stream::Stream;
98
99 use crate::RemoteMessage;
100
101 pub type SeqId = u64;
103
104 #[derive(thiserror::Error, Debug)]
107 #[non_exhaustive]
108 pub enum MessageLogError {
109 #[error("flush: [{0}, {1})")]
111 Flush(SeqId, SeqId, #[source] anyhow::Error),
112
113 #[error("append: {0}")]
115 Append(SeqId, #[source] anyhow::Error),
116
117 #[error("read: {0}")]
119 Read(SeqId, #[source] anyhow::Error),
120
121 #[error("trim: {0}")]
123 Trim(SeqId, #[source] anyhow::Error),
124
125 #[error(transparent)]
127 Other(#[from] anyhow::Error),
128 }
129
130 #[async_trait]
136 pub trait MessageLog<M: RemoteMessage>: Sync + Send + Debug {
137 type Stream<'a>: Stream<Item = Result<(SeqId, M), MessageLogError>> + Send
139 where
140 Self: 'a;
141
142 async fn append(&mut self, message: M) -> Result<(), MessageLogError>;
145
146 async fn flush(&mut self) -> Result<SeqId, MessageLogError>;
148
149 async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError>;
152
153 async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError>;
155
156 async fn read(&self, from: SeqId) -> Result<Self::Stream<'_>, MessageLogError>;
160
161 async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError>;
165 }
166}
167
168pub mod test_utils {
170
171 use std::collections::VecDeque;
172
173 use futures::pin_mut;
174 use log::SeqId;
175 use tokio_stream::StreamExt;
176
177 use super::*;
178
179 #[derive(Debug, Clone)]
181 pub struct TestLog<M: RemoteMessage> {
182 queue: Arc<Mutex<VecDeque<(SeqId, M)>>>,
183 current_seq_id: Arc<Mutex<SeqId>>,
184 observer: Option<mpsc::UnboundedSender<(String, M)>>,
186 }
187
188 impl<M: RemoteMessage> Default for TestLog<M> {
189 fn default() -> Self {
190 Self::new()
191 }
192 }
193
194 impl<M: RemoteMessage> TestLog<M> {
195 pub fn new() -> Self {
197 Self {
198 queue: Arc::new(Mutex::new(VecDeque::new())),
199 current_seq_id: Arc::new(Mutex::new(0)),
200 observer: None,
201 }
202 }
203
204 pub fn new_with_observer(observer: mpsc::UnboundedSender<(String, M)>) -> Self {
208 Self {
209 queue: Arc::new(Mutex::new(VecDeque::new())),
210 current_seq_id: Arc::new(Mutex::new(0)),
211 observer: Some(observer),
212 }
213 }
214 }
215
216 #[async_trait]
217 impl<M: RemoteMessage + Clone> MessageLog<M> for TestLog<M> {
218 type Stream<'a> =
219 futures::stream::Iter<std::vec::IntoIter<Result<(SeqId, M), MessageLogError>>>;
220
221 async fn append(&mut self, message: M) -> Result<(), MessageLogError> {
222 let mut seq_id = self.current_seq_id.lock().unwrap();
223 self.queue
224 .lock()
225 .unwrap()
226 .push_back((*seq_id, message.clone()));
227 *seq_id += 1;
228 if let Some(observer) = &self.observer {
229 observer.send(("append".to_string(), message)).unwrap();
230 }
231 Ok(())
232 }
233
234 async fn flush(&mut self) -> Result<SeqId, MessageLogError> {
235 let seq_id = *self.current_seq_id.lock().unwrap();
236 Ok(seq_id)
237 }
238
239 async fn append_and_flush(&mut self, message: &M) -> Result<SeqId, MessageLogError> {
240 self.append(message.clone()).await?;
241 self.flush().await
242 }
243
244 async fn trim(&mut self, new_start: SeqId) -> Result<(), MessageLogError> {
245 let mut queue = self.queue.lock().unwrap();
246 while let Some((id, _)) = queue.front() {
247 if *id < new_start {
248 queue.pop_front();
249 } else {
250 break;
251 }
252 }
253 Ok(())
254 }
255
256 async fn read(&self, seq_id: SeqId) -> Result<Self::Stream<'_>, MessageLogError> {
257 let queue = self.queue.lock().unwrap();
258 let filtered_items: Vec<_> = queue
259 .iter()
260 .filter(move |(id, _)| *id >= seq_id)
261 .map(|(seq_id, msg)| Ok((*seq_id, msg.clone())))
262 .collect();
263 for entry in filtered_items.iter() {
264 if let Some(observer) = &self.observer
265 && let Ok((_, msg)) = entry.as_ref()
266 {
267 observer.send(("read".to_string(), msg.clone())).unwrap();
268 }
269 }
270 Ok(futures::stream::iter(filtered_items.into_iter()))
271 }
272
273 async fn read_one(&self, seq_id: SeqId) -> Result<M, MessageLogError> {
274 let it = self.read(seq_id).await?;
275
276 pin_mut!(it);
277 match it.next().await {
278 Some(Ok((result_seq_id, message))) => {
279 if result_seq_id != seq_id {
280 panic!("no seq id {}", seq_id);
281 }
282 return Ok(message);
283 }
284 Some(Err(err)) => {
285 return Err(err);
286 }
287 None => {
288 return Err(MessageLogError::Read(
289 seq_id,
290 anyhow::anyhow!("failed to find message with sequence {}", seq_id),
291 ));
292 }
293 }
294 }
295 }
296}
297
#[cfg(test)]
mod tests {

    use std::assert_matches::assert_matches;

    use futures::StreamExt;

    use super::test_utils::TestLog;
    use super::*;
    use crate::id;
    use crate::mailbox::log::SeqId;

    /// Appends, reads back, trims, then appends and reads again on a `TestLog`.
    #[tokio::test]
    async fn test_local_write_ahead_log_basic() {
        let mut wal = TestLog::new();
        for value in [124u64, 56u64] {
            wal.append(value).await.unwrap();
        }
        // Ids 0, 1 and 2 are now taken, so flush reports 3.
        assert_eq!(wal.append_and_flush(&999u64).await.unwrap(), 3);

        // Scoped so the first reader is gone before the log is mutated again.
        {
            let expected: [(SeqId, u64); 2] = [(1, 56), (2, 999)];
            let mut entries = wal.read(1).await.unwrap();
            for (want_seq, want_msg) in expected {
                let (seq, msg): (SeqId, u64) = entries.next().await.unwrap().unwrap();
                assert_eq!(seq, want_seq);
                assert_eq!(msg, want_msg);
            }
            assert_matches!(entries.next().await, None);
        }

        wal.trim(2).await.unwrap();
        assert_eq!(wal.append_and_flush(&777u64).await.unwrap(), 4);

        let expected: [(SeqId, u64); 2] = [(2, 999), (3, 777)];
        let mut entries = wal.read(2).await.unwrap();
        for (want_seq, want_msg) in expected {
            let (seq, msg): (SeqId, u64) = entries.next().await.unwrap().unwrap();
            assert_eq!(seq, want_seq);
            assert_eq!(msg, want_msg);
        }
        assert_matches!(entries.next().await, None);
    }

    /// Posts through the durable sender and checks both delivery and the log.
    #[tokio::test]
    async fn test_durable_mailbox_sender() {
        let inner = Mailbox::new_detached(id!(world0[0].actor0));
        let write_ahead_log = TestLog::new();
        let mut durable_mbox = DurableMailboxSender::new(write_ahead_log.clone(), inner.clone());

        let (port1, mut receiver1) = inner.open_port::<u64>();
        // The receiver is kept bound so port2 stays open for the test's duration.
        let (port2, _receiver2) = inner.open_port::<u64>();
        let port1 = port1.bind();
        let port2 = port2.bind();

        let outgoing = [
            (port1.port_id().clone(), 1u64),
            (port2.port_id().clone(), 2u64),
            (port1.port_id().clone(), 3u64),
        ];
        for (dest, value) in outgoing {
            durable_mbox.post(
                MessageEnvelope::new_unknown(dest, Serialized::serialize(&value).unwrap()),
                monitored_return_handle(),
            );
        }
        assert_eq!(receiver1.recv().await.unwrap(), 1u64);

        durable_mbox.flush().await.unwrap();

        let mut logged = write_ahead_log.read(1).await.unwrap();

        let (seq, envelope): (SeqId, MessageEnvelope) = logged.next().await.unwrap().unwrap();
        assert_eq!(seq, 1);
        assert_eq!(port2.port_id(), envelope.dest());
        assert_eq!(2u64, envelope.deserialized::<u64>().unwrap());

        let (seq, envelope): (SeqId, MessageEnvelope) = logged.next().await.unwrap().unwrap();
        assert_eq!(seq, 2);
        assert_eq!(3u64, envelope.deserialized::<u64>().unwrap());
    }
}