Skip to main content

monarch_hyperactor/
pytokio.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! `pytokio` is Monarch's Python <-> Tokio async bridge.
//!
//! It provides a small, *non-asyncio* async world where Python code
//! can *compose* Rust/Tokio futures using `await`.
//!
//! ## The core idea
//!
//! In `pytokio`:
//!
//! - `PythonTask` = a one-shot Rust/Tokio future that produces a
//!   Python value.
//! - `from_coroutine` = wraps a Python coroutine as a Rust future
//!   that drives it.
//! - `Shared` = an awaitable handle to a spawned background Tokio
//!   task.
//!
//! More concretely:
//!
//! - Rust bindings return a Python-visible `PythonTask`
//!   (`PyPythonTask`), which wraps a Rust `PythonTask` holding a
//!   boxed Tokio future returning `PyResult<Py<PyAny>>`.
//! - `PythonTask.from_coroutine(coro)` wraps a *Python coroutine* as
//!   a `PythonTask` by creating a Rust/Tokio future that drives
//!   `coro.__await__()` (via `send`/`throw`) and awaits the
//!   `PythonTask`s it yields.
//! - Python code may `await` a `PythonTask` / `Shared` **only** when
//!   running under `PythonTask.from_coroutine(...)`. Awaiting
//!   arbitrary Python awaitables (e.g. `asyncio` futures) is an
//!   error.
//! - Calling `task.spawn()` / `spawn_abortable()` returns a `Shared`
//!   (`PyShared`), which yields the result of the background Tokio
//!   task running the original `PythonTask`.
//!
//! This is intentionally *not* a general-purpose async bridge: it’s a
//! way to use Python syntax to drive and compose Tokio futures.
//!
//! ## Wrapping a Python coroutine
//!
//! ```ignore
//! async def work():
//!     x = await some_rust_binding()      # must yield PythonTask / Shared
//!     await PythonTask.sleep(0.1)        # also a PythonTask
//!     return x
//!
//! task = PythonTask.from_coroutine(work())
//! result = task.block_on()              # block the calling Python thread while a
//!                                       # Tokio runtime drives the task to completion
//! ```
//!
//! `from_coroutine` drives the coroutine by repeatedly resuming it
//! and awaiting the `PythonTask`s it yields, using a Tokio runtime.
//!
//! ## Spawning
//!
//! `spawn()` runs a `PythonTask` on a background Tokio task and
//! returns a `Shared` handle.
//!
//! To `await` the handle, you must still be inside a
//! `from_coroutine`-driven coroutine:
//!
//! ```ignore
//! async def work():
//!     task = some_rust_binding()
//!     shared = task.spawn()
//!     # ... do other work ...
//!     result = await shared             # valid here (inside from_coroutine world)
//!     return result
//!
//! result = PythonTask.from_coroutine(work()).block_on()
//! ```
//!
//! In synchronous contexts, you can wait for a spawned task without
//! `from_coroutine`:
//!
//! ```ignore
//! shared = task.spawn()
//! result = shared.block_on()            # blocks the calling Python thread
//! ```
//!
//! If `spawn_abortable()` is used, dropping the returned `Shared`
//! aborts the underlying Tokio task.
//!
//! ## Context propagation
//!
//! `from_coroutine` preserves Monarch’s `context()` across Tokio
//! thread hops, so code calling `context()` inside a `PythonTask`
//! sees the same actor context as the call site that constructed the
//! task.

97use std::error::Error;
98use std::future::Future;
99use std::pin::Pin;
100
101use hyperactor_config::CONFIG;
102use hyperactor_config::ConfigAttr;
103use hyperactor_config::attrs::declare_attrs;
104use monarch_types::SerializablePyErr;
105use monarch_types::py_global;
106use pyo3::IntoPyObjectExt;
107#[cfg(test)]
108use pyo3::PyClass;
109use pyo3::exceptions::PyRuntimeError;
110use pyo3::exceptions::PyStopIteration;
111use pyo3::exceptions::PyTimeoutError;
112use pyo3::exceptions::PyValueError;
113use pyo3::prelude::*;
114use pyo3::types::PyNone;
115use pyo3::types::PyString;
116use pyo3::types::PyTuple;
117use pyo3::types::PyType;
118use tokio::sync::Mutex;
119use tokio::sync::watch;
120use tokio::task::JoinHandle;
121
122use crate::pickle::reduce_shared;
123use crate::runtime::get_tokio_runtime;
124use crate::runtime::monarch_with_gil;
125use crate::runtime::monarch_with_gil_blocking;
126use crate::runtime::signal_safe_block_on;
127
declare_attrs! {
    /// If true, capture a Python stack trace at `PythonTask` creation
    /// time and log it when a spawned task errors but is never
    /// awaited/polled.
    ///
    /// Defaults to `false`. When enabled, every task creation pays for
    /// a `traceback.extract_stack()` call under the GIL.
    // The first `ConfigAttr::new` argument is the environment variable
    // name (the unawaited-error log message tells users to run with
    // `MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK=1`); the
    // second is presumably the Python-side config key — TODO confirm
    // against `hyperactor_config`.
    @meta(CONFIG = ConfigAttr::new(
        Some("MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK".to_string()),
        Some("enable_unawaited_python_task_traceback".to_string()),
    ))
    pub attr ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK: bool = false;
}

