hyperactor/test_utils/
cancel_safe.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Utilities for testing **cancel safety** of futures.
10//!
11//! # What does "cancel-safe" mean?
12//!
13//! A future is *cancel-safe* if, at **any** `Poll::Pending` boundary:
14//!
15//! 1. **State remains valid** – dropping the future there does not
16//!    violate external invariants or leave shared state corrupted.
17//! 2. **Restartability holds** – from that state, constructing a
18//!    fresh future for the same logical operation can still run to
19//!    completion and produce the expected result.
20//! 3. **No partial side effects** – cancellation never leaves behind
21//!    a visible "half-done" action; effects are either not started,
22//!    or fully completed in an idempotent way.
23//!
24//! # Why cancel-safety matters
25//!
26//! Executors are free to drop futures after any `Poll::Pending`. This
27//! means that cancellation is not an exceptional path – it is *part
28//! of the normal contract*. A cancel-unsafe future can leak
29//! resources, corrupt protocol state, or leave behind truncated I/O.
30//!
31//! # What this module offers
32//!
//! This module provides helpers (`assert_cancel_safe`,
//! `assert_cancel_safe_async`, and the matching `assert_cancel_safe!` /
//! `assert_cancel_safe_async!` macros) that:
35//!
36//! - drive a future to completion once, counting its yield points,
//! - then, for every cancellation boundary `k` from `0` (dropped
//!   before the first poll) up to the number of yield points, poll a
//!   fresh future `k` times, drop it, and finally ensure a **new run**
//!   still produces the expected result.
40//!
41//! # Examples
42//!
//! - ✓ Pure/logical futures: simple state machines with no I/O (e.g.
//!   one that yields twice, then returns 42).
45//! - ✓ Framed writers that stage bytes internally and only commit
46//!   once the frame is fully written.
47//! - ✗ Writers that flush a partial frame before returning `Pending`.
48//! - ✗ Futures that consume from a shared queue before `Pending` and
49//!   drop without rollback.
50
51use std::fmt::Debug;
52use std::future::Future;
53use std::pin::Pin;
54use std::task::Context;
55use std::task::Poll;
56use std::task::RawWaker;
57use std::task::RawWakerVTable;
58use std::task::Waker;
59
/// Builds a [`Waker`] whose every operation is a no-op.
///
/// The helpers in this module re-poll futures in a loop instead of
/// waiting to be woken, so the waker only exists to satisfy the
/// `Context` API.
fn noop_waker() -> Waker {
    static NOOP_VTABLE: RawWakerVTable =
        RawWakerVTable::new(noop_clone, noop_ignore, noop_ignore, noop_ignore);

    /// A raw waker that carries no data and points at the no-op vtable.
    fn noop_raw() -> RawWaker {
        RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
    }

    /// `clone` entry: hand back another data-less no-op raw waker.
    fn noop_clone(_data: *const ()) -> RawWaker {
        noop_raw()
    }

    /// Shared `wake` / `wake_by_ref` / `drop` entry: nothing to do.
    fn noop_ignore(_data: *const ()) {}

    // SAFETY: every vtable entry ignores the data pointer (so a null
    // pointer is fine), `clone` returns an equivalent raw waker, and
    // the remaining entries do nothing, which trivially upholds the
    // `RawWaker` contract.
    unsafe { Waker::from_raw(noop_raw()) }
}
72
/// Polls `fut` exactly once with the supplied context.
///
/// Requiring `F: Unpin` lets a plain `&mut F` be pinned on the spot;
/// in this module callers pass a `Pin<Box<_>>`, which is always `Unpin`.
fn poll_once<F: Future + Unpin>(fut: &mut F, cx: &mut Context<'_>) -> Poll<F::Output> {
    let pinned: Pin<&mut F> = Pin::new(fut);
    Future::poll(pinned, cx)
}
77
78/// Drive a fresh future to completion, returning (`pending_count`,
79/// `out`). `pending_count` is the number of times the future returned
80/// `Poll::Pending` before it finally resolved to `Poll::Ready`.
81fn run_to_completion_count_pending<F, T>(mut mk: impl FnMut() -> F) -> (usize, T)
82where
83    F: Future<Output = T>,
84{
85    let waker = noop_waker();
86    let mut cx = Context::from_waker(&waker);
87
88    let mut fut = Box::pin(mk());
89    let mut pending_count = 0usize;
90
91    loop {
92        match poll_once(&mut fut, &mut cx) {
93            Poll::Ready(out) => return (pending_count, out),
94            Poll::Pending => {
95                pending_count += 1;
96                // Nothing else to do: we are just counting yield
97                // points.
98            }
99        }
100    }
101}
102
103/// Runtime-independent version: on each `Poll::Pending`, we just poll
104/// again. Suitable for pure/logical futures that don’t rely on
105/// timers, IO, or other external progress driven by an async runtime.
106pub fn assert_cancel_safe<F, T>(mut mk: impl FnMut() -> F, expected: &T)
107where
108    F: Future<Output = T>,
109    T: Debug + PartialEq,
110{
111    // 1) Establish ground truth and number of yield points.
112    let (pending_total, out) = run_to_completion_count_pending(&mut mk);
113    assert_eq!(&out, expected, "baseline run output mismatch");
114
115    // 2) Cancel at every poll boundary k, then ensure a fresh run
116    // still matches.
117    for k in 0..=pending_total {
118        let waker = noop_waker();
119        let mut cx = Context::from_waker(&waker);
120
121        // Poll exactly k times (dropping afterwards).
122        {
123            let mut fut = Box::pin(mk());
124            for _ in 0..k {
125                if poll_once(&mut fut, &mut cx).is_ready() {
126                    // Future completed earlier than k: no
127                    // cancellation point here. Drop and move on to
128                    // next k.
129                    break;
130                }
131            }
132            // Drop here = "cancellation".
133            drop(fut);
134        }
135
136        // 3) Now ensure we can still complete cleanly and match
137        // expected. This verifies cancelling at this boundary didn’t
138        // corrupt global state or violate invariants needed for a
139        // clean, subsequent run.
140        let (_, out2) = run_to_completion_count_pending(&mut mk);
141        assert_eq!(
142            &out2, expected,
143            "output mismatch after cancelling at poll #{k}"
144        );
145    }
146}
147
148/// Cancel-safety check for async futures. On every `Poll::Pending`,
149/// runs `on_pending().await` to drive external progress (e.g.
150/// advancing a paused clock or IO). Cancels at each yield boundary
151/// and ensures a fresh run still produces `expected`.
152pub async fn assert_cancel_safe_async<F, T, P, FutStep>(
153    mut mk: impl FnMut() -> F,
154    expected: &T,
155    mut on_pending: P,
156) where
157    F: Future<Output = T>,
158    T: Debug + PartialEq,
159    P: FnMut() -> FutStep,
160    FutStep: Future<Output = ()>,
161{
162    let waker = noop_waker();
163    let mut cx = Context::from_waker(&waker);
164
165    // 1) First, establish expected + number of pendings with the
166    // ability to drive progress.
167    let mut pending_total = 0usize;
168    {
169        let mut fut = Box::pin(mk());
170        loop {
171            match poll_once(&mut fut, &mut cx) {
172                Poll::Ready(out) => {
173                    assert_eq!(&out, expected, "baseline run output mismatch");
174                    break;
175                }
176                Poll::Pending => {
177                    pending_total += 1;
178                    on_pending().await;
179                }
180            }
181        }
182    }
183
184    // 2) Cancel at each poll boundary.
185    for k in 0..=pending_total {
186        // Poll exactly k steps, advancing external progress each
187        // time.
188        {
189            let mut fut = Box::pin(mk());
190            for _ in 0..k {
191                match poll_once(&mut fut, &mut cx) {
192                    Poll::Ready(_) => break, // Completed earlier than k
193                    Poll::Pending => on_pending().await,
194                }
195            }
196            drop(fut); // cancellation
197        }
198
199        // 3) Then ensure a clean full completion still yields
200        // expected.
201        {
202            let mut fut = Box::pin(mk());
203            loop {
204                match poll_once(&mut fut, &mut cx) {
205                    Poll::Ready(out) => {
206                        assert_eq!(
207                            &out, expected,
208                            "output mismatch after cancelling at poll #{k}"
209                        );
210                        break;
211                    }
212                    Poll::Pending => on_pending().await,
213                }
214            }
215        }
216    }
217}
218
/// Convenience macro for `assert_cancel_safe`.
///
/// Example:
/// ```ignore
/// assert_cancel_safe!(CountToThree { step: 0 }, 42);
/// ```
///
/// - `$make_future` is any expression that produces a fresh future
///   when evaluated (e.g. `CountToThree { step: 0 }`). It is wrapped
///   in a closure, so it is re-evaluated for every run.
/// - `$expected` is the value you expect the future to resolve to.
///   **Pass a plain value, not a reference**; the macro takes a
///   reference internally.
///
/// Like the std `assert_*!` macros, a trailing comma after the last
/// argument is accepted.
#[macro_export]
macro_rules! assert_cancel_safe {
    ($make_future:expr, $expected:expr $(,)?) => {{
        $crate::test_utils::cancel_safe::assert_cancel_safe(|| $make_future, &$expected)
    }};
}
235
/// Async convenience macro for `assert_cancel_safe_async`. Must be
/// used inside an async context, since it awaits the check.
///
/// Example:
/// ```ignore
/// assert_cancel_safe_async!(
///     two_sleeps(),
///     7,
///     || async { tokio::time::advance(std::time::Duration::from_millis(1)).await }
/// );
/// ```
///
/// - `$make_future` is any expression that produces a fresh future
///   when evaluated (e.g. `two_sleeps()`). It is wrapped in a closure,
///   so it is re-evaluated for every run.
/// - `$expected` is the value you expect the future to resolve to.
///   **Pass a plain value, not a reference**; the macro takes a
///   reference internally.
/// - `$on_pending` is a closure that returns a future, used to drive
///   external progress each time the future yields `Poll::Pending`.
///
/// Like the std `assert_*!` macros, a trailing comma after the last
/// argument is accepted.
#[macro_export]
macro_rules! assert_cancel_safe_async {
    ($make_future:expr, $expected:expr, $on_pending:expr $(,)?) => {{
        $crate::test_utils::cancel_safe::assert_cancel_safe_async(
            || $make_future,
            &$expected,
            $on_pending,
        )
        .await
    }};
}
266
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::rc::Rc;

    use tokio::time::Duration;
    use tokio::time::{self};

    use super::*;

    // A future that yields twice, then returns a number.
    struct CountToThree {
        step: u8,
    }

    impl Future for CountToThree {
        type Output = u8;

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.step += 1;
            match self.step {
                1 | 2 => Poll::Pending, // yield twice...
                3 => Poll::Ready(42),   // ... 3rd time's a charm
                _ => panic!("polled after completion"),
            }
        }
    }

    // Smoke test: verify that a simple state-machine future (yields
    // twice, then completes) passes the cancel-safety checks.
    #[test]
    fn test_count_to_three_cancel_safe() {
        assert_cancel_safe!(CountToThree { step: 0 }, 42u8);
    }

    // A deliberately cancel-*unsafe* future: on its first poll it takes
    // every token from the shared pool, and only puts them back when it
    // completes on the second poll. Dropping it in between leaks the
    // tokens, so a later fresh run sees an empty pool.
    struct TakeTokens {
        pool: Rc<Cell<u32>>,
        taken: Option<u32>,
    }

    impl Future for TakeTokens {
        type Output = u32;

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            match self.taken {
                None => {
                    let taken = self.pool.replace(0);
                    self.taken = Some(taken);
                    Poll::Pending
                }
                Some(taken) => {
                    self.pool.set(self.pool.get() + taken);
                    Poll::Ready(taken)
                }
            }
        }
    }

    // Negative test: the checker must catch a future that leaves shared
    // state corrupted when dropped at a `Poll::Pending` boundary.
    #[test]
    #[should_panic(expected = "output mismatch after cancelling at poll #1")]
    fn test_take_tokens_not_cancel_safe() {
        let pool = Rc::new(Cell::new(1));
        assert_cancel_safe!(TakeTokens { pool: pool.clone(), taken: None }, 1u32);
    }

    // A future that waits for two sleeps (1ms each), then returns 7.
    #[allow(clippy::disallowed_methods)]
    async fn two_sleeps() -> u8 {
        time::sleep(Duration::from_millis(1)).await;
        time::sleep(Duration::from_millis(1)).await;
        7
    }

    // Smoke test: verify that a timer-based async future (with two
    // sleeps) passes the async cancel-safety checks under tokio's
    // mocked time.
    #[tokio::test(start_paused = true)]
    async fn test_two_sleeps_cancel_safe_async() {
        assert_cancel_safe_async!(two_sleeps(), 7, || async {
            time::advance(Duration::from_millis(1)).await
        });
    }
}