monarch_tensor_worker/
py_pipe.rs1use std::collections::HashMap;
10
11use monarch_messages::worker::ResolvableFunction;
12use monarch_types::PyTree;
13use monarch_types::TryIntoPyObjectUnsafe;
14use pyo3::prelude::*;
15use pyo3::types::PyTuple;
16use torch_sys::RValue;
17
18use crate::pipe::Pipe;
19
20#[pyclass]
22pub struct PyPipe {
23 pipe: Box<dyn Pipe<PyTree<RValue>> + Send + Sync>,
24 #[pyo3(get)]
25 ranks: HashMap<String, usize>,
26 #[pyo3(get)]
27 sizes: HashMap<String, usize>,
28 allow_unsafe_obj_conversion: bool,
29}
30
31impl PyPipe {
32 pub fn new(
33 pipe: Box<dyn Pipe<PyTree<RValue>> + Send + Sync>,
34 ranks: HashMap<String, usize>,
35 sizes: HashMap<String, usize>,
36 allow_unsafe_obj_conversion: bool,
37 ) -> Self {
38 Self {
39 pipe,
40 ranks,
41 sizes,
42 allow_unsafe_obj_conversion,
43 }
44 }
45}
46
47#[pymethods]
48impl PyPipe {
49 fn send(&mut self, py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<()> {
50 let val = value.extract::<PyTree<RValue>>()?;
51 py.allow_threads(move || self.pipe.send(val))?;
52 Ok(())
53 }
54
55 fn recv<'a>(&mut self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
56 let val = py.allow_threads(|| self.pipe.recv())?;
57 if self.allow_unsafe_obj_conversion {
58 unsafe { val.try_to_object_unsafe(py) }
61 } else {
62 val.into_pyobject(py)
63 }
64 }
65}
66
67pub fn run_py_pipe(
70 pipe: PyPipe,
71 func: ResolvableFunction,
72 args: Vec<PyTree<RValue>>,
73 kwargs: HashMap<String, PyTree<RValue>>,
74) -> PyResult<()> {
75 Python::with_gil(|py| {
76 let pipe_obj: Py<PyPipe> = Py::new(py, pipe)?;
77 let func = func.resolve(py)?;
78 let mut py_args = vec![pipe_obj.into_bound(py).into_any()];
79 py_args.extend(
80 args.into_iter()
81 .map(|a| a.into_pyobject(py))
82 .collect::<Result<Vec<_>, _>>()?,
83 );
84 func.call(PyTuple::new(py, py_args)?, Some(&kwargs.into_pyobject(py)?))?;
85 Ok(())
86 })
87}
88
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::try_join;
    use indoc::indoc;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::types::PyModule;
    use timed_test::async_timed_test;
    use torch_sys::RValue;

    use super::PyPipe;
    use super::run_py_pipe;
    use crate::pipe::AsyncPipe;
    use crate::pipe::create_local_pipe;

    /// Round-trips one value through a Python echo function.
    ///
    /// The Rust side sends `RValue::Int(3)` over a local pipe. The Python
    /// function receives it and sends it straight back, and the Rust side
    /// checks that the echoed value matches.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_py_pipe() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        // NOTE(review): presumably torch must be imported before `RValue`
        // values can be converted to Python objects — confirm.
        Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;

        // Define `test_helpers.func`, which echoes a single message back over
        // the pipe, so `run_py_pipe` can resolve it by name below.
        Python::with_gil(|py| -> Result<()> {
            let _echo_module = PyModule::from_code(
                py,
                c_str!(indoc! {r#"
                    def func(pipe):
                        val = pipe.recv()
                        pipe.send(val)
                "#}),
                c_str!("test_helpers.py"),
                c_str!("test_helpers"),
            )?;
            Ok(())
        })?;

        let (mut client, server) = create_local_pipe();

        // Python side: `run_py_pipe` is synchronous, so run it on the
        // blocking pool.
        let python_side = async move {
            tokio::task::spawn_blocking(move || {
                let pipe = PyPipe::new(Box::new(server), HashMap::new(), HashMap::new(), false);
                run_py_pipe(pipe, "test_helpers.func".into(), vec![], HashMap::new())
            })
            .await??;
            anyhow::Ok(())
        };

        // Rust side: send a value and expect the same value echoed back.
        let rust_side = async move {
            client.send(RValue::Int(3).into()).await?;
            let echoed = client.recv().await?;
            assert_matches!(echoed.into_leaf().unwrap(), RValue::Int(3));
            anyhow::Ok(())
        };

        let ((), ()) = try_join!(python_side, rust_side)?;
        Ok(())
    }
}