hyperactor/sync/mvar.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module contains an implementation of the MVar synchronization
10//! primitive.
11
12use std::mem::take;
13use std::sync::Arc;
14
15use tokio::sync::Mutex;
16use tokio::sync::MutexGuard;
17use tokio::sync::watch;
18
19/// An MVar is a primitive that combines synchronization and the exchange
20/// of a value. Its semantics are analogous to a synchronous channel of
21/// size 1: if the MVar is full, then `put` blocks until it is emptied;
22/// if the MVar is empty, then `take` blocks until it is filled.
23///
24/// MVars, first introduced in "[Concurrent Haskell](https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/concurrent-haskell.pdf)"
25/// are surprisingly versatile in use. They can be used as:
26/// - a communication channel (with `put` and `take` corresponding to `send` and `recv`);
27/// - a semaphore (with `put` and `take` corresponding to `signal` and `wait`);
28/// - a mutex (with `put` and `take` corresponding to `lock` and `unlock`);
29#[derive(Clone, Debug)]
30pub struct MVar<T> {
31 seq: watch::Sender<usize>,
32 value: Arc<Mutex<Option<T>>>,
33}
34
35impl<T> MVar<T> {
36 /// Create a new MVar with an optional initial value; if no value is
37 /// provided the MVar starts empty.
38 fn new(init: Option<T>) -> Self {
39 let (seq, _) = watch::channel(0);
40 Self {
41 seq,
42 value: Arc::new(Mutex::new(init)),
43 }
44 }
45
46 /// Create a new full MVar with the provided value.
47 pub fn full(value: T) -> Self {
48 Self::new(Some(value))
49 }
50
51 /// Create a new empty MVar.
52 pub fn empty() -> Self {
53 Self::new(None)
54 }
55
56 async fn waitseq(&self, seq: usize) -> (MutexGuard<'_, Option<T>>, usize) {
57 let mut sub = self.seq.subscribe();
58 while *sub.borrow_and_update() < seq {
59 sub.changed().await.unwrap();
60 }
61 let locked = self.value.lock().await;
62 let seq = *sub.borrow_and_update();
63 (locked, seq + 1)
64 }
65
66 fn notify(&self, seq: usize) {
67 self.seq.send_replace(seq);
68 }
69
70 /// Wait until the MVar is full and take its value.
71 /// This method is cancellation safe.
72 pub async fn take(&self) -> T {
73 let mut seq = 0;
74 loop {
75 let mut value;
76 (value, seq) = self.waitseq(seq).await;
77
78 if let Some(current_value) = take(&mut *value) {
79 self.notify(seq);
80 break current_value;
81 }
82 drop(value);
83 }
84 }
85
86 /// Wait until the MVar is empty and put a new value.
87 /// This method is cancellation safe.
88 pub async fn put(&self, new_value: T) {
89 let mut seq = 0;
90 loop {
91 let mut value;
92 (value, seq) = self.waitseq(seq).await;
93
94 if value.is_none() {
95 *value = Some(new_value);
96 self.notify(seq);
97 break;
98 }
99 drop(value);
100 }
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Drains a full MVar, then relays a value from `mv0` to `mv1` through
    /// a spawned task and checks that it arrives.
    #[tokio::test]
    async fn test_mvar() {
        let mv0 = MVar::full(0);
        let mv1 = MVar::empty();

        assert_eq!(mv0.take().await, 0);

        // The relay task works on clones, which share the same slots.
        let (source, sink) = (mv0.clone(), mv1.clone());
        tokio::spawn(async move {
            let forwarded = source.take().await;
            sink.put(forwarded).await
        });

        mv0.put(1).await;
        assert_eq!(mv1.take().await, 1);
    }
}
124}