// monarch_hyperactor/runtime.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Mutex;
12use std::sync::OnceLock;
13use std::sync::atomic::AtomicUsize;
14use std::sync::atomic::Ordering;
15use std::time::Duration;
16
17use anyhow::Result;
18use hyperactor::Proc;
19use hyperactor::channel::ChannelAddr;
20use hyperactor::channel::ChannelTransport;
21use hyperactor::mailbox::BoxedMailboxSender;
22use hyperactor::mailbox::PanickingMailboxSender;
23use once_cell::sync::Lazy;
24use once_cell::unsync::OnceCell as UnsyncOnceCell;
25use pyo3::PyResult;
26use pyo3::Python;
27use pyo3::exceptions::PyRuntimeError;
28use pyo3::prelude::*;
29use pyo3::types::PyAnyMethods;
30use pyo3_async_runtimes::TaskLocals;
31use tokio::runtime::Handle;
32use tokio::task;
33
34/// Global tokio runtime container.
35///
36/// `handle` is cheap to clone and is what callers receive from
37/// `get_tokio_runtime()`. Holding a `Handle` does not lock anything, so
38/// concurrent block_on calls from different threads do not contend.
39///
40/// `runtime` exists only so the atexit handler can take ownership and
41/// call `shutdown_timeout`. Under normal operation nothing locks it; the
42/// mutex is uncontended at shutdown.
43struct GlobalRuntime {
44    handle: Handle,
45    runtime: Mutex<Option<tokio::runtime::Runtime>>,
46}
47
48static INSTANCE: OnceLock<GlobalRuntime> = OnceLock::new();
49
50fn global_runtime() -> &'static GlobalRuntime {
51    INSTANCE.get_or_init(|| {
52        let runtime = tokio::runtime::Builder::new_multi_thread()
53            .thread_name_fn(|| {
54                static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
55                let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
56                format!("monarch-pytokio-worker-{}", id)
57            })
58            .enable_all()
59            .build()
60            .unwrap();
61        let handle = runtime.handle().clone();
62        GlobalRuntime {
63            handle,
64            runtime: Mutex::new(Some(runtime)),
65        }
66    })
67}
68
69pub fn get_tokio_runtime() -> Handle {
70    global_runtime().handle.clone()
71}
72
73/// atexit handler that tears down the global Tokio runtime.
74///
75/// Callers obtain a cloned `Handle` from `get_tokio_runtime()` rather
76/// than a guard, so the `runtime` mutex is uncontended at shutdown. We
77/// can take ownership of the `Runtime` and call `shutdown_timeout`
78/// directly. If a worker thread is still inside `Handle::block_on` on a
79/// future that never resolves (e.g. a non-main thread that cannot
80/// observe SIGINT), `shutdown_timeout` aborts spawned tasks and returns
81/// after at most one second; the stuck worker is then a daemon thread
82/// that CPython kills on interpreter exit.
83#[pyfunction]
84pub fn shutdown_tokio_runtime(py: Python<'_>) {
85    // Called from Python's atexit, which holds the GIL. Release it so tokio
86    // worker threads can acquire it to complete their Python work.
87    py.detach(|| {
88        let Some(global) = INSTANCE.get() else {
89            return;
90        };
91        let Some(rt) = global.runtime.lock().unwrap().take() else {
92            return;
93        };
94        rt.shutdown_timeout(Duration::from_secs(1));
95    });
96}
97
98/// A global runtime proc used by this crate.
99pub(crate) fn get_proc_runtime() -> &'static Proc {
100    static RUNTIME_PROC: OnceLock<Proc> = OnceLock::new();
101    RUNTIME_PROC.get_or_init(|| {
102        let addr = ChannelAddr::any(ChannelTransport::Local);
103        let proc_id = hyperactor::ProcAddr::instance(addr, "monarch_hyperactor_runtime");
104        Proc::configured(proc_id, BoxedMailboxSender::new(PanickingMailboxSender))
105    })
106}
107
108/// Stores the native thread ID of the main Python thread.
109/// This is lazily initialized on first call to `is_main_thread`.
110static MAIN_THREAD_NATIVE_ID: OnceLock<i64> = OnceLock::new();
111
112/// Returns the native thread ID of the main Python thread.
113/// On first call, looks it up via `threading.main_thread().native_id`.
114fn get_main_thread_native_id() -> i64 {
115    *MAIN_THREAD_NATIVE_ID.get_or_init(|| {
116        Python::attach(|py| {
117            let threading = py.import("threading").expect("failed to import threading");
118            let main_thread = threading
119                .call_method0("main_thread")
120                .expect("failed to get main_thread");
121            main_thread
122                .getattr("native_id")
123                .expect("failed to get native_id")
124                .extract::<i64>()
125                .expect("native_id is not an i64")
126        })
127    })
128}
129
130/// Returns the current thread's native ID in a cross-platform way.
131#[cfg(target_os = "linux")]
132fn get_current_thread_id() -> i64 {
133    nix::unistd::gettid().as_raw() as i64
134}
135
/// Returns the calling thread's native ID (macOS: via `pthread_threadid_np`).
///
/// Failures of the underlying call and IDs above `i64::MAX` are only checked
/// in debug builds.
#[cfg(target_os = "macos")]
fn get_current_thread_id() -> i64 {
    let mut raw_tid: u64 = 0;

    // SAFETY: the out-pointer refers to a live, properly aligned local `u64`.
    // Passing 0 (a null pthread_t) as the thread asks for the calling
    // thread's ID.
    let status = unsafe { libc::pthread_threadid_np(0, &mut raw_tid) };
    debug_assert_eq!(
        status, 0,
        "pthread_threadid_np failed with error code: {}",
        status
    );

    // The OS reports a u64; the rest of this file compares IDs as i64.
    debug_assert!(
        raw_tid <= i64::MAX as u64,
        "thread ID {} exceeds i64::MAX",
        raw_tid
    );
    raw_tid as i64
}
153
// No `get_current_thread_id` exists for targets other than Linux and macOS,
// so building for any of them is rejected up front with a clear message.
#[cfg(all(not(target_os = "linux"), not(target_os = "macos")))]
compile_error!("get_current_thread_id is only implemented for Linux and macOS");
157
158/// Returns true if the current thread is the main Python thread.
159/// Compares the current thread's native ID against the main Python thread's native ID.
160pub fn is_main_thread() -> bool {
161    let current_tid = get_current_thread_id();
162    current_tid == get_main_thread_native_id()
163}
164
165pub fn initialize(py: Python) -> Result<()> {
166    // Eagerly initialize the main thread ID while we're on the main thread
167    // with the GIL held. If this were lazily initialized on a background
168    // tokio thread during shutdown, the `py.import("threading")` call inside
169    // get_main_thread_native_id() would trigger module_from_spec on a
170    // partially-finalized interpreter, causing a segfault.
171    let _ = get_main_thread_native_id();
172
173    let atexit = py.import("atexit")?;
174    let shutdown_fn = wrap_pyfunction!(shutdown_tokio_runtime, py)?;
175    atexit.call_method1("register", (shutdown_fn,))?;
176    Ok(())
177}
178
179/// Block the current thread on a future, but make sure to check for signals
180/// originating from the Python signal handler.
181///
182/// Python's signal handler just sets a flag that it expects the Python
183/// interpreter to handle later via a call to `PyErr_CheckSignals`. When we
184/// enter into potentially long-running native code, we need to make sure to be
185/// checking for signals frequently, otherwise we will ignore them. This will
186/// manifest as `ctrl-C` not doing anything.
187///
188/// One additional wrinkle is that `PyErr_CheckSignals` only works on the main
189/// Python thread; if it's called on any other thread it silently does nothing.
190/// So, we check if we're on the main thread by comparing native thread IDs.
191pub fn signal_safe_block_on<F>(py: Python, future: F) -> PyResult<F::Output>
192where
193    F: Future + Send + 'static,
194    F::Output: Send + 'static,
195{
196    let runtime = get_tokio_runtime();
197    // Release the GIL, otherwise the work in `future` that tries to acquire the
198    // GIL on another thread may deadlock.
199    py.detach(|| {
200        if is_main_thread() {
201            // Spawn the future onto the tokio runtime
202            let handle = runtime.spawn(future);
203            // Block the current thread on waiting for *either* the future to
204            // complete or a signal.
205            runtime.block_on(async {
206                tokio::select! {
207                    result = handle => result.map_err(|e| PyRuntimeError::new_err(format!("JoinErr: {:?}", e))),
208                    signal = async {
209                        let sleep_for = std::time::Duration::from_millis(100);
210                        loop {
211                            // Acquiring the GIL in a loop is sad, hopefully once
212                            // every 100ms is fine.
213                            Python::attach(|py| py.check_signals())?;
214                            tokio::time::sleep(sleep_for).await;
215                        }
216                    } => signal
217                }
218            })
219        } else {
220            // If we're not on the main thread, we can just block it. We've
221            // released the GIL, so the Python main thread will continue on, and
222            // `PyErr_CheckSignals` doesn't do anything anyway.
223            Ok(runtime.block_on(future))
224        }
225    })
226}
227
228/// A test function that sleeps indefinitely in a loop.
229/// This is used for testing signal handling in signal_safe_block_on.
230/// The function will sleep forever until interrupted by a signal.
231#[pyfunction]
232pub fn sleep_indefinitely_for_unit_tests(py: Python) -> PyResult<()> {
233    // Create a future that sleeps indefinitely
234    let future = async {
235        loop {
236            tracing::info!("idef sleeping for 100ms");
237            tokio::time::sleep(Duration::from_millis(100)).await;
238        }
239    };
240
241    // Use signal_safe_block_on to run the future, which should make it
242    // interruptible by signals like SIGINT
243    signal_safe_block_on(py, future)
244}
245
246/// Initialize the runtime module and expose Python functions
247pub fn register_python_bindings(runtime_mod: &Bound<'_, PyModule>) -> PyResult<()> {
248    let sleep_indefinitely_fn =
249        wrap_pyfunction!(sleep_indefinitely_for_unit_tests, runtime_mod.py())?;
250    sleep_indefinitely_fn.setattr(
251        "__module__",
252        "monarch._rust_bindings.monarch_hyperactor.runtime",
253    )?;
254    runtime_mod.add_function(sleep_indefinitely_fn)?;
255    Ok(())
256}
257
258struct SimpleRuntime;
259
260impl pyo3_async_runtimes::generic::Runtime for SimpleRuntime {
261    type JoinError = task::JoinError;
262    type JoinHandle = task::JoinHandle<()>;
263
264    fn spawn<F>(fut: F) -> Self::JoinHandle
265    where
266        F: Future<Output = ()> + Send + 'static,
267    {
268        get_tokio_runtime().spawn(async move {
269            fut.await;
270        })
271    }
272}
273
274tokio::task_local! {
275    static TASK_LOCALS: UnsyncOnceCell<TaskLocals>;
276}
277
278impl pyo3_async_runtimes::generic::ContextExt for SimpleRuntime {
279    fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
280    where
281        F: Future<Output = R> + Send + 'static,
282    {
283        let cell = UnsyncOnceCell::new();
284        cell.set(locals).unwrap();
285
286        Box::pin(TASK_LOCALS.scope(cell, fut))
287    }
288
289    fn get_task_locals() -> Option<TaskLocals> {
290        TASK_LOCALS
291            .try_with(|c| {
292                c.get()
293                    .map(|locals| monarch_with_gil_blocking(|py| locals.clone_ref(py)))
294            })
295            .unwrap_or_default()
296    }
297}
298
299pub fn future_into_py<F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
300where
301    F: Future<Output = PyResult<T>> + Send + 'static,
302    T: for<'py> IntoPyObject<'py>,
303{
304    pyo3_async_runtimes::generic::future_into_py::<SimpleRuntime, F, T>(py, fut)
305}
306
307/// Global lock to serialize GIL acquisition from Rust threads in async contexts.
308///
309/// Under high concurrency, many async tasks can simultaneously try to acquire the GIL.
310/// Each call blocks the current tokio worker thread, which can cause runtime starvation
311/// and apparent deadlocks (nothing else gets polled).
312///
313/// This wrapper serializes GIL acquisition among callers that opt in, so at most one
314/// tokio task is blocked in `Python::attach` at a time, improving fairness under
315/// contention.
316///
317/// Note: this does not globally prevent other sync code from calling `Python::attach`
318/// directly. Use `monarch_with_gil` or `monarch_with_gil_blocking` for Python interaction
319/// that occurs on async hot paths.
320static GIL_LOCK: Lazy<tokio::sync::Mutex<()>> = Lazy::new(|| tokio::sync::Mutex::new(()));
321
// Thread-local depth counter for re-entrant GIL acquisition.
//
// A non-zero value means the current thread is already inside a
// `monarch_with_gil` or `monarch_with_gil_blocking` call. On re-entry (e.g.
// when Python calls back into Rust while we are already executing under
// `Python::attach`), callers bypass `GIL_LOCK` to avoid deadlocks.
//
// Without this, the following scenario would deadlock:
// 1. Rust async code calls `monarch_with_gil`, acquires `GIL_LOCK`
// 2. Inside the closure, Python code is executed
// 3. Python code calls back into Rust (e.g., via a PyO3 callback)
// 4. The callback tries to call `monarch_with_gil` again
// 5. DEADLOCK: waiting for `GIL_LOCK` which is held by the same logical call chain
thread_local! {
    static GIL_DEPTH: std::cell::Cell<u32> = const { std::cell::Cell::new(0) };
}

