1use std::sync::Arc;
13use std::sync::Mutex;
14use std::time::Duration;
15
16use tokio::sync::Notify;
17use tokio::time::Instant;
18use tokio::time::sleep_until;
19
20pub struct Alarm {
29 status: Arc<Mutex<AlarmStatus>>,
30 notify: Arc<Notify>,
31 version: usize,
32}
33enum AlarmStatus {
34 Unarmed,
35 Armed {
36 version: usize,
37 deadline: Instant,
38 armed_at: Instant,
39 },
40 Dropped,
41}
42
43impl Alarm {
44 pub fn new() -> Self {
46 Self {
47 status: Arc::new(Mutex::new(AlarmStatus::Unarmed)),
48 notify: Arc::new(Notify::new()),
49 version: 0,
50 }
51 }
52
53 pub fn arm(&mut self, duration: Duration) {
55 let mut status = self.status.lock().unwrap();
56 let armed_at = tokio::time::Instant::now();
57 *status = AlarmStatus::Armed {
58 version: self.version,
59 deadline: armed_at + duration,
60 armed_at,
61 };
62 drop(status);
63 self.notify.notify_waiters();
64 self.version += 1;
65 }
66
67 pub fn rearm(&mut self, duration: Duration) {
71 let remaining = match *self.status.lock().unwrap() {
72 AlarmStatus::Armed { armed_at, .. } => {
73 let elapsed = tokio::time::Instant::now() - armed_at;
74 duration.saturating_sub(elapsed)
75 }
76 AlarmStatus::Unarmed | AlarmStatus::Dropped => duration,
77 };
78 self.arm(remaining);
79 }
80
81 pub fn disarm(&mut self) {
83 let mut status = self.status.lock().unwrap();
84 *status = AlarmStatus::Unarmed;
85 drop(status);
86 self.notify.notify_waiters();
89 }
90
91 pub fn fire(&mut self) {
93 self.arm(Duration::from_millis(0))
94 }
95
96 pub fn sleeper(&self) -> AlarmSleeper {
99 AlarmSleeper {
100 status: Arc::clone(&self.status),
101 notify: Arc::clone(&self.notify),
102 min_version: 0,
103 }
104 }
105}
106
107impl Drop for Alarm {
108 fn drop(&mut self) {
109 let mut status = self.status.lock().unwrap();
110 *status = AlarmStatus::Dropped;
111 drop(status);
112 self.notify.notify_waiters();
113 }
114}
115
116impl Default for Alarm {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122pub struct AlarmSleeper {
124 status: Arc<Mutex<AlarmStatus>>,
125 notify: Arc<Notify>,
126 min_version: usize,
127}
128
129impl AlarmSleeper {
130 pub async fn sleep(&mut self) -> bool {
136 loop {
137 let notified = self.notify.notified();
139 let deadline = match *self.status.lock().unwrap() {
140 AlarmStatus::Dropped => return false,
141 AlarmStatus::Unarmed => None,
142 AlarmStatus::Armed { version, .. } if version < self.min_version => None,
143 AlarmStatus::Armed {
144 version, deadline, ..
145 } if tokio::time::Instant::now() >= deadline => {
146 self.min_version = version + 1;
147 return true;
148 }
149 AlarmStatus::Armed {
150 version: _,
151 deadline,
152 ..
153 } => Some(deadline),
154 };
155
156 if let Some(deadline) = deadline {
157 tokio::select! {
158 _ = sleep_until(deadline) => (),
159 _ = notified => (),
160 }
161 } else {
162 notified.await;
163 }
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio_test::assert_pending;
    use tokio_test::task;

    use super::*;

    // NOTE(review): `Duration::from_mins` is a comparatively recent std
    // addition — confirm it is supported by the crate's minimum toolchain.

    /// A spawned sleeper completes with `true` after `fire`. A sleeper waiting
    /// on a long deadline completes with `false` once the alarm is dropped.
    #[tokio::test]
    async fn test_basic() {
        let mut alarm = Alarm::new();
        let mut sleeper = alarm.sleeper();
        let handle = tokio::spawn(async move { sleeper.sleep().await });
        assert!(!handle.is_finished());
        alarm.fire();

        assert!(handle.await.unwrap());

        let mut sleeper = alarm.sleeper();
        alarm.arm(Duration::from_mins(10));
        let handle = tokio::spawn(async move { sleeper.sleep().await });
        drop(alarm);
        assert!(!handle.await.unwrap());
    }

    /// A single firing wakes a given sleeper only once. Each later `fire`
    /// wakes it again, and after the alarm is dropped `sleep` returns `false`.
    #[tokio::test]
    async fn test_sleep_once() {
        let mut alarm = Alarm::new();
        alarm.fire();
        // A sleeper created after the firing still observes it once.
        let mut sleeper = alarm.sleeper();
        assert!(sleeper.sleep().await);

        // The same firing does not wake this sleeper a second time.
        assert_pending!(task::spawn(sleeper.sleep()).poll());
        alarm.fire();
        assert!(sleeper.sleep().await);
        assert_pending!(task::spawn(sleeper.sleep()).poll());
        drop(alarm);
        assert!(!sleeper.sleep().await);
    }

    /// Arming again with a shorter duration replaces the long deadline.
    #[tokio::test]
    async fn test_reset() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_mins(10));
        let mut sleeper = alarm.sleeper();
        assert_pending!(task::spawn(sleeper.sleep()).poll());
        alarm.arm(Duration::from_millis(10));
        assert!(sleeper.sleep().await);
    }

    /// Disarming keeps sleepers pending; arming again later makes them fire.
    #[tokio::test]
    async fn test_disarm() {
        let mut alarm = Alarm::new();
        alarm.arm(Duration::from_mins(10));
        let mut sleeper = alarm.sleeper();
        assert_pending!(task::spawn(sleeper.sleep()).poll());
        alarm.disarm();
        assert_pending!(task::spawn(sleeper.sleep()).poll());
        alarm.arm(Duration::from_millis(10));
        assert!(sleeper.sleep().await);
    }
}