// Import Python helpers used for actor context propagation.
// Each `py_global!(name, module, attr)` below is used in this file as
// `name(py)` to obtain `module.attr` from Python.
// - `context()` returns the current Monarch actor context.
// - `actor_mesh` is the module that owns the `_context` contextvar we
//   must manually set/restore when driving coroutines on Tokio threads.
py_global!(context, "monarch._src.actor.actor_mesh", "context");
py_global!(actor_mesh_module, "monarch._src.actor", "actor_mesh");
145
146/// Capture the current Python stack trace (creation call site) if
147/// `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK` is enabled.
148///
149/// Returns `None` when disabled to avoid the overhead of
150/// `traceback.extract_stack()`.
151fn current_traceback() -> PyResult<Option<Py<PyAny>>> {
152    if hyperactor_config::global::get(ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK) {
153        monarch_with_gil_blocking(|py| {
154            Ok(Some(
155                py.import("traceback")?
156                    .call_method0("extract_stack")?
157                    .unbind(),
158            ))
159        })
160    } else {
161        Ok(None)
162    }
163}
164
165/// Format a captured traceback (from `traceback.extract_stack()`) as
166/// a single string suitable for logging.
167fn format_traceback(py: Python<'_>, traceback: &Py<PyAny>) -> PyResult<String> {
168    let tb = py
169        .import("traceback")?
170        .call_method1("format_list", (traceback,))?;
171    PyString::new(py, "")
172        .call_method1("join", (tb,))?
173        .extract::<String>()
174}
175
/// Helper struct to make a Rust/Tokio future (returning a Python
/// result) passable in an actor message.
///
/// The future resolves to `PyResult<Py<PyAny>>` so it can return a
/// Python value or raise a Python exception, and it is `Send +
/// 'static` so it can cross thread/actor boundaries.
///
/// Also so that we don't have to write this massive type signature
/// everywhere.
pub(crate) struct PythonTask {
    /// Boxed, pinned Rust/Tokio future producing a Python result.
    /// It is moved out exactly once by `PythonTask::take`, which
    /// consumes the whole task.
    ///
    /// NOTE(review): in the code visible here the mutex is never
    /// locked (only `into_inner()` is called), so it appears to exist
    /// to make the struct `Sync` — a `dyn Future + Send` is not `Sync`,
    /// while a `tokio::sync::Mutex` around it is. Confirm before
    /// replacing it.
    // Type decoder ring:
    //
    // Mutex<Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>>>
    //   │     │   │   │                                        │      │
    //   │     │   │   │                                        │      └─ owns all data, no dangling refs
    //   │     │   │   │                                        └─ can cross thread boundaries
    //   │     │   │   └─ any future type (type-erased)
    //   │     │   └─ heap-allocated (because unsized)
    //   │     └─ immovable (safe to poll self-referential futures)
    //   └─ exclusive access for consumption
    future: Mutex<Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>>>,

