monarch_hyperactor/
value_mesh.rs1use std::sync::Mutex;
10
11use hyperactor_mesh::ValueMesh;
12use ndslice::Extent;
13use ndslice::Region;
14use ndslice::view::BuildFromRegion;
15use ndslice::view::Ranked;
16use pyo3::exceptions::PyValueError;
17use pyo3::prelude::*;
18use pyo3::types::PyAny;
19use pyo3::types::PyList;
20use serde_multipart::Part;
21
22use crate::buffers::FrozenBuffer;
23use crate::pickle::unpickle;
24use crate::shape::PyShape;
25
26#[derive(Clone)]
31enum LazyPyObject {
32 Pickled(Part),
33 Unpickled(Py<PyAny>),
34}
35
36type LazyCell = Mutex<LazyPyObject>;
37
38impl LazyPyObject {
39 fn resolve(cell: &LazyCell, py: Python<'_>) -> PyResult<Py<PyAny>> {
42 let mut guard = cell.lock().unwrap();
43
44 match &*guard {
45 LazyPyObject::Unpickled(obj) => Ok(obj.clone_ref(py)),
46 LazyPyObject::Pickled(part) => {
47 let py_obj = unpickle(
48 py,
49 FrozenBuffer {
50 inner: part.clone().into_bytes(),
51 },
52 )?
53 .unbind();
54
55 *guard = LazyPyObject::Unpickled(py_obj.clone_ref(py));
56
57 Ok(py_obj)
58 }
59 }
60 }
61}
62
63fn compress(inner: &mut ValueMesh<LazyCell>) {
64 inner.compress_adjacent_in_place_by(|a, b| match (&*a.lock().unwrap(), &*b.lock().unwrap()) {
65 (LazyPyObject::Unpickled(a), LazyPyObject::Unpickled(b)) => a.as_ptr() == b.as_ptr(),
66 (LazyPyObject::Pickled(a), LazyPyObject::Pickled(b)) => a == b,
67 _ => false,
68 });
69}
70
71#[pyclass(name = "ValueMesh", module = "monarch._src.actor.actor_mesh")]
72pub struct PyValueMesh {
73 inner: ValueMesh<LazyCell>,
74}
75
76#[pymethods]
77impl PyValueMesh {
78 #[new]
80 fn new(_py: Python<'_>, shape: &PyShape, values: Bound<'_, PyList>) -> PyResult<Self> {
81 let s = shape.get_inner();
85 let region = Region::new(s.labels().to_vec(), s.slice().clone());
86 let vals: Vec<LazyCell> = values
87 .extract::<Vec<Py<PyAny>>>()?
88 .into_iter()
89 .map(|v| Mutex::new(LazyPyObject::Unpickled(v)))
90 .collect();
91
92 let mut inner =
93 <ValueMesh<LazyCell> as BuildFromRegion<LazyCell>>::build_dense(region, vals)
94 .map_err(|e| PyValueError::new_err(e.to_string()))?;
95
96 compress(&mut inner);
103
104 Ok(Self { inner })
105 }
106
107 fn __len__(&self) -> usize {
109 self.inner.region().num_ranks()
110 }
111
112 #[getter]
114 fn _shape(&self) -> PyShape {
115 PyShape::from(ndslice::Shape::from(self.inner.region().clone()))
116 }
117
118 fn values(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
120 let n = self.inner.region().num_ranks();
121 let mut vec: Vec<Py<PyAny>> = Vec::with_capacity(n);
122 for rank in 0..n {
123 vec.push(LazyPyObject::resolve(self.inner.get(rank).unwrap(), py)?);
124 }
125 Ok(PyList::new(py, vec)?.into())
126 }
127
128 fn get(&self, py: Python<'_>, rank: usize) -> PyResult<Py<PyAny>> {
130 let n = self.inner.region().num_ranks();
131 if rank >= n {
132 return Err(PyValueError::new_err(format!(
133 "index {} out of range (len={})",
134 rank, n
135 )));
136 }
137
138 LazyPyObject::resolve(self.inner.get(rank).unwrap(), py)
139 }
140
141 #[staticmethod]
143 fn from_indexed(
144 _py: Python<'_>,
145 shape: &PyShape,
146 pairs: Vec<(usize, Py<PyAny>)>,
147 ) -> PyResult<Self> {
148 let s = shape.get_inner();
150 let region = Region::new(s.labels().to_vec(), s.slice().clone());
151 let lazy_pairs: Vec<(usize, LazyCell)> = pairs
152 .into_iter()
153 .map(|(rank, obj)| (rank, Mutex::new(LazyPyObject::Unpickled(obj))))
154 .collect();
155 let mut inner = <ValueMesh<LazyCell> as ndslice::view::BuildFromRegionIndexed<
156 LazyCell,
157 >>::build_indexed(region, lazy_pairs)
158 .map_err(|e| PyValueError::new_err(e.to_string()))?;
159
160 compress(&mut inner);
167
168 Ok(Self { inner })
169 }
170}
171
172impl PyValueMesh {
173 pub fn build_from_parts(extent: &Extent, parts: Vec<Part>) -> PyResult<Self> {
176 let lazy_values: Vec<LazyCell> = parts
177 .into_iter()
178 .map(|p| Mutex::new(LazyPyObject::Pickled(p)))
179 .collect();
180 let mut inner = <ValueMesh<LazyCell> as BuildFromRegion<LazyCell>>::build_dense(
181 ndslice::View::region(extent),
182 lazy_values,
183 )
184 .map_err(|e| PyValueError::new_err(e.to_string()))?;
185 compress(&mut inner);
186
187 Ok(Self { inner })
188 }
189}
190
191#[pyfunction]
195fn _make_test_value_mesh(
196 labels: Vec<String>,
197 sizes: Vec<usize>,
198 values: Bound<'_, PyList>,
199) -> PyResult<PyValueMesh> {
200 let strides: Vec<usize> = {
201 let mut s = vec![1usize; sizes.len()];
202 for i in (0..sizes.len().saturating_sub(1)).rev() {
203 s[i] = s[i + 1] * sizes[i + 1];
204 }
205 s
206 };
207 let slice =
208 ndslice::Slice::new(0, sizes, strides).map_err(|e| PyValueError::new_err(e.to_string()))?;
209 let region = Region::new(labels, slice);
210 let vals: Vec<LazyCell> = values
211 .extract::<Vec<Py<PyAny>>>()?
212 .into_iter()
213 .map(|v| Mutex::new(LazyPyObject::Unpickled(v)))
214 .collect();
215 let mut inner = <ValueMesh<LazyCell> as BuildFromRegion<LazyCell>>::build_dense(region, vals)
216 .map_err(|e| PyValueError::new_err(e.to_string()))?;
217 compress(&mut inner);
218 Ok(PyValueMesh { inner })
219}
220
221pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
222 module.add_class::<PyValueMesh>()?;
223 module.add_function(wrap_pyfunction!(_make_test_value_mesh, module)?)?;
224 Ok(())
225}