monarch_types/
pytree.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use hyperactor::Named;
10use pyo3::Bound;
11use pyo3::FromPyObject;
12use pyo3::IntoPyObject;
13use pyo3::IntoPyObjectExt;
14use pyo3::PyAny;
15use pyo3::PyErr;
16use pyo3::Python;
17use pyo3::exceptions::PyRuntimeError;
18use pyo3::exceptions::PyTypeError;
19use pyo3::prelude::PyAnyMethods;
20use pyo3::prelude::PyResult;
21use pyo3::prelude::PyTupleMethods;
22use pyo3::types::PyBool;
23use pyo3::types::PyBoolMethods;
24use pyo3::types::PyList;
25use pyo3::types::PyModule;
26use pyo3::types::PyTuple;
27use serde::Deserialize;
28use serde::Serialize;
29
30use crate::PickledPyObject;
31use crate::python::TryIntoPyObjectUnsafe;
32
/// Structural description of a flattened pytree.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TreeSpec {
    /// The tree is a single leaf. `PyTree::flatten` selects this variant when
    /// the Python tree spec's `is_leaf()` returns true, and `From<T>` uses it
    /// for a lone value; no pickled spec is stored.
    Leaf,
    /// Any other tree: holds the pickled Python tree spec object.
    Tree(PickledPyObject),
}
38
/// A Rust wrapper around a PyTorch pytree, which can be serialized and sent
/// across the wire.
/// <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>
///
/// The tree is held in flattened form: a [`TreeSpec`] describing the
/// structure plus the leaf values, in the order `tree_flatten` yielded them.
// NOTE: We have runtime deps on torch's pytree module and `pickle`, which the
// user must ensure is available.
#[derive(Clone, Debug, Serialize, Deserialize, Named)]
pub struct PyTree<T> {
    /// A wrapper around the tree spec.
    // NOTE: This is currently just the pickled bytes of the tree spec.  We
    // could also just deserialize to a `PyObject`, but this would mean
    // acquiring the GIL when deserializing the message.
    treespec: TreeSpec,
    /// The deserialized leaves of the pytree.
    leaves: Vec<T>,
}
54
55impl<T> PyTree<T> {
56    pub fn is_leaf(&self) -> bool {
57        matches!(self.treespec, TreeSpec::Leaf)
58    }
59
60    pub fn into_leaf(mut self) -> Option<T> {
61        if self.is_leaf() {
62            self.leaves.pop()
63        } else {
64            None
65        }
66    }
67
68    pub fn leaves(&self) -> &[T] {
69        &self.leaves
70    }
71
72    pub fn into_leaves(self) -> Vec<T> {
73        self.leaves
74    }
75
76    pub fn for_each<F>(&self, func: F)
77    where
78        F: FnMut(&T),
79    {
80        self.leaves.iter().for_each(func)
81    }
82
83    pub fn iter<'a>(&'a self) -> std::slice::Iter<'a, T> {
84        self.leaves.iter()
85    }
86
87    pub fn iter_mut<'a>(&'a mut self) -> std::slice::IterMut<'a, T> {
88        self.leaves.iter_mut()
89    }
90
91    /// Map leaf values with the given func.
92    pub fn into_map<U, F>(self, func: F) -> PyTree<U>
93    where
94        F: FnMut(T) -> U,
95    {
96        PyTree {
97            treespec: self.treespec,
98            leaves: self.leaves.into_iter().map(func).collect(),
99        }
100    }
101
102    /// Map the leaves of the pytree with the given fallible callback.
103    // NOTE: This ends up copying the serialized treespec.
104    pub fn try_map<U, F, E>(&self, mut func: F) -> Result<PyTree<U>, E>
105    where
106        F: FnMut(&T) -> Result<U, E>,
107    {
108        let mut leaves = vec![];
109        for leaf in self.leaves.iter() {
110            leaves.push(func(leaf)?);
111        }
112        Ok(PyTree {
113            treespec: self.treespec.clone(),
114            leaves,
115        })
116    }
117
118    /// Map the leaves of the pytree with the given fallible callback.
119    pub fn try_into_map<U, F, E>(self, mut func: F) -> Result<PyTree<U>, E>
120    where
121        F: FnMut(T) -> Result<U, E>,
122    {
123        let mut leaves = vec![];
124        for leaf in self.leaves.into_iter() {
125            leaves.push(func(leaf)?);
126        }
127        Ok(PyTree {
128            treespec: self.treespec,
129            leaves,
130        })
131    }
132
133    fn unflatten_impl<'a>(
134        py: Python<'a>,
135        treespec: &TreeSpec,
136        mut leaves: impl Iterator<Item = PyResult<Bound<'a, PyAny>>>,
137    ) -> PyResult<Bound<'a, PyAny>> {
138        if let TreeSpec::Tree(tree) = treespec {
139            // Call into pytorch's unflatten.
140            let module = PyModule::import(py, "torch.utils._pytree")?;
141            let function = module.getattr("tree_unflatten")?;
142            let leaves = leaves.collect::<Result<Vec<_>, _>>()?;
143            let leaves = PyList::new(py, &leaves)?;
144            let args = PyTuple::new(py, vec![leaves.as_any(), &tree.unpickle(py)?])?;
145            let result = function.call(args, None)?;
146            Ok(result)
147        } else {
148            leaves.next().ok_or(PyRuntimeError::new_err(
149                "Pytree leaf unexpectedly had no value",
150            ))?
151        }
152    }
153}
154
155impl<T> From<T> for PyTree<T> {
156    fn from(value: T) -> Self {
157        PyTree {
158            treespec: TreeSpec::Leaf,
159            leaves: vec![value],
160        }
161    }
162}
163
164impl<'py, T> IntoPyObject<'py> for PyTree<T>
165where
166    T: IntoPyObject<'py>,
167{
168    type Target = PyAny;
169    type Output = Bound<'py, Self::Target>;
170    type Error = PyErr;
171
172    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
173        PyTree::<T>::unflatten_impl(
174            py,
175            &self.treespec,
176            self.leaves.into_iter().map(|l| l.into_bound_py_any(py)),
177        )
178    }
179}
180
181impl<'a, 'py, T> IntoPyObject<'py> for &'a PyTree<T>
182where
183    &'a T: IntoPyObject<'py>,
184    T: 'a,
185{
186    type Target = PyAny;
187    type Output = Bound<'py, Self::Target>;
188    type Error = PyErr;
189
190    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
191        PyTree::<T>::unflatten_impl(
192            py,
193            &self.treespec,
194            self.leaves.iter().map(|l| l.into_bound_py_any(py)),
195        )
196    }
197}
198
199/// Serialize into a `PyObject`.
200impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyAny> for &'a PyTree<T>
201where
202    &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
203    T: 'a,
204{
205    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
206        PyTree::<T>::unflatten_impl(
207            py,
208            &self.treespec,
209            self.leaves
210                .iter()
211                // SAFETY: Safety requirements are propagated via the `unsafe`
212                // tag on this method.
213                .map(|l| unsafe { l.try_to_object_unsafe(py) }),
214        )
215    }
216}
217
218impl<'a, T: FromPyObject<'a>> PyTree<T> {
219    pub fn flatten(tree: &Bound<'a, PyAny>) -> PyResult<Self> {
220        let py = tree.py();
221
222        // Call into pytorch's flatten.
223        let pytree_module = PyModule::import(py, "torch.utils._pytree")?;
224        let tree_flatten = pytree_module.getattr("tree_flatten")?;
225        let res = tree_flatten.call1((tree,))?;
226
227        // Convert leaves to Rust objects.
228        let (leaves, treespec) = match res.downcast::<PyTuple>()?.as_slice() {
229            [leaves, treespec] => {
230                let mut out = vec![];
231                for leaf in leaves.try_iter()? {
232                    out.push(T::extract_bound(&leaf?)?);
233                }
234                (out, treespec)
235            }
236            _ => return Err(PyTypeError::new_err("unexpected result from tree_flatten")),
237        };
238
239        if treespec
240            .call_method0("is_leaf")?
241            .downcast::<PyBool>()?
242            .is_true()
243        {
244            Ok(Self {
245                treespec: TreeSpec::Leaf,
246                leaves,
247            })
248        } else {
249            Ok(Self {
250                treespec: TreeSpec::Tree(PickledPyObject::pickle(treespec)?),
251                leaves,
252            })
253        }
254    }
255}
256
257/// Deserialize from a `PyObject`.
258impl<'a, T: FromPyObject<'a>> FromPyObject<'a> for PyTree<T> {
259    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
260        Self::flatten(ob)
261    }
262}
263
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use pyo3::IntoPyObject;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::py_run;

    use super::PyTree;

    /// Flattening a Python list exposes its elements as leaves, and converting
    /// the tree back yields an equal list.
    #[test]
    fn flatten_unflatten() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| -> Result<()> {
            let source = py.eval(c_str!("[1, 2]"), None, None)?;
            let flattened: PyTree<u64> = PyTree::flatten(&source)?;
            assert_eq!(flattened.leaves, vec![1u64, 2u64]);
            // `py_run!` exposes this binding to Python under the name `list`.
            let list = flattened.into_pyobject(py)?;
            py_run!(py, list, "assert list == [1, 2]");
            Ok(())
        })
    }

    /// `try_map` applies the callback to every leaf and keeps their order.
    #[test]
    fn try_map() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| -> Result<()> {
            let source = py.eval(c_str!("[1, 2]"), None, None)?;
            let flattened: PyTree<u64> = PyTree::flatten(&source)?;
            let incremented = flattened.try_map(|v| anyhow::Ok(v + 1))?;
            assert_eq!(incremented.leaves, vec![2u64, 3u64]);
            Ok(())
        })
    }
}