/// RAII guard that puts the thread's GIL depth counter back to the value it
/// had before the matching `increment_gil_depth()` call.
///
/// The guard must be kept alive for as long as the thread should count as
/// "inside" a GIL section. Dropping it immediately (for example by not
/// binding the return value) undoes the increment right away, which would
/// silently disable re-entrancy detection; hence `#[must_use]`.
#[must_use = "dropping the guard immediately restores the previous GIL depth"]
struct GilDepthGuard {
    /// Depth observed on this thread just before the increment.
    prev_depth: u32,
}

impl Drop for GilDepthGuard {
    fn drop(&mut self) {
        // Restore rather than subtract: this returns the counter to the exact
        // pre-increment value.
        GIL_DEPTH.with(|d| d.set(self.prev_depth));
    }
}

/// Increments the current thread's GIL depth counter and returns a guard that
/// restores the previous depth when dropped.
fn increment_gil_depth() -> GilDepthGuard {
    // `replace` stores the new depth and hands back the old one in one step.
    let prev_depth = GIL_DEPTH.with(|d| d.replace(d.get() + 1));
    GilDepthGuard { prev_depth }
}

/// Returns true if the current thread is already inside a `monarch_with_gil`
/// or `monarch_with_gil_blocking` call (i.e. the depth counter is non-zero).
fn is_reentrant() -> bool {
    GIL_DEPTH.with(|d| d.get() > 0)
}
363
364/// Async wrapper around `Python::attach` intended for async call sites.
365///
366/// Why: under high concurrency, many async tasks can simultaneously
367/// try to acquire the GIL. Each call blocks the current tokio worker
368/// thread, which can cause runtime starvation / apparent deadlocks
369/// (nothing else gets polled).
370///
371/// This wrapper serializes GIL acquisition among async callers so at most one tokio
372/// task is blocked in `Python::attach` at a time, preventing runtime starvation
373/// under GIL contention.
374///
375/// Note: this does not globally prevent other sync code from calling
376/// `Python::attach` directly. Use this wrapper for Python
377/// interaction that occurs on async hot paths.
378///
379/// # Re-entrancy Safety
380///
381/// This function is re-entrant safe. If called while already inside a `monarch_with_gil`
382/// or `monarch_with_gil_blocking` call (e.g., from a Python→Rust callback), it bypasses
383/// the `GIL_LOCK` to avoid deadlocks.
384///
385/// # Example
386/// ```ignore
387/// let result = monarch_with_gil(|py| {
388///     // Do work with Python GIL
389///     Ok(42)
390/// })
391/// .await?;
392/// ```
393pub async fn monarch_with_gil<F, R>(f: F) -> R
394where
395    F: for<'py> FnOnce(Python<'py>) -> R + Send,
396{
397    // If we're already inside a monarch_with_gil call (re-entrant), skip the lock
398    // to avoid deadlock from Python→Rust callbacks
399    if is_reentrant() {
400        let _depth_guard = increment_gil_depth();
401        return Python::attach(f);
402    }
403
404    // Not re-entrant: acquire the serialization lock
405    let _lock_guard = GIL_LOCK.lock().await;
406    let _depth_guard = increment_gil_depth();
407    Python::attach(f)
408}
409
410/// Blocking wrapper around `Python::with_gil` for use in synchronous contexts.
411///
412/// Unlike `monarch_with_gil`, this function does NOT use the `GIL_LOCK` async mutex.
413/// Since it is blocking call, it simply acquires the GIL and releases it when the
414/// closure returns.
415///
416/// # Example
417/// ```ignore
418/// let result = monarch_with_gil_blocking(|py| {
419///     // Do work with Python GIL
420///     Ok(42)
421/// })?;
422/// ```
423pub fn monarch_with_gil_blocking<F, R>(f: F) -> R
424where
425    F: for<'py> FnOnce(Python<'py>) -> R + Send,
426{
427    let _depth_guard = increment_gil_depth();
428    Python::attach(f)
429}