monarch_hyperactor/pytokio.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/// `pytokio` is Monarch's Python <-> Tokio async bridge.
10///
11/// It provides a small, *non-asyncio* async world where Python code
12/// can *compose* Rust/Tokio futures using `await`.
13///
14/// ## The core idea
15///
16/// In `pytokio`:
17///
18/// - `PythonTask` = a one-shot Rust/Tokio future that produces a
19/// Python value.
20/// - `from_coroutine` = wraps a Python coroutine as a Rust future
21/// that drives it.
22/// - `Shared` = an awaitable handle to a spawned background Tokio
23/// task.
24///
25/// More concretely:
26///
27/// - Rust bindings return a Python-visible `PythonTask`
28/// (`PyPythonTask`), which wraps a Rust `PythonTask` holding a
29/// boxed Tokio future returning `PyResult<Py<PyAny>>`.
30/// - `PythonTask.from_coroutine(coro)` wraps a *Python coroutine* as
31/// a `PythonTask` by creating a Rust/Tokio future that drives
32/// `coro.__await__()` (via `send`/`throw`) and awaits the
33/// `PythonTask`s it yields.
34/// - Python code may `await` a `PythonTask` / `Shared` **only** when
35/// running under `PythonTask.from_coroutine(...)`. Awaiting
36/// arbitrary Python awaitables (e.g. `asyncio` futures) is an
37/// error.
38/// - Calling `task.spawn()` / `spawn_abortable()` returns a `Shared`
39/// (`PyShared`), which yields the result of the background Tokio
40/// task running the original `PythonTask`.
41///
42/// This is intentionally *not* a general-purpose async bridge: it’s a
43/// way to use Python syntax to drive and compose Tokio futures.
44///
45/// ## Wrapping a Python coroutine
46///
47/// ```ignore
48/// async def work():
49/// x = await some_rust_binding() # must yield PythonTask / Shared
50/// await PythonTask.sleep(0.1) # also a PythonTask
51/// return x
52///
53/// task = PythonTask.from_coroutine(work())
54/// result = task.block_on() # block the calling Python thread while a
55/// # Tokio runtime drives the task to completion
56/// ```
57///
58/// `from_coroutine` drives the coroutine by repeatedly resuming it
59/// and awaiting the `PythonTask`s it yields, using a Tokio runtime.
60///
61/// ## Spawning
62///
63/// `spawn()` runs a `PythonTask` on a background Tokio task and
64/// returns a `Shared` handle.
65///
66/// To `await` the handle, you must still be inside a
67/// `from_coroutine`-driven coroutine:
68///
69/// ```ignore
70/// async def work():
71/// task = some_rust_binding()
72/// shared = task.spawn()
73/// # ... do other work ...
74/// result = await shared # valid here (inside from_coroutine world)
75/// return result
76///
77/// result = PythonTask.from_coroutine(work()).block_on()
78/// ```
79///
80/// In synchronous contexts, you can wait for a spawned task without
81/// `from_coroutine`:
82///
83/// ```ignore
84/// shared = task.spawn()
85/// result = shared.block_on() # blocks the calling Python thread
86/// ```
87///
88/// If `spawn_abortable()` is used, dropping the returned `Shared`
89/// aborts the underlying Tokio task.
90///
91/// ## Context propagation
92///
93/// `from_coroutine` preserves Monarch’s `context()` across Tokio
94/// thread hops, so code calling `context()` inside a `PythonTask`
95/// sees the same actor context as the call site that constructed the
96/// task.
97use std::error::Error;
98use std::future::Future;
99use std::pin::Pin;
100
101use hyperactor_config::CONFIG;
102use hyperactor_config::ConfigAttr;
103use hyperactor_config::attrs::declare_attrs;
104use monarch_types::SerializablePyErr;
105use monarch_types::py_global;
106use pyo3::IntoPyObjectExt;
107#[cfg(test)]
108use pyo3::PyClass;
109use pyo3::exceptions::PyRuntimeError;
110use pyo3::exceptions::PyStopIteration;
111use pyo3::exceptions::PyTimeoutError;
112use pyo3::exceptions::PyValueError;
113use pyo3::prelude::*;
114use pyo3::types::PyNone;
115use pyo3::types::PyString;
116use pyo3::types::PyTuple;
117use pyo3::types::PyType;
118use tokio::sync::Mutex;
119use tokio::sync::watch;
120use tokio::task::JoinHandle;
121
122use crate::pickle::reduce_shared;
123use crate::runtime::get_tokio_runtime;
124use crate::runtime::monarch_with_gil;
125use crate::runtime::monarch_with_gil_blocking;
126use crate::runtime::signal_safe_block_on;
127
declare_attrs! {
    /// If true, capture a Python stack trace at `PythonTask` creation
    /// time and log it when a spawned task errors but is never
    /// awaited/polled.
    ///
    /// Defaults to `false`; the capture goes through
    /// `traceback.extract_stack()`, which is skipped entirely while
    /// this flag is off.
    // NOTE(review): the two strings passed to `ConfigAttr::new` look like
    // the environment-variable name and the Python-side config key
    // respectively -- confirm against `hyperactor_config`.
    @meta(CONFIG = ConfigAttr::new(
        Some("MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK".to_string()),
        Some("enable_unawaited_python_task_traceback".to_string()),
    ))
    pub attr ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK: bool = false;
}
138
// Python helpers used for actor context propagation.
//
// - `context` resolves to `monarch._src.actor.actor_mesh.context`; calling
//   it returns the current Monarch actor context.
// - `actor_mesh_module` resolves to the `monarch._src.actor.actor_mesh`
//   module itself, which owns the `_context` contextvar that this file
//   sets and restores by hand whenever it runs Python code on a Tokio
//   thread.
py_global!(context, "monarch._src.actor.actor_mesh", "context");
py_global!(actor_mesh_module, "monarch._src.actor", "actor_mesh");
145
146/// Capture the current Python stack trace (creation call site) if
147/// `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK` is enabled.
148///
149/// Returns `None` when disabled to avoid the overhead of
150/// `traceback.extract_stack()`.
151fn current_traceback() -> PyResult<Option<Py<PyAny>>> {
152 if hyperactor_config::global::get(ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK) {
153 monarch_with_gil_blocking(|py| {
154 Ok(Some(
155 py.import("traceback")?
156 .call_method0("extract_stack")?
157 .unbind(),
158 ))
159 })
160 } else {
161 Ok(None)
162 }
163}
164
165/// Format a captured traceback (from `traceback.extract_stack()`) as
166/// a single string suitable for logging.
167fn format_traceback(py: Python<'_>, traceback: &Py<PyAny>) -> PyResult<String> {
168 let tb = py
169 .import("traceback")?
170 .call_method1("format_list", (traceback,))?;
171 PyString::new(py, "")
172 .call_method1("join", (tb,))?
173 .extract::<String>()
174}
175
/// Helper struct to make a Rust/Tokio future (returning a Python
/// result) passable in an actor message.
///
/// The future resolves to `PyResult<Py<PyAny>>` so it can return a
/// Python value or raise a Python exception, and it is `Send +
/// 'static` so it can cross thread/actor boundaries.
///
/// It also exists so that the long future type does not have to be
/// spelled out at every use site.
pub(crate) struct PythonTask {
    /// Boxed, pinned Rust/Tokio future producing a Python result,
    /// protected so it can be taken/consumed exactly once when the
    /// task is driven.
    // Type decoder ring, reading the type from the outside in:
    //
    //   Mutex<..>                - exclusive access for consumption
    //   Pin<..>                  - immovable (safe to poll
    //                              self-referential futures)
    //   Box<..>                  - heap-allocated (because unsized)
    //   dyn Future<Output = ..>  - any future type (type-erased)
    //   + Send                   - can cross thread boundaries
    //   + 'static                - owns all data, no dangling refs
    future: Mutex<Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>>>,

