1use pyo3::Bound;
10use pyo3::FromPyObject;
11use pyo3::IntoPyObject;
12use pyo3::IntoPyObjectExt;
13use pyo3::PyAny;
14use pyo3::PyErr;
15use pyo3::Python;
16use pyo3::exceptions::PyRuntimeError;
17use pyo3::exceptions::PyTypeError;
18use pyo3::prelude::PyAnyMethods;
19use pyo3::prelude::PyResult;
20use pyo3::prelude::PyTupleMethods;
21use pyo3::types::PyBool;
22use pyo3::types::PyBoolMethods;
23use pyo3::types::PyList;
24use pyo3::types::PyModule;
25use pyo3::types::PyTuple;
26use serde::Deserialize;
27use serde::Serialize;
28
29use crate::PickledPyObject;
30use crate::python::TryIntoPyObjectUnsafe;
31
32#[derive(Clone, Debug, Serialize, Deserialize)]
33pub enum TreeSpec {
34 Leaf,
35 Tree(PickledPyObject),
36}
37
38#[derive(Clone, Debug, Serialize, Deserialize)]
44pub struct PyTree<T> {
45 treespec: TreeSpec,
50 leaves: Vec<T>,
52}
53
54impl<T> PyTree<T> {
55 pub fn is_leaf(&self) -> bool {
56 matches!(self.treespec, TreeSpec::Leaf)
57 }
58
59 pub fn into_leaf(mut self) -> Option<T> {
60 if self.is_leaf() {
61 self.leaves.pop()
62 } else {
63 None
64 }
65 }
66
67 pub fn leaves(&self) -> &[T] {
68 &self.leaves
69 }
70
71 pub fn into_leaves(self) -> Vec<T> {
72 self.leaves
73 }
74
75 pub fn for_each<F>(&self, func: F)
76 where
77 F: FnMut(&T),
78 {
79 self.leaves.iter().for_each(func)
80 }
81
82 pub fn iter<'a>(&'a self) -> std::slice::Iter<'a, T> {
83 self.leaves.iter()
84 }
85
86 pub fn iter_mut<'a>(&'a mut self) -> std::slice::IterMut<'a, T> {
87 self.leaves.iter_mut()
88 }
89
90 pub fn into_map<U, F>(self, func: F) -> PyTree<U>
92 where
93 F: FnMut(T) -> U,
94 {
95 PyTree {
96 treespec: self.treespec,
97 leaves: self.leaves.into_iter().map(func).collect(),
98 }
99 }
100
101 pub fn try_map<U, F, E>(&self, mut func: F) -> Result<PyTree<U>, E>
104 where
105 F: FnMut(&T) -> Result<U, E>,
106 {
107 let mut leaves = vec![];
108 for leaf in self.leaves.iter() {
109 leaves.push(func(leaf)?);
110 }
111 Ok(PyTree {
112 treespec: self.treespec.clone(),
113 leaves,
114 })
115 }
116
117 pub fn try_into_map<U, F, E>(self, mut func: F) -> Result<PyTree<U>, E>
119 where
120 F: FnMut(T) -> Result<U, E>,
121 {
122 let mut leaves = vec![];
123 for leaf in self.leaves.into_iter() {
124 leaves.push(func(leaf)?);
125 }
126 Ok(PyTree {
127 treespec: self.treespec,
128 leaves,
129 })
130 }
131
132 fn unflatten_impl<'a>(
133 py: Python<'a>,
134 treespec: &TreeSpec,
135 mut leaves: impl Iterator<Item = PyResult<Bound<'a, PyAny>>>,
136 ) -> PyResult<Bound<'a, PyAny>> {
137 if let TreeSpec::Tree(tree) = treespec {
138 let module = PyModule::import(py, "torch.utils._pytree")?;
140 let function = module.getattr("tree_unflatten")?;
141 let leaves = leaves.collect::<Result<Vec<_>, _>>()?;
142 let leaves = PyList::new(py, &leaves)?;
143 let args = PyTuple::new(py, vec![leaves.as_any(), &tree.unpickle(py)?])?;
144 let result = function.call(args, None)?;
145 Ok(result)
146 } else {
147 leaves.next().ok_or(PyRuntimeError::new_err(
148 "Pytree leaf unexpectedly had no value",
149 ))?
150 }
151 }
152}
153
154impl<T> From<T> for PyTree<T> {
155 fn from(value: T) -> Self {
156 PyTree {
157 treespec: TreeSpec::Leaf,
158 leaves: vec![value],
159 }
160 }
161}
162
163impl<'py, T> IntoPyObject<'py> for PyTree<T>
164where
165 T: IntoPyObject<'py>,
166{
167 type Target = PyAny;
168 type Output = Bound<'py, Self::Target>;
169 type Error = PyErr;
170
171 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
172 PyTree::<T>::unflatten_impl(
173 py,
174 &self.treespec,
175 self.leaves.into_iter().map(|l| l.into_bound_py_any(py)),
176 )
177 }
178}
179
180impl<'a, 'py, T> IntoPyObject<'py> for &'a PyTree<T>
181where
182 &'a T: IntoPyObject<'py>,
183 T: 'a,
184{
185 type Target = PyAny;
186 type Output = Bound<'py, Self::Target>;
187 type Error = PyErr;
188
189 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
190 PyTree::<T>::unflatten_impl(
191 py,
192 &self.treespec,
193 self.leaves.iter().map(|l| l.into_bound_py_any(py)),
194 )
195 }
196}
197
198impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyAny> for &'a PyTree<T>
200where
201 &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
202 T: 'a,
203{
204 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
205 PyTree::<T>::unflatten_impl(
206 py,
207 &self.treespec,
208 self.leaves
209 .iter()
210 .map(|l| unsafe { l.try_to_object_unsafe(py) }),
213 )
214 }
215}
216
217impl<'a, T: FromPyObject<'a>> PyTree<T> {
218 pub fn flatten(tree: &Bound<'a, PyAny>) -> PyResult<Self> {
219 let py = tree.py();
220
221 let pytree_module = PyModule::import(py, "torch.utils._pytree")?;
223 let tree_flatten = pytree_module.getattr("tree_flatten")?;
224 let res = tree_flatten.call1((tree,))?;
225
226 let (leaves, treespec) = match res.downcast::<PyTuple>()?.as_slice() {
228 [leaves, treespec] => {
229 let mut out = vec![];
230 for leaf in leaves.try_iter()? {
231 out.push(T::extract_bound(&leaf?)?);
232 }
233 (out, treespec)
234 }
235 _ => return Err(PyTypeError::new_err("unexpected result from tree_flatten")),
236 };
237
238 if treespec
239 .call_method0("is_leaf")?
240 .downcast::<PyBool>()?
241 .is_true()
242 {
243 Ok(Self {
244 treespec: TreeSpec::Leaf,
245 leaves,
246 })
247 } else {
248 Ok(Self {
249 treespec: TreeSpec::Tree(PickledPyObject::pickle(treespec)?),
250 leaves,
251 })
252 }
253 }
254}
255
256impl<'a, T: FromPyObject<'a>> FromPyObject<'a> for PyTree<T> {
258 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
259 Self::flatten(ob)
260 }
261}
262
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use pyo3::IntoPyObject;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::py_run;

    use super::PyTree;

    /// A list round-trips through flatten -> unflatten unchanged.
    #[test]
    fn flatten_unflatten() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let tree = py.eval(c_str!("[1, 2]"), None, None)?;
            let tree: PyTree<u64> = PyTree::flatten(&tree)?;
            assert!(!tree.is_leaf());
            assert_eq!(tree.leaves, vec![1u64, 2u64]);
            let list = tree.into_pyobject(py)?;
            py_run!(py, list, "assert list == [1, 2]");
            anyhow::Ok(())
        })?;
        Ok(())
    }

    /// `try_map` transforms every leaf and keeps the structure.
    #[test]
    fn try_map() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let tree = py.eval(c_str!("[1, 2]"), None, None)?;
            let tree: PyTree<u64> = PyTree::flatten(&tree)?;
            let tree = tree.try_map(|v| anyhow::Ok(v + 1))?;
            assert_eq!(tree.leaves, vec![2u64, 3u64]);
            anyhow::Ok(())
        })?;
        Ok(())
    }

    /// Flattening a scalar yields a bare leaf that `into_leaf` can unwrap.
    #[test]
    fn flatten_leaf() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let obj = py.eval(c_str!("3"), None, None)?;
            let tree: PyTree<u64> = PyTree::flatten(&obj)?;
            assert!(tree.is_leaf());
            assert_eq!(tree.into_leaf(), Some(3u64));
            anyhow::Ok(())
        })?;
        Ok(())
    }

    /// A leaf tree converts back to its single Python value.
    #[test]
    fn leaf_into_pyobject() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let value = PyTree::<u64>::from(7u64).into_pyobject(py)?;
            py_run!(py, value, "assert value == 7");
            anyhow::Ok(())
        })?;
        Ok(())
    }
}