monarch_hyperactor/
pytokio.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/// `pytokio` is Monarch's Python <-> Tokio async bridge.
10///
11/// It provides a small, *non-asyncio* async world where Python code
12/// can *compose* Rust/Tokio futures using `await`.
13///
14/// ## The core idea
15///
16/// In `pytokio`:
17///
18/// - `PythonTask` = a one-shot Rust/Tokio future that produces a
19///   Python value.
20/// - `from_coroutine` = wraps a Python coroutine as a Rust future
21///   that drives it.
22/// - `Shared` = an awaitable handle to a spawned background Tokio
23///   task.
24///
25/// More concretely:
26///
27/// - Rust bindings return a Python-visible `PythonTask`
28///   (`PyPythonTask`), which wraps a Rust `PythonTask` holding a
29///   boxed Tokio future returning `PyResult<Py<PyAny>>`.
30/// - `PythonTask.from_coroutine(coro)` wraps a *Python coroutine* as
31///   a `PythonTask` by creating a Rust/Tokio future that drives
32///   `coro.__await__()` (via `send`/`throw`) and awaits the
33///   `PythonTask`s it yields.
34/// - Python code may `await` a `PythonTask` / `Shared` **only** when
35///   running under `PythonTask.from_coroutine(...)`. Awaiting
36///   arbitrary Python awaitables (e.g. `asyncio` futures) is an
37///   error.
38/// - Calling `task.spawn()` / `spawn_abortable()` returns a `Shared`
39///   (`PyShared`), which yields the result of the background Tokio
40///   task running the original `PythonTask`.
41///
42/// This is intentionally *not* a general-purpose async bridge: it’s a
43/// way to use Python syntax to drive and compose Tokio futures.
44///
45/// ## Wrapping a Python coroutine
46///
47/// ```ignore
48/// async def work():
49///     x = await some_rust_binding()      # must yield PythonTask / Shared
50///     await PythonTask.sleep(0.1)        # also a PythonTask
51///     return x
52///
53/// task = PythonTask.from_coroutine(work())
54/// result = task.block_on()              # block the calling Python thread while a
55///                                       # Tokio runtime drives the task to completion
56/// ```
57///
58/// `from_coroutine` drives the coroutine by repeatedly resuming it
59/// and awaiting the `PythonTask`s it yields, using a Tokio runtime.
60///
61/// ## Spawning
62///
63/// `spawn()` runs a `PythonTask` on a background Tokio task and
64/// returns a `Shared` handle.
65///
66/// To `await` the handle, you must still be inside a
67/// `from_coroutine`-driven coroutine:
68///
69/// ```ignore
70/// async def work():
71///     task = some_rust_binding()
72///     shared = task.spawn()
73///     # ... do other work ...
74///     result = await shared             # valid here (inside from_coroutine world)
75///     return result
76///
77/// result = PythonTask.from_coroutine(work()).block_on()
78/// ```
79///
80/// In synchronous contexts, you can wait for a spawned task without
81/// `from_coroutine`:
82///
83/// ```ignore
84/// shared = task.spawn()
85/// result = shared.block_on()            # blocks the calling Python thread
86/// ```
87///
88/// If `spawn_abortable()` is used, dropping the returned `Shared`
89/// aborts the underlying Tokio task.
90///
91/// ## Context propagation
92///
93/// `from_coroutine` preserves Monarch’s `context()` across Tokio
94/// thread hops, so code calling `context()` inside a `PythonTask`
95/// sees the same actor context as the call site that constructed the
96/// task.
97use std::error::Error;
98use std::future::Future;
99use std::pin::Pin;
100
101use hyperactor_config::CONFIG;
102use hyperactor_config::ConfigAttr;
103use hyperactor_config::attrs::declare_attrs;
104use monarch_types::SerializablePyErr;
105use monarch_types::py_global;
106use pyo3::IntoPyObjectExt;
107#[cfg(test)]
108use pyo3::PyClass;
109use pyo3::exceptions::PyRuntimeError;
110use pyo3::exceptions::PyStopIteration;
111use pyo3::exceptions::PyTimeoutError;
112use pyo3::exceptions::PyValueError;
113use pyo3::prelude::*;
114use pyo3::types::PyNone;
115use pyo3::types::PyString;
116use pyo3::types::PyTuple;
117use pyo3::types::PyType;
118use tokio::sync::Mutex;
119use tokio::sync::watch;
120use tokio::task::JoinHandle;
121
122use crate::pickle::reduce_shared;
123use crate::runtime::get_tokio_runtime;
124use crate::runtime::monarch_with_gil;
125use crate::runtime::monarch_with_gil_blocking;
126use crate::runtime::signal_safe_block_on;
127
128declare_attrs! {
129    /// If true, capture a Python stack trace at `PythonTask` creation
130    /// time and log it when a spawned task errors but is never
131    /// awaited/polled.
132    @meta(CONFIG = ConfigAttr::new(
133        Some("MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK".to_string()),
134        Some("enable_unawaited_python_task_traceback".to_string()),
135    ))
136    pub attr ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK: bool = false;
137}
138
139// Import Python helpers used for actor context propagation.
140// `context()` returns the current Monarch actor context.
141// `actor_mesh` is the module that owns the `_context` contextvar we
142// must manually set/restore when driving coroutines on Tokio threads.
143py_global!(context, "monarch._src.actor.actor_mesh", "context");
144py_global!(actor_mesh_module, "monarch._src.actor", "actor_mesh");
145
146/// Capture the current Python stack trace (creation call site) if
147/// `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK` is enabled.
148///
149/// Returns `None` when disabled to avoid the overhead of
150/// `traceback.extract_stack()`.
151fn current_traceback() -> PyResult<Option<Py<PyAny>>> {
152    if hyperactor_config::global::get(ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK) {
153        monarch_with_gil_blocking(|py| {
154            Ok(Some(
155                py.import("traceback")?
156                    .call_method0("extract_stack")?
157                    .unbind(),
158            ))
159        })
160    } else {
161        Ok(None)
162    }
163}
164
165/// Format a captured traceback (from `traceback.extract_stack()`) as
166/// a single string suitable for logging.
167fn format_traceback(py: Python<'_>, traceback: &Py<PyAny>) -> PyResult<String> {
168    let tb = py
169        .import("traceback")?
170        .call_method1("format_list", (traceback,))?;
171    PyString::new(py, "")
172        .call_method1("join", (tb,))?
173        .extract::<String>()
174}
175
176/// Helper struct to make a Rust/Tokio future (returning a Python
177/// result) passable in an actor message.
178///
179/// The future resolves to `PyResult<Py<PyAny>>` so it can return a
180/// Python value or raise a Python exception, and it is `Send +
181/// 'static` so it can cross thread/actor boundaries.
182///
183/// Also so that we don't have to write this massive type signature
184/// everywhere.
185pub(crate) struct PythonTask {
186    /// Boxed, pinned Rust/Tokio future producing a Python result,
187    /// protected so it can be taken/consumed exactly once when the
188    /// task is driven.
189    // Type decoder ring:
190    //
191    // Mutex<Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>>>
192    //   │     │   │   │                                        │      │
193    //   │     │   │   │                                        │      └─ owns all data, no dangling refs
194    //   │     │   │   │                                        └─ can cross thread boundaries
195    //   │     │   │   └─ any future type (type-erased)
196    //   │     │   └─ heap-allocated (because unsized)
197    //   │     └─ immovable (safe to poll self-referential futures)
198    //   └─ exclusive access for consumption
199    future: Mutex<Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>>>,
200
201    /// Optional Python stack trace captured at task construction
202    /// time, used to annotate logs when a spawned task errors but
203    /// nobody awaits/polls it.
204    traceback: Option<Py<PyAny>>,
205}
206
207impl PythonTask {
208    /// Construct a `PythonTask` from a Rust/Tokio future and an
209    /// optional captured Python traceback.
210    ///
211    /// The future is boxed and pinned so it can be stored in the
212    /// struct and later driven safely.
213    fn new_with_traceback(
214        fut: impl Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
215        traceback: Option<Py<PyAny>>,
216    ) -> Self {
217        Self {
218            future: Mutex::new(Box::pin(fut)),
219            traceback,
220        }
221    }
222
223    /// Construct a `PythonTask`, capturing a creation-site traceback
224    /// if enabled by `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK`.
225    pub(crate) fn new(
226        fut: impl Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
227    ) -> PyResult<Self> {
228        Ok(Self::new_with_traceback(fut, current_traceback()?))
229    }
230
231    /// Return the optional captured creation-site traceback (if
232    /// enabled).
233    fn traceback(&self) -> &Option<Py<PyAny>> {
234        &self.traceback
235    }
236
237    /// Consume the task and return the boxed, pinned future.
238    ///
239    /// This is a one-shot operation: it moves the future out of the
240    /// struct so it can be driven to completion.
241    pub(crate) fn take(
242        self,
243    ) -> Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>> {
244        self.future.into_inner()
245    }
246}
247
248impl std::fmt::Debug for PythonTask {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.debug_struct("PythonTask")
251            .field("future", &"<PythonFuture>")
252            .finish()
253    }
254}
255
256/// Python-visible wrapper for a one-shot `PythonTask`.
257///
258/// Exposed to Python as
259/// `monarch._rust_bindings.monarch_hyperactor.pytokio.PythonTask`.
260/// This object owns the underlying Rust task and is *consumed* when
261/// it is run (e.g. via `spawn()`, `spawn_abortable()`, or
262/// `block_on()`), hence `inner: Option<_>`.
263#[pyclass(
264    name = "PythonTask",
265    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
266)]
267pub struct PyPythonTask {
268    inner: Option<PythonTask>,
269}
270
271impl From<PythonTask> for PyPythonTask {
272    fn from(task: PythonTask) -> Self {
273        Self { inner: Some(task) }
274    }
275}
276
277/// Minimal await-iterator used to implement Python's `__await__`
278/// protocol for pytokio.
279///
280/// This iterator yields the task object exactly once. The Rust-side
281/// coroutine driver (`from_coroutine`) resumes the Python coroutine
282/// and expects it to yield a `PythonTask` (or `Shared`) object back
283/// to Rust.
284#[pyclass(
285    name = "PythonTaskAwaitIterator",
286    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
287)]
288struct PythonTaskAwaitIterator {
289    value: Option<Py<PyAny>>,
290}
291
292impl PythonTaskAwaitIterator {
293    /// Create an await-iterator that will yield `task` exactly once.
294    fn new(task: Py<PyAny>) -> PythonTaskAwaitIterator {
295        PythonTaskAwaitIterator { value: Some(task) }
296    }
297}
298
299#[pymethods]
300impl PythonTaskAwaitIterator {
301    /// First `send(...)` yields the stored task; subsequent sends
302    /// raise `StopIteration`.
303    ///
304    /// Python's await machinery calls `send(None)` to advance the
305    /// iterator.
306    fn send(&mut self, value: Py<PyAny>) -> PyResult<Py<PyAny>> {
307        self.value
308            .take()
309            .ok_or_else(|| PyStopIteration::new_err((value,)))
310    }
311
312    /// Convert the thrown Python exception value into a `PyErr` and
313    /// surface it to Rust.
314    fn throw(&mut self, value: Py<PyAny>) -> PyResult<Py<PyAny>> {
315        Err(monarch_with_gil_blocking(|py| {
316            PyErr::from_value(value.into_bound(py))
317        }))
318    }
319
320    /// Iterator protocol: `next(it)` is equivalent to
321    /// `it.send(None)`.
322    fn __next__(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
323        self.send(py.None())
324    }
325}
326
327impl PyPythonTask {
328    /// Construct a Python-visible `PythonTask` from a Rust future,
329    /// attaching an explicit creation-site traceback (if provided).
330    ///
331    /// The input future produces a Rust value `T`; on completion we
332    /// reacquire the GIL and convert `T` into a Python object
333    /// (`Py<PyAny>`).
334    fn new_with_traceback<F, T>(fut: F, traceback: Option<Py<PyAny>>) -> PyResult<Self>
335    where
336        F: Future<Output = PyResult<T>> + Send + 'static,
337        T: for<'py> IntoPyObject<'py> + Send,
338    {
339        Ok(PythonTask::new_with_traceback(
340            async {
341                let result = fut.await?;
342                monarch_with_gil(|py| result.into_py_any(py)).await
343            },
344            traceback,
345        )
346        .into())
347    }
348
349    /// Construct a `PythonTask`, capturing a creation-site traceback
350    /// if enabled.
351    ///
352    /// See `new_with_traceback` for conversion semantics (`T` ->
353    /// Python object under the GIL).
354    pub fn new<F, T>(fut: F) -> PyResult<Self>
355    where
356        F: Future<Output = PyResult<T>> + Send + 'static,
357        T: for<'py> IntoPyObject<'py> + Send,
358    {
359        Self::new_with_traceback(fut, current_traceback()?)
360    }
361}
362
363// Helper: convert a Rust error into a generic Python ValueError.
364fn to_py_error<T>(e: T) -> PyErr
365where
366    T: Error,
367{
368    PyErr::new::<PyValueError, _>(e.to_string())
369}
370
371impl PyPythonTask {
372    /// Consume this `PythonTask` and return the underlying Rust
373    /// future.
374    ///
375    /// This is a one-shot operation: after calling `take_task`, the
376    /// `PyPythonTask` is considered *consumed* and cannot be
377    /// spawned/awaited/blocked-on again.
378    pub fn take_task(
379        &mut self,
380    ) -> PyResult<Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>> {
381        self.inner
382            .take()
383            .map(|task| task.take())
384            .ok_or_else(|| PyValueError::new_err("PythonTask already consumed"))
385    }
386
387    /// Return the captured creation-site traceback (if enabled),
388    /// cloning it under the GIL.
389    ///
390    /// Fails if the task has already been consumed.
391    fn traceback(&self) -> PyResult<Option<Py<PyAny>>> {
392        if let Some(task) = &self.inner {
393            Ok(monarch_with_gil_blocking(|py| {
394                task.traceback().as_ref().map(|t| t.clone_ref(py))
395            }))
396        } else {
397            Err(PyValueError::new_err("PythonTask already consumed"))
398        }
399    }
400
401    /// Spawn this task onto the Tokio runtime and return a `Shared`
402    /// handle that *aborts on drop*.
403    ///
404    /// Use this when the underlying future is *abort-safe*
405    /// (cancellation-safe): dropping the returned `Shared` will call
406    /// `JoinHandle::abort()`, preventing the background task from
407    /// running forever.
408    ///
409    /// This is especially useful for long-lived or periodic tasks
410    /// (e.g. timers) where "nobody is awaiting the result anymore"
411    /// should stop the work.
412    ///
413    /// Like `spawn()`, this consumes the `PyPythonTask` (it can only
414    /// be spawned once).
415    pub(crate) fn spawn_abortable(&mut self) -> PyResult<PyShared> {
416        let (tx, rx) = watch::channel(None);
417        let traceback = self.traceback()?;
418        let traceback1 = self.traceback()?;
419        let task = self.take_task()?;
420        let handle = get_tokio_runtime().spawn(async move {
421            send_result(tx, task.await, traceback1);
422        });
423        Ok(PyShared {
424            rx,
425            handle: Some(handle),
426            abort: true,
427            traceback,
428        })
429    }
430}
431
432/// Publish a completed task result to the `watch` channel.
433///
434/// If the receiver has already been dropped, `watch::Sender::send`
435/// returns the unsent value as `SendError`. We treat that as "nobody
436/// will ever observe this result".
437///
438/// In the special case where the unobserved result is an error, we
439/// log it (and include the task creation traceback when available) to
440/// avoid silently losing failures from background tasks.
441fn send_result(
442    tx: tokio::sync::watch::Sender<Option<PyResult<Py<PyAny>>>>,
443    result: PyResult<Py<PyAny>>,
444    traceback: Option<Py<PyAny>>,
445) {
446    // a SendErr just means that there are no consumers of the value left.
447    match tx.send(Some(result)) {
448        Err(tokio::sync::watch::error::SendError(Some(Err(pyerr)))) => {
449            monarch_with_gil_blocking(|py| {
450                let tb = if let Some(tb) = traceback {
451                    format_traceback(py, &tb).unwrap()
452                } else {
453                    "None (run with `MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK=1` to see a traceback here)\n".into()
454                };
455                tracing::error!(
456                    "PythonTask errored but is not being awaited; this will not crash your program, but indicates that \
457                    something went wrong.\n{}\nTraceback where the task was created (most recent call last):\n{}",
458                    SerializablePyErr::from(py, &pyerr),
459                    tb
460                );
461            });
462        }
463        _ => {}
464    };
465}
466
467#[pymethods]
468impl PyPythonTask {
469    /// Run this task to completion synchronously on the embedded
470    /// Tokio runtime.
471    ///
472    /// This blocks the calling Python thread until the underlying
473    /// Rust future completes. Consumes the task (like `spawn`): the
474    /// `PyPythonTask` cannot be used again.
475    fn block_on(mut slf: PyRefMut<PyPythonTask>, py: Python<'_>) -> PyResult<Py<PyAny>> {
476        let task = slf.take_task()?;
477
478        // Mutable borrows of Python objects must be dropped before
479        // releasing the GIL. `signal_safe_block_on` releases the GIL;
480        // holding `slf` across that would make other Python access
481        // throw.
482        drop(slf);
483        signal_safe_block_on(py, task)?
484    }
485
486    /// Spawn this task onto the Tokio runtime and return a `Shared`
487    /// handle.
488    ///
489    /// The returned `Shared` is awaitable *inside* the
490    /// `from_coroutine` world, or may be waited on synchronously via
491    /// `Shared.block_on()`. Consumes the task.
492    pub(crate) fn spawn(&mut self) -> PyResult<PyShared> {
493        let (tx, rx) = watch::channel(None);
494        let traceback = self.traceback()?;
495        let traceback1 = self.traceback()?;
496        let task = self.take_task()?;
497        let handle = get_tokio_runtime().spawn(async move {
498            send_result(tx, task.await, traceback1);
499        });
500        Ok(PyShared {
501            rx,
502            handle: Some(handle),
503            abort: false,
504            traceback,
505        })
506    }
507
508    /// Implement Python's `await` protocol for `PythonTask`.
509    ///
510    /// This is only supported inside the `pytokio` world driven by
511    /// `PythonTask.from_coroutine`; attempting to `await` a
512    /// `PythonTask` while an `asyncio` event loop is running is an
513    /// error.
514    fn __await__(slf: PyRef<'_, Self>) -> PyResult<PythonTaskAwaitIterator> {
515        let py = slf.py();
516        let l = pyo3_async_runtimes::get_running_loop(py);
517        if l.is_ok() {
518            return Err(PyRuntimeError::new_err(
519                "Attempting to __await__ a PythonTask when the asyncio event loop is active. PythonTask objects should only be awaited in coroutines passed to PythonTask.from_coroutine",
520            ));
521        }
522
523        Ok(PythonTaskAwaitIterator::new(slf.into_py_any(py)?))
524    }
525
526    /// Wrap a Python coroutine into a `PythonTask` that is driven by
527    /// Tokio.
528    ///
529    /// This converts `coro` into its await-iterator
530    /// (`coro.__await__()`), then repeatedly resumes it via
531    /// `send`/`throw`. Whenever the coroutine yields a
532    /// `PythonTask`/`Shared`, we extract its underlying Rust future,
533    /// `await` it on Tokio, and feed the result back into the
534    /// coroutine on the next iteration.
535    ///
536    /// Inside this coroutine, `await` is only supported for pytokio
537    /// values (`PythonTask` / `Shared`). Awaiting arbitrary Python
538    /// awaitables (e.g. `asyncio` futures) is an error.
539    ///
540    /// The current Monarch `context()` is captured at construction
541    /// time and restored while running the coroutine so `context()`
542    /// inside the task reflects the call site that created it (even
543    /// across Tokio thread hops).
544    #[staticmethod]
545    fn from_coroutine(py: Python<'_>, coro: Py<PyAny>) -> PyResult<PyPythonTask> {
546        // context() used inside a PythonTask should inherit the value of
547        // context() from the context in which the PythonTask was constructed.
548        // We need to do this manually because the value of the contextvar isn't
549        // maintained inside the tokio runtime.
550        let monarch_context = context(py).call0()?.unbind();
551        PyPythonTask::new(async move {
552            let (coroutine_iterator, none) = monarch_with_gil(|py| {
553                coro.into_bound(py)
554                    .call_method0("__await__")
555                    .map(|x| (x.unbind(), py.None()))
556            })
557            .await?;
558            let mut last: PyResult<Py<PyAny>> = Ok(none);
559            enum Action {
560                Return(Py<PyAny>),
561                Wait(Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>),
562            }
563            loop {
564                let action = monarch_with_gil(|py| -> PyResult<Action> {
565                    // We may be executing in a new thread at this point, so we need to set the value
566                    // of context().
567                    let _context = actor_mesh_module(py).getattr("_context")?;
568                    let old_context = _context.call_method1("get", (PyNone::get(py),))?;
569                    _context
570                        .call_method1("set", (monarch_context.clone_ref(py),))
571                        .expect("failed to set _context");
572
573                    let result = match last {
574                        Ok(value) => coroutine_iterator.bind(py).call_method1("send", (value,)),
575                        Err(pyerr) => coroutine_iterator
576                            .bind(py)
577                            .call_method1("throw", (pyerr.into_value(py),)),
578                    };
579
580                    // Reset context() so that when this tokio thread yields, it has its original state.
581                    _context
582                        .call_method1("set", (old_context,))
583                        .expect("failed to restore _context");
584                    match result {
585                        Ok(task) => Ok(Action::Wait(
586                            task.extract::<Py<PyPythonTask>>()
587                                .and_then(|t| t.borrow_mut(py).take_task())
588                                .unwrap_or_else(|pyerr| Box::pin(async move { Err(pyerr) })),
589                        )),
590                        Err(err) => {
591                            let err = err.into_pyobject(py)?.into_any();
592                            if err.is_instance_of::<PyStopIteration>() {
593                                Ok(Action::Return(
594                                    err.into_pyobject(py)?.getattr("value")?.unbind(),
595                                ))
596                            } else {
597                                Err(PyErr::from_value(err))
598                            }
599                        }
600                    }
601                })
602                .await?;
603                match action {
604                    Action::Return(x) => {
605                        return Ok(x);
606                    }
607                    Action::Wait(task) => {
608                        last = task.await;
609                    }
610                };
611            }
612        })
613    }
614
615    /// Wrap this task with a timeout and return a new `PythonTask`.
616    ///
617    /// Consumes the original task. If it does not complete within
618    /// `seconds`, the returned task fails with `TimeoutError`.
619    fn with_timeout(&mut self, seconds: f64) -> PyResult<PyPythonTask> {
620        let tb = self.traceback()?;
621        let task = self.take_task()?;
622        PyPythonTask::new_with_traceback(
623            async move {
624                tokio::time::timeout(std::time::Duration::from_secs_f64(seconds), task)
625                    .await
626                    .map_err(|_| PyTimeoutError::new_err(()))?
627            },
628            tb,
629        )
630    }
631
632    /// Run a Python callable on Tokio's blocking thread pool and
633    /// return a `Shared` handle.
634    ///
635    /// This is for CPU-bound or otherwise blocking Python work that
636    /// must not run on a Tokio async worker thread. The callable `f`
637    /// is executed via `tokio::spawn_blocking`, and its result (or
638    /// raised exception) is delivered through the returned `Shared`.
639    ///
640    /// The current Monarch `context()` is captured and restored while
641    /// running `f` so calls to `context()` from inside `f` see the
642    /// originating actor context.
643    #[staticmethod]
644    fn spawn_blocking(py: Python<'_>, f: Py<PyAny>) -> PyResult<PyShared> {
645        let (tx, rx) = watch::channel(None);
646        let traceback = current_traceback()?;
647        let traceback1 = traceback.as_ref().map_or_else(
648            || None,
649            |t| monarch_with_gil_blocking(|py| Some(t.clone_ref(py))),
650        );
651        let monarch_context = context(py).call0()?.unbind();
652        // The `_context` contextvar needs to be propagated through to the thread that
653        // runs the blocking tokio task. Upon completion, the original value of `_context`
654        // is restored.
655        let handle = get_tokio_runtime().spawn_blocking(move || {
656            let result = monarch_with_gil_blocking(|py| {
657                let _context = actor_mesh_module(py).getattr("_context")?;
658                let old_context = _context.call_method1("get", (PyNone::get(py),))?;
659                _context
660                    .call_method1("set", (monarch_context.clone_ref(py),))
661                    .expect("failed to set _context");
662                let result = f.call0(py);
663                _context
664                    .call_method1("set", (old_context,))
665                    .expect("failed to restore _context");
666                result
667            });
668            send_result(tx, result, traceback1);
669        });
670        Ok(PyShared {
671            rx,
672            handle: Some(handle),
673            abort: false,
674            traceback,
675        })
676    }
677
678    /// Wait for the first task to complete and return `(result,
679    /// index)`.
680    ///
681    /// This consumes all input tasks (each is `take_task()`'d). The
682    /// returned task resolves to a tuple of the winning task's result
683    /// and its index in the input list.
684    #[staticmethod]
685    fn select_one(mut tasks: Vec<PyRefMut<'_, PyPythonTask>>) -> PyResult<PyPythonTask> {
686        if tasks.is_empty() {
687            return Err(PyValueError::new_err("Cannot select from empty task list"));
688        }
689
690        let mut futures = Vec::new();
691        for task_ref in tasks.iter_mut() {
692            futures.push(task_ref.take_task()?);
693        }
694
695        PyPythonTask::new(async move {
696            let (result, index, _remaining) = futures::future::select_all(futures).await;
697            result.map(|r| (r, index))
698        })
699    }
700
701    /// Sleep for `seconds` on the Tokio runtime.
702    #[staticmethod]
703    fn sleep(seconds: f64) -> PyResult<PyPythonTask> {
704        PyPythonTask::new(async move {
705            tokio::time::sleep(tokio::time::Duration::from_secs_f64(seconds)).await;
706            Ok(())
707        })
708    }
709
710    /// Support `PythonTask[T]` type syntax on the Python side (no
711    /// runtime effect).
712    #[classmethod]
713    fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: Py<PyAny>) -> Py<PyAny> {
714        cls.clone().unbind().into()
715    }
716}
717
718/// Awaitable handle to a spawned background Tokio task.
719///
720/// `Shared` is returned by `PythonTask.spawn()` /
721/// `spawn_abortable()`. It carries a `watch` receiver that is
722/// fulfilled exactly once with the task's `PyResult<Py<PyAny>>`.
723///
724/// Usage:
725///   - `await shared` inside the `PythonTask.from_coroutine(...)`
726///     world, or
727///   - `shared.block_on()` to wait synchronously.
728///
729/// If `abort` is true (from `spawn_abortable()`), dropping this
730/// object aborts the underlying Tokio task via its `JoinHandle`.
731#[pyclass(
732    name = "Shared",
733    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
734)]
735pub struct PyShared {
736    /// One-shot result channel. Starts as `None`; becomes
737    /// `Some(Ok(obj))` or `Some(Err(pyerr))` when the background task
738    /// completes.
739    rx: watch::Receiver<Option<PyResult<Py<PyAny>>>>,
740
741    /// Handle for the spawned Tokio task that is producing `rx`’s
742    /// result. `None` for `Shared.from_value(...)`.
743    handle: Option<JoinHandle<()>>,
744
745    /// If true, dropping `Shared` aborts the background task via
746    /// `handle.abort()`. This is set by `spawn_abortable()`.
747    abort: bool,
748
749    /// Optional creation-site traceback (captured when enabled) used
750    /// when logging un-awaited errors / for derived tasks.
751    traceback: Option<Py<PyAny>>,
752}
753
754/// If this `Shared` was created via `spawn_abortable()`, abort the
755/// underlying Tokio task on drop.
756///
757/// This prevents abandoned background work from running forever when
758/// no receivers remain. We guard against panics during interpreter
759/// shutdown / runtime teardown.
760impl Drop for PyShared {
761    fn drop(&mut self) {
762        if self.abort {
763            // When the PyShared is dropped, we don't want the background task to go
764            // forever, because nothing will wait on the rx.
765            if let Some(h) = self.handle.as_ref() {
766                // Guard against panics during interpreter shutdown when tokio runtime may be gone
767                let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
768                    h.abort();
769                }));
770            }
771        }
772    }
773}
774
775#[pymethods]
776impl PyShared {
777    /// Convert this `Shared` handle into a `PythonTask` that waits
778    /// for its result.
779    ///
780    /// Internally, this clones the `watch::Receiver` and returns a
781    /// new one-shot task that:
782    ///   1) waits for the sender to publish `Some(result)`, and then
783    ///   2) returns/clones the stored `Py<PyAny>` / `PyErr` under the
784    ///      GIL.
785    ///
786    /// Cloning the receiver allows multiple independent awaiters to
787    /// observe the same completion.
788    pub(crate) fn task(&self) -> PyResult<PyPythonTask> {
789        // watch channels start unchanged, and when a value is sent to them signal
790        // the receivers `changed` future.
791        // By cloning the rx before awaiting it,
792        // we can have multiple awaiters get triggered by the same change.
793        // self.rx will always be in the state where it hasn't see the change yet.
794        let mut rx = self.rx.clone();
795        PyPythonTask::new_with_traceback(
796            async move {
797                // Check if a value is already available (not None).
798                // The channel is initialized with None, and the sender sets it to Some(result).
799                // If it's still None, wait for a change. Otherwise, the value is ready.
800                if rx.borrow().is_none() {
801                    rx.changed().await.map_err(to_py_error)?;
802                }
803                // We need to hold the GIL when cloning Python objects (Py<PyAny> and PyErr).
804                monarch_with_gil(|py| {
805                    let borrowed = rx.borrow();
806                    match borrowed.as_ref().unwrap() {
807                        Ok(v) => Ok(v.bind(py).clone().unbind()),
808                        Err(err) => Err(err.clone_ref(py)),
809                    }
810                })
811                .await
812            },
813            self.traceback.as_ref().map_or_else(
814                || None,
815                |t| monarch_with_gil_blocking(|py| Some(t.clone_ref(py))),
816            ),
817        )
818    }
819
820    /// Implement Python's `await` protocol for `Shared`.
821    ///
822    /// This delegates to `self.task()` (which returns a `PythonTask`
823    /// that waits for the background result) and then returns that
824    /// task's await-iterator.
825    ///
826    /// Note: `await shared` is only supported inside the
827    /// `PythonTask.from_coroutine(...)` world (because it ultimately
828    /// awaits a `PythonTask`).
829    fn __await__(&mut self, py: Python<'_>) -> PyResult<PythonTaskAwaitIterator> {
830        let task = self.task()?;
831        Ok(PythonTaskAwaitIterator::new(task.into_py_any(py)?))
832    }
833
834    /// Wait synchronously for this `Shared` to resolve.
835    ///
836    /// This blocks the calling Python thread until the underlying
837    /// background task has published its result into the watch
838    /// channel, then returns that `Py<PyAny>` (or raises the stored
839    /// Python exception).
840    ///
841    /// If the value is already available, returns immediately without
842    /// blocking. This is important for cases where `block_on` is called
843    /// from within a tokio runtime (e.g., during unpickling on a worker
844    /// thread) - we can't call `runtime.block_on()` from within a runtime.
845    pub fn block_on(slf: PyRef<PyShared>, py: Python<'_>) -> PyResult<Py<PyAny>> {
846        // Check if value is already available - return immediately if so.
847        // This avoids calling into the tokio runtime when unnecessary,
848        // which is critical when called from within a tokio worker thread.
849        if let Some(value) = slf.poll()? {
850            return Ok(value);
851        }
852
853        let task = slf.task()?.take_task()?;
854        // Explicitly drop the reference so that if another thread attempts to borrow
855        // this object mutably during signal_safe_block_on, it won't throw an exception.
856        drop(slf);
857        signal_safe_block_on(py, task)?
858    }
859
860    /// Support `Shared[T]` type syntax on the Python side (no runtime
861    /// effect).
862    #[classmethod]
863    fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: Py<PyAny>) -> Py<PyAny> {
864        cls.clone().unbind().into()
865    }
866
867    /// Non-blocking check for completion.
868    ///
869    /// Returns:
870    ///   - `Ok(None)` if the background task has not finished yet,
871    ///   - `Ok(Some(obj))` if it completed successfully,
872    ///   - `Err(pyerr)` if it completed with an exception.
873    ///
874    /// This does not wait; it only inspects the current watch value.
875    pub(crate) fn poll(&self) -> PyResult<Option<Py<PyAny>>> {
876        let b = self.rx.borrow();
877        let r = b.as_ref();
878        match r {
879            None => Ok(None),
880            Some(r) => Python::attach(|py| match r {
881                Ok(v) => Ok(Some(v.clone_ref(py))),
882                Err(err) => Err(err.clone_ref(py)),
883            }),
884        }
885    }
886
887    /// Construct a `Shared` that is already completed with `value`.
888    ///
889    /// This is a convenience for APIs that want to return a `Shared`
890    /// without spawning a background task. The returned handle has no
891    /// `JoinHandle` and will immediately yield `value` via `poll()`,
892    /// `await` (inside `from_coroutine`), or `block_on()`.
893    #[classmethod]
894    fn from_value(_cls: &Bound<'_, PyType>, value: Py<PyAny>) -> PyResult<Self> {
895        let (tx, rx) = watch::channel(None);
896        tx.send(Some(Ok(value))).map_err(to_py_error)?;
897        Ok(Self {
898            rx,
899            handle: None,
900            abort: false,
901            traceback: None,
902        })
903    }
904
905    /// Pickle protocol support for PyShared.
906    ///
907    /// This implements the pickle reduce protocol:
908    /// - If the shared is finished, pickle as (Shared.from_value, (value,))
909    /// - If pending pickles are allowed, defer pickling and return (pop_pending_pickle, ())
910    /// - Otherwise, block on the shared and pickle as (Shared.from_value, (value,))
911    fn __reduce__<'py>(
912        slf: &Bound<'py, Self>,
913        py: Python<'py>,
914    ) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
915        reduce_shared(py, slf)
916    }
917}
918
919/// Return true if the current thread is executing within a Tokio
920/// runtime context.
921///
922/// This checks whether `tokio::runtime::Handle::try_current()`
923/// succeeds.
924#[pyfunction]
925fn is_tokio_thread() -> bool {
926    tokio::runtime::Handle::try_current().is_ok()
927}
928
929/// Register the pytokio Python bindings into the given module.
930///
931/// This wires up the exported pyclasses (`PythonTask`, `Shared`)
932/// and module-level functions used by the Monarch Python layer.
933pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
934    hyperactor_mod.add_class::<PyPythonTask>()?;
935    hyperactor_mod.add_class::<PyShared>()?;
936    let f = wrap_pyfunction!(is_tokio_thread, hyperactor_mod)?;
937    f.setattr(
938        "__module__",
939        "monarch._rust_bindings.monarch_hyperactor.pytokio",
940    )?;
941    hyperactor_mod.add_function(f)?;
942
943    Ok(())
944}
945
/// Initialize the embedded Python interpreter at most once per process.
///
/// Idempotent and thread-safe: every caller returns only after the
/// single initialization has completed.
#[cfg(test)]
pub(crate) fn ensure_python() {
    static PYTHON_READY: std::sync::OnceLock<()> = std::sync::OnceLock::new();
    PYTHON_READY.get_or_init(pyo3::Python::initialize);
}
957
/// Test-only helper trait that lets Rust tests "await" a `PyPythonTask`.
///
/// Implementations consume the task and take its inner future. They
/// `.await` it on Tokio to obtain a `Py<PyAny>`, then convert or
/// discard that result.
#[cfg(test)]
pub(crate) trait AwaitPyExt {
    /// Await the task and convert its `Py<PyAny>` result into `Py<T>`.
    async fn await_py<T: PyClass>(self) -> Result<Py<T>, PyErr>;