    /// Optional Python stack trace captured at task construction
    /// time (only populated when `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK`
    /// is enabled), used to annotate logs when a spawned task errors
    /// but nobody awaits/polls it.
    traceback: Option<Py<PyAny>>,
}
206
207impl PythonTask {
208    /// Construct a `PythonTask` from a Rust/Tokio future and an
209    /// optional captured Python traceback.
210    ///
211    /// The future is boxed and pinned so it can be stored in the
212    /// struct and later driven safely.
213    fn new_with_traceback(
214        fut: impl Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
215        traceback: Option<Py<PyAny>>,
216    ) -> Self {
217        Self {
218            future: Mutex::new(Box::pin(fut)),
219            traceback,
220        }
221    }
222
223    /// Construct a `PythonTask`, capturing a creation-site traceback
224    /// if enabled by `ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK`.
225    pub(crate) fn new(
226        fut: impl Future<Output = PyResult<Py<PyAny>>> + Send + 'static,
227    ) -> PyResult<Self> {
228        Ok(Self::new_with_traceback(fut, current_traceback()?))
229    }
230
231    /// Return the optional captured creation-site traceback (if
232    /// enabled).
233    fn traceback(&self) -> &Option<Py<PyAny>> {
234        &self.traceback
235    }
236
237    /// Consume the task and return the boxed, pinned future.
238    ///
239    /// This is a one-shot operation: it moves the future out of the
240    /// struct so it can be driven to completion.
241    pub(crate) fn take(
242        self,
243    ) -> Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send + 'static>> {
244        self.future.into_inner()
245    }
246}
247
248impl std::fmt::Debug for PythonTask {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        f.debug_struct("PythonTask")
251            .field("future", &"<PythonFuture>")
252            .finish()
253    }
254}
255
/// Python-visible wrapper for a one-shot `PythonTask`.
///
/// Exposed to Python as
/// `monarch._rust_bindings.monarch_hyperactor.pytokio.PythonTask`.
/// This object owns the underlying Rust task and is *consumed* when
/// it is run (e.g. via `spawn()`, `spawn_abortable()`, `block_on()`,
/// `with_timeout()`, or `select_one()`), hence `inner: Option<_>`.
#[pyclass(
    name = "PythonTask",
    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
)]
pub struct PyPythonTask {
    /// The wrapped task. It is `Some` until `take_task()` moves it out.
    /// After that, consuming operations fail with
    /// `ValueError("PythonTask already consumed")`.
    inner: Option<PythonTask>,
}
270
271impl From<PythonTask> for PyPythonTask {
272    fn from(task: PythonTask) -> Self {
273        Self { inner: Some(task) }
274    }
275}
276
/// Minimal await-iterator used to implement Python's `__await__`
/// protocol for pytokio.
///
/// This iterator yields the task object exactly once. The Rust-side
/// coroutine driver (`from_coroutine`) resumes the Python coroutine
/// and expects it to yield a `PythonTask` object back to Rust.
/// (`await shared` also ends up yielding a `PythonTask`: `Shared`'s
/// `__await__` wraps the task returned by `Shared.task()`.)
#[pyclass(
    name = "PythonTaskAwaitIterator",
    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
)]
struct PythonTaskAwaitIterator {
    /// The object handed out by the first `send`/`__next__`. It is
    /// `None` afterwards, and further resumes raise `StopIteration`.
    value: Option<Py<PyAny>>,
}
291
292impl PythonTaskAwaitIterator {
293    /// Create an await-iterator that will yield `task` exactly once.
294    fn new(task: Py<PyAny>) -> PythonTaskAwaitIterator {
295        PythonTaskAwaitIterator { value: Some(task) }
296    }
297}
298
299#[pymethods]
300impl PythonTaskAwaitIterator {
301    /// First `send(...)` yields the stored task; subsequent sends
302    /// raise `StopIteration`.
303    ///
304    /// Python's await machinery calls `send(None)` to advance the
305    /// iterator.
306    fn send(&mut self, value: Py<PyAny>) -> PyResult<Py<PyAny>> {
307        self.value
308            .take()
309            .ok_or_else(|| PyStopIteration::new_err((value,)))
310    }
311
312    /// Convert the thrown Python exception value into a `PyErr` and
313    /// surface it to Rust.
314    fn throw(&mut self, value: Py<PyAny>) -> PyResult<Py<PyAny>> {
315        Err(monarch_with_gil_blocking(|py| {
316            PyErr::from_value(value.into_bound(py))
317        }))
318    }
319
320    /// Iterator protocol: `next(it)` is equivalent to
321    /// `it.send(None)`.
322    fn __next__(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
323        self.send(py.None())
324    }
325}
326
327impl PyPythonTask {
328    /// Construct a Python-visible `PythonTask` from a Rust future,
329    /// attaching an explicit creation-site traceback (if provided).
330    ///
331    /// The input future produces a Rust value `T`; on completion we
332    /// reacquire the GIL and convert `T` into a Python object
333    /// (`Py<PyAny>`).
334    fn new_with_traceback<F, T>(fut: F, traceback: Option<Py<PyAny>>) -> PyResult<Self>
335    where
336        F: Future<Output = PyResult<T>> + Send + 'static,
337        T: for<'py> IntoPyObject<'py> + Send,
338    {
339        Ok(PythonTask::new_with_traceback(
340            async {
341                let result = fut.await?;
342                monarch_with_gil(|py| result.into_py_any(py)).await
343            },
344            traceback,
345        )
346        .into())
347    }
348
349    /// Construct a `PythonTask`, capturing a creation-site traceback
350    /// if enabled.
351    ///
352    /// See `new_with_traceback` for conversion semantics (`T` ->
353    /// Python object under the GIL).
354    pub fn new<F, T>(fut: F) -> PyResult<Self>
355    where
356        F: Future<Output = PyResult<T>> + Send + 'static,
357        T: for<'py> IntoPyObject<'py> + Send,
358    {
359        Self::new_with_traceback(fut, current_traceback()?)
360    }
361}
362
363// Helper: convert a Rust error into a generic Python ValueError.
364fn to_py_error<T>(e: T) -> PyErr
365where
366    T: Error,
367{
368    PyErr::new::<PyValueError, _>(e.to_string())
369}
370
371impl PyPythonTask {
372    /// Consume this `PythonTask` and return the underlying Rust
373    /// future.
374    ///
375    /// This is a one-shot operation: after calling `take_task`, the
376    /// `PyPythonTask` is considered *consumed* and cannot be
377    /// spawned/awaited/blocked-on again.
378    pub fn take_task(
379        &mut self,
380    ) -> PyResult<Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>> {
381        self.inner
382            .take()
383            .map(|task| task.take())
384            .ok_or_else(|| PyValueError::new_err("PythonTask already consumed"))
385    }
386
387    /// Return the captured creation-site traceback (if enabled),
388    /// cloning it under the GIL.
389    ///
390    /// Fails if the task has already been consumed.
391    fn traceback(&self) -> PyResult<Option<Py<PyAny>>> {
392        if let Some(task) = &self.inner {
393            Ok(monarch_with_gil_blocking(|py| {
394                task.traceback().as_ref().map(|t| t.clone_ref(py))
395            }))
396        } else {
397            Err(PyValueError::new_err("PythonTask already consumed"))
398        }
399    }
400
401    /// Spawn this task onto the Tokio runtime and return a `Shared`
402    /// handle that *aborts on drop*.
403    ///
404    /// Use this when the underlying future is *abort-safe*
405    /// (cancellation-safe): dropping the returned `Shared` will call
406    /// `JoinHandle::abort()`, preventing the background task from
407    /// running forever.
408    ///
409    /// This is especially useful for long-lived or periodic tasks
410    /// (e.g. timers) where "nobody is awaiting the result anymore"
411    /// should stop the work.
412    ///
413    /// Like `spawn()`, this consumes the `PyPythonTask` (it can only
414    /// be spawned once).
415    pub(crate) fn spawn_abortable(&mut self) -> PyResult<PyShared> {
416        let (tx, rx) = watch::channel(None);
417        let traceback = self.traceback()?;
418        let traceback1 = self.traceback()?;
419        let task = self.take_task()?;
420        let handle = get_tokio_runtime().spawn(async move {
421            send_result(tx, task.await, traceback1);
422        });
423        Ok(PyShared {
424            rx,
425            handle: Some(handle),
426            abort: true,
427            traceback,
428        })
429    }
430}
431
432/// Publish a completed task result to the `watch` channel.
433///
434/// If the receiver has already been dropped, `watch::Sender::send`
435/// returns the unsent value as `SendError`. We treat that as "nobody
436/// will ever observe this result".
437///
438/// In the special case where the unobserved result is an error, we
439/// log it (and include the task creation traceback when available) to
440/// avoid silently losing failures from background tasks.
441fn send_result(
442    tx: tokio::sync::watch::Sender<Option<PyResult<Py<PyAny>>>>,
443    result: PyResult<Py<PyAny>>,
444    traceback: Option<Py<PyAny>>,
445) {
446    // a SendErr just means that there are no consumers of the value left.
447    if let Err(tokio::sync::watch::error::SendError(Some(Err(pyerr)))) = tx.send(Some(result)) {
448        monarch_with_gil_blocking(|py| {
449            let tb = if let Some(tb) = traceback {
450                format_traceback(py, &tb).unwrap()
451            } else {
452                "None (run with `MONARCH_HYPERACTOR_ENABLE_UNAWAITED_PYTHON_TASK_TRACEBACK=1` to see a traceback here)\n".into()
453            };
454            tracing::error!(
455                "PythonTask errored but is not being awaited; this will not crash your program, but indicates that \
456                something went wrong.\n{}\nTraceback where the task was created (most recent call last):\n{}",
457                SerializablePyErr::from(py, &pyerr),
458                tb
459            );
460        });
461    };
462}
463
464#[pymethods]
465impl PyPythonTask {
466    /// Run this task to completion synchronously on the embedded
467    /// Tokio runtime.
468    ///
469    /// This blocks the calling Python thread until the underlying
470    /// Rust future completes. Consumes the task (like `spawn`): the
471    /// `PyPythonTask` cannot be used again.
472    fn block_on(mut slf: PyRefMut<PyPythonTask>, py: Python<'_>) -> PyResult<Py<PyAny>> {
473        let task = slf.take_task()?;
474
475        // Mutable borrows of Python objects must be dropped before
476        // releasing the GIL. `signal_safe_block_on` releases the GIL;
477        // holding `slf` across that would make other Python access
478        // throw.
479        drop(slf);
480        signal_safe_block_on(py, task)?
481    }
482
483    /// Spawn this task onto the Tokio runtime and return a `Shared`
484    /// handle.
485    ///
486    /// The returned `Shared` is awaitable *inside* the
487    /// `from_coroutine` world, or may be waited on synchronously via
488    /// `Shared.block_on()`. Consumes the task.
489    pub(crate) fn spawn(&mut self) -> PyResult<PyShared> {
490        let (tx, rx) = watch::channel(None);
491        let traceback = self.traceback()?;
492        let traceback1 = self.traceback()?;
493        let task = self.take_task()?;
494        let handle = get_tokio_runtime().spawn(async move {
495            send_result(tx, task.await, traceback1);
496        });
497        Ok(PyShared {
498            rx,
499            handle: Some(handle),
500            abort: false,
501            traceback,
502        })
503    }
504
505    /// Implement Python's `await` protocol for `PythonTask`.
506    ///
507    /// This is only supported inside the `pytokio` world driven by
508    /// `PythonTask.from_coroutine`; attempting to `await` a
509    /// `PythonTask` while an `asyncio` event loop is running is an
510    /// error.
511    fn __await__(slf: PyRef<'_, Self>) -> PyResult<PythonTaskAwaitIterator> {
512        let py = slf.py();
513        let l = pyo3_async_runtimes::get_running_loop(py);
514        if l.is_ok() {
515            return Err(PyRuntimeError::new_err(
516                "Attempting to __await__ a PythonTask when the asyncio event loop is active. PythonTask objects should only be awaited in coroutines passed to PythonTask.from_coroutine",
517            ));
518        }
519
520        Ok(PythonTaskAwaitIterator::new(slf.into_py_any(py)?))
521    }
522
523    /// Wrap a Python coroutine into a `PythonTask` that is driven by
524    /// Tokio.
525    ///
526    /// This converts `coro` into its await-iterator
527    /// (`coro.__await__()`), then repeatedly resumes it via
528    /// `send`/`throw`. Whenever the coroutine yields a
529    /// `PythonTask`/`Shared`, we extract its underlying Rust future,
530    /// `await` it on Tokio, and feed the result back into the
531    /// coroutine on the next iteration.
532    ///
533    /// Inside this coroutine, `await` is only supported for pytokio
534    /// values (`PythonTask` / `Shared`). Awaiting arbitrary Python
535    /// awaitables (e.g. `asyncio` futures) is an error.
536    ///
537    /// The current Monarch `context()` is captured at construction
538    /// time and restored while running the coroutine so `context()`
539    /// inside the task reflects the call site that created it (even
540    /// across Tokio thread hops).
541    #[staticmethod]
542    fn from_coroutine(py: Python<'_>, coro: Py<PyAny>) -> PyResult<PyPythonTask> {
543        // context() used inside a PythonTask should inherit the value of
544        // context() from the context in which the PythonTask was constructed.
545        // We need to do this manually because the value of the contextvar isn't
546        // maintained inside the tokio runtime.
547        let monarch_context = context(py).call0()?.unbind();
548        PyPythonTask::new(async move {
549            let (coroutine_iterator, none) = monarch_with_gil(|py| {
550                coro.into_bound(py)
551                    .call_method0("__await__")
552                    .map(|x| (x.unbind(), py.None()))
553            })
554            .await?;
555            let mut last: PyResult<Py<PyAny>> = Ok(none);
556            enum Action {
557                Return(Py<PyAny>),
558                Wait(Pin<Box<dyn Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static>>),
559            }
560            loop {
561                let action = monarch_with_gil(|py| -> PyResult<Action> {
562                    // We may be executing in a new thread at this point, so we need to set the value
563                    // of context().
564                    let _context = actor_mesh_module(py).getattr("_context")?;
565                    let old_context = _context.call_method1("get", (PyNone::get(py),))?;
566                    _context
567                        .call_method1("set", (monarch_context.clone_ref(py),))
568                        .expect("failed to set _context");
569
570                    let result = match last {
571                        Ok(value) => coroutine_iterator.bind(py).call_method1("send", (value,)),
572                        Err(pyerr) => coroutine_iterator
573                            .bind(py)
574                            .call_method1("throw", (pyerr.into_value(py),)),
575                    };
576
577                    // Reset context() so that when this tokio thread yields, it has its original state.
578                    _context
579                        .call_method1("set", (old_context,))
580                        .expect("failed to restore _context");
581                    match result {
582                        Ok(task) => Ok(Action::Wait(
583                            task.extract::<Py<PyPythonTask>>()
584                                .and_then(|t| t.borrow_mut(py).take_task())
585                                .unwrap_or_else(|pyerr| Box::pin(async move { Err(pyerr) })),
586                        )),
587                        Err(err) => {
588                            let err = err.into_pyobject(py)?.into_any();
589                            if err.is_instance_of::<PyStopIteration>() {
590                                Ok(Action::Return(
591                                    err.into_pyobject(py)?.getattr("value")?.unbind(),
592                                ))
593                            } else {
594                                Err(PyErr::from_value(err))
595                            }
596                        }
597                    }
598                })
599                .await?;
600                match action {
601                    Action::Return(x) => {
602                        return Ok(x);
603                    }
604                    Action::Wait(task) => {
605                        last = task.await;
606                    }
607                };
608            }
609        })
610    }
611
612    /// Wrap this task with a timeout and return a new `PythonTask`.
613    ///
614    /// Consumes the original task. If it does not complete within
615    /// `seconds`, the returned task fails with `TimeoutError`.
616    fn with_timeout(&mut self, seconds: f64) -> PyResult<PyPythonTask> {
617        let tb = self.traceback()?;
618        let task = self.take_task()?;
619        PyPythonTask::new_with_traceback(
620            async move {
621                tokio::time::timeout(std::time::Duration::from_secs_f64(seconds), task)
622                    .await
623                    .map_err(|_| PyTimeoutError::new_err(()))?
624            },
625            tb,
626        )
627    }
628
629    /// Run a Python callable on Tokio's blocking thread pool and
630    /// return a `Shared` handle.
631    ///
632    /// This is for CPU-bound or otherwise blocking Python work that
633    /// must not run on a Tokio async worker thread. The callable `f`
634    /// is executed via `tokio::spawn_blocking`, and its result (or
635    /// raised exception) is delivered through the returned `Shared`.
636    ///
637    /// The current Monarch `context()` is captured and restored while
638    /// running `f` so calls to `context()` from inside `f` see the
639    /// originating actor context.
640    #[staticmethod]
641    fn spawn_blocking(py: Python<'_>, f: Py<PyAny>) -> PyResult<PyShared> {
642        let (tx, rx) = watch::channel(None);
643        let traceback = current_traceback()?;
644        let traceback1 = traceback.as_ref().map_or_else(
645            || None,
646            |t| monarch_with_gil_blocking(|py| Some(t.clone_ref(py))),
647        );
648        let monarch_context = context(py).call0()?.unbind();
649        // The `_context` contextvar needs to be propagated through to the thread that
650        // runs the blocking tokio task. Upon completion, the original value of `_context`
651        // is restored.
652        let handle = get_tokio_runtime().spawn_blocking(move || {
653            let result = monarch_with_gil_blocking(|py| {
654                let _context = actor_mesh_module(py).getattr("_context")?;
655                let old_context = _context.call_method1("get", (PyNone::get(py),))?;
656                _context
657                    .call_method1("set", (monarch_context.clone_ref(py),))
658                    .expect("failed to set _context");
659                let result = f.call0(py);
660                _context
661                    .call_method1("set", (old_context,))
662                    .expect("failed to restore _context");
663                result
664            });
665            send_result(tx, result, traceback1);
666        });
667        Ok(PyShared {
668            rx,
669            handle: Some(handle),
670            abort: false,
671            traceback,
672        })
673    }
674
675    /// Wait for the first task to complete and return `(result,
676    /// index)`.
677    ///
678    /// This consumes all input tasks (each is `take_task()`'d). The
679    /// returned task resolves to a tuple of the winning task's result
680    /// and its index in the input list.
681    #[staticmethod]
682    fn select_one(mut tasks: Vec<PyRefMut<'_, PyPythonTask>>) -> PyResult<PyPythonTask> {
683        if tasks.is_empty() {
684            return Err(PyValueError::new_err("Cannot select from empty task list"));
685        }
686
687        let mut futures = Vec::new();
688        for task_ref in tasks.iter_mut() {
689            futures.push(task_ref.take_task()?);
690        }
691
692        PyPythonTask::new(async move {
693            let (result, index, _remaining) = futures::future::select_all(futures).await;
694            result.map(|r| (r, index))
695        })
696    }
697
698    /// Sleep for `seconds` on the Tokio runtime.
699    #[staticmethod]
700    fn sleep(seconds: f64) -> PyResult<PyPythonTask> {
701        PyPythonTask::new(async move {
702            tokio::time::sleep(tokio::time::Duration::from_secs_f64(seconds)).await;
703            Ok(())
704        })
705    }
706
707    /// Support `PythonTask[T]` type syntax on the Python side (no
708    /// runtime effect).
709    #[classmethod]
710    fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: Py<PyAny>) -> Py<PyAny> {
711        cls.clone().unbind().into()
712    }
713}
714
/// Awaitable handle to a spawned background Tokio task.
///
/// `Shared` is returned by `PythonTask.spawn()` /
/// `spawn_abortable()` (and by `PythonTask.spawn_blocking(...)`). It
/// carries a `watch` receiver that is fulfilled exactly once with the
/// task's `PyResult<Py<PyAny>>`.
///
/// Usage:
///   - `await shared` inside the `PythonTask.from_coroutine(...)`
///     world, or
///   - `shared.block_on()` to wait synchronously.
///
/// If `abort` is true (from `spawn_abortable()`), dropping this
/// object aborts the underlying Tokio task via its `JoinHandle`.
#[pyclass(
    name = "Shared",
    module = "monarch._rust_bindings.monarch_hyperactor.pytokio"
)]
pub struct PyShared {
    /// One-shot result channel. Starts as `None`; becomes
    /// `Some(Ok(obj))` or `Some(Err(pyerr))` when the background task
    /// completes.
    rx: watch::Receiver<Option<PyResult<Py<PyAny>>>>,

