hyperactor/sync/
flag.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! A simple flagging mechanism to coordinate between tasks.
10//!
11//! ```
12//! # use hyperactor::sync::flag;
13//!
14//! # tokio_test::block_on(async {
15//! let (flag, guard) = flag::guarded();
16//! assert!(!flag.signalled());
17//! let (flag1, guard1) = flag::guarded();
18//! tokio::spawn(async move {
19//!     let _guard = guard;
20//!     flag1.await;
21//! });
22//! drop(guard1);
23//! flag.await
24//! # })
25//! ```
26
27use std::future::Future;
28use std::future::IntoFuture;
29use std::sync::Arc;
30use std::sync::atomic::AtomicBool;
31use std::sync::atomic::Ordering;
32
33use tokio::sync::Notify;
34
35/// Create a new guarded flag. The flag obtains when the guard is dropped.
36pub fn guarded() -> (Flag, Guard) {
37    let state = Arc::new(Default::default());
38    let flag = Flag(Arc::clone(&state));
39    let guard = Guard(state);
40    (flag, guard)
41}
42
43#[derive(Debug, Default)]
44struct State {
45    flagged: AtomicBool,
46    notify: Notify,
47}
48
49impl State {
50    fn set(&self) {
51        self.flagged.store(true, Ordering::SeqCst);
52        self.notify.notify_one();
53    }
54
55    fn get(&self) -> bool {
56        self.flagged.load(Ordering::SeqCst)
57    }
58
59    async fn wait(&self) {
60        if !self.flagged.load(Ordering::SeqCst) {
61            self.notify.notified().await;
62        }
63    }
64}
65
66/// A flag indicating that an event occured. Flags can be queried and awaited.
67#[derive(Debug)]
68pub struct Flag(Arc<State>);
69
70impl Flag {
71    /// Returns true if the flag has been set.
72    pub fn signalled(&self) -> bool {
73        self.0.get()
74    }
75}
76
77impl IntoFuture for Flag {
78    type Output = ();
79    type IntoFuture = impl Future<Output = Self::Output>;
80    fn into_future(self) -> Self::IntoFuture {
81        async move { self.0.wait().await }
82    }
83}
84
85/// A guard that sets the flag when dropped.
86#[derive(Debug)]
87pub struct Guard(Arc<State>);
88
89impl Guard {
90    /// Sets the flag. This is equivalent to `drop(guard)`, but
91    /// conveys the intent more clearly.
92    pub fn signal(self) {}
93}
94
95impl Drop for Guard {
96    fn drop(&mut self) {
97        self.0.set();
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Signalling sets the flag immediately, and awaiting an already-set
    /// flag completes.
    #[tokio::test]
    async fn test_basic() {
        let (flag, guard) = guarded();
        assert!(!flag.signalled());

        guard.signal();
        assert!(flag.signalled());

        flag.await;
    }

    /// A plain drop of the guard (without calling `signal`) also sets the flag.
    #[tokio::test]
    async fn test_drop_signals() {
        let (flag, guard) = guarded();
        assert!(!flag.signalled());

        drop(guard);
        assert!(flag.signalled());

        flag.await;
    }

    /// A task that is already blocked on the flag is woken by the signal.
    #[tokio::test]
    async fn test_basic_running_await() {
        let (flag, guard) = guarded();

        let handle = tokio::spawn(async move {
            flag.await;
        });

        // A real delay gives the spawned task time to start waiting on the
        // flag before it is signalled. The lint is allowed because `sleep` is
        // presumably on the project's clippy disallowed-methods list.
        #[allow(clippy::disallowed_methods)]
        tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

        // The guard is still alive, so the waiter cannot have completed yet.
        assert!(!handle.is_finished());

        guard.signal();
        handle.await.unwrap();
    }
}