// monarch_hyperactor/runtime.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::OnceLock;
12use std::sync::RwLock;
13use std::sync::RwLockReadGuard;
14use std::sync::atomic::AtomicUsize;
15use std::sync::atomic::Ordering;
16use std::time::Duration;
17
18use anyhow::Result;
19use hyperactor::Proc;
20use hyperactor::channel::ChannelAddr;
21use hyperactor::channel::ChannelTransport;
22use hyperactor::mailbox::BoxedMailboxSender;
23use hyperactor::mailbox::PanickingMailboxSender;
24use hyperactor::reference;
25use once_cell::sync::Lazy;
26use once_cell::unsync::OnceCell as UnsyncOnceCell;
27use pyo3::PyResult;
28use pyo3::Python;
29use pyo3::exceptions::PyRuntimeError;
30use pyo3::prelude::*;
31use pyo3::types::PyAnyMethods;
32use pyo3_async_runtimes::TaskLocals;
33use tokio::task;
34
// The global tokio runtime shared by this crate: `None` until first use (see
// `get_tokio_runtime`) and taken back out by `shutdown_tokio_runtime` at exit.
//
// this must be a RwLock and only return a guard for reading the runtime.
// Otherwise multiple threads can deadlock fighting for the Runtime object if they hold it
// while blocking on something.
static INSTANCE: std::sync::LazyLock<RwLock<Option<tokio::runtime::Runtime>>> =
    std::sync::LazyLock::new(|| RwLock::new(None));
