hyperactor/test_utils/cancel_safe.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Utilities for testing **cancel safety** of futures.
10//!
11//! # What does "cancel-safe" mean?
12//!
13//! A future is *cancel-safe* if, at **any** `Poll::Pending` boundary:
14//!
15//! 1. **State remains valid** – dropping the future there does not
16//! violate external invariants or leave shared state corrupted.
17//! 2. **Restartability holds** – from that state, constructing a
18//! fresh future for the same logical operation can still run to
19//! completion and produce the expected result.
20//! 3. **No partial side effects** – cancellation never leaves behind
21//! a visible "half-done" action; effects are either not started,
22//! or fully completed in an idempotent way.
23//!
24//! # Why cancel-safety matters
25//!
26//! Executors are free to drop futures after any `Poll::Pending`. This
27//! means that cancellation is not an exceptional path – it is *part
28//! of the normal contract*. A cancel-unsafe future can leak
29//! resources, corrupt protocol state, or leave behind truncated I/O.
30//!
31//! # What this module offers
32//!
33//! This module provides helpers (`assert_cancel_safe`,
34//! `assert_cancel_safe_async`) that:
35//!
//! - drive a future to completion once, counting its yield points
//!   (the `N` times it returns `Poll::Pending`),
//! - then, for every cancellation boundary `k` in `0..=N`, poll a fresh
//!   future `k` times, drop it, and finally ensure a **new run** still
//!   produces the expected result.
40//!
41//! # Examples
42//!
//! - ✓ Pure/logical futures: simple state machines with no I/O (e.g.
//!   one that yields twice, then returns 42).
45//! - ✓ Framed writers that stage bytes internally and only commit
46//! once the frame is fully written.
47//! - ✗ Writers that flush a partial frame before returning `Pending`.
48//! - ✗ Futures that consume from a shared queue before `Pending` and
49//! drop without rollback.
50
51use std::fmt::Debug;
52use std::future::Future;
53use std::pin::Pin;
54use std::task::Context;
55use std::task::Poll;
56use std::task::RawWaker;
57use std::task::RawWakerVTable;
58use std::task::Waker;
59
/// Build a [`Waker`] whose every operation does nothing.
///
/// Useful for polling futures by hand: waking it never schedules
/// anything, so the caller alone decides when to poll again.
fn noop_waker() -> Waker {
    // Every vtable entry ignores its data pointer, so null is fine.
    fn raw() -> RawWaker {
        RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
    }
    fn clone_raw(_data: *const ()) -> RawWaker {
        raw()
    }
    // Shared by `wake`, `wake_by_ref` and `drop`: nothing to do.
    fn ignore(_data: *const ()) {}

    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, ignore, ignore, ignore);

    // SAFETY: The vtable doesn't use the data pointer.
    unsafe { Waker::from_raw(raw()) }
}
72
/// Poll `fut` exactly once with the supplied context.
///
/// Requiring `F: Unpin` lets a plain `&mut F` be pinned on the spot;
/// the callers in this module pass `Pin<Box<_>>`, which is always `Unpin`.
fn poll_once<F: Future + Unpin>(fut: &mut F, cx: &mut Context<'_>) -> Poll<F::Output> {
    let pinned = Pin::new(fut);
    Future::poll(pinned, cx)
}
77
78/// Drive a fresh future to completion, returning (`pending_count`,
79/// `out`). `pending_count` is the number of times the future returned
80/// `Poll::Pending` before it finally resolved to `Poll::Ready`.
81fn run_to_completion_count_pending<F, T>(mut mk: impl FnMut() -> F) -> (usize, T)
82where
83 F: Future<Output = T>,
84{
85 let waker = noop_waker();
86 let mut cx = Context::from_waker(&waker);
87
88 let mut fut = Box::pin(mk());
89 let mut pending_count = 0usize;
90
91 loop {
92 match poll_once(&mut fut, &mut cx) {
93 Poll::Ready(out) => return (pending_count, out),
94 Poll::Pending => {
95 pending_count += 1;
96 // Nothing else to do: we are just counting yield
97 // points.
98 }
99 }
100 }
101}
102
103/// Runtime-independent version: on each `Poll::Pending`, we just poll
104/// again. Suitable for pure/logical futures that don’t rely on
105/// timers, IO, or other external progress driven by an async runtime.
106pub fn assert_cancel_safe<F, T>(mut mk: impl FnMut() -> F, expected: &T)
107where
108 F: Future<Output = T>,
109 T: Debug + PartialEq,
110{
111 // 1) Establish ground truth and number of yield points.
112 let (pending_total, out) = run_to_completion_count_pending(&mut mk);
113 assert_eq!(&out, expected, "baseline run output mismatch");
114
115 // 2) Cancel at every poll boundary k, then ensure a fresh run
116 // still matches.
117 for k in 0..=pending_total {
118 let waker = noop_waker();
119 let mut cx = Context::from_waker(&waker);
120
121 // Poll exactly k times (dropping afterwards).
122 {
123 let mut fut = Box::pin(mk());
124 for _ in 0..k {
125 if poll_once(&mut fut, &mut cx).is_ready() {
126 // Future completed earlier than k: no
127 // cancellation point here. Drop and move on to
128 // next k.
129 break;
130 }
131 }
132 // Drop here = "cancellation".
133 drop(fut);
134 }
135
136 // 3) Now ensure we can still complete cleanly and match
137 // expected. This verifies cancelling at this boundary didn’t
138 // corrupt global state or violate invariants needed for a
139 // clean, subsequent run.
140 let (_, out2) = run_to_completion_count_pending(&mut mk);
141 assert_eq!(
142 &out2, expected,
143 "output mismatch after cancelling at poll #{k}"
144 );
145 }
146}
147
148/// Cancel-safety check for async futures. On every `Poll::Pending`,
149/// runs `on_pending().await` to drive external progress (e.g.
150/// advancing a paused clock or IO). Cancels at each yield boundary
151/// and ensures a fresh run still produces `expected`.
152pub async fn assert_cancel_safe_async<F, T, P, FutStep>(
153 mut mk: impl FnMut() -> F,
154 expected: &T,
155 mut on_pending: P,
156) where
157 F: Future<Output = T>,
158 T: Debug + PartialEq,
159 P: FnMut() -> FutStep,
160 FutStep: Future<Output = ()>,
161{
162 let waker = noop_waker();
163 let mut cx = Context::from_waker(&waker);
164
165 // 1) First, establish expected + number of pendings with the
166 // ability to drive progress.
167 let mut pending_total = 0usize;
168 {
169 let mut fut = Box::pin(mk());
170 loop {
171 match poll_once(&mut fut, &mut cx) {
172 Poll::Ready(out) => {
173 assert_eq!(&out, expected, "baseline run output mismatch");
174 break;
175 }
176 Poll::Pending => {
177 pending_total += 1;
178 on_pending().await;
179 }
180 }
181 }
182 }
183
184 // 2) Cancel at each poll boundary.
185 for k in 0..=pending_total {
186 // Poll exactly k steps, advancing external progress each
187 // time.
188 {
189 let mut fut = Box::pin(mk());
190 for _ in 0..k {
191 match poll_once(&mut fut, &mut cx) {
192 Poll::Ready(_) => break, // Completed earlier than k
193 Poll::Pending => on_pending().await,
194 }
195 }
196 drop(fut); // cancellation
197 }
198
199 // 3) Then ensure a clean full completion still yields
200 // expected.
201 {
202 let mut fut = Box::pin(mk());
203 loop {
204 match poll_once(&mut fut, &mut cx) {
205 Poll::Ready(out) => {
206 assert_eq!(
207 &out, expected,
208 "output mismatch after cancelling at poll #{k}"
209 );
210 break;
211 }
212 Poll::Pending => on_pending().await,
213 }
214 }
215 }
216 }
217}
218
/// Convenience macro for `assert_cancel_safe`.
///
/// Example:
/// ```ignore
/// assert_cancel_safe!(CountToThree { step: 0 }, 42);
/// ```
///
/// - `$make_future` is any expression that produces a fresh future when
///   evaluated (e.g. `CountToThree { step: 0 }`). It is wrapped in a
///   closure, so it is re-evaluated for every run.
/// - `$expected` is the value you expect the future to resolve to.
///   **Pass a plain value, not a reference**. The macro will take a
///   reference internally.
///
/// A trailing comma after the last argument is accepted.
#[macro_export]
macro_rules! assert_cancel_safe {
    ($make_future:expr, $expected:expr $(,)?) => {{
        $crate::test_utils::cancel_safe::assert_cancel_safe(|| $make_future, &$expected)
    }};
}
235
/// Async convenience macro for `assert_cancel_safe_async`.
///
/// The expansion ends in `.await`, so it must be used inside an async
/// context (e.g. a `#[tokio::test]` function).
///
/// Example:
/// ```ignore
/// assert_cancel_safe_async!(
///     two_sleeps(),
///     7,
///     || async { tokio::time::advance(std::time::Duration::from_millis(1)).await }
/// );
/// ```
///
/// - `$make_future` is any expression that produces a fresh future when
///   evaluated (e.g. `two_sleeps()`). It is wrapped in a closure, so it is
///   re-evaluated for every run.
/// - `$expected` is the value you expect the future to resolve to.
///   **Pass a plain value, not a reference**. The macro will take a
///   reference internally.
/// - `$on_pending` is a closure that returns a future (e.g. an async
///   block). It is awaited to drive external progress each time the future
///   yields `Poll::Pending`.
///
/// A trailing comma after the last argument is accepted.
#[macro_export]
macro_rules! assert_cancel_safe_async {
    ($make_future:expr, $expected:expr, $on_pending:expr $(,)?) => {{
        $crate::test_utils::cancel_safe::assert_cancel_safe_async(
            || $make_future,
            &$expected,
            $on_pending,
        )
        .await
    }};
}
266
#[cfg(test)]
mod tests {
    use std::cell::Cell;
    use std::rc::Rc;