    /// Handle for the spawned Tokio task that is producing `rx`’s
    /// result. `None` for `Shared.from_value(...)`.
    handle: Option<JoinHandle<()>>,

    /// If true, dropping `Shared` aborts the background task via
    /// `handle.abort()`. This is set by `spawn_abortable()`;
    /// `spawn()` and `spawn_blocking()` set it to false.
    abort: bool,

    /// Optional creation-site traceback (captured when enabled) used
    /// when logging un-awaited errors / for derived tasks (it is
    /// copied onto the task returned by `task()`).
    traceback: Option<Py<PyAny>>,
}
750
751/// If this `Shared` was created via `spawn_abortable()`, abort the
752/// underlying Tokio task on drop.
753///
754/// This prevents abandoned background work from running forever when
755/// no receivers remain. We guard against panics during interpreter
756/// shutdown / runtime teardown.
757impl Drop for PyShared {
758    fn drop(&mut self) {
759        if self.abort {
760            // When the PyShared is dropped, we don't want the background task to go
761            // forever, because nothing will wait on the rx.
762            if let Some(h) = self.handle.as_ref() {
763                // Guard against panics during interpreter shutdown when tokio runtime may be gone
764                let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
765                    h.abort();
766                }));
767            }
768        }
769    }
770}
771
772#[pymethods]
773impl PyShared {
774    /// Convert this `Shared` handle into a `PythonTask` that waits
775    /// for its result.
776    ///
777    /// Internally, this clones the `watch::Receiver` and returns a
778    /// new one-shot task that:
779    ///   1) waits for the sender to publish `Some(result)`, and then
780    ///   2) returns/clones the stored `Py<PyAny>` / `PyErr` under the
781    ///      GIL.
782    ///
783    /// Cloning the receiver allows multiple independent awaiters to
784    /// observe the same completion.
785    pub(crate) fn task(&self) -> PyResult<PyPythonTask> {
786        // watch channels start unchanged, and when a value is sent to them signal
787        // the receivers `changed` future.
788        // By cloning the rx before awaiting it,
789        // we can have multiple awaiters get triggered by the same change.
790        // self.rx will always be in the state where it hasn't see the change yet.
791        let mut rx = self.rx.clone();
792        PyPythonTask::new_with_traceback(
793            async move {
794                // Check if a value is already available (not None).
795                // The channel is initialized with None, and the sender sets it to Some(result).
796                // If it's still None, wait for a change. Otherwise, the value is ready.
797                if rx.borrow().is_none() {
798                    rx.changed().await.map_err(to_py_error)?;
799                }
800                // We need to hold the GIL when cloning Python objects (Py<PyAny> and PyErr).
801                monarch_with_gil(|py| {
802                    let borrowed = rx.borrow();
803                    match borrowed.as_ref().unwrap() {
804                        Ok(v) => Ok(v.bind(py).clone().unbind()),
805                        Err(err) => Err(err.clone_ref(py)),
806                    }
807                })
808                .await
809            },
810            self.traceback.as_ref().map_or_else(
811                || None,
812                |t| monarch_with_gil_blocking(|py| Some(t.clone_ref(py))),
813            ),
814        )
815    }
816
817    /// Implement Python's `await` protocol for `Shared`.
818    ///
819    /// This delegates to `self.task()` (which returns a `PythonTask`
820    /// that waits for the background result) and then returns that
821    /// task's await-iterator.
822    ///
823    /// Note: `await shared` is only supported inside the
824    /// `PythonTask.from_coroutine(...)` world (because it ultimately
825    /// awaits a `PythonTask`).
826    fn __await__(&self, py: Python<'_>) -> PyResult<PythonTaskAwaitIterator> {
827        let task = self.task()?;
828        Ok(PythonTaskAwaitIterator::new(task.into_py_any(py)?))
829    }
830
831    /// Wait synchronously for this `Shared` to resolve.
832    ///
833    /// This blocks the calling Python thread until the underlying
834    /// background task has published its result into the watch
835    /// channel, then returns that `Py<PyAny>` (or raises the stored
836    /// Python exception).
837    ///
838    /// If the value is already available, returns immediately without
839    /// blocking. This is important for cases where `block_on` is called
840    /// from within a tokio runtime (e.g., during unpickling on a worker
841    /// thread) - we can't call `runtime.block_on()` from within a runtime.
842    pub fn block_on(slf: PyRef<PyShared>, py: Python<'_>) -> PyResult<Py<PyAny>> {
843        // Check if value is already available - return immediately if so.
844        // This avoids calling into the tokio runtime when unnecessary,
845        // which is critical when called from within a tokio worker thread.
846        if let Some(value) = slf.poll()? {
847            return Ok(value);
848        }
849
850        let task = slf.task()?.take_task()?;
851        // Explicitly drop the reference so that if another thread attempts to borrow
852        // this object mutably during signal_safe_block_on, it won't throw an exception.
853        drop(slf);
854        signal_safe_block_on(py, task)?
855    }
856
857    /// Support `Shared[T]` type syntax on the Python side (no runtime
858    /// effect).
859    #[classmethod]
860    fn __class_getitem__(cls: &Bound<'_, PyType>, _arg: Py<PyAny>) -> Py<PyAny> {
861        cls.clone().unbind().into()
862    }
863
864    /// Non-blocking check for completion.
865    ///
866    /// Returns:
867    ///   - `Ok(None)` if the background task has not finished yet,
868    ///   - `Ok(Some(obj))` if it completed successfully,
869    ///   - `Err(pyerr)` if it completed with an exception.
870    ///
871    /// This does not wait; it only inspects the current watch value.
872    pub(crate) fn poll(&self) -> PyResult<Option<Py<PyAny>>> {
873        let b = self.rx.borrow();
874        let r = b.as_ref();
875        match r {
876            None => Ok(None),
877            Some(r) => Python::attach(|py| match r {
878                Ok(v) => Ok(Some(v.clone_ref(py))),
879                Err(err) => Err(err.clone_ref(py)),
880            }),
881        }
882    }
883
884    /// Construct a `Shared` that is already completed with `value`.
885    ///
886    /// This is a convenience for APIs that want to return a `Shared`
887    /// without spawning a background task. The returned handle has no
888    /// `JoinHandle` and will immediately yield `value` via `poll()`,
889    /// `await` (inside `from_coroutine`), or `block_on()`.
890    #[classmethod]
891    fn from_value(_cls: &Bound<'_, PyType>, value: Py<PyAny>) -> PyResult<Self> {
892        let (tx, rx) = watch::channel(None);
893        tx.send(Some(Ok(value))).map_err(to_py_error)?;
894        Ok(Self {
895            rx,
896            handle: None,
897            abort: false,
898            traceback: None,
899        })
900    }
901
902    /// Pickle protocol support for PyShared.
903    ///
904    /// This implements the pickle reduce protocol:
905    /// - If the shared is finished, pickle as (Shared.from_value, (value,))
906    /// - If pending pickles are allowed, defer pickling and return (pop_pending_pickle, ())
907    /// - Otherwise, block on the shared and pickle as (Shared.from_value, (value,))
908    fn __reduce__<'py>(
909        slf: &Bound<'py, Self>,
910        py: Python<'py>,
911    ) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
912        reduce_shared(py, slf)
913    }
914}
915
916/// Return true if the current thread is executing within a Tokio
917/// runtime context.
918///
919/// This checks whether `tokio::runtime::Handle::try_current()`
920/// succeeds.
921#[pyfunction]
922fn is_tokio_thread() -> bool {
923    tokio::runtime::Handle::try_current().is_ok()
924}
925
926/// Register the pytokio Python bindings into the given module.
927///
928/// This wires up the exported pyclasses (`PythonTask`, `Shared`)
929/// and module-level functions used by the Monarch Python layer.
930pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
931    hyperactor_mod.add_class::<PyPythonTask>()?;
932    hyperactor_mod.add_class::<PyShared>()?;
933    let f = wrap_pyfunction!(is_tokio_thread, hyperactor_mod)?;
934    f.setattr(
935        "__module__",
936        "monarch._rust_bindings.monarch_hyperactor.pytokio",
937    )?;
938    hyperactor_mod.add_function(f)?;
939
940    Ok(())
941}
942
/// Initialize the embedded Python interpreter at most once (test-only).
///
/// Concurrent and repeated calls are fine: the first caller runs
/// `pyo3::Python::initialize`, and everyone else waits for and then
/// reuses that result.
#[cfg(test)]
pub(crate) fn ensure_python() {
    static PYTHON_INITIALIZED: std::sync::OnceLock<()> = std::sync::OnceLock::new();
    PYTHON_INITIALIZED.get_or_init(pyo3::Python::initialize);
}
954
#[cfg(test)]
/// Test helper that lets Rust code "await" a `PyPythonTask` directly.
///
/// Semantics:
///   - consume the `PyPythonTask`,
///   - take the inner future,
///   - `.await` it on tokio to get `Py<PyAny>`,
///   - turn that into `Py<T>` (or discard it, for `await_unit`).
pub(crate) trait AwaitPyExt {
    /// Await the task and convert its Python result into `Py<T>`.
    ///
    /// An error produced by the task itself is returned as `Err`.
    async fn await_py<T: PyClass>(self) -> Result<Py<T>, PyErr>;