40
/// Returns a read guard to the lazily-initialized global tokio runtime,
/// creating the runtime on first call.
///
/// Uses double-checked locking on `INSTANCE`: an optimistic read lock first,
/// then a write lock only if initialization is needed. The returned guard
/// keeps `INSTANCE` read-locked while held; many readers may hold it
/// concurrently, and only `shutdown_tokio_runtime` takes the write lock.
pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::runtime::Runtime> {
    // First try to get a read lock and check if runtime exists
    {
        let read_guard = INSTANCE.read().unwrap();
        if read_guard.is_some() {
            // Narrow the guard from `&Option<Runtime>` down to `&Runtime`.
            return RwLockReadGuard::map(read_guard, |lock: &Option<tokio::runtime::Runtime>| {
                lock.as_ref().unwrap()
            });
        }
        // Drop the read lock by letting it go out of scope
    }

    // Runtime doesn't exist, upgrade to write lock to initialize
    let mut write_guard = INSTANCE.write().unwrap();
    // Re-check under the write lock: another thread may have initialized the
    // runtime between our read unlock above and this write lock.
    if write_guard.is_none() {
        *write_guard = Some(
            tokio::runtime::Builder::new_multi_thread()
                .thread_name_fn(|| {
                    // Give each worker thread a unique, recognizable name.
                    static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
                    let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
                    format!("monarch-pytokio-worker-{}", id)
                })
                .enable_all()
                .build()
                .unwrap(),
        );
    }

    // Downgrade write lock to read lock and return the reference
    let read_guard = std::sync::RwLockWriteGuard::downgrade(write_guard);
    RwLockReadGuard::map(read_guard, |lock: &Option<tokio::runtime::Runtime>| {
        lock.as_ref().unwrap()
    })
}
75
76#[pyfunction]
77pub fn shutdown_tokio_runtime(py: Python<'_>) {
78    // Called from Python's atexit, which holds the GIL. Release it so tokio
79    // worker threads can acquire it to complete their Python work.
80    py.detach(|| {
81        if let Some(x) = INSTANCE.write().unwrap().take() {
82            x.shutdown_timeout(Duration::from_secs(1));
83        }
84    });
85}
86
87/// A global runtime proc used by this crate.
88pub(crate) fn get_proc_runtime() -> &'static Proc {
89    static RUNTIME_PROC: OnceLock<Proc> = OnceLock::new();
90    RUNTIME_PROC.get_or_init(|| {
91        let addr = ChannelAddr::any(ChannelTransport::Local);
92        let proc_id = reference::ProcId::unique(addr, "monarch_hyperactor_runtime");
93        Proc::configured(proc_id, BoxedMailboxSender::new(PanickingMailboxSender))
94    })
95}
96
/// Stores the native thread ID of the main Python thread.
/// This is lazily initialized on first call to `is_main_thread` (via
/// `get_main_thread_native_id`), or eagerly by `initialize`.
static MAIN_THREAD_NATIVE_ID: OnceLock<i64> = OnceLock::new();
100
101/// Returns the native thread ID of the main Python thread.
102/// On first call, looks it up via `threading.main_thread().native_id`.
103fn get_main_thread_native_id() -> i64 {
104    *MAIN_THREAD_NATIVE_ID.get_or_init(|| {
105        Python::attach(|py| {
106            let threading = py.import("threading").expect("failed to import threading");
107            let main_thread = threading
108                .call_method0("main_thread")
109                .expect("failed to get main_thread");
110            main_thread
111                .getattr("native_id")
112                .expect("failed to get native_id")
113                .extract::<i64>()
114                .expect("native_id is not an i64")
115        })
116    })
117}
118
119/// Returns the current thread's native ID in a cross-platform way.
120#[cfg(target_os = "linux")]
121fn get_current_thread_id() -> i64 {
122    nix::unistd::gettid().as_raw() as i64
123}
124
/// Returns the current thread's native ID in a cross-platform way.
#[cfg(target_os = "macos")]
fn get_current_thread_id() -> i64 {
    let mut tid: u64 = 0;
    // SAFETY: a null (0) pthread_t asks pthread_threadid_np for the calling
    // thread's ID, and `tid` is a valid out-pointer for the duration of the call.
    let ret = unsafe { libc::pthread_threadid_np(0, &mut tid) };
    debug_assert_eq!(
        ret, 0,
        "pthread_threadid_np failed with error code: {}",
        ret
    );
    // macOS thread IDs are u64 so we need to convert to i64.
    debug_assert!(tid <= i64::MAX as u64, "thread ID {} exceeds i64::MAX", tid);
    tid as i64
}
142
/// Returns the current thread's native ID in a cross-platform way.
// Fail at build time on unsupported platforms rather than at runtime.
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
compile_error!("get_current_thread_id is only implemented for Linux and macOS");
146
147/// Returns true if the current thread is the main Python thread.
148/// Compares the current thread's native ID against the main Python thread's native ID.
149pub fn is_main_thread() -> bool {
150    let current_tid = get_current_thread_id();
151    current_tid == get_main_thread_native_id()
152}
153
154pub fn initialize(py: Python) -> Result<()> {
155    // Eagerly initialize the main thread ID while we're on the main thread
156    // with the GIL held. If this were lazily initialized on a background
157    // tokio thread during shutdown, the `py.import("threading")` call inside
158    // get_main_thread_native_id() would trigger module_from_spec on a
159    // partially-finalized interpreter, causing a segfault.
160    let _ = get_main_thread_native_id();
161
162    let atexit = py.import("atexit")?;
163    let shutdown_fn = wrap_pyfunction!(shutdown_tokio_runtime, py)?;
164    atexit.call_method1("register", (shutdown_fn,))?;
165    Ok(())
166}
167
/// Block the current thread on a future, but make sure to check for signals
/// originating from the Python signal handler.
///
/// Python's signal handler just sets a flag that it expects the Python
/// interpreter to handle later via a call to `PyErr_CheckSignals`. When we
/// enter into potentially long-running native code, we need to make sure to be
/// checking for signals frequently, otherwise we will ignore them. This will
/// manifest as `ctrl-C` not doing anything.
///
/// One additional wrinkle is that `PyErr_CheckSignals` only works on the main
/// Python thread; if it's called on any other thread it silently does nothing.
/// So, we check if we're on the main thread by comparing native thread IDs.
///
/// # Errors
///
/// On the main thread, returns the pending Python exception if a signal is
/// delivered (via `check_signals`), or a `RuntimeError` wrapping the
/// `JoinError` if the spawned task panics or is cancelled.
pub fn signal_safe_block_on<F>(py: Python, future: F) -> PyResult<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let runtime = get_tokio_runtime();
    // Release the GIL, otherwise the work in `future` that tries to acquire the
    // GIL on another thread may deadlock.
    py.detach(|| {
        if is_main_thread() {
            // Spawn the future onto the tokio runtime
            let handle = runtime.spawn(future);
            // Block the current thread on waiting for *either* the future to
            // complete or a signal.
            runtime.block_on(async {
                tokio::select! {
                    result = handle => result.map_err(|e| PyRuntimeError::new_err(format!("JoinErr: {:?}", e))),
                    signal = async {
                        let sleep_for = std::time::Duration::from_millis(100);
                        // This branch only ever exits by returning an Err from
                        // check_signals; on success it loops forever.
                        loop {
                            // Acquiring the GIL in a loop is sad, hopefully once
                            // every 100ms is fine.
                            Python::attach(|py| py.check_signals())?;
                            tokio::time::sleep(sleep_for).await;
                        }
                    } => signal
                }
            })
        } else {
            // If we're not on the main thread, we can just block it. We've
            // released the GIL, so the Python main thread will continue on, and
            // `PyErr_CheckSignals` doesn't do anything anyway.
            Ok(runtime.block_on(future))
        }
    })
}
216
217/// A test function that sleeps indefinitely in a loop.
218/// This is used for testing signal handling in signal_safe_block_on.
219/// The function will sleep forever until interrupted by a signal.
220#[pyfunction]
221pub fn sleep_indefinitely_for_unit_tests(py: Python) -> PyResult<()> {
222    // Create a future that sleeps indefinitely
223    let future = async {
224        loop {
225            tracing::info!("idef sleeping for 100ms");
226            tokio::time::sleep(Duration::from_millis(100)).await;
227        }
228    };
229
230    // Use signal_safe_block_on to run the future, which should make it
231    // interruptible by signals like SIGINT
232    signal_safe_block_on(py, future)
233}
234
235/// Initialize the runtime module and expose Python functions
236pub fn register_python_bindings(runtime_mod: &Bound<'_, PyModule>) -> PyResult<()> {
237    let sleep_indefinitely_fn =
238        wrap_pyfunction!(sleep_indefinitely_for_unit_tests, runtime_mod.py())?;
239    sleep_indefinitely_fn.setattr(
240        "__module__",
241        "monarch._rust_bindings.monarch_hyperactor.runtime",
242    )?;
243    runtime_mod.add_function(sleep_indefinitely_fn)?;
244    Ok(())
245}
246
247struct SimpleRuntime;
248
249impl pyo3_async_runtimes::generic::Runtime for SimpleRuntime {
250    type JoinError = task::JoinError;
251    type JoinHandle = task::JoinHandle<()>;
252
253    fn spawn<F>(fut: F) -> Self::JoinHandle
254    where
255        F: Future<Output = ()> + Send + 'static,
256    {
257        get_tokio_runtime().spawn(async move {
258            fut.await;
259        })
260    }
261}
262
// Per-task slot for the pyo3 `TaskLocals` (event loop + context) that
// `SimpleRuntime::scope` installs and `get_task_locals` reads back.
tokio::task_local! {
    static TASK_LOCALS: UnsyncOnceCell<TaskLocals>;
}
266
267impl pyo3_async_runtimes::generic::ContextExt for SimpleRuntime {
268    fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
269    where
270        F: Future<Output = R> + Send + 'static,
271    {
272        let cell = UnsyncOnceCell::new();
273        cell.set(locals).unwrap();
274
275        Box::pin(TASK_LOCALS.scope(cell, fut))
276    }
277
278    fn get_task_locals() -> Option<TaskLocals> {
279        TASK_LOCALS
280            .try_with(|c| {
281                c.get()
282                    .map(|locals| monarch_with_gil_blocking(|py| locals.clone_ref(py)))
283            })
284            .unwrap_or_default()
285    }
286}
287
288pub fn future_into_py<F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
289where
290    F: Future<Output = PyResult<T>> + Send + 'static,
291    T: for<'py> IntoPyObject<'py>,
292{
293    pyo3_async_runtimes::generic::future_into_py::<SimpleRuntime, F, T>(py, fut)
294}
295
/// Global lock to serialize GIL acquisition from Rust threads in async contexts.
///
/// Under high concurrency, many async tasks can simultaneously try to acquire the GIL.
/// Each call blocks the current tokio worker thread, which can cause runtime starvation
/// and apparent deadlocks (nothing else gets polled).
///
/// This wrapper serializes GIL acquisition among callers that opt in, so at most one
/// tokio task is blocked in `Python::attach` at a time, improving fairness under
/// contention.
///
/// Re-entrant callers bypass this lock entirely (tracked via `GIL_DEPTH`).
///
/// Note: this does not globally prevent other sync code from calling `Python::attach`
/// directly. Use `monarch_with_gil` or `monarch_with_gil_blocking` for Python interaction
/// that occurs on async hot paths.
static GIL_LOCK: Lazy<tokio::sync::Mutex<()>> = Lazy::new(|| tokio::sync::Mutex::new(()));
310
// Thread-local depth counter for re-entrant GIL acquisition.
//
// This tracks when we're already inside a `monarch_with_gil` or `monarch_with_gil_blocking`
// call. On re-entry (e.g., when Python calls back into Rust while we're already executing
// under `Python::attach`), we bypass the `GIL_LOCK` to avoid deadlocks.
//
// Without this, the following scenario would deadlock:
// 1. Rust async code calls `monarch_with_gil`, acquires `GIL_LOCK`
// 2. Inside the closure, Python code is executed
// 3. Python code calls back into Rust (e.g., via a PyO3 callback)
// 4. The callback tries to call `monarch_with_gil` again
// 5. DEADLOCK: waiting for `GIL_LOCK` which is held by the same logical call chain
//
// The counter is saved and restored by `GilDepthGuard` (see `increment_gil_depth`).
thread_local! {
    static GIL_DEPTH: std::cell::Cell<u32> = const { std::cell::Cell::new(0) };
}
326
327/// RAII guard that decrements the GIL depth counter when dropped.
328struct GilDepthGuard {
329    prev_depth: u32,
330}
331
332impl Drop for GilDepthGuard {
333    fn drop(&mut self) {
334        GIL_DEPTH.with(|d| d.set(self.prev_depth));
335    }
336}
337
338/// Increments the GIL depth counter and returns a guard that restores it on drop.
339fn increment_gil_depth() -> GilDepthGuard {
340    let prev_depth = GIL_DEPTH.with(|d| {
341        let current = d.get();
342        d.set(current + 1);
343        current
344    });
345    GilDepthGuard { prev_depth }
346}
347
348/// Returns true if we're already inside a `monarch_with_gil` call (re-entrant).
349fn is_reentrant() -> bool {
350    GIL_DEPTH.with(|d| d.get() > 0)
351}
352
353/// Async wrapper around `Python::attach` intended for async call sites.
354///
355/// Why: under high concurrency, many async tasks can simultaneously
356/// try to acquire the GIL. Each call blocks the current tokio worker
357/// thread, which can cause runtime starvation / apparent deadlocks
358/// (nothing else gets polled).
359///
360/// This wrapper serializes GIL acquisition among async callers so at most one tokio
361/// task is blocked in `Python::attach` at a time, preventing runtime starvation
362/// under GIL contention.
363///
364/// Note: this does not globally prevent other sync code from calling
365/// `Python::attach` directly. Use this wrapper for Python
366/// interaction that occurs on async hot paths.
367///
368/// # Re-entrancy Safety
369///
370/// This function is re-entrant safe. If called while already inside a `monarch_with_gil`
371/// or `monarch_with_gil_blocking` call (e.g., from a Python→Rust callback), it bypasses
372/// the `GIL_LOCK` to avoid deadlocks.
373///
374/// # Example
375/// ```ignore
376/// let result = monarch_with_gil(|py| {
377///     // Do work with Python GIL
378///     Ok(42)
379/// })
380/// .await?;
381/// ```
382pub async fn monarch_with_gil<F, R>(f: F) -> R
383where
384    F: for<'py> FnOnce(Python<'py>) -> R + Send,
385{
386    // If we're already inside a monarch_with_gil call (re-entrant), skip the lock
387    // to avoid deadlock from Python→Rust callbacks
388    if is_reentrant() {
389        let _depth_guard = increment_gil_depth();
390        return Python::attach(f);
391    }
392
393    // Not re-entrant: acquire the serialization lock
394    let _lock_guard = GIL_LOCK.lock().await;
395    let _depth_guard = increment_gil_depth();
396    Python::attach(f)
397}
398
399/// Blocking wrapper around `Python::with_gil` for use in synchronous contexts.
400///
401/// Unlike `monarch_with_gil`, this function does NOT use the `GIL_LOCK` async mutex.
402/// Since it is blocking call, it simply acquires the GIL and releases it when the
403/// closure returns.
404///
405/// # Example
406/// ```ignore
407/// let result = monarch_with_gil_blocking(|py| {
408///     // Do work with Python GIL
409///     Ok(42)
410/// })?;
411/// ```
412pub fn monarch_with_gil_blocking<F, R>(f: F) -> R
413where
414    F: for<'py> FnOnce(Python<'py>) -> R + Send,
415{
416    let _depth_guard = increment_gil_depth();
417    Python::attach(f)
418}