1use std::sync::Arc;
13use std::sync::Mutex;
14use std::time::Duration;
15
16use tokio::sync::Notify;
17use tokio::time::Instant;
18use tokio::time::sleep_until;
19
20use crate::clock::Clock;
21use crate::clock::RealClock;
22
23pub struct Alarm {
32 status: Arc<Mutex<AlarmStatus>>,
33 notify: Arc<Notify>,
34 version: usize,
35}
36enum AlarmStatus {
37 Unarmed,
38 Armed { version: usize, deadline: Instant },
39 Dropped,
40}
41
42impl Alarm {
43 pub fn new() -> Self {
45 Self {
46 status: Arc::new(Mutex::new(AlarmStatus::Unarmed)),
47 notify: Arc::new(Notify::new()),
48 version: 0,
49 }
50 }
51
52 pub fn arm(&mut self, duration: Duration) {
54 let mut status = self.status.lock().unwrap();
55 *status = AlarmStatus::Armed {
56 version: self.version,
57 deadline: RealClock.now() + duration,
58 };
59 drop(status);
60 self.notify.notify_waiters();
61 self.version += 1;
62 }
63
64 pub fn disarm(&mut self) {
66 let mut status = self.status.lock().unwrap();
67 *status = AlarmStatus::Unarmed;
68 drop(status);
69 self.notify.notify_waiters();
72 }
73
74 pub fn fire(&mut self) {
76 self.arm(Duration::from_millis(0))
77 }
78
79 pub fn sleeper(&self) -> AlarmSleeper {
82 AlarmSleeper {
83 status: Arc::clone(&self.status),
84 notify: Arc::clone(&self.notify),
85 min_version: 0,
86 }
87 }
88}
89
90impl Drop for Alarm {
91 fn drop(&mut self) {
92 let mut status = self.status.lock().unwrap();
93 *status = AlarmStatus::Dropped;
94 drop(status);
95 self.notify.notify_waiters();
96 }
97}
98
99impl Default for Alarm {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105pub struct AlarmSleeper {
107 status: Arc<Mutex<AlarmStatus>>,
108 notify: Arc<Notify>,
109 min_version: usize,
110}
111
112impl AlarmSleeper {
113 pub async fn sleep(&mut self) -> bool {
119 loop {
120 let notified = self.notify.notified();
122 let deadline = match *self.status.lock().unwrap() {
123 AlarmStatus::Dropped => return false,
124 AlarmStatus::Unarmed => None,
125 AlarmStatus::Armed { version, .. } if version < self.min_version => None,
126 AlarmStatus::Armed { version, deadline } if RealClock.now() >= deadline => {
127 self.min_version = version + 1;
128 return true;
129 }
130 AlarmStatus::Armed {
131 version: _,
132 deadline,
133 } => Some(deadline),
134 };
135
136 if let Some(deadline) = deadline {
137 tokio::select! {
138 _ = sleep_until(deadline) => (),
139 _ = notified => (),
140 }
141 } else {
142 notified.await;
143 }
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio_test::assert_pending;
    use tokio_test::task;

    use super::*;

    /// A spawned sleeper wakes with `true` on fire and `false` on drop.
    #[tokio::test]
    async fn test_basic() {
        let mut alarm = Alarm::new();
        let mut waiter = alarm.sleeper();
        let join = tokio::spawn(async move { waiter.sleep().await });
        assert!(!join.is_finished());
        alarm.fire();

        assert!(join.await.unwrap());

        // Armed far in the future: only dropping the alarm ends the sleep.
        let mut waiter = alarm.sleeper();
        alarm.arm(Duration::from_secs(600));
        let join = tokio::spawn(async move { waiter.sleep().await });
        drop(alarm);
        assert!(!join.await.unwrap());
    }

    /// Each firing is reported exactly once per sleeper.
    #[tokio::test]
    async fn test_sleep_once() {
        let mut alarm = Alarm::new();
        alarm.fire();
        let mut waiter = alarm.sleeper();
        assert!(waiter.sleep().await);

        // The first firing was consumed, so a second sleep must block.
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.fire();
        assert!(waiter.sleep().await);
        assert_pending!(task::spawn(waiter.sleep()).poll());
        drop(alarm);
        assert!(!waiter.sleep().await);
    }

    /// Re-arming with a shorter duration replaces the earlier deadline.
    #[tokio::test]
    async fn test_reset() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_secs(600));
        let mut waiter = alarm.sleeper();
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.arm(Duration::from_millis(10));
        assert!(waiter.sleep().await);
    }

    /// A disarmed alarm stays pending until it is armed again.
    #[tokio::test]
    async fn test_disarm() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_secs(600));
        let mut waiter = alarm.sleeper();
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.disarm();
        assert_pending!(task::spawn(waiter.sleep()).poll());
        alarm.arm(Duration::from_millis(10));
        assert!(waiter.sleep().await);
    }
}