hyperactor/sync/
mvar.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module contains an implementation of the MVar synchronization
10//! primitive.
11
12use std::mem::take;
13use std::sync::Arc;
14
15use tokio::sync::Mutex;
16use tokio::sync::MutexGuard;
17use tokio::sync::watch;
18
19/// An MVar is a primitive that combines synchronization and the exchange
20/// of a value. Its semantics are analogous to a synchronous channel of
21/// size 1: if the MVar is full, then `put` blocks until it is emptied;
22/// if the MVar is empty, then `take` blocks until it is filled.
23///
24/// MVars, first introduced in "[Concurrent Haskell](https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/concurrent-haskell.pdf)"
25/// are surprisingly versatile in use. They can be used as:
26/// - a communication channel (with `put` and `take` corresponding to `send` and `recv`);
27/// - a semaphore (with `put` and `take` corresponding to `signal` and `wait`);
28/// - a mutex (with `put` and `take` corresponding to `lock` and `unlock`);
29#[derive(Clone, Debug)]
30pub struct MVar<T> {
31    seq: watch::Sender<usize>,
32    value: Arc<Mutex<Option<T>>>,
33}
34
35impl<T> MVar<T> {
36    /// Create a new MVar with an optional initial value; if no value is
37    /// provided the MVar starts empty.
38    fn new(init: Option<T>) -> Self {
39        let (seq, _) = watch::channel(0);
40        Self {
41            seq,
42            value: Arc::new(Mutex::new(init)),
43        }
44    }
45
46    /// Create a new full MVar with the provided value.
47    pub fn full(value: T) -> Self {
48        Self::new(Some(value))
49    }
50
51    /// Create a new empty MVar.
52    pub fn empty() -> Self {
53        Self::new(None)
54    }
55
56    async fn waitseq(&self, seq: usize) -> (MutexGuard<'_, Option<T>>, usize) {
57        let mut sub = self.seq.subscribe();
58        while *sub.borrow_and_update() < seq {
59            sub.changed().await.unwrap();
60        }
61        let locked = self.value.lock().await;
62        let seq = *sub.borrow_and_update();
63        (locked, seq + 1)
64    }
65
66    fn notify(&self, seq: usize) {
67        self.seq.send_replace(seq);
68    }
69
70    /// Wait until the MVar is full and take its value.
71    /// This method is cancellation safe.
72    pub async fn take(&self) -> T {
73        let mut seq = 0;
74        loop {
75            let mut value;
76            (value, seq) = self.waitseq(seq).await;
77
78            if let Some(current_value) = take(&mut *value) {
79                self.notify(seq);
80                break current_value;
81            }
82            drop(value);
83        }
84    }
85
86    /// Wait until the MVar is empty and put a new value.
87    /// This method is cancellation safe.
88    pub async fn put(&self, new_value: T) {
89        let mut seq = 0;
90        loop {
91            let mut value;
92            (value, seq) = self.waitseq(seq).await;
93
94            if value.is_none() {
95                *value = Some(new_value);
96                self.notify(seq);
97                break;
98            }
99            drop(value);
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Passes values through two MVars, with a spawned task relaying
    /// whatever it takes from the first one into the second.
    #[tokio::test]
    async fn test_mvar() {
        let source = MVar::full(0);
        let sink = MVar::empty();

        // Drain the initial value so that `source` is empty from here on.
        let initial = source.take().await;
        assert_eq!(initial, 0);

        // The relay's `take` can only complete after the `put` below
        // refills `source`.
        let relay_source = source.clone();
        let relay_sink = sink.clone();
        tokio::spawn(async move {
            let relayed = relay_source.take().await;
            relay_sink.put(relayed).await
        });

        source.put(1).await;
        let received = sink.take().await;
        assert_eq!(received, 1);
    }
}
124}