hyperactor/sync/
mvar.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module contains an implementation of the MVar synchronization
10//! primitive.
11
12use std::mem::take;
13use std::sync::Arc;
14
15use tokio::sync::Mutex;
16use tokio::sync::MutexGuard;
17use tokio::sync::watch;
18
19/// An MVar is a primitive that combines synchronization and the exchange
20/// of a value. Its semantics are analogous to a synchronous channel of
21/// size 1: if the MVar is full, then `put` blocks until it is emptied;
22/// if the MVar is empty, then `take` blocks until it is filled.
23///
24/// MVars, first introduced in "[Concurrent Haskell](https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/concurrent-haskell.pdf)"
25/// are surprisingly versatile in use. They can be used as:
26/// - a communication channel (with `put` and `take` corresponding to `send` and `recv`);
27/// - a semaphore (with `put` and `take` corresponding to `signal` and `wait`);
28/// - a mutex (with `put` and `take` corresponding to `lock` and `unlock`);
29#[derive(Debug)]
30pub struct MVar<T> {
31    seq: watch::Sender<usize>,
32    value: Arc<Mutex<Option<T>>>,
33}
34
35// Manual Clone impl: cloning an MVar clones the Arc, not the inner value,
36// so T does not need to be Clone.
37impl<T> Clone for MVar<T> {
38    fn clone(&self) -> Self {
39        Self {
40            seq: self.seq.clone(),
41            value: self.value.clone(),
42        }
43    }
44}
45
46impl<T> MVar<T> {
47    /// Create a new MVar with an optional initial value; if no value is
48    /// provided the MVar starts empty.
49    fn new(init: Option<T>) -> Self {
50        let (seq, _) = watch::channel(0);
51        Self {
52            seq,
53            value: Arc::new(Mutex::new(init)),
54        }
55    }
56
57    /// Create a new full MVar with the provided value.
58    pub fn full(value: T) -> Self {
59        Self::new(Some(value))
60    }
61
62    /// Create a new empty MVar.
63    pub fn empty() -> Self {
64        Self::new(None)
65    }
66
67    async fn waitseq(&self, seq: usize) -> (MutexGuard<'_, Option<T>>, usize) {
68        let mut sub = self.seq.subscribe();
69        while *sub.borrow_and_update() < seq {
70            sub.changed().await.unwrap();
71        }
72        let locked = self.value.lock().await;
73        let seq = *sub.borrow_and_update();
74        (locked, seq + 1)
75    }
76
77    fn notify(&self, seq: usize) {
78        self.seq.send_replace(seq);
79    }
80
81    /// Wait until the MVar is full and take its value.
82    /// This method is cancellation safe.
83    pub async fn take(&self) -> T {
84        let mut seq = 0;
85        loop {
86            let mut value;
87            (value, seq) = self.waitseq(seq).await;
88
89            if let Some(current_value) = take(&mut *value) {
90                self.notify(seq);
91                break current_value;
92            }
93            drop(value);
94        }
95    }
96
97    /// Wait until the MVar is empty and put a new value.
98    /// This method is cancellation safe.
99    pub async fn put(&self, new_value: T) {
100        let mut seq = 0;
101        loop {
102            let mut value;
103            (value, seq) = self.waitseq(seq).await;
104
105            if value.is_none() {
106                *value = Some(new_value);
107                self.notify(seq);
108                break;
109            }
110            drop(value);
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mvar() {
        let mv0 = MVar::full(0);
        let mv1 = MVar::empty();

        assert_eq!(mv0.take().await, 0);

        // Relay task: forwards the next value put into `mv0` over to `mv1`.
        let relay = tokio::spawn({
            let mv0 = mv0.clone();
            let mv1 = mv1.clone();
            async move { mv1.put(mv0.take().await).await }
        });

        mv0.put(1).await;
        assert_eq!(mv1.take().await, 1);
        // Await the relay instead of detaching it, so a failure inside the
        // task fails this test.
        relay.await.unwrap();
    }

    #[tokio::test]
    async fn test_put_waits_until_empty() {
        let mv = MVar::full(1);

        let writer = tokio::spawn({
            let mv = mv.clone();
            async move { mv.put(2).await }
        });

        // Let the writer run; it must remain blocked because the MVar is
        // still full and nothing has taken from it yet.
        tokio::task::yield_now().await;
        assert!(!writer.is_finished());

        assert_eq!(mv.take().await, 1);
        writer.await.unwrap();
        assert_eq!(mv.take().await, 2);
    }
}
135}