    use tokio::time::Duration;
    use tokio::time::{self};

    use super::*;

    // A future that yields twice, then returns a number.
    struct CountToThree {
        step: u8,
    }

    impl Future for CountToThree {
        type Output = u8;

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            self.step += 1;
            match self.step {
                1 | 2 => Poll::Pending, // yield twice...
                3 => Poll::Ready(42),   // ... 3rd time's a charm
                _ => panic!("polled after completion"),
            }
        }
    }

    // Smoke test: verify that a simple state-machine future (yields
    // twice, then completes) passes the cancel-safety checks.
    #[test]
    fn test_count_to_three_cancel_safe() {
        assert_cancel_safe!(CountToThree { step: 0 }, 42u8);
    }

    // A future that takes a "reservation" (bumps a shared counter) on its
    // first poll and yields once. On its second poll it returns the number
    // of outstanding reservations (its own included), then releases its
    // own. With `rollback_on_drop` set, dropping it while it holds the
    // reservation releases it. Without it, the reservation leaks: a
    // "half-done" side effect left behind by cancellation.
    struct Reserve {
        outstanding: Rc<Cell<u32>>,
        reserved: bool,
        rollback_on_drop: bool,
    }

    impl Reserve {
        fn new(outstanding: &Rc<Cell<u32>>, rollback_on_drop: bool) -> Self {
            Self {
                outstanding: Rc::clone(outstanding),
                reserved: false,
                rollback_on_drop,
            }
        }
    }

    impl Future for Reserve {
        type Output = u32;

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.reserved {
                // First poll: take the reservation, then yield.
                self.outstanding.set(self.outstanding.get() + 1);
                self.reserved = true;
                return Poll::Pending;
            }
            // Second poll: report outstanding reservations, release ours.
            let seen = self.outstanding.get();
            self.outstanding.set(seen - 1);
            self.reserved = false;
            Poll::Ready(seen)
        }
    }

    impl Drop for Reserve {
        fn drop(&mut self) {
            if self.rollback_on_drop && self.reserved {
                self.outstanding.set(self.outstanding.get() - 1);
            }
        }
    }

    // Dropping mid-flight rolls the reservation back, so every fresh run
    // still observes exactly one outstanding reservation (its own).
    #[test]
    fn test_reserve_with_rollback_cancel_safe() {
        let outstanding = Rc::new(Cell::new(0u32));
        assert_cancel_safe!(Reserve::new(&outstanding, true), 1u32);
        // Nothing leaked across all the runs.
        assert_eq!(outstanding.get(), 0);
    }

    // Negative test: cancelling after the first poll leaks a reservation,
    // so the next fresh run observes two and the checker must fail.
    #[test]
    #[should_panic(expected = "output mismatch after cancelling at poll #1")]
    fn test_reserve_without_rollback_is_detected() {
        let outstanding = Rc::new(Cell::new(0u32));
        assert_cancel_safe!(Reserve::new(&outstanding, false), 1u32);
    }

    // A future that waits for two sleeps (1ms each), then returns 7.
    #[allow(clippy::disallowed_methods)]
    async fn two_sleeps() -> u8 {
        time::sleep(Duration::from_millis(1)).await;
        time::sleep(Duration::from_millis(1)).await;
        7
    }

    // Smoke test: verify that a timer-based async future (with two
    // sleeps) passes the async cancel-safety checks under tokio's
    // mocked time.
    #[tokio::test(start_paused = true)]
    async fn test_two_sleeps_cancel_safe_async() {
        assert_cancel_safe_async!(two_sleeps(), 7, || async {
            time::advance(Duration::from_millis(1)).await
        });
    }
}