1use std::sync::Arc;
13use std::sync::Mutex;
14use std::time::Duration;
15
16use tokio::sync::Notify;
17use tokio::time::Instant;
18use tokio::time::sleep_until;
19
20use crate::clock::Clock;
21use crate::clock::RealClock;
22
23pub struct Alarm {
32 status: Arc<Mutex<AlarmStatus>>,
33 notify: Arc<Notify>,
34 version: usize,
35}
36enum AlarmStatus {
37 Unarmed,
38 Armed {
39 version: usize,
40 deadline: Instant,
41 armed_at: Instant,
42 },
43 Dropped,
44}
45
46impl Alarm {
47 pub fn new() -> Self {
49 Self {
50 status: Arc::new(Mutex::new(AlarmStatus::Unarmed)),
51 notify: Arc::new(Notify::new()),
52 version: 0,
53 }
54 }
55
56 pub fn arm(&mut self, duration: Duration) {
58 let mut status = self.status.lock().unwrap();
59 let armed_at = RealClock.now();
60 *status = AlarmStatus::Armed {
61 version: self.version,
62 deadline: armed_at + duration,
63 armed_at,
64 };
65 drop(status);
66 self.notify.notify_waiters();
67 self.version += 1;
68 }
69
70 pub fn rearm(&mut self, duration: Duration) {
74 let remaining = match *self.status.lock().unwrap() {
75 AlarmStatus::Armed { armed_at, .. } => {
76 let elapsed = RealClock.now() - armed_at;
77 duration.saturating_sub(elapsed)
78 }
79 AlarmStatus::Unarmed | AlarmStatus::Dropped => duration,
80 };
81 self.arm(remaining);
82 }
83
84 pub fn disarm(&mut self) {
86 let mut status = self.status.lock().unwrap();
87 *status = AlarmStatus::Unarmed;
88 drop(status);
89 self.notify.notify_waiters();
92 }
93
94 pub fn fire(&mut self) {
96 self.arm(Duration::from_millis(0))
97 }
98
99 pub fn sleeper(&self) -> AlarmSleeper {
102 AlarmSleeper {
103 status: Arc::clone(&self.status),
104 notify: Arc::clone(&self.notify),
105 min_version: 0,
106 }
107 }
108}
109
110impl Drop for Alarm {
111 fn drop(&mut self) {
112 let mut status = self.status.lock().unwrap();
113 *status = AlarmStatus::Dropped;
114 drop(status);
115 self.notify.notify_waiters();
116 }
117}
118
119impl Default for Alarm {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125pub struct AlarmSleeper {
127 status: Arc<Mutex<AlarmStatus>>,
128 notify: Arc<Notify>,
129 min_version: usize,
130}
131
132impl AlarmSleeper {
133 pub async fn sleep(&mut self) -> bool {
139 loop {
140 let notified = self.notify.notified();
142 let deadline = match *self.status.lock().unwrap() {
143 AlarmStatus::Dropped => return false,
144 AlarmStatus::Unarmed => None,
145 AlarmStatus::Armed { version, .. } if version < self.min_version => None,
146 AlarmStatus::Armed {
147 version, deadline, ..
148 } if RealClock.now() >= deadline => {
149 self.min_version = version + 1;
150 return true;
151 }
152 AlarmStatus::Armed {
153 version: _,
154 deadline,
155 ..
156 } => Some(deadline),
157 };
158
159 if let Some(deadline) = deadline {
160 tokio::select! {
161 _ = sleep_until(deadline) => (),
162 _ = notified => (),
163 }
164 } else {
165 notified.await;
166 }
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio_test::assert_pending;
    use tokio_test::task;

    use super::*;

    #[tokio::test]
    async fn test_basic() {
        let mut alarm = Alarm::new();

        // A sleeper on an unarmed alarm stays blocked until the alarm fires.
        let mut first = alarm.sleeper();
        let fired = tokio::spawn(async move { first.sleep().await });
        assert!(!fired.is_finished());
        alarm.fire();
        assert!(fired.await.unwrap());

        // Dropping the alarm releases a pending sleeper with `false`.
        let mut second = alarm.sleeper();
        alarm.arm(Duration::from_mins(10));
        let defunct = tokio::spawn(async move { second.sleep().await });
        drop(alarm);
        assert!(!defunct.await.unwrap());
    }

    #[tokio::test]
    async fn test_sleep_once() {
        let mut alarm = Alarm::new();
        alarm.fire();
        let mut waiter = alarm.sleeper();

        // The first firing is reported exactly once.
        assert!(waiter.sleep().await);
        assert_pending!(task::spawn(waiter.sleep()).poll());

        // So is the second.
        alarm.fire();
        assert!(waiter.sleep().await);
        assert_pending!(task::spawn(waiter.sleep()).poll());

        // Once the alarm is gone, sleeping reports it as defunct.
        drop(alarm);
        assert!(!waiter.sleep().await);
    }

    #[tokio::test]
    async fn test_reset() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_mins(10));
        let mut waiter = alarm.sleeper();

        // Far-off deadline: nothing to report yet.
        assert_pending!(task::spawn(waiter.sleep()).poll());

        // Arming again with a short duration replaces the deadline.
        alarm.arm(Duration::from_millis(10));
        assert!(waiter.sleep().await);
    }

    #[tokio::test]
    async fn test_disarm() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_mins(10));
        let mut waiter = alarm.sleeper();
        assert_pending!(task::spawn(waiter.sleep()).poll());

        // A disarmed alarm keeps the sleeper waiting.
        alarm.disarm();
        assert_pending!(task::spawn(waiter.sleep()).poll());

        // Re-arming lets it fire again.
        alarm.arm(Duration::from_millis(10));
        assert!(waiter.sleep().await);
    }
}