    /// Optional Python stack trace captured at task construction
    /// time, used to annotate logs when a spawned task errors but
    /// nobody awaits/polls it.
    traceback: Option<Py<PyAny>>,
}
206
207impl PythonTask {
208 /// Construct a `PythonTask` from a Rust/Tokio future and an
209 /// optional captured Python traceback.
210 ///
211 /// The future is boxed and pinned so it can be stored in the
212 /// struct and later driven safely.
213 fn new_with_traceback(
214 fut: impl Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
215 traceback: Option<Py<PyAny>>,
216 ) -> Self {
217 Self {
218 future: Mutex::new(Box::pin(fut)),
219 traceback,
220 }
221 }
222
223 /// Construct a `PythonTask`, capturing a creation-site traceback
224 /// if enabled by `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK`.
225 pub(crate) fn new(
226 fut: impl Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
227 ) -> PyResult<Self> {
228 Ok(Self::new_with_traceback(fut, current_traceback()?))
229 }
230
231 /// Return the optional captured creation-site traceback (if
232 /// enabled).
233 fn traceback(&self) -> &Option<Py<PyAny>> {
234 &self.traceback
235 }
236
237 /// Consume the task and return the boxed, pinned future.
238 ///
239 /// This is a one-shot operation: it moves the future out of the
240 /// struct so it can be driven to completion.
241 pub(crate) fn take(
242 self,
243 ) -> Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>> {
244 self.future.into_inner()
245 }
246}
247
248impl std::fmt::Debug for PythonTask {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 f.debug_struct("PythonTask")
251 .field("future", &"<PythonFuture>")
252 .finish()
253 }
254}
255
/// Python-visible wrapper for a one-shot `PythonTask`.
///
/// Exposed to Python as
/// `monarch._rust_bindings.monarch_hyperactor.pytokio.PythonTask`.
/// This object owns the underlying Rust task and is *consumed* when
/// it is run (e.g. via `spawn()`, `spawn_abortable()`, or
/// `block_on()`), hence `inner: Option<_>`.
#[pyclass(
    name = "PythonTask",
    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
)]
pub struct PyPythonTask {
    /// The wrapped task; `Some` until it is taken by `take_task()`,
    /// `None` afterwards (at which point further use reports
    /// "PythonTask already consumed").
    inner: Option<PythonTask>,
}
270
271impl From<PythonTask> for PyPythonTask {
272 fn from(task: PythonTask) -> Self {
273 Self { inner: Some(task) }
274 }
275}
276
/// Minimal await-iterator used to implement Python's `__await__`
/// protocol for pytokio.
///
/// This iterator yields the task object exactly once. The Rust-side
/// coroutine driver (`from_coroutine`) resumes the Python coroutine
/// and expects it to yield a `PythonTask` (or `Shared`) object back
/// to Rust.
#[pyclass(
    name = "PythonTaskAwaitIterator",
    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
)]
struct PythonTaskAwaitIterator {
    /// The task object still to be yielded; `None` once it has been
    /// handed out by `send`.
    value: Option<Py<PyAny>>,
}
291
292impl PythonTaskAwaitIterator {
293 /// Create an await-iterator that will yield `task` exactly once.
294 fn new(task: Py<PyAny>) -> PythonTaskAwaitIterator {
295 PythonTaskAwaitIterator { value: Some(task) }
296 }
297}
298
299#[pymethods]
300impl PythonTaskAwaitIterator {
301 /// First `send(...)` yields the stored task; subsequent sends
302 /// raise `StopIteration`.
303 ///
304 /// Python's await machinery calls `send(None)` to advance the
305 /// iterator.
306 fn send(&mut self, value: Py<PyAny>) -> PyResult<Py<PyAny>> {
307 self.value
308 .take()
309 .ok_or_else(|| PyStopIteration::new_err((value,)))
310 }
311
312 /// Convert the thrown Python exception value into a `PyErr` and
313 /// surface it to Rust.
314 fn throw(&mut self, value: Py<PyAny>) -> PyResult<Py<PyAny>> {
315 Err(monarch_with_gil_blocking(|py| {
316 PyErr::from_value(value.into_bound(py))
317 }))
318 }
319
320 /// Iterator protocol: `next(it)` is equivalent to
321 /// `it.send(None)`.
322 fn __next__(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
323 self.send(py.None())
324 }
325}
326
327impl PyPythonTask {
328 /// Construct a Python-visible `PythonTask` from a Rust future,
329 /// attaching an explicit creation-site traceback (if provided).
330 ///
331 /// The input future produces a Rust value `T`; on completion we
332 /// reacquire the GIL and convert `T` into a Python object
333 /// (`Py<PyAny>`).
334 fn new_with_traceback<F, T>(fut: F, traceback: Option<Py<PyAny>>) -> PyResult<Self>
335 where
336 F: Future<Output = PyResult<T>> + Send + 'static,
337 T: for<'py> IntoPyObject<'py> + Send,
338 {
339 Ok(PythonTask::new_with_traceback(
340 async {
341 let result = fut.await?;
342 monarch_with_gil(|py| result.into_py_any(py)).await
343 },
344 traceback,
345 )
346 .into())
347 }
348
349 /// Construct a `PythonTask`, capturing a creation-site traceback
350 /// if enabled.
351 ///
352 /// See `new_with_traceback` for conversion semantics (`T` ->
353 /// Python object under the GIL).
354 pub fn new<F, T>(fut: F) -> PyResult<Self>
355 where
356 F: Future<Output = PyResult<T>> + Send + 'static,
357 T: for<'py> IntoPyObject<'py> + Send,
358 {
359 Self::new_with_traceback(fut, current_traceback()?)
360 }
361}
362
363// Helper: convert a Rust error into a generic Python ValueError.
364fn to_py_error<T>(e: T) -> PyErr
365where
366 T: Error,
367{
368 PyErr::new::<PyValueError, _>(e.to_string())
369}
370
371impl PyPythonTask {
372 /// Consume this `PythonTask` and return the underlying Rust
373 /// future.
374 ///
375 /// This is a one-shot operation: after calling `take_task`, the
376 /// `PyPythonTask` is considered *consumed* and cannot be
377 /// spawned/awaited/blocked-on again.
378 pub fn take_task(
379 &mut self,
380 ) -> PyResult<Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>> {
381 self.inner
382 .take()
383 .map(|task| task.take())
384 .ok_or_else(|| PyValueError::new_err("PythonTask already consumed"))
385 }
386
387 /// Return the captured creation-site traceback (if enabled),
388 /// cloning it under the GIL.
389 ///
390 /// Fails if the task has already been consumed.
391 fn traceback(&self) -> PyResult<Option<Py<PyAny>>> {
392 if let Some(task) = &self.inner {
393 Ok(monarch_with_gil_blocking(|py| {
394 task.traceback().as_ref().map(|t| t.clone_ref(py))
395 }))
396 } else {
397 Err(PyValueError::new_err("PythonTask already consumed"))
398 }
399 }
400
401 /// Spawn this task onto the Tokio runtime and return a `Shared`
402 /// handle that *aborts on drop*.
403 ///
404 /// Use this when the underlying future is *abort-safe*
405 /// (cancellation-safe): dropping the returned `Shared` will call
406 /// `JoinHandle::abort()`, preventing the background task from
407 /// running forever.
408 ///
409 /// This is especially useful for long-lived or periodic tasks
410 /// (e.g. timers) where "nobody is awaiting the result anymore"
411 /// should stop the work.
412 ///
413 /// Like `spawn()`, this consumes the `PyPythonTask` (it can only
414 /// be spawned once).
415 pub(crate) fn spawn_abortable(&mut self) -> PyResult<PyShared> {
416 let (tx, rx) = watch::channel(None);
417 let traceback = self.traceback()?;
418 let traceback1 = self.traceback()?;
419 let task = self.take_task()?;
420 let handle = get_tokio_runtime().spawn(async move {
421 send_result(tx, task.await, traceback1);
422 });
423 Ok(PyShared {
424 rx,
425 handle: Some(handle),
426 abort: true,
427 traceback,
428 })
429 }
430}
431
432/// Publish a completed task result to the `watch` channel.
433///
434/// If the receiver has already been dropped, `watch::Sender::send`
435/// returns the unsent value as `SendError`. We treat that as "nobody
436/// will ever observe this result".
437///
438/// In the special case where the unobserved result is an error, we
439/// log it (and include the task creation traceback when available) to
440/// avoid silently losing failures from background tasks.
441fn send_result(
442 tx: tokio::sync::watch::Sender<Option<PyResult<Py<PyAny>>>>,
443 result: PyResult<Py<PyAny>>,
444 traceback: Option<Py<PyAny>>,
445) {
446 // a SendErr just means that there are no consumers of the value left.
447 match tx.send(Some(result)) {
448 Err(tokio::sync::watch::error::SendError(Some(Err(pyerr)))) => {
449 monarch_with_gil_blocking(|py| {
450 let tb = if let Some(tb) = traceback {
451 format_traceback(py, &tb).unwrap()
452 } else {
453 "None (run with `MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK=1` to see a traceback here)\n".into()
454 };
455 tracing::error!(
456 "PythonTask errored but is not being awaited; this will not crash your program, but indicates that \
457 something went wrong.\n{}\nTraceback where the task was created (most recent call last):\n{}",
458 SerializablePyErr::from(py, &pyerr),
459 tb
460 );
461 });
462 }
463 _ => {}
464 };
465}
466
467#[pymethods]
468impl PyPythonTask {
469 /// Run this task to completion synchronously on the embedded
470 /// Tokio runtime.
471 ///
472 /// This blocks the calling Python thread until the underlying
473 /// Rust future completes. Consumes the task (like `spawn`): the
474 /// `PyPythonTask` cannot be used again.
475 fn block_on(mut slf: PyRefMut<PyPythonTask>, py: Python<'_>) -> PyResult<Py<PyAny>> {
476 let task = slf.take_task()?;
477
478 // Mutable borrows of Python objects must be dropped before
479 // releasing the GIL. `signal_safe_block_on` releases the GIL;
480 // holding `slf` across that would make other Python access
481 // throw.
482 drop(slf);
483 signal_safe_block_on(py, task)?
484 }
485
486 /// Spawn this task onto the Tokio runtime and return a `Shared`
487 /// handle.
488 ///
489 /// The returned `Shared` is awaitable *inside* the
490 /// `from_coroutine` world, or may be waited on synchronously via
491 /// `Shared.block_on()`. Consumes the task.
492 pub(crate) fn spawn(&mut self) -> PyResult<PyShared> {
493 let (tx, rx) = watch::channel(None);
494 let traceback = self.traceback()?;
495 let traceback1 = self.traceback()?;
496 let task = self.take_task()?;
497 let handle = get_tokio_runtime().spawn(async move {
498 send_result(tx, task.await, traceback1);
499 });
500 Ok(PyShared {
501 rx,
502 handle: Some(handle),
503 abort: false,
504 traceback,
505 })
506 }
507
508 /// Implement Python's `await` protocol for `PythonTask`.
509 ///
510 /// This is only supported inside the `pytokio` world driven by
511 /// `PythonTask.from_coroutine`; attempting to `await` a
512 /// `PythonTask` while an `asyncio` event loop is running is an
513 /// error.
514 fn __await__(slf: PyRef<'_, Self>) -> PyResult<PythonTaskAwaitIterator> {
515 let py = slf.py();
516 let l = pyo3_async_runtimes::get_running_loop(py);
517 if l.is_ok() {
518 return Err(PyRuntimeError::new_err(
519 "Attempting to __await__ a PythonTask when the asyncio event loop is active. PythonTask objects should only be awaited in coroutines passed to PythonTask.from_coroutine",
520 ));
521 }
522
523 Ok(PythonTaskAwaitIterator::new(slf.into_py_any(py)?))
524 }
525
526 /// Wrap a Python coroutine into a `PythonTask` that is driven by
527 /// Tokio.
528 ///
529 /// This converts `coro` into its await-iterator
530 /// (`coro.__await__()`), then repeatedly resumes it via
531 /// `send`/`throw`. Whenever the coroutine yields a
532 /// `PythonTask`/`Shared`, we extract its underlying Rust future,
533 /// `await` it on Tokio, and feed the result back into the
534 /// coroutine on the next iteration.
535 ///
536 /// Inside this coroutine, `await` is only supported for pytokio
537 /// values (`PythonTask` / `Shared`). Awaiting arbitrary Python
538 /// awaitables (e.g. `asyncio` futures) is an error.
539 ///
540 /// The current Monarch `context()` is captured at construction
541 /// time and restored while running the coroutine so `context()`
542 /// inside the task reflects the call site that created it (even
543 /// across Tokio thread hops).
544 #[staticmethod]
545 fn from_coroutine(py: Python<'_>, coro: Py<PyAny>) -> PyResult<PyPythonTask> {
546 // context() used inside a PythonTask should inherit the value of
547 // context() from the context in which the PythonTask was constructed.
548 // We need to do this manually because the value of the contextvar isn't
549 // maintained inside the tokio runtime.
550 let monarch_context = context(py).call0()?.unbind();
551 PyPythonTask::new(async move {
552 let (coroutine_iterator, none) = monarch_with_gil(|py| {
553 coro.into_bound(py)
554 .call_method0("__await__")
555 .map(|x| (x.unbind(), py.None()))
556 })
557 .await?;
558 let mut last: PyResult<Py<PyAny>> = Ok(none);
559 enum Action {
560 Return(Py<PyAny>),
561 Wait(Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>),
562 }
563 loop {
564 let action = monarch_with_gil(|py| -> PyResult<Action> {
565 // We may be executing in a new thread at this point, so we need to set the value
566 // of context().
567 let _context = actor_mesh_module(py).getattr("_context")?;
568 let old_context = _context.call_method1("get", (PyNone::get(py),))?;
569 _context
570 .call_method1("set", (monarch_context.clone_ref(py),))
571 .expect("failed to set _context");
572
573 let result = match last {
574 Ok(value) => coroutine_iterator.bind(py).call_method1("send", (value,)),
575 Err(pyerr) => coroutine_iterator
576 .bind(py)
577 .call_method1("throw", (pyerr.into_value(py),)),
578 };
579
580 // Reset context() so that when this tokio thread yields, it has its original state.
581 _context
582 .call_method1("set", (old_context,))
583 .expect("failed to restore _context");
584 match result {
585 Ok(task) => Ok(Action::Wait(
586 task.extract::<Py<PyPythonTask>>()
587 .and_then(|t| t.borrow_mut(py).take_task())
588 .unwrap_or_else(|pyerr| Box::pin(async move { Err(pyerr) })),
589 )),
590 Err(err) => {
591 let err = err.into_pyobject(py)?.into_any();
592 if err.is_instance_of::<PyStopIteration>() {
593 Ok(Action::Return(
594 err.into_pyobject(py)?.getattr("value")?.unbind(),
595 ))
596 } else {
597 Err(PyErr::from_value(err))
598 }
599 }
600 }
601 })
602 .await?;
603 match action {
604 Action::Return(x) => {
605 return Ok(x);
606 }
607 Action::Wait(task) => {
608 last = task.await;
609 }
610 };
611 }
612 })
613 }
614
615 /// Wrap this task with a timeout and return a new `PythonTask`.
616 ///
617 /// Consumes the original task. If it does not complete within
618 /// `seconds`, the returned task fails with `TimeoutError`.
619 fn with_timeout(&mut self, seconds: f64) -> PyResult<PyPythonTask> {
620 let tb = self.traceback()?;
621 let task = self.take_task()?;
622 PyPythonTask::new_with_traceback(
623 async move {
624 tokio::time::timeout(std::time::Duration::from_secs_f64(seconds), task)
625 .await
626 .map_err(|_| PyTimeoutError::new_err(()))?
627 },
628 tb,
629 )
630 }
631
632 /// Run a Python callable on Tokio's blocking thread pool and
633 /// return a `Shared` handle.
634 ///
635 /// This is for CPU-bound or otherwise blocking Python work that
636 /// must not run on a Tokio async worker thread. The callable `f`
637 /// is executed via `tokio::spawn_blocking`, and its result (or
638 /// raised exception) is delivered through the returned `Shared`.
639 ///
640 /// The current Monarch `context()` is captured and restored while
641 /// running `f` so calls to `context()` from inside `f` see the
642 /// originating actor context.
643 #[staticmethod]
644 fn spawn_blocking(py: Python<'_>, f: Py<PyAny>) -> PyResult<PyShared> {
645 let (tx, rx) = watch::channel(None);
646 let traceback = current_traceback()?;
647 let traceback1 = traceback.as_ref().map_or_else(
648 || None,
649 |t| monarch_with_gil_blocking(|py| Some(t.clone_ref(py))),
650 );
651 let monarch_context = context(py).call0()?.unbind();
652 // The `_context` contextvar needs to be propagated through to the thread that
653 // runs the blocking tokio task. Upon completion, the original value of `_context`
654 // is restored.
655 let handle = get_tokio_runtime().spawn_blocking(move || {
656 let result = monarch_with_gil_blocking(|py| {
657 let _context = actor_mesh_module(py).getattr("_context")?;
658 let old_context = _context.call_method1("get", (PyNone::get(py),))?;
659 _context
660 .call_method1("set", (monarch_context.clone_ref(py),))
661 .expect("failed to set _context");
662 let result = f.call0(py);
663 _context
664 .call_method1("set", (old_context,))
665 .expect("failed to restore _context");
666 result
667 });
668 send_result(tx, result, traceback1);
669 });
670 Ok(PyShared {
671 rx,
672 handle: Some(handle),
673 abort: false,
674 traceback,
675 })
676 }
677
678 /// Wait for the first task to complete and return `(result,
679 /// index)`.
680 ///
681 /// This consumes all input tasks (each is `take_task()`'d). The
682 /// returned task resolves to a tuple of the winning task's result
683 /// and its index in the input list.
684 #[staticmethod]
685 fn select_one(mut tasks: Vec<PyRefMut<'_, PyPythonTask>>) -> PyResult<PyPythonTask> {
686 if tasks.is_empty() {
687 return Err(PyValueError::new_err("Cannot select from empty task list"));
688 }
689
690 let mut futures = Vec::new();
691 for task_ref in tasks.iter_mut() {
692 futures.push(task_ref.take_task()?);
693 }
694
695 PyPythonTask::new(async move {
696 let (result, index, _remaining) = futures::future::select_all(futures).await;
697 result.map(|r| (r, index))
698 })
699 }
700
701 /// Sleep for `seconds` on the Tokio runtime.
702 #[staticmethod]
703 fn sleep(seconds: f64) -> PyResult<PyPythonTask> {
704 PyPythonTask::new(async move {
705 tokio::time::sleep(tokio::time::Duration::from_secs_f64(seconds)).await;
706 Ok(())
707 })
708 }
709
710 /// Support `PythonTask[T]` type syntax on the Python side (no
711 /// runtime effect).
712 #[classmethod]
713 fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: Py<PyAny>) -> Py<PyAny> {
714 cls.clone().unbind().into()
715 }
716}
717
/// Awaitable handle to a spawned background Tokio task.
///
/// `Shared` is returned by `PythonTask.spawn()` /
/// `spawn_abortable()` / `spawn_blocking()`, or built already
/// completed by `Shared.from_value(...)`. It carries a `watch`
/// receiver that is fulfilled exactly once with the task's
/// `PyResult<Py<PyAny>>`.
///
/// Usage:
/// - `await shared` inside the `PythonTask.from_coroutine(...)`
///   world, or
/// - `shared.block_on()` to wait synchronously.
///
/// If `abort` is true (from `spawn_abortable()`), dropping this
/// object aborts the underlying Tokio task via its `JoinHandle`.
#[pyclass(
    name = "Shared",
    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
)]
pub struct PyShared {
    /// One-shot result channel. Starts as `None`; becomes
    /// `Some(Ok(obj))` or `Some(Err(pyerr))` when the background task
    /// completes.
    rx: watch::Receiver<Option<PyResult<Py<PyAny>>>>,

