hyperactor/sync/
monitor.rs1#![allow(unused_assignments)]
35
36use std::future::Future;
37use std::future::IntoFuture;
38use std::sync::Arc;
39use std::sync::Mutex;
40
41use enum_as_inner::EnumAsInner;
42use tokio::task::JoinSet;
43
44use crate::sync::flag;
45
46pub fn group() -> (Group, Handle) {
49 let (flag, guard) = flag::guarded();
50 let state = Arc::new(Mutex::new(State::Running {
51 _guard: guard,
52 tasks: JoinSet::new(),
53 }));
54
55 let group = Group(Arc::clone(&state));
56 let handle = Handle(Some((flag, state)));
57
58 (group, handle)
59}
60
61pub struct Handle(Option<(flag::Flag, Arc<Mutex<State>>)>);
64
65impl Handle {
66 pub fn status(&self) -> Status {
68 self.unwrap_state().lock().unwrap().status()
69 }
70
71 pub fn abort(&self) {
75 self.unwrap_state().lock().unwrap().stop(true)
76 }
77
78 fn unwrap_state(&self) -> &Arc<Mutex<State>> {
79 &self.0.as_ref().unwrap().1
80 }
81
82 fn take(&mut self) -> Option<(flag::Flag, Arc<Mutex<State>>)> {
83 self.0.take()
84 }
85}
86
87impl Drop for Handle {
88 fn drop(&mut self) {
89 if let Some((_, ref state)) = self.0 {
90 state.lock().unwrap().stop(true);
91 }
92 }
93}
94
95impl IntoFuture for Handle {
96 type Output = Status;
97 type IntoFuture = impl Future<Output = Self::Output>;
98 fn into_future(mut self) -> Self::IntoFuture {
99 async move {
100 let (flag, state) = self.take().unwrap();
101 flag.await;
102 #[allow(clippy::let_and_return)]
103 let status = state.lock().unwrap().status();
104 status
105 }
106 }
107}
108
109#[derive(Clone)]
114pub struct Group(Arc<Mutex<State>>);
115
/// Lifecycle status of a task group.
///
/// It is reported by `Handle::status` and yielded by awaiting a `Handle`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Status {
    /// The group has not stopped yet.
    Running,
    /// The group stopped because a task returned `Err` or `Group::fail` was
    /// called.
    Failed,
    /// The group was aborted, either explicitly via `Handle::abort` or because
    /// a `Group` or un-awaited `Handle` was dropped.
    Aborted,
}
131
132impl Group {
133 pub fn spawn<F, T, E>(&self, fut: F)
137 where
138 F: Future<Output = Result<T, E>> + Send + 'static,
139 {
140 let state = Arc::clone(&self.0);
141 if let Some((_, tasks)) = self.0.lock().unwrap().as_running_mut() {
142 tasks.spawn(async move {
143 if fut.await.is_ok() {
144 return;
145 }
146 state.lock().unwrap().stop(false);
147 });
148 }
149 }
150
151 pub fn fail(&self) {
154 self.0.lock().unwrap().stop(false);
155 }
156}
157
158impl Drop for Group {
159 fn drop(&mut self) {
160 self.0.lock().unwrap().stop(true);
161 }
162}
163
164#[derive(EnumAsInner, Debug)]
165enum State {
166 Running {
167 _guard: flag::Guard,
168 tasks: JoinSet<()>,
169 },
170 Stopped(bool ),
171}
172
173impl State {
174 fn stop(&mut self, aborted: bool) {
175 if self.is_running() {
176 *self = State::Stopped(aborted);
179 }
180 }
181
182 pub fn status(&self) -> Status {
183 match self {
184 State::Running { .. } => Status::Running,
185 State::Stopped(false) => Status::Failed,
186 State::Stopped(true) => Status::Aborted,
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use futures::future;

    use super::*;

    /// Aborting via the handle is reflected in its status, and the aborted
    /// group can still be awaited.
    #[tokio::test]
    async fn test_basic() {
        let (_group, handle) = group();
        assert_eq!(handle.status(), Status::Running);

        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);
        handle.await;
    }

    /// Dropping the `Group` aborts it.
    #[tokio::test]
    async fn test_group_drop() {
        let (group, handle) = group();
        assert_eq!(handle.status(), Status::Running);

        drop(group);
        assert_eq!(handle.status(), Status::Aborted);
        handle.await;
    }

    /// Aborting the group drops its still-pending tasks.
    #[tokio::test]
    async fn test_abort_with_active_tasks() {
        let (group, handle) = group();
        let (task_dropped, task_guard) = flag::guarded();

        // The task never completes. Its guard is released only when the task
        // itself is dropped.
        group.spawn(async move {
            let _held = task_guard;
            future::pending::<Result<(), ()>>().await
        });

        assert!(!task_dropped.signalled());
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);

        task_dropped.await;
    }

    /// A task returning `Err` fails the group and aborts its other tasks.
    #[tokio::test]
    async fn test_fail_on_task_failure() {
        let (group, handle) = group();

        let (first_scheduled, first_scheduled_guard) = flag::guarded();
        let (first_aborted, first_aborted_guard) = flag::guarded();

        group.spawn(async move {
            let _held = first_aborted_guard;
            first_scheduled_guard.signal();
            future::pending::<Result<(), ()>>().await
        });

        let (fail_now, fail_now_guard) = flag::guarded();
        group.spawn(async move {
            fail_now.await;
            Result::<(), ()>::Err(())
        });

        // Trigger the failure only once the first task is known to be running.
        first_scheduled.await;
        fail_now_guard.signal();
        first_aborted.await;

        assert_eq!(handle.await, Status::Failed);
    }
}