    /// Await a task that yields no meaningful object, reporting only
    /// whether it succeeded ("did it work?").
    async fn await_unit(self) -> Result<(), PyErr>;
}
973
#[cfg(test)]
impl AwaitPyExt for PyPythonTask {
    /// Consume the task, drive its future to completion, and extract the
    /// resulting object as `Py<T>`.
    ///
    /// # Panics
    ///
    /// Panics if the task's future was already taken, or if the result
    /// cannot be extracted as `Py<T>`.
    async fn await_py<T: PyClass>(mut self) -> Result<Py<T>, PyErr> {
        // A Python-level failure from the task itself is propagated via `?`.
        let untyped: Py<PyAny> = self
            .take_task()
            .expect("PyPythonTask already consumed in await_py")
            .await?;

        // Extraction into a typed handle is done while holding the GIL.
        monarch_with_gil(|py| {
            let typed = untyped
                .bind(py)
                .extract::<Py<T>>()
                .expect("spawn() did not return expected Python type");
            Ok(typed)
        })
        .await
    }

    /// Consume the task, drive its future to completion, and discard the
    /// produced object, returning only success or the task's error.
    ///
    /// # Panics
    ///
    /// Panics if the task's future was already taken.
    async fn await_unit(mut self) -> Result<(), PyErr> {
        // Even a "no value" task materializes some Python object (e.g.
        // `None`); it is simply dropped here.
        self.take_task()
            .expect("PyPythonTask already consumed in await_unit")
            .await
            .map(|_ignored| ())
    }
}