    /// Handle for the spawned Tokio task that is producing `rx`'s
    /// result. `None` for `Shared.from_value(...)`.
    handle: Option<JoinHandle<()>>,

    /// If true, dropping `Shared` aborts the background task via
    /// `handle.abort()`. This is set by `spawn_abortable()`.
    abort: bool,

    /// Optional creation-site traceback (captured when enabled) used
    /// when logging un-awaited errors / for derived tasks.
    traceback: Option<Py<PyAny>>,
}
753
754/// If this `Shared` was created via `spawn_abortable()`, abort the
755/// underlying Tokio task on drop.
756///
757/// This prevents abandoned background work from running forever when
758/// no receivers remain. We guard against panics during interpreter
759/// shutdown / runtime teardown.
760impl Drop for PyShared {
761 fn drop(&mut self) {
762 if self.abort {
763 // When the PyShared is dropped, we don't want the background task to go
764 // forever, because nothing will wait on the rx.
765 if let Some(h) = self.handle.as_ref() {
766 // Guard against panics during interpreter shutdown when tokio runtime may be gone
767 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
768 h.abort();
769 }));
770 }
771 }
772 }
773}
774
775#[pymethods]
776impl PyShared {
777 /// Convert this `Shared` handle into a `PythonTask` that waits
778 /// for its result.
779 ///
780 /// Internally, this clones the `watch::Receiver` and returns a
781 /// new one-shot task that:
782 /// 1) waits for the sender to publish `Some(result)`, and then
783 /// 2) returns/clones the stored `Py<PyAny>` / `PyErr` under the
784 /// GIL.
785 ///
786 /// Cloning the receiver allows multiple independent awaiters to
787 /// observe the same completion.
788 pub(crate) fn task(&self) -> PyResult<PyPythonTask> {
789 // watch channels start unchanged, and when a value is sent to them signal
790 // the receivers `changed` future.
791 // By cloning the rx before awaiting it,
792 // we can have multiple awaiters get triggered by the same change.
793 // self.rx will always be in the state where it hasn't see the change yet.
794 let mut rx = self.rx.clone();
795 PyPythonTask::new_with_traceback(
796 async move {
797 // Check if a value is already available (not None).
798 // The channel is initialized with None, and the sender sets it to Some(result).
799 // If it's still None, wait for a change. Otherwise, the value is ready.
800 if rx.borrow().is_none() {
801 rx.changed().await.map_err(to_py_error)?;
802 }
803 // We need to hold the GIL when cloning Python objects (Py<PyAny> and PyErr).
804 monarch_with_gil(|py| {
805 let borrowed = rx.borrow();
806 match borrowed.as_ref().unwrap() {
807 Ok(v) => Ok(v.bind(py).clone().unbind()),
808 Err(err) => Err(err.clone_ref(py)),
809 }
810 })
811 .await
812 },
813 self.traceback.as_ref().map_or_else(
814 || None,
815 |t| monarch_with_gil_blocking(|py| Some(t.clone_ref(py))),
816 ),
817 )
818 }
819
820 /// Implement Python's `await` protocol for `Shared`.
821 ///
822 /// This delegates to `self.task()` (which returns a `PythonTask`
823 /// that waits for the background result) and then returns that
824 /// task's await-iterator.
825 ///
826 /// Note: `await shared` is only supported inside the
827 /// `PythonTask.from_coroutine(...)` world (because it ultimately
828 /// awaits a `PythonTask`).
829 fn __await__(&mut self, py: Python<'_>) -> PyResult<PythonTaskAwaitIterator> {
830 let task = self.task()?;
831 Ok(PythonTaskAwaitIterator::new(task.into_py_any(py)?))
832 }
833
834 /// Wait synchronously for this `Shared` to resolve.
835 ///
836 /// This blocks the calling Python thread until the underlying
837 /// background task has published its result into the watch
838 /// channel, then returns that `Py<PyAny>` (or raises the stored
839 /// Python exception).
840 ///
841 /// If the value is already available, returns immediately without
842 /// blocking. This is important for cases where `block_on` is called
843 /// from within a tokio runtime (e.g., during unpickling on a worker
844 /// thread) - we can't call `runtime.block_on()` from within a runtime.
845 pub fn block_on(slf: PyRef<PyShared>, py: Python<'_>) -> PyResult<Py<PyAny>> {
846 // Check if value is already available - return immediately if so.
847 // This avoids calling into the tokio runtime when unnecessary,
848 // which is critical when called from within a tokio worker thread.
849 if let Some(value) = slf.poll()? {
850 return Ok(value);
851 }
852
853 let task = slf.task()?.take_task()?;
854 // Explicitly drop the reference so that if another thread attempts to borrow
855 // this object mutably during signal_safe_block_on, it won't throw an exception.
856 drop(slf);
857 signal_safe_block_on(py, task)?
858 }
859
860 /// Support `Shared[T]` type syntax on the Python side (no runtime
861 /// effect).
862 #[classmethod]
863 fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: Py<PyAny>) -> Py<PyAny> {
864 cls.clone().unbind().into()
865 }
866
867 /// Non-blocking check for completion.
868 ///
869 /// Returns:
870 /// - `Ok(None)` if the background task has not finished yet,
871 /// - `Ok(Some(obj))` if it completed successfully,
872 /// - `Err(pyerr)` if it completed with an exception.
873 ///
874 /// This does not wait; it only inspects the current watch value.
875 pub(crate) fn poll(&self) -> PyResult<Option<Py<PyAny>>> {
876 let b = self.rx.borrow();
877 let r = b.as_ref();
878 match r {
879 None => Ok(None),
880 Some(r) => Python::attach(|py| match r {
881 Ok(v) => Ok(Some(v.clone_ref(py))),
882 Err(err) => Err(err.clone_ref(py)),
883 }),
884 }
885 }
886
887 /// Construct a `Shared` that is already completed with `value`.
888 ///
889 /// This is a convenience for APIs that want to return a `Shared`
890 /// without spawning a background task. The returned handle has no
891 /// `JoinHandle` and will immediately yield `value` via `poll()`,
892 /// `await` (inside `from_coroutine`), or `block_on()`.
893 #[classmethod]
894 fn from_value(_cls: &Bound<'_, PyType>, value: Py<PyAny>) -> PyResult<Self> {
895 let (tx, rx) = watch::channel(None);
896 tx.send(Some(Ok(value))).map_err(to_py_error)?;
897 Ok(Self {
898 rx,
899 handle: None,
900 abort: false,
901 traceback: None,
902 })
903 }
904
905 /// Pickle protocol support for PyShared.
906 ///
907 /// This implements the pickle reduce protocol:
908 /// - If the shared is finished, pickle as (Shared.from_value, (value,))
909 /// - If pending pickles are allowed, defer pickling and return (pop_pending_pickle, ())
910 /// - Otherwise, block on the shared and pickle as (Shared.from_value, (value,))
911 fn __reduce__<'py>(
912 slf: &Bound<'py, Self>,
913 py: Python<'py>,
914 ) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
915 reduce_shared(py, slf)
916 }
917}
918
919/// Return true if the current thread is executing within a Tokio
920/// runtime context.
921///
922/// This checks whether `tokio::runtime::Handle::try_current()`
923/// succeeds.
924#[pyfunction]
925fn is_tokio_thread() -> bool {
926 tokio::runtime::Handle::try_current().is_ok()
927}
928
929/// Register the pytokio Python bindings into the given module.
930///
931/// This wires up the exported pyclasses (`PythonTask`, `Shared`)
932/// and module-level functions used by the Monarch Python layer.
933pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
934 hyperactor_mod.add_class::<PyPythonTask>()?;
935 hyperactor_mod.add_class::<PyShared>()?;
936 let f = wrap_pyfunction!(is_tokio_thread, hyperactor_mod)?;
937 f.setattr(
938 "__module__",
939 "monarch._rust_bindings.monarch_hyperactor.pytokio",
940 )?;
941 hyperactor_mod.add_function(f)?;
942
943 Ok(())
944}
945
/// Initialize the embedded Python interpreter, at most once per
/// process.
///
/// May be called any number of times from any number of threads; only
/// the first call performs the initialization.
#[cfg(test)]
pub(crate) fn ensure_python() {
    static INIT: std::sync::OnceLock<()> = std::sync::OnceLock::new();
    let _ = INIT.get_or_init(pyo3::Python::initialize);
}
957
#[cfg(test)]
// Test-only helper: lets Rust test code "await" a `PyPythonTask`
// directly.
//
// Semantics of the implementation for `PyPythonTask`:
// - consume the `PyPythonTask`,
// - take the inner future,
// - `.await` it on tokio to get `Py<PyAny>`,
// - turn that into `Py<T>` (or discard it, for `await_unit`).
pub(crate) trait AwaitPyExt {
    // Await the task and return its result as a `Py<T>`.
    async fn await_py<T: PyClass>(self) -> Result<Py<T>, PyErr>;