    /// Await the task only to learn whether it succeeded, discarding
    /// the resulting Python object (e.g. `None`). For tasks whose
    /// future just resolves to "no object" — i.e. "did it work?".
    async fn await_unit(self) -> Result<(), PyErr>;
}
970
/// Test-only implementation: drives the task's inner future directly on
/// the current tokio executor.
///
/// Both methods panic if the task's future was already taken.
#[cfg(test)]
impl AwaitPyExt for PyPythonTask {
    /// Await the task, then extract its result as `Py<T>` under the GIL.
    ///
    /// Task errors are propagated as `Err`; a result of the wrong Python
    /// type panics (this is a test helper).
    async fn await_py<T: PyClass>(mut self) -> Result<Py<T>, PyErr> {
        let result: Py<PyAny> = self
            .take_task()
            .expect("PyPythonTask already consumed in await_py")
            .await?;

        monarch_with_gil(|py| {
            Ok(result
                .bind(py)
                .extract::<Py<T>>()
                .expect("spawn() did not return expected Python type"))
        })
        .await
    }

    /// Await the task and throw away whatever Python object it produced;
    /// only success vs. failure is reported.
    async fn await_unit(mut self) -> Result<(), PyErr> {
        let pending = self
            .take_task()
            .expect("PyPythonTask already consumed in await_unit");
        // Python results are always some object (`None` included); drop it.
        pending.await.map(drop)
    }
}