monarch_hyperactor/
pympsc.rs1use std::sync::Arc;
13use std::sync::Mutex;
14use std::sync::mpsc;
15
16use monarch_types::MapPyErr;
17use pyo3::Bound;
18use pyo3::IntoPyObject;
19use pyo3::IntoPyObjectExt;
20use pyo3::Py;
21use pyo3::PyAny;
22use pyo3::PyResult;
23use pyo3::Python;
24use pyo3::pyclass;
25use pyo3::pymethods;
26use pyo3::types::PyModule;
27use pyo3::types::PyModuleMethods;
28use pyo3::wrap_pyfunction;
29
30use crate::py_cell::PyCell;
31use crate::pywaker::PyEvent;
32use crate::pywaker::{self};
33
34pub fn channel() -> Result<(Sender, PyReceiver), nix::Error> {
36 let (tx, rx) = mpsc::channel();
37 let rx = Arc::new(Mutex::new(rx));
38
39 let (waker, event) = pywaker::event()?;
40
41 Ok((
42 Sender {
43 tx,
44 waker: Arc::new(waker),
45 },
46 PyReceiver {
47 rx,
48 event: PyCell::new(event),
49 },
50 ))
51}
52
53pub trait IntoPyObjectBox: Send {
55 fn into_py_object(self: Box<Self>, py: Python<'_>) -> PyResult<Py<PyAny>>;
56}
57
58impl<T> IntoPyObjectBox for T
59where
60 T: for<'py> IntoPyObject<'py> + Send,
61{
62 fn into_py_object(self: Box<Self>, py: Python<'_>) -> PyResult<Py<PyAny>> {
63 (*self).into_py_any(py)
64 }
65}
66
/// Error returned by [`Sender::send`] when the message could not be queued
/// because the receiving half of the channel is gone.
#[derive(Debug)]
pub struct SendError;

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SendError")
    }
}

impl std::error::Error for SendError {}
78
79#[derive(Clone)]
82pub struct Sender {
83 tx: mpsc::Sender<Box<dyn IntoPyObjectBox>>,
84 waker: Arc<pywaker::Waker>,
85}
86
87impl std::fmt::Debug for Sender {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("Sender").finish_non_exhaustive()
90 }
91}
92
93impl Sender {
94 pub fn send<T>(&self, msg: T) -> Result<(), SendError>
97 where
98 T: IntoPyObjectBox + Send + 'static,
99 {
100 self.tx.send(Box::new(msg)).map_err(|_| SendError)?;
101 let _ = self.waker.wake();
102 Ok(())
103 }
104}
105
106#[pyclass(name = "Receiver", module = "monarch._src.actor.mpsc")]
109pub struct PyReceiver {
110 rx: Arc<Mutex<mpsc::Receiver<Box<dyn IntoPyObjectBox>>>>,
111 event: PyCell<PyEvent>,
112}
113
114impl std::fmt::Debug for PyReceiver {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 f.debug_struct("PyReceiver").finish_non_exhaustive()
117 }
118}
119
120#[pymethods]
121impl PyReceiver {
122 fn try_recv(&self, py: Python<'_>) -> PyResult<Option<Py<PyAny>>> {
123 match self.rx.lock().unwrap().try_recv() {
124 Ok(boxed_msg) => Ok(Some(boxed_msg.into_py_object(py)?)),
125 Err(mpsc::TryRecvError::Empty) => Ok(None),
126 Err(mpsc::TryRecvError::Disconnected) => {
127 Err(pyo3::exceptions::PyEOFError::new_err("Channel closed"))
128 }
129 }
130 }
131
132 fn _event<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyEvent>> {
133 let py_event = self.event.clone_ref(py)?;
134 Ok(py_event.bind(py).clone())
135 }
136}
137
138mod testing {
139 use pyo3::pyfunction;
140 use pyo3::types::PyAnyMethods;
141
142 use super::*;
143
144 #[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pympsc")]
149 struct PyTestSender {
150 sender: Arc<Mutex<Sender>>,
151 }
152
153 #[pymethods]
154 impl PyTestSender {
155 fn send(&self, _py: Python<'_>, obj: Py<PyAny>) -> PyResult<()> {
156 self.sender.lock().unwrap().send(obj).map_pyerr()?;
157 Ok(())
158 }
159 }
160
161 #[pyfunction]
162 fn channel_for_test(_py: Python<'_>) -> PyResult<(PyTestSender, PyReceiver)> {
163 let (tx, rx) = channel().map_pyerr()?;
164 let tx = PyTestSender {
165 sender: Arc::new(Mutex::new(tx)),
166 };
167 Ok((tx, rx))
168 }
169
170 pub(super) fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
171 hyperactor_mod.add_class::<PyTestSender>()?;
172 let channel_for_test = wrap_pyfunction!(channel_for_test, hyperactor_mod)?;
173 channel_for_test.setattr(
174 "__module__",
175 "monarch._rust_bindings.monarch_hyperactor.pympsc",
176 )?;
177
178 hyperactor_mod.add_function(channel_for_test)?;
179 Ok(())
180 }
181}
182
183pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
184 hyperactor_mod.add_class::<PyReceiver>()?;
185 testing::register_python_bindings(hyperactor_mod)?;
186 Ok(())
187}