monarch_types/
pytree.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use pyo3::Bound;
10use pyo3::FromPyObject;
11use pyo3::IntoPyObject;
12use pyo3::IntoPyObjectExt;
13use pyo3::PyAny;
14use pyo3::PyErr;
15use pyo3::Python;
16use pyo3::exceptions::PyRuntimeError;
17use pyo3::exceptions::PyTypeError;
18use pyo3::prelude::PyAnyMethods;
19use pyo3::prelude::PyResult;
20use pyo3::prelude::PyTupleMethods;
21use pyo3::types::PyBool;
22use pyo3::types::PyBoolMethods;
23use pyo3::types::PyList;
24use pyo3::types::PyModule;
25use pyo3::types::PyTuple;
26use serde::Deserialize;
27use serde::Serialize;
28
29use crate::PickledPyObject;
30use crate::python::TryIntoPyObjectUnsafe;
31
/// The structural description of a [`PyTree`], i.e. everything about the
/// tree other than its leaf values.
// NOTE(review): for non-self-describing serde formats the variant order may
// be part of the wire format -- confirm before reordering variants.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TreeSpec {
    /// The pytree is a single bare value; no tree spec is stored for it.
    Leaf,
    /// The pytree is a container; holds the pickled tree spec needed to
    /// rebuild it from its leaves.
    Tree(PickledPyObject),
}
37
/// A Rust wrapper around a PyTorch pytree, which can be serialized and sent
/// across the wire.
///
/// See <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>.
// NOTE: We have runtime deps on torch's pytree module and `pickle`, which the
// user must ensure are available.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PyTree<T> {
    /// A wrapper around the tree spec.
    // NOTE: This is currently just the pickled bytes of the tree spec.  We
    // could also just deserialize to a `PyObject`, but this would mean
    // acquiring the GIL when deserializing the message.
    treespec: TreeSpec,
    /// The deserialized leaves of the pytree, in the order in which they
    /// were produced when the tree was flattened.
    leaves: Vec<T>,
}
53
54impl<T> PyTree<T> {
55    pub fn is_leaf(&self) -> bool {
56        matches!(self.treespec, TreeSpec::Leaf)
57    }
58
59    pub fn into_leaf(mut self) -> Option<T> {
60        if self.is_leaf() {
61            self.leaves.pop()
62        } else {
63            None
64        }
65    }
66
67    pub fn leaves(&self) -> &[T] {
68        &self.leaves
69    }
70
71    pub fn into_leaves(self) -> Vec<T> {
72        self.leaves
73    }
74
75    pub fn for_each<F>(&self, func: F)
76    where
77        F: FnMut(&T),
78    {
79        self.leaves.iter().for_each(func)
80    }
81
82    pub fn iter<'a>(&'a self) -> std::slice::Iter<'a, T> {
83        self.leaves.iter()
84    }
85
86    pub fn iter_mut<'a>(&'a mut self) -> std::slice::IterMut<'a, T> {
87        self.leaves.iter_mut()
88    }
89
90    /// Map leaf values with the given func.
91    pub fn into_map<U, F>(self, func: F) -> PyTree<U>
92    where
93        F: FnMut(T) -> U,
94    {
95        PyTree {
96            treespec: self.treespec,
97            leaves: self.leaves.into_iter().map(func).collect(),
98        }
99    }
100
101    /// Map the leaves of the pytree with the given fallible callback.
102    // NOTE: This ends up copying the serialized treespec.
103    pub fn try_map<U, F, E>(&self, mut func: F) -> Result<PyTree<U>, E>
104    where
105        F: FnMut(&T) -> Result<U, E>,
106    {
107        let mut leaves = vec![];
108        for leaf in self.leaves.iter() {
109            leaves.push(func(leaf)?);
110        }
111        Ok(PyTree {
112            treespec: self.treespec.clone(),
113            leaves,
114        })
115    }
116
117    /// Map the leaves of the pytree with the given fallible callback.
118    pub fn try_into_map<U, F, E>(self, mut func: F) -> Result<PyTree<U>, E>
119    where
120        F: FnMut(T) -> Result<U, E>,
121    {
122        let mut leaves = vec![];
123        for leaf in self.leaves.into_iter() {
124            leaves.push(func(leaf)?);
125        }
126        Ok(PyTree {
127            treespec: self.treespec,
128            leaves,
129        })
130    }
131
132    fn unflatten_impl<'a>(
133        py: Python<'a>,
134        treespec: &TreeSpec,
135        mut leaves: impl Iterator<Item = PyResult<Bound<'a, PyAny>>>,
136    ) -> PyResult<Bound<'a, PyAny>> {
137        if let TreeSpec::Tree(tree) = treespec {
138            // Call into pytorch's unflatten.
139            let module = PyModule::import(py, "torch.utils._pytree")?;
140            let function = module.getattr("tree_unflatten")?;
141            let leaves = leaves.collect::<Result<Vec<_>, _>>()?;
142            let leaves = PyList::new(py, &leaves)?;
143            let args = PyTuple::new(py, vec![leaves.as_any(), &tree.unpickle(py)?])?;
144            let result = function.call(args, None)?;
145            Ok(result)
146        } else {
147            leaves.next().ok_or(PyRuntimeError::new_err(
148                "Pytree leaf unexpectedly had no value",
149            ))?
150        }
151    }
152}
153
154impl<T> From<T> for PyTree<T> {
155    fn from(value: T) -> Self {
156        PyTree {
157            treespec: TreeSpec::Leaf,
158            leaves: vec![value],
159        }
160    }
161}
162
163impl<'py, T> IntoPyObject<'py> for PyTree<T>
164where
165    T: IntoPyObject<'py>,
166{
167    type Target = PyAny;
168    type Output = Bound<'py, Self::Target>;
169    type Error = PyErr;
170
171    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
172        PyTree::<T>::unflatten_impl(
173            py,
174            &self.treespec,
175            self.leaves.into_iter().map(|l| l.into_bound_py_any(py)),
176        )
177    }
178}
179
180impl<'a, 'py, T> IntoPyObject<'py> for &'a PyTree<T>
181where
182    &'a T: IntoPyObject<'py>,
183    T: 'a,
184{
185    type Target = PyAny;
186    type Output = Bound<'py, Self::Target>;
187    type Error = PyErr;
188
189    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
190        PyTree::<T>::unflatten_impl(
191            py,
192            &self.treespec,
193            self.leaves.iter().map(|l| l.into_bound_py_any(py)),
194        )
195    }
196}
197
198/// Serialize into a `PyObject`.
199impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyAny> for &'a PyTree<T>
200where
201    &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
202    T: 'a,
203{
204    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
205        PyTree::<T>::unflatten_impl(
206            py,
207            &self.treespec,
208            self.leaves
209                .iter()
210                // SAFETY: Safety requirements are propagated via the `unsafe`
211                // tag on this method.
212                .map(|l| unsafe { l.try_to_object_unsafe(py) }),
213        )
214    }
215}
216
217impl<'a, T: FromPyObject<'a>> PyTree<T> {
218    pub fn flatten(tree: &Bound<'a, PyAny>) -> PyResult<Self> {
219        let py = tree.py();
220
221        // Call into pytorch's flatten.
222        let pytree_module = PyModule::import(py, "torch.utils._pytree")?;
223        let tree_flatten = pytree_module.getattr("tree_flatten")?;
224        let res = tree_flatten.call1((tree,))?;
225
226        // Convert leaves to Rust objects.
227        let (leaves, treespec) = match res.downcast::<PyTuple>()?.as_slice() {
228            [leaves, treespec] => {
229                let mut out = vec![];
230                for leaf in leaves.try_iter()? {
231                    out.push(T::extract_bound(&leaf?)?);
232                }
233                (out, treespec)
234            }
235            _ => return Err(PyTypeError::new_err("unexpected result from tree_flatten")),
236        };
237
238        if treespec
239            .call_method0("is_leaf")?
240            .downcast::<PyBool>()?
241            .is_true()
242        {
243            Ok(Self {
244                treespec: TreeSpec::Leaf,
245                leaves,
246            })
247        } else {
248            Ok(Self {
249                treespec: TreeSpec::Tree(PickledPyObject::pickle(treespec)?),
250                leaves,
251            })
252        }
253    }
254}
255
256/// Deserialize from a `PyObject`.
257impl<'a, T: FromPyObject<'a>> FromPyObject<'a> for PyTree<T> {
258    fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
259        Self::flatten(ob)
260    }
261}
262
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use pyo3::IntoPyObject;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::py_run;

    use super::PyTree;

    /// A Python list survives a flatten / unflatten round trip.
    #[test]
    fn flatten_unflatten() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let object = py.eval(c_str!("[1, 2]"), None, None)?;
            let flattened: PyTree<u64> = PyTree::flatten(&object)?;
            assert_eq!(flattened.leaves(), [1u64, 2u64]);
            // `list` is referenced by name from the Python snippet below.
            let list = flattened.into_pyobject(py)?;
            py_run!(py, list, "assert list == [1, 2]");
            anyhow::Ok(())
        })
    }

    /// `try_map` applies the callback to every leaf.
    #[test]
    fn try_map() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let object = py.eval(c_str!("[1, 2]"), None, None)?;
            let flattened: PyTree<u64> = PyTree::flatten(&object)?;
            let incremented = flattened.try_map(|leaf| anyhow::Ok(leaf + 1))?;
            assert_eq!(incremented.leaves(), [2u64, 3u64]);
            anyhow::Ok(())
        })
    }
}