    // For tasks whose future just resolves to (), i.e. no object,
    // just "did it work?"
    async fn await_unit(self) -> Result<(), PyErr>;
}
973
#[cfg(test)]
impl AwaitPyExt for PyPythonTask {
    // Drive the task to completion and downcast its result to `Py<T>`.
    //
    // Panics if the task was already consumed or if the result is not
    // a `T`; errors raised by the task itself are returned as `Err`.
    async fn await_py<T: PyClass>(mut self) -> Result<Py<T>, PyErr> {
        // Take ownership of the inner future.
        let fut = self
            .take_task()
            .expect("PyPythonTask already consumed in await_py");

        // The future always produces an untyped Python object.
        let py_any: Py<PyAny> = fut.await?;

        // Narrow `Py<PyAny>` down to `Py<T>` under the GIL.
        monarch_with_gil(|py| {
            let typed: Py<T> = py_any
                .bind(py)
                .extract::<Py<T>>()
                .expect("spawn() did not return expected Python type");
            Ok(typed)
        })
        .await
    }

    // Drive the task to completion and report only success or failure.
    //
    // Panics if the task was already consumed.
    async fn await_unit(mut self) -> Result<(), PyErr> {
        let fut = self
            .take_task()
            .expect("PyPythonTask already consumed in await_unit");

        // The future still yields a `Py<PyAny>` (for "no value" that is
        // just a Python object standing in for nothing); it is not
        // needed here, so it is dropped as soon as it arrives.
        fut.await.map(drop)
    }
}