hyperactor/
time.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module contains various utilities for dealing with time.
10//! (This probably belongs in a separate crate.)
11
12use std::sync::Arc;
13use std::sync::Mutex;
14use std::time::Duration;
15
16use tokio::sync::Notify;
17use tokio::time::Instant;
18use tokio::time::sleep_until;
19
20use crate::clock::Clock;
21use crate::clock::RealClock;
22
23/// An alarm that can be Armed to fire at some future time.
24///
25/// Alarm is itself owned, and may have multiple sleepers attached
26/// to it. Each sleeper is awoken at most once for each alarm that has
27/// been set.
28///
29/// When instances of `Alarm` are dropped, sleepers are awoken,
30/// returning `false`, indicating that the alarm is defunct.
31pub struct Alarm {
32    status: Arc<Mutex<AlarmStatus>>,
33    notify: Arc<Notify>,
34    version: usize,
35}
36enum AlarmStatus {
37    Unarmed,
38    Armed { version: usize, deadline: Instant },
39    Dropped,
40}
41
42impl Alarm {
43    /// Create a new, unset alarm.
44    pub fn new() -> Self {
45        Self {
46            status: Arc::new(Mutex::new(AlarmStatus::Unarmed)),
47            notify: Arc::new(Notify::new()),
48            version: 0,
49        }
50    }
51
52    /// Arm the alarm to fire after the provided duration.
53    pub fn arm(&mut self, duration: Duration) {
54        let mut status = self.status.lock().unwrap();
55        *status = AlarmStatus::Armed {
56            version: self.version,
57            deadline: RealClock.now() + duration,
58        };
59        drop(status);
60        self.notify.notify_waiters();
61        self.version += 1;
62    }
63
64    /// Disarm the alarm, canceling any pending alarms.
65    pub fn disarm(&mut self) {
66        let mut status = self.status.lock().unwrap();
67        *status = AlarmStatus::Unarmed;
68        drop(status);
69        // Not technically needed (sleepers will still converge),
70        // but this clears up the timers:
71        self.notify.notify_waiters();
72    }
73
74    /// Fire the alarm immediately.
75    pub fn fire(&mut self) {
76        self.arm(Duration::from_millis(0))
77    }
78
79    /// Create a new sleeper for this alarm. Many sleepers can wait for the alarm
80    /// to fire at any given time.
81    pub fn sleeper(&self) -> AlarmSleeper {
82        AlarmSleeper {
83            status: Arc::clone(&self.status),
84            notify: Arc::clone(&self.notify),
85            min_version: 0,
86        }
87    }
88}
89
90impl Drop for Alarm {
91    fn drop(&mut self) {
92        let mut status = self.status.lock().unwrap();
93        *status = AlarmStatus::Dropped;
94        drop(status);
95        self.notify.notify_waiters();
96    }
97}
98
99impl Default for Alarm {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105/// A single alarm sleeper.
106pub struct AlarmSleeper {
107    status: Arc<Mutex<AlarmStatus>>,
108    notify: Arc<Notify>,
109    min_version: usize,
110}
111
112impl AlarmSleeper {
113    /// Sleep until the alarm fires. Returns true if the alarm fired,
114    /// and false if the alarm has been dropped.
115    ///
116    /// Sleep will fire (return true) at most once for each time the
117    /// alarm is set.
118    pub async fn sleep(&mut self) -> bool {
119        loop {
120            // Obtain a notifier before checking the state, to avoid the unlock-notify race.
121            let notified = self.notify.notified();
122            let deadline = match *self.status.lock().unwrap() {
123                AlarmStatus::Dropped => return false,
124                AlarmStatus::Unarmed => None,
125                AlarmStatus::Armed { version, .. } if version < self.min_version => None,
126                AlarmStatus::Armed { version, deadline } if RealClock.now() >= deadline => {
127                    self.min_version = version + 1;
128                    return true;
129                }
130                AlarmStatus::Armed {
131                    version: _,
132                    deadline,
133                } => Some(deadline),
134            };
135
136            if let Some(deadline) = deadline {
137                tokio::select! {
138                    _ = sleep_until(deadline) => (),
139                    _ = notified => (),
140                }
141            } else {
142                notified.await;
143            }
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio_test::assert_pending;
    use tokio_test::task;

    use super::*;

    /// A sleeper resolves to `true` when the alarm fires, and to `false`
    /// when the alarm is dropped before its deadline.
    #[tokio::test]
    async fn test_basic() {
        let mut alarm = Alarm::new();

        let mut first = alarm.sleeper();
        let fired = tokio::spawn(async move { first.sleep().await });
        // Weak check: the spawned task has not had a chance to run yet.
        assert!(!fired.is_finished());
        alarm.fire();
        assert!(fired.await.unwrap());

        let mut second = alarm.sleeper();
        alarm.arm(Duration::from_secs(600));
        let defunct = tokio::spawn(async move { second.sleep().await });
        drop(alarm);
        // The alarm went away before its deadline.
        assert!(!defunct.await.unwrap());
    }

    /// A sleeper fires exactly once per arming.
    #[tokio::test]
    async fn test_sleep_once() {
        let mut alarm = Alarm::new();
        alarm.fire();
        let mut waiter = alarm.sleeper();
        assert!(waiter.sleep().await);

        // The same arming must not wake it a second time.
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.fire();
        assert!(waiter.sleep().await);
        // Nor the new one.
        assert_pending!(task::spawn(waiter.sleep()).poll());
        drop(alarm);
        assert!(!waiter.sleep().await);
    }

    /// Re-arming with an earlier deadline takes effect.
    #[tokio::test]
    async fn test_reset() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_secs(600));
        let mut waiter = alarm.sleeper();
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.arm(Duration::from_millis(10));
        assert!(waiter.sleep().await);
    }

    /// A disarmed alarm keeps sleepers waiting until it is armed again.
    #[tokio::test]
    async fn test_disarm() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_secs(600));
        let mut waiter = alarm.sleeper();
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.disarm();
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.arm(Duration::from_millis(10));
        assert!(waiter.sleep().await);
    }
}