monarch_hyperactor/
pympsc.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! An mpsc channel used to send messages from Rust to Python without acquiring
//! the GIL on the sender side.
11
12use std::sync::Arc;
13use std::sync::Mutex;
14use std::sync::mpsc;
15
16use monarch_types::MapPyErr;
17use pyo3::Bound;
18use pyo3::IntoPyObject;
19use pyo3::IntoPyObjectExt;
20use pyo3::Py;
21use pyo3::PyAny;
22use pyo3::PyResult;
23use pyo3::Python;
24use pyo3::pyclass;
25use pyo3::pymethods;
26use pyo3::types::PyModule;
27use pyo3::types::PyModuleMethods;
28use pyo3::wrap_pyfunction;
29
30use crate::py_cell::PyCell;
31use crate::pywaker::PyEvent;
32use crate::pywaker::{self};
33
34/// Create a new channel with a Rust sender and a Python receiver.
35pub fn channel() -> Result<(Sender, PyReceiver), nix::Error> {
36    let (tx, rx) = mpsc::channel();
37    let rx = Arc::new(Mutex::new(rx));
38
39    let (waker, event) = pywaker::event()?;
40
41    Ok((
42        Sender {
43            tx,
44            waker: Arc::new(waker),
45        },
46        PyReceiver {
47            rx,
48            event: PyCell::new(event),
49        },
50    ))
51}
52
53/// A blanket trait used to convert boxed objects into python objects.
54pub trait IntoPyObjectBox: Send {
55    fn into_py_object(self: Box<Self>, py: Python<'_>) -> PyResult<Py<PyAny>>;
56}
57
58impl<T> IntoPyObjectBox for T
59where
60    T: for<'py> IntoPyObject<'py> + Send,
61{
62    fn into_py_object(self: Box<Self>, py: Python<'_>) -> PyResult<Py<PyAny>> {
63        (*self).into_py_any(py)
64    }
65}
66
/// Error returned by `Sender::send` when the message could not be queued.
#[derive(Debug)]
pub struct SendError;

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SendError")
    }
}

impl std::error::Error for SendError {}
78
79/// A channel that can be used to send messages from Rust to Python without acquiring
80/// the GIL on the sender side.
81#[derive(Clone)]
82pub struct Sender {
83    tx: mpsc::Sender<Box<dyn IntoPyObjectBox>>,
84    waker: Arc<pywaker::Waker>,
85}
86
87impl std::fmt::Debug for Sender {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("Sender").finish_non_exhaustive()
90    }
91}
92
93impl Sender {
94    /// Send a message to the channel. The object must be convertible to a Python object;
95    /// conversion is deferred until the message is received in a Python context.
96    pub fn send<T>(&self, msg: T) -> Result<(), SendError>
97    where
98        T: IntoPyObjectBox + Send + 'static,
99    {
100        self.tx.send(Box::new(msg)).map_err(|_| SendError)?;
101        let _ = self.waker.wake();
102        Ok(())
103    }
104}
105
106/// The receiver side of a channel. Objects are converted to Python heap objects when
107/// they are received.
108#[pyclass(name = "Receiver", module = "monarch._src.actor.mpsc")]
109pub struct PyReceiver {
110    rx: Arc<Mutex<mpsc::Receiver<Box<dyn IntoPyObjectBox>>>>,
111    event: PyCell<PyEvent>,
112}
113
114impl std::fmt::Debug for PyReceiver {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("PyReceiver").finish_non_exhaustive()
117    }
118}
119
120#[pymethods]
121impl PyReceiver {
122    fn try_recv(&self, py: Python<'_>) -> PyResult<Option<Py<PyAny>>> {
123        match self.rx.lock().unwrap().try_recv() {
124            Ok(boxed_msg) => Ok(Some(boxed_msg.into_py_object(py)?)),
125            Err(mpsc::TryRecvError::Empty) => Ok(None),
126            Err(mpsc::TryRecvError::Disconnected) => {
127                Err(pyo3::exceptions::PyEOFError::new_err("Channel closed"))
128            }
129        }
130    }
131
132    fn _event<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyEvent>> {
133        let py_event = self.event.clone_ref(py)?;
134        Ok(py_event.bind(py).clone())
135    }
136}
137
138mod testing {
139    use pyo3::pyfunction;
140    use pyo3::types::PyAnyMethods;
141
142    use super::*;
143
144    // NOTE: We can't use a Python calss name that starts with "Test" since
145    // during Python testing, Pytest will inspect anything that starts with
146    // "Test" and check if its callable which in pyo3 >= 0.26 will raise
147    // a TypeError.
148    #[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pympsc")]
149    struct PyTestSender {
150        sender: Arc<Mutex<Sender>>,
151    }
152
153    #[pymethods]
154    impl PyTestSender {
155        fn send(&self, _py: Python<'_>, obj: Py<PyAny>) -> PyResult<()> {
156            self.sender.lock().unwrap().send(obj).map_pyerr()?;
157            Ok(())
158        }
159    }
160
161    #[pyfunction]
162    fn channel_for_test(_py: Python<'_>) -> PyResult<(PyTestSender, PyReceiver)> {
163        let (tx, rx) = channel().map_pyerr()?;
164        let tx = PyTestSender {
165            sender: Arc::new(Mutex::new(tx)),
166        };
167        Ok((tx, rx))
168    }
169
170    pub(super) fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
171        hyperactor_mod.add_class::<PyTestSender>()?;
172        let channel_for_test = wrap_pyfunction!(channel_for_test, hyperactor_mod)?;
173        channel_for_test.setattr(
174            "__module__",
175            "monarch._rust_bindings.monarch_hyperactor.pympsc",
176        )?;
177
178        hyperactor_mod.add_function(channel_for_test)?;
179        Ok(())
180    }
181}
182
183pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
184    hyperactor_mod.add_class::<PyReceiver>()?;
185    testing::register_python_bindings(hyperactor_mod)?;
186    Ok(())
187}