hyperactor/sync/
monitor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Monitors supervise a set of related tasks, aborting them on any failure.
10//!
11//! ```
12//! # use hyperactor::sync::monitor;
13//! # use hyperactor::sync::flag;
14//!
15//! # tokio_test::block_on(async {
16//! let (group, handle) = monitor::group();
17//! let (flag, guard) = flag::guarded();
18//! group.spawn(async move {
19//!     flag.await;
20//!     Result::<(), ()>::Err(())
21//! });
22//! group.spawn(async move {
23//!     guard.signal();
24//!     Result::<(), ()>::Ok(())
25//! });
26//! assert_eq!(handle.await, monitor::Status::Failed);
27//! # })
28//! ```
29
30// EnumAsInner generates code that triggers a false positive
31// unused_assignments lint on struct variant fields. #[allow] on the
32// enum itself doesn't propagate into derive-macro-generated code, so
33// the suppression must be at module scope.
34#![allow(unused_assignments)]
35
36use std::future::Future;
37use std::future::IntoFuture;
38use std::sync::Arc;
39use std::sync::Mutex;
40
41use enum_as_inner::EnumAsInner;
42use tokio::task::JoinSet;
43
44use crate::sync::flag;
45
46/// Create a new monitored group and handle. The group is aborted
47/// if either group or its handle are dropped.
48pub fn group() -> (Group, Handle) {
49    let (flag, guard) = flag::guarded();
50    let state = Arc::new(Mutex::new(State::Running {
51        _guard: guard,
52        tasks: JoinSet::new(),
53    }));
54
55    let group = Group(Arc::clone(&state));
56    let handle = Handle(Some((flag, state)));
57
58    (group, handle)
59}
60
61/// A handle to a monitored task group. Handles may be awaited to
62/// wait for the completion of the group (failure or abortion).
63pub struct Handle(Option<(flag::Flag, Arc<Mutex<State>>)>);
64
65impl Handle {
66    /// The current status of the group.
67    pub fn status(&self) -> Status {
68        self.unwrap_state().lock().unwrap().status()
69    }
70
71    /// Abort the group. This aborts all tasks and returns immediately.
72    /// Note that the group status is not guaranteed to converge to
73    /// [`Status::Aborted`] as this call may race with failing tasks.
74    pub fn abort(&self) {
75        self.unwrap_state().lock().unwrap().stop(true)
76    }
77
78    fn unwrap_state(&self) -> &Arc<Mutex<State>> {
79        &self.0.as_ref().unwrap().1
80    }
81
82    fn take(&mut self) -> Option<(flag::Flag, Arc<Mutex<State>>)> {
83        self.0.take()
84    }
85}
86
87impl Drop for Handle {
88    fn drop(&mut self) {
89        if let Some((_, ref state)) = self.0 {
90            state.lock().unwrap().stop(true);
91        }
92    }
93}
94
95impl IntoFuture for Handle {
96    type Output = Status;
97    type IntoFuture = impl Future<Output = Self::Output>;
98    fn into_future(mut self) -> Self::IntoFuture {
99        async move {
100            let (flag, state) = self.take().unwrap();
101            flag.await;
102            #[allow(clippy::let_and_return)]
103            let status = state.lock().unwrap().status();
104            status
105        }
106    }
107}
108
109/// A group of tasks that share a common fate. Any tasks that are spawned onto
110/// the group will be aborted if any task fails or if the group is aborted.
111///
112/// The group is also aborted if the group itself is dropped.
113#[derive(Clone)]
114pub struct Group(Arc<Mutex<State>>);
115
/// The status of a group. Groups start out in [`Status::Running`]
/// and transition at most once, to either [`Status::Failed`]
/// or [`Status::Aborted`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Status {
    /// The group is running zero or more tasks,
    /// none of which have failed.
    Running,
    /// One of the group's tasks has failed (or [`Group::fail`] was
    /// called), and the remaining tasks have been canceled.
    Failed,
    /// The group was aborted by calling [`Handle::abort`], or by
    /// dropping the [`Group`] or its [`Handle`], and any running tasks
    /// have been canceled.
    Aborted,
}
131
132impl Group {
133    /// Spawn a new task onto this group. If the task fails, the group is
134    /// aborted. If tasks are spawned onto an already stopped group, they are
135    /// simply never run.
136    pub fn spawn<F, T, E>(&self, fut: F)
137    where
138        F: Future<Output = Result<T, E>> + Send + 'static,
139    {
140        let state = Arc::clone(&self.0);
141        if let Some((_, tasks)) = self.0.lock().unwrap().as_running_mut() {
142            tasks.spawn(async move {
143                if fut.await.is_ok() {
144                    return;
145                }
146                state.lock().unwrap().stop(false);
147            });
148        }
149    }
150
151    /// Fail the group. This is equivalent to spawning a task that
152    /// immediately fails.
153    pub fn fail(&self) {
154        self.0.lock().unwrap().stop(false);
155    }
156}
157
158impl Drop for Group {
159    fn drop(&mut self) {
160        self.0.lock().unwrap().stop(true);
161    }
162}
163
164#[derive(EnumAsInner, Debug)]
165enum State {
166    Running {
167        _guard: flag::Guard,
168        tasks: JoinSet<()>,
169    },
170    Stopped(bool /*aborted*/),
171}
172
173impl State {
174    fn stop(&mut self, aborted: bool) {
175        if self.is_running() {
176            // This drops both `tasks` and `_guard` which will
177            // abort tasks and notify any waiters.
178            *self = State::Stopped(aborted);
179        }
180    }
181
182    pub fn status(&self) -> Status {
183        match self {
184            State::Running { .. } => Status::Running,
185            State::Stopped(false) => Status::Failed,
186            State::Stopped(true) => Status::Aborted,
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use futures::future;

    use super::*;

    /// Aborting through the handle is visible immediately via `status()`
    /// and is also what awaiting the handle resolves to.
    #[tokio::test]
    async fn test_basic() {
        let (_group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);
        assert_eq!(handle.await, Status::Aborted);
    }

    /// Dropping the group aborts it.
    #[tokio::test]
    async fn test_group_drop() {
        let (group, handle) = group();
        assert_eq!(handle.status(), Status::Running);
        drop(group);
        assert_eq!(handle.status(), Status::Aborted);
        assert_eq!(handle.await, Status::Aborted);
    }

    /// Aborting the group cancels a task that would otherwise never
    /// finish; the task's guard is dropped, which signals `flag`.
    #[tokio::test]
    async fn test_abort_with_active_tasks() {
        let (group, handle) = group();
        let (flag, guard) = flag::guarded();

        group.spawn(async move {
            let _guard = guard;
            future::pending::<Result<(), ()>>().await
        });

        assert!(!flag.signalled());
        handle.abort();
        assert_eq!(handle.status(), Status::Aborted);

        flag.await;
        assert_eq!(handle.await, Status::Aborted);
    }

    /// A failing task stops the group as failed and cancels its sibling.
    #[tokio::test]
    async fn test_fail_on_task_failure() {
        let (group, handle) = group();

        let (first_task_is_scheduled, first_task_is_scheduled_guard) = flag::guarded();
        let (first_task_is_aborted, _first_task_is_aborted_guard) = flag::guarded();

        group.spawn(async move {
            // Dropped when the task is canceled, signalling
            // `first_task_is_aborted`.
            let _guard = _first_task_is_aborted_guard;
            first_task_is_scheduled_guard.signal();
            future::pending::<Result<(), ()>>().await
        });

        let (second_task_should_fail, second_task_should_fail_guard) = flag::guarded();
        group.spawn(async move {
            second_task_should_fail.await;
            Result::<(), ()>::Err(())
        });

        first_task_is_scheduled.await;
        second_task_should_fail_guard.signal();
        first_task_is_aborted.await;

        assert_eq!(handle.await, Status::Failed);
    }
}