1use hyperactor::Named;
10use pyo3::Bound;
11use pyo3::FromPyObject;
12use pyo3::IntoPyObject;
13use pyo3::IntoPyObjectExt;
14use pyo3::PyAny;
15use pyo3::PyErr;
16use pyo3::Python;
17use pyo3::exceptions::PyRuntimeError;
18use pyo3::exceptions::PyTypeError;
19use pyo3::prelude::PyAnyMethods;
20use pyo3::prelude::PyResult;
21use pyo3::prelude::PyTupleMethods;
22use pyo3::types::PyBool;
23use pyo3::types::PyBoolMethods;
24use pyo3::types::PyList;
25use pyo3::types::PyModule;
26use pyo3::types::PyTuple;
27use serde::Deserialize;
28use serde::Serialize;
29
30use crate::PickledPyObject;
31use crate::python::TryIntoPyObjectUnsafe;
32
/// Structural half of a [`PyTree`]: describes how the flat leaves fit back together.
// NOTE: variant order may be part of the serialized form for index-based serde formats;
// do not reorder variants.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum TreeSpec {
    /// The tree is a single bare value with no container structure.
    Leaf,
    /// A non-leaf treespec produced by `torch.utils._pytree.tree_flatten`, kept in pickled form
    /// and unpickled again when the tree is rebuilt in Python.
    Tree(PickledPyObject),
}
38
/// A Python pytree (as handled by `torch.utils._pytree`) split into its structure and a flat
/// list of leaves of type `T`.
#[derive(Clone, Debug, Serialize, Deserialize, Named)]
pub struct PyTree<T> {
    /// Structure of the tree: a bare leaf, or a pickled torch treespec.
    treespec: TreeSpec,
    /// Flattened leaf values. When a `TreeSpec::Leaf` tree is converted back to Python only the
    /// first element is used.
    leaves: Vec<T>,
}
54
55impl<T> PyTree<T> {
56 pub fn is_leaf(&self) -> bool {
57 matches!(self.treespec, TreeSpec::Leaf)
58 }
59
60 pub fn into_leaf(mut self) -> Option<T> {
61 if self.is_leaf() {
62 self.leaves.pop()
63 } else {
64 None
65 }
66 }
67
68 pub fn leaves(&self) -> &[T] {
69 &self.leaves
70 }
71
72 pub fn into_leaves(self) -> Vec<T> {
73 self.leaves
74 }
75
76 pub fn for_each<F>(&self, func: F)
77 where
78 F: FnMut(&T),
79 {
80 self.leaves.iter().for_each(func)
81 }
82
83 pub fn iter<'a>(&'a self) -> std::slice::Iter<'a, T> {
84 self.leaves.iter()
85 }
86
87 pub fn iter_mut<'a>(&'a mut self) -> std::slice::IterMut<'a, T> {
88 self.leaves.iter_mut()
89 }
90
91 pub fn into_map<U, F>(self, func: F) -> PyTree<U>
93 where
94 F: FnMut(T) -> U,
95 {
96 PyTree {
97 treespec: self.treespec,
98 leaves: self.leaves.into_iter().map(func).collect(),
99 }
100 }
101
102 pub fn try_map<U, F, E>(&self, mut func: F) -> Result<PyTree<U>, E>
105 where
106 F: FnMut(&T) -> Result<U, E>,
107 {
108 let mut leaves = vec![];
109 for leaf in self.leaves.iter() {
110 leaves.push(func(leaf)?);
111 }
112 Ok(PyTree {
113 treespec: self.treespec.clone(),
114 leaves,
115 })
116 }
117
118 pub fn try_into_map<U, F, E>(self, mut func: F) -> Result<PyTree<U>, E>
120 where
121 F: FnMut(T) -> Result<U, E>,
122 {
123 let mut leaves = vec![];
124 for leaf in self.leaves.into_iter() {
125 leaves.push(func(leaf)?);
126 }
127 Ok(PyTree {
128 treespec: self.treespec,
129 leaves,
130 })
131 }
132
133 fn unflatten_impl<'a>(
134 py: Python<'a>,
135 treespec: &TreeSpec,
136 mut leaves: impl Iterator<Item = PyResult<Bound<'a, PyAny>>>,
137 ) -> PyResult<Bound<'a, PyAny>> {
138 if let TreeSpec::Tree(tree) = treespec {
139 let module = PyModule::import(py, "torch.utils._pytree")?;
141 let function = module.getattr("tree_unflatten")?;
142 let leaves = leaves.collect::<Result<Vec<_>, _>>()?;
143 let leaves = PyList::new(py, &leaves)?;
144 let args = PyTuple::new(py, vec![leaves.as_any(), &tree.unpickle(py)?])?;
145 let result = function.call(args, None)?;
146 Ok(result)
147 } else {
148 leaves.next().ok_or(PyRuntimeError::new_err(
149 "Pytree leaf unexpectedly had no value",
150 ))?
151 }
152 }
153}
154
155impl<T> From<T> for PyTree<T> {
156 fn from(value: T) -> Self {
157 PyTree {
158 treespec: TreeSpec::Leaf,
159 leaves: vec![value],
160 }
161 }
162}
163
164impl<'py, T> IntoPyObject<'py> for PyTree<T>
165where
166 T: IntoPyObject<'py>,
167{
168 type Target = PyAny;
169 type Output = Bound<'py, Self::Target>;
170 type Error = PyErr;
171
172 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
173 PyTree::<T>::unflatten_impl(
174 py,
175 &self.treespec,
176 self.leaves.into_iter().map(|l| l.into_bound_py_any(py)),
177 )
178 }
179}
180
181impl<'a, 'py, T> IntoPyObject<'py> for &'a PyTree<T>
182where
183 &'a T: IntoPyObject<'py>,
184 T: 'a,
185{
186 type Target = PyAny;
187 type Output = Bound<'py, Self::Target>;
188 type Error = PyErr;
189
190 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
191 PyTree::<T>::unflatten_impl(
192 py,
193 &self.treespec,
194 self.leaves.iter().map(|l| l.into_bound_py_any(py)),
195 )
196 }
197}
198
199impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyAny> for &'a PyTree<T>
201where
202 &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
203 T: 'a,
204{
205 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
206 PyTree::<T>::unflatten_impl(
207 py,
208 &self.treespec,
209 self.leaves
210 .iter()
211 .map(|l| unsafe { l.try_to_object_unsafe(py) }),
214 )
215 }
216}
217
218impl<'a, T: FromPyObject<'a>> PyTree<T> {
219 pub fn flatten(tree: &Bound<'a, PyAny>) -> PyResult<Self> {
220 let py = tree.py();
221
222 let pytree_module = PyModule::import(py, "torch.utils._pytree")?;
224 let tree_flatten = pytree_module.getattr("tree_flatten")?;
225 let res = tree_flatten.call1((tree,))?;
226
227 let (leaves, treespec) = match res.downcast::<PyTuple>()?.as_slice() {
229 [leaves, treespec] => {
230 let mut out = vec![];
231 for leaf in leaves.try_iter()? {
232 out.push(T::extract_bound(&leaf?)?);
233 }
234 (out, treespec)
235 }
236 _ => return Err(PyTypeError::new_err("unexpected result from tree_flatten")),
237 };
238
239 if treespec
240 .call_method0("is_leaf")?
241 .downcast::<PyBool>()?
242 .is_true()
243 {
244 Ok(Self {
245 treespec: TreeSpec::Leaf,
246 leaves,
247 })
248 } else {
249 Ok(Self {
250 treespec: TreeSpec::Tree(PickledPyObject::pickle(treespec)?),
251 leaves,
252 })
253 }
254 }
255}
256
257impl<'a, T: FromPyObject<'a>> FromPyObject<'a> for PyTree<T> {
259 fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
260 Self::flatten(ob)
261 }
262}
263
// Round-trip tests for `PyTree`. They start an embedded Python interpreter and call into
// `torch.utils._pytree`, so `torch` must be importable in the test environment.
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use pyo3::IntoPyObject;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::py_run;

    use super::PyTree;

    /// Flattening `[1, 2]` yields the leaves `[1, 2]`, and converting back reproduces the list.
    #[test]
    fn flatten_unflatten() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let tree = py.eval(c_str!("[1, 2]"), None, None)?;
            let tree: PyTree<u64> = PyTree::flatten(&tree)?;
            assert_eq!(tree.leaves, vec![1u64, 2u64]);
            let list = tree.into_pyobject(py)?;
            // `py_run!` exposes the Rust local `list` to the Python snippet under the same name.
            py_run!(py, list, "assert list == [1, 2]");
            anyhow::Ok(())
        })?;
        Ok(())
    }

    /// `try_map` applies the closure to each leaf in order and keeps the structure.
    #[test]
    fn try_map() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let tree = py.eval(c_str!("[1, 2]"), None, None)?;
            let tree: PyTree<u64> = PyTree::flatten(&tree)?;
            let tree = tree.try_map(|v| anyhow::Ok(v + 1))?;
            assert_eq!(tree.leaves, vec![2u64, 3u64]);
            anyhow::Ok(())
        })?;
        Ok(())
    }
}