hyperactor/sync/mvar.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module contains an implementation of the MVar synchronization
10//! primitive.
11
12use std::mem::take;
13use std::sync::Arc;
14
15use tokio::sync::Mutex;
16use tokio::sync::MutexGuard;
17use tokio::sync::watch;
18
19/// An MVar is a primitive that combines synchronization and the exchange
20/// of a value. Its semantics are analogous to a synchronous channel of
21/// size 1: if the MVar is full, then `put` blocks until it is emptied;
22/// if the MVar is empty, then `take` blocks until it is filled.
23///
24/// MVars, first introduced in "[Concurrent Haskell](https://www.microsoft.com/en-us/research/wp-content/uploads/1996/01/concurrent-haskell.pdf)"
25/// are surprisingly versatile in use. They can be used as:
26/// - a communication channel (with `put` and `take` corresponding to `send` and `recv`);
27/// - a semaphore (with `put` and `take` corresponding to `signal` and `wait`);
28/// - a mutex (with `put` and `take` corresponding to `lock` and `unlock`);
29#[derive(Debug)]
30pub struct MVar<T> {
31 seq: watch::Sender<usize>,
32 value: Arc<Mutex<Option<T>>>,
33}
34
35// Manual Clone impl: cloning an MVar clones the Arc, not the inner value,
36// so T does not need to be Clone.
37impl<T> Clone for MVar<T> {
38 fn clone(&self) -> Self {
39 Self {
40 seq: self.seq.clone(),
41 value: self.value.clone(),
42 }
43 }
44}
45
46impl<T> MVar<T> {
47 /// Create a new MVar with an optional initial value; if no value is
48 /// provided the MVar starts empty.
49 fn new(init: Option<T>) -> Self {
50 let (seq, _) = watch::channel(0);
51 Self {
52 seq,
53 value: Arc::new(Mutex::new(init)),
54 }
55 }
56
57 /// Create a new full MVar with the provided value.
58 pub fn full(value: T) -> Self {
59 Self::new(Some(value))
60 }
61
62 /// Create a new empty MVar.
63 pub fn empty() -> Self {
64 Self::new(None)
65 }
66
67 async fn waitseq(&self, seq: usize) -> (MutexGuard<'_, Option<T>>, usize) {
68 let mut sub = self.seq.subscribe();
69 while *sub.borrow_and_update() < seq {
70 sub.changed().await.unwrap();
71 }
72 let locked = self.value.lock().await;
73 let seq = *sub.borrow_and_update();
74 (locked, seq + 1)
75 }
76
77 fn notify(&self, seq: usize) {
78 self.seq.send_replace(seq);
79 }
80
81 /// Wait until the MVar is full and take its value.
82 /// This method is cancellation safe.
83 pub async fn take(&self) -> T {
84 let mut seq = 0;
85 loop {
86 let mut value;
87 (value, seq) = self.waitseq(seq).await;
88
89 if let Some(current_value) = take(&mut *value) {
90 self.notify(seq);
91 break current_value;
92 }
93 drop(value);
94 }
95 }
96
97 /// Wait until the MVar is empty and put a new value.
98 /// This method is cancellation safe.
99 pub async fn put(&self, new_value: T) {
100 let mut seq = 0;
101 loop {
102 let mut value;
103 (value, seq) = self.waitseq(seq).await;
104
105 if value.is_none() {
106 *value = Some(new_value);
107 self.notify(seq);
108 break;
109 }
110 drop(value);
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mvar() {
        let mv0 = MVar::full(0);
        let mv1 = MVar::empty();

        assert_eq!(mv0.take().await, 0);

        tokio::spawn({
            let mv0 = mv0.clone();
            let mv1 = mv1.clone();
            async move { mv1.put(mv0.take().await).await }
        });

        mv0.put(1).await;
        assert_eq!(mv1.take().await, 1);
    }

    /// `put` on a full MVar must wait until a `take` empties it, rather than
    /// overwriting the stored value.
    #[tokio::test]
    async fn test_put_waits_until_empty() {
        let mv = MVar::full(1);
        let putter = tokio::spawn({
            let mv = mv.clone();
            async move { mv.put(2).await }
        });

        // Give the putter a chance to run and find the MVar full.
        tokio::task::yield_now().await;

        assert_eq!(mv.take().await, 1);
        putter.await.unwrap();
        assert_eq!(mv.take().await, 2);
    }

    /// Dropping a pending `take` must leave the MVar usable and unchanged.
    #[tokio::test]
    async fn test_take_cancellation() {
        let mv: MVar<i32> = MVar::empty();

        // Poll `take` once (it cannot complete on an empty MVar), then drop it.
        tokio::select! {
            biased;
            _ = mv.take() => panic!("take completed on an empty MVar"),
            _ = std::future::ready(()) => {}
        }

        mv.put(7).await;
        assert_eq!(mv.take().await, 7);
    }
}