hyperactor/sync/
monitor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Monitors supervise a set of related tasks, aborting them on any failure.
10//!
11//! ```
12//! # use hyperactor::sync::monitor;
13//! # use hyperactor::sync::flag;
14//!
15//! # tokio_test::block_on(async {
16//! let (group, handle) = monitor::group();
17//! let (flag, guard) = flag::guarded();
18//! group.spawn(async move {
19//!     flag.await;
20//!     Result::<(), ()>::Err(())
21//! });
22//! group.spawn(async move {
23//!     guard.signal();
24//!     Result::<(), ()>::Ok(())
25//! });
26//! assert_eq!(handle.await, monitor::Status::Failed);
27//! # })
28//! ```
29use std::future::Future;
30use std::future::IntoFuture;
31use std::sync::Arc;
32use std::sync::Mutex;
33
34use enum_as_inner::EnumAsInner;
35use tokio::task::JoinSet;
36
37use crate::sync::flag;
38
39/// Create a new monitored group and handle. The group is aborted
40/// if either group or its handle are dropped.
41pub fn group() -> (Group, Handle) {
42    let (flag, guard) = flag::guarded();
43    let state = Arc::new(Mutex::new(State::Running {
44        _guard: guard,
45        tasks: JoinSet::new(),
46    }));
47
48    let group = Group(Arc::clone(&state));
49    let handle = Handle(Some((flag, state)));
50
51    (group, handle)
52}
53
/// A handle to a monitored task group. Handles may be awaited to
/// wait for the completion of the group (failure or abortion).
///
/// Awaiting a handle waits on the flag created by [`group`] and then
/// yields the group's [`Status`]. A handle can also be queried with
/// [`Handle::status`] or used to stop the group with [`Handle::abort`].
/// Dropping a handle aborts the group if it is still running.
//
// The tuple pairs the flag with the state shared with `Group`. It is wrapped
// in an `Option` so that the `IntoFuture` implementation can move both values
// out through `take`, even though `Handle` implements `Drop`.
pub struct Handle(Option<(flag::Flag, Arc<Mutex<State>>)>);
57
58impl Handle {
59    /// The current status of the group.
60    pub fn status(&self) -> Status {
61        self.unwrap_state().lock().unwrap().status()
62    }
63
64    /// Abort the group. This aborts all tasks and returns immediately.
65    /// Note that the group status is not guaranteed to converge to
66    /// [`Status::Aborted`] as this call may race with failing tasks.
67    pub fn abort(&self) {
68        self.unwrap_state().lock().unwrap().stop(true)
69    }
70
71    fn unwrap_state(&self) -> &Arc<Mutex<State>> {
72        &self.0.as_ref().unwrap().1
73    }
74
75    fn take(&mut self) -> Option<(flag::Flag, Arc<Mutex<State>>)> {
76        self.0.take()
77    }
78}
79
80impl Drop for Handle {
81    fn drop(&mut self) {
82        if let Some((_, ref state)) = self.0 {
83            state.lock().unwrap().stop(true);
84        }
85    }
86}
87
88impl IntoFuture for Handle {
89    type Output = Status;
90    type IntoFuture = impl Future<Output = Self::Output>;
91    fn into_future(mut self) -> Self::IntoFuture {
92        async move {
93            let (flag, state) = self.take().unwrap();
94            flag.await;
95            #[allow(clippy::let_and_return)]
96            let status = state.lock().unwrap().status();
97            status
98        }
99    }
100}
101
/// A group of tasks that share a common fate. Any tasks that are spawned onto
/// the group will be aborted if any task fails or if the group is aborted.
///
/// The group is also aborted if the group itself is dropped.
///
/// Cloning a group yields another reference to the same shared state, and
/// every clone runs the same destructor: dropping any one clone stops the
/// group for all of them.
#[derive(Clone)]
pub struct Group(Arc<Mutex<State>>);
108
/// The status of a group. Groups start out in [`Status::Running`]
/// and transition exactly zero or one time to either [`Status::Failed`]
/// or [`Status::Aborted`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Status {
    /// The group is running zero or more tasks,
    /// none of which have failed.
    Running,
    /// One of the group's tasks has failed, or [`Group::fail`] was
    /// called, and the remaining tasks have been canceled.
    Failed,
    /// The group was aborted by calling [`Handle::abort`], or by dropping
    /// the [`Group`] or its [`Handle`], and any running tasks have been
    /// canceled.
    Aborted,
}
124
125impl Group {
126    /// Spawn a new task onto this group. If the task fails, the group is
127    /// aborted. If tasks are spawned onto an already stopped group, they are
128    /// simply never run.
129    pub fn spawn<F, T, E>(&self, fut: F)
130    where
131        F: Future<Output = Result<T, E>> + Send + 'static,
132    {
133        let state = Arc::clone(&self.0);
134        if let Some((_, tasks)) = self.0.lock().unwrap().as_running_mut() {
135            tasks.spawn(async move {
136                if fut.await.is_ok() {
137                    return;
138                }
139                state.lock().unwrap().stop(false);
140            });
141        }
142    }
143
144    /// Fail the group. This is equivalent to spawning a task that
145    /// immediately fails.
146    pub fn fail(&self) {
147        self.0.lock().unwrap().stop(false);
148    }
149}
150
151impl Drop for Group {
152    fn drop(&mut self) {
153        self.0.lock().unwrap().stop(true);
154    }
155}
156
/// The state shared, behind a mutex, by a `Group` and its `Handle`.
///
/// A state starts out as `Running` and is replaced by `Stopped` at most
/// once (see `State::stop`).
#[derive(EnumAsInner, Debug)]
enum State {
    /// The group has not been stopped.
    Running {
        /// Guard paired with the flag held by the `Handle`. It is never read;
        /// it is held here so that it is dropped when the group stops.
        _guard: flag::Guard,
        /// The tasks that have been spawned onto the group.
        tasks: JoinSet<()>,
    },
    /// The group has stopped. The flag is `true` if it was aborted and
    /// `false` if it failed.
    Stopped(bool /*aborted*/),
}
165
166impl State {
167    fn stop(&mut self, aborted: bool) {
168        if self.is_running() {
169            // This drops both `tasks` and `_guard` which will
170            // abort tasks and notify any waiters.
171            *self = State::Stopped(aborted);
172        }
173    }
174
175    pub fn status(&self) -> Status {
176        match self {
177            State::Running { .. } => Status::Running,
178            State::Stopped(false) => Status::Failed,
179            State::Stopped(true) => Status::Aborted,
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use futures::future;

    use super::*;

    /// Aborting through the handle is visible via `status` right away and is
    /// also what awaiting the handle reports.
    #[tokio::test]
    async fn test_basic() {
        let (_group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);
        assert_eq!(handle.await, Status::Aborted);
    }

    /// Dropping the group aborts it.
    #[tokio::test]
    async fn test_group_drop() {
        let (group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        drop(group);
        assert_eq!(handle.status(), Status::Aborted);
        assert_eq!(handle.await, Status::Aborted);
    }

    /// `Group::fail` stops the group as failed without any task running.
    #[tokio::test]
    async fn test_fail() {
        let (group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        group.fail();
        assert_eq!(handle.status(), Status::Failed);
        assert_eq!(handle.await, Status::Failed);
    }

    /// Aborting the group cancels a task that would otherwise never finish.
    #[tokio::test]
    async fn test_abort_with_active_tasks() {
        let (group, handle) = group();
        let (flag, guard) = flag::guarded();

        group.spawn(async move {
            // Held by the task so that `flag` resolves once the task's
            // future has been dropped.
            let _guard = guard;
            future::pending::<Result<(), ()>>().await
        });

        assert!(!flag.signalled());
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);

        flag.await;
    }

    /// A failing task stops the group as failed and cancels its siblings.
    #[tokio::test]
    async fn test_fail_on_task_failure() {
        let (group, handle) = group();

        let (first_task_is_scheduled, first_task_is_scheduled_guard) = flag::guarded();
        let (first_task_is_aborted, _first_task_is_aborted_guard) = flag::guarded();

        group.spawn(async move {
            // Dropped together with the task's future, which is how the test
            // observes that the first task was canceled.
            let _guard = _first_task_is_aborted_guard;
            first_task_is_scheduled_guard.signal();
            future::pending::<Result<(), ()>>().await
        });

        let (second_task_should_fail, second_task_should_fail_guard) = flag::guarded();
        group.spawn(async move {
            second_task_should_fail.await;
            Result::<(), ()>::Err(())
        });

        // Only let the second task fail once the first one is running.
        first_task_is_scheduled.await;
        second_task_should_fail_guard.signal();
        first_task_is_aborted.await;

        assert_eq!(handle.await, Status::Failed);
    }
}