hyperactor/sync/
monitor.rs1use std::future::Future;
30use std::future::IntoFuture;
31use std::sync::Arc;
32use std::sync::Mutex;
33
34use enum_as_inner::EnumAsInner;
35use tokio::task::JoinSet;
36
37use crate::sync::flag;
38
39pub fn group() -> (Group, Handle) {
42 let (flag, guard) = flag::guarded();
43 let state = Arc::new(Mutex::new(State::Running {
44 _guard: guard,
45 tasks: JoinSet::new(),
46 }));
47
48 let group = Group(Arc::clone(&state));
49 let handle = Handle(Some((flag, state)));
50
51 (group, handle)
52}
53
54pub struct Handle(Option<(flag::Flag, Arc<Mutex<State>>)>);
57
58impl Handle {
59 pub fn status(&self) -> Status {
61 self.unwrap_state().lock().unwrap().status()
62 }
63
64 pub fn abort(&self) {
68 self.unwrap_state().lock().unwrap().stop(true)
69 }
70
71 fn unwrap_state(&self) -> &Arc<Mutex<State>> {
72 &self.0.as_ref().unwrap().1
73 }
74
75 fn take(&mut self) -> Option<(flag::Flag, Arc<Mutex<State>>)> {
76 self.0.take()
77 }
78}
79
80impl Drop for Handle {
81 fn drop(&mut self) {
82 if let Some((_, ref state)) = self.0 {
83 state.lock().unwrap().stop(true);
84 }
85 }
86}
87
88impl IntoFuture for Handle {
89 type Output = Status;
90 type IntoFuture = impl Future<Output = Self::Output>;
91 fn into_future(mut self) -> Self::IntoFuture {
92 async move {
93 let (flag, state) = self.take().unwrap();
94 flag.await;
95 #[allow(clippy::let_and_return)]
96 let status = state.lock().unwrap().status();
97 status
98 }
99 }
100}
101
102#[derive(Clone)]
107pub struct Group(Arc<Mutex<State>>);
108
109#[derive(Debug, PartialEq, Eq, Clone, Copy)]
113pub enum Status {
114 Running,
117 Failed,
120 Aborted,
123}
124
125impl Group {
126 pub fn spawn<F, T, E>(&self, fut: F)
130 where
131 F: Future<Output = Result<T, E>> + Send + 'static,
132 {
133 let state = Arc::clone(&self.0);
134 if let Some((_, tasks)) = self.0.lock().unwrap().as_running_mut() {
135 tasks.spawn(async move {
136 if fut.await.is_ok() {
137 return;
138 }
139 state.lock().unwrap().stop(false);
140 });
141 }
142 }
143
144 pub fn fail(&self) {
147 self.0.lock().unwrap().stop(false);
148 }
149}
150
151impl Drop for Group {
152 fn drop(&mut self) {
153 self.0.lock().unwrap().stop(true);
154 }
155}
156
157#[derive(EnumAsInner, Debug)]
158enum State {
159 Running {
160 _guard: flag::Guard,
161 tasks: JoinSet<()>,
162 },
163 Stopped(bool ),
164}
165
166impl State {
167 fn stop(&mut self, aborted: bool) {
168 if self.is_running() {
169 *self = State::Stopped(aborted);
172 }
173 }
174
175 pub fn status(&self) -> Status {
176 match self {
177 State::Running { .. } => Status::Running,
178 State::Stopped(false) => Status::Failed,
179 State::Stopped(true) => Status::Aborted,
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use futures::future;

    use super::*;

    /// Aborting through the handle is visible both synchronously and as the
    /// awaited result.
    #[tokio::test]
    async fn test_basic() {
        let (_group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);
        assert_eq!(handle.await, Status::Aborted);
    }

    /// Dropping the group aborts it.
    #[tokio::test]
    async fn test_group_drop() {
        let (group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        drop(group);
        assert_eq!(handle.status(), Status::Aborted);
        assert_eq!(handle.await, Status::Aborted);
    }

    /// Dropping an un-awaited handle aborts the group's tasks.
    #[tokio::test]
    async fn test_handle_drop() {
        let (group, handle) = group();
        let (flag, guard) = flag::guarded();

        group.spawn(async move {
            let _guard = guard;
            future::pending::<Result<(), ()>>().await
        });

        drop(handle);
        // Completes only once the pending task has been aborted.
        flag.await;
    }

    /// `Group::fail` stops the group as failed, and the first stop wins over a
    /// later abort.
    #[tokio::test]
    async fn test_fail() {
        let (group, handle) = group();
        group.fail();
        assert_eq!(handle.status(), Status::Failed);
        handle.abort();
        drop(group);
        assert_eq!(handle.await, Status::Failed);
    }

    /// Aborting the group aborts tasks that are still pending.
    #[tokio::test]
    async fn test_abort_with_active_tasks() {
        let (group, handle) = group();
        let (flag, guard) = flag::guarded();

        group.spawn(async move {
            let _guard = guard;
            future::pending::<Result<(), ()>>().await
        });

        assert!(!flag.signalled());
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);

        flag.await;
        assert_eq!(handle.await, Status::Aborted);
    }

    /// One failing task fails the group and aborts its sibling.
    #[tokio::test]
    async fn test_fail_on_task_failure() {
        let (group, handle) = group();

        let (first_task_is_scheduled, first_task_is_scheduled_guard) = flag::guarded();
        let (first_task_is_aborted, first_task_is_aborted_guard) = flag::guarded();

        group.spawn(async move {
            let _guard = first_task_is_aborted_guard;
            first_task_is_scheduled_guard.signal();
            future::pending::<Result<(), ()>>().await
        });

        let (second_task_should_fail, second_task_should_fail_guard) = flag::guarded();
        group.spawn(async move {
            second_task_should_fail.await;
            Result::<(), ()>::Err(())
        });

        first_task_is_scheduled.await;
        second_task_should_fail_guard.signal();
        first_task_is_aborted.await;

        assert_eq!(handle.await, Status::Failed);
    }
}