monarch_tensor_worker/
py_pipe.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10
11use monarch_messages::worker::ResolvableFunction;
12use monarch_types::PyTree;
13use monarch_types::TryIntoPyObjectUnsafe;
14use pyo3::prelude::*;
15use pyo3::types::PyTuple;
16use torch_sys::RValue;
17
18use crate::pipe::Pipe;
19
20/// Wrapper around `Pipe` to make it usable in Python.
21#[pyclass]
22pub struct PyPipe {
23    pipe: Box<dyn Pipe<PyTree<RValue>> + Send + Sync>,
24    #[pyo3(get)]
25    ranks: HashMap<String, usize>,
26    #[pyo3(get)]
27    sizes: HashMap<String, usize>,
28    allow_unsafe_obj_conversion: bool,
29}
30
31impl PyPipe {
32    pub fn new(
33        pipe: Box<dyn Pipe<PyTree<RValue>> + Send + Sync>,
34        ranks: HashMap<String, usize>,
35        sizes: HashMap<String, usize>,
36        allow_unsafe_obj_conversion: bool,
37    ) -> Self {
38        Self {
39            pipe,
40            ranks,
41            sizes,
42            allow_unsafe_obj_conversion,
43        }
44    }
45}
46
47#[pymethods]
48impl PyPipe {
49    fn send(&mut self, py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<()> {
50        let val = value.extract::<PyTree<RValue>>()?;
51        py.allow_threads(move || self.pipe.send(val))?;
52        Ok(())
53    }
54
55    fn recv<'a>(&mut self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
56        let val = py.allow_threads(|| self.pipe.recv())?;
57        if self.allow_unsafe_obj_conversion {
58            // SAFETY: A caller who initialized this PyPipe with allow_unsafe_obj_conversion=True
59            // asserts that it is safe to use this unsafe method.
60            unsafe { val.try_to_object_unsafe(py) }
61        } else {
62            val.into_pyobject(py)
63        }
64    }
65}
66
67/// Run a Python pipe server, which loads a remote function sent over the pipe
68/// then delegates to it.
69pub fn run_py_pipe(
70    pipe: PyPipe,
71    func: ResolvableFunction,
72    args: Vec<PyTree<RValue>>,
73    kwargs: HashMap<String, PyTree<RValue>>,
74) -> PyResult<()> {
75    Python::with_gil(|py| {
76        let pipe_obj: Py<PyPipe> = Py::new(py, pipe)?;
77        let func = func.resolve(py)?;
78        let mut py_args = vec![pipe_obj.into_bound(py).into_any()];
79        py_args.extend(
80            args.into_iter()
81                .map(|a| a.into_pyobject(py))
82                .collect::<Result<Vec<_>, _>>()?,
83        );
84        func.call(PyTuple::new(py, py_args)?, Some(&kwargs.into_pyobject(py)?))?;
85        Ok(())
86    })
87}
88
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::try_join;
    use indoc::indoc;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::types::PyModule;
    use timed_test::async_timed_test;
    use torch_sys::RValue;

    use super::PyPipe;
    use super::run_py_pipe;
    use crate::pipe::AsyncPipe;
    use crate::pipe::create_local_pipe;

    /// Imports torch. Torch initializes internal structures that the FFI functions
    /// converting ivalues to and from Python objects rely on.
    fn import_torch() -> Result<()> {
        Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
        Ok(())
    }

    /// Defines the `test_helpers` Python module. Its `func` receives one value from the
    /// pipe and sends it straight back.
    fn define_echo_handler() -> Result<()> {
        Python::with_gil(|py| {
            let _module = PyModule::from_code(
                py,
                c_str!(indoc! {r#"
                    def func(pipe):
                        val = pipe.recv()
                        pipe.send(val)
                "#}),
                c_str!("test_helpers.py"),
                c_str!("test_helpers"),
            )?;
            anyhow::Ok(())
        })
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn test_py_pipe() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        import_torch()?;
        define_echo_handler()?;

        let (mut client, server) = create_local_pipe();

        // Server half. The Python handler holds the GIL and blocks on pipe I/O, so it
        // runs on a blocking-task thread rather than on the async runtime.
        let serve = async move {
            tokio::task::spawn_blocking(move || {
                let handler = PyPipe::new(
                    Box::new(server),
                    HashMap::new(),
                    HashMap::new(),
                    false, // allow_unsafe_obj_conversion
                );
                run_py_pipe(
                    handler,
                    "test_helpers.func".into(),
                    vec![],
                    HashMap::new(),
                )
            })
            .await??;
            anyhow::Ok(())
        };

        // Client half: send a value and expect the handler to echo it back unchanged.
        let exercise = async move {
            client.send(RValue::Int(3).into()).await?;
            let echoed = client.recv().await?;
            assert_matches!(echoed.into_leaf().unwrap(), RValue::Int(3));
            anyhow::Ok(())
        };

        // Both futures are lazy; neither makes progress until polled here.
        let ((), ()) = try_join!(serve, exercise)?;
        Ok(())
    }
}