monarch_hyperactor/
value_mesh.rs1use hyperactor_mesh::ValueMesh;
10use ndslice::Extent;
11use ndslice::Region;
12use ndslice::view::BuildFromRegion;
13use ndslice::view::Ranked;
14use ndslice::view::ViewExt;
15use pyo3::exceptions::PyValueError;
16use pyo3::prelude::*;
17use pyo3::types::PyAny;
18use pyo3::types::PyList;
19
20use crate::shape::PyShape;
21
22#[pyclass(name = "ValueMesh", module = "monarch._src.actor.actor_mesh")]
23pub struct PyValueMesh {
24 inner: ValueMesh<Py<PyAny>>,
25}
26
27#[pymethods]
28impl PyValueMesh {
29 #[new]
31 fn new(_py: Python<'_>, shape: &PyShape, values: Bound<'_, PyList>) -> PyResult<Self> {
32 let s = shape.get_inner();
36 let region = Region::new(s.labels().to_vec(), s.slice().clone());
37 let vals: Vec<Py<PyAny>> = values.extract()?;
38
39 let mut inner =
41 <ValueMesh<Py<PyAny>> as BuildFromRegion<Py<PyAny>>>::build_dense(region, vals)
42 .map_err(|e| PyValueError::new_err(e.to_string()))?;
43
44 inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
51
52 Ok(Self { inner })
53 }
54
55 fn __len__(&self) -> usize {
57 self.inner.region().num_ranks()
58 }
59
60 #[getter]
62 fn _shape(&self) -> PyShape {
63 PyShape::from(ndslice::Shape::from(self.inner.region().clone()))
64 }
65
66 fn values(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
68 let vec: Vec<Py<PyAny>> = self.inner.values().collect();
71 Ok(PyList::new(py, vec)?.into())
72 }
73
74 fn get(&self, py: Python<'_>, rank: usize) -> PyResult<Py<PyAny>> {
76 let n = self.inner.region().num_ranks();
77 if rank >= n {
78 return Err(PyValueError::new_err(format!(
79 "index {} out of range (len={})",
80 rank, n
81 )));
82 }
83
84 let v: Py<PyAny> = self.inner.get(rank).unwrap().clone_ref(py);
89
90 Ok(v)
91 }
92
93 #[staticmethod]
95 fn from_indexed(
96 _py: Python<'_>,
97 shape: &PyShape,
98 pairs: Vec<(usize, Py<PyAny>)>,
99 ) -> PyResult<Self> {
100 let s = shape.get_inner();
102 let region = Region::new(s.labels().to_vec(), s.slice().clone());
103 let mut inner = <ValueMesh<Py<PyAny>> as ndslice::view::BuildFromRegionIndexed<
104 Py<PyAny>,
105 >>::build_indexed(region, pairs)
106 .map_err(|e| PyValueError::new_err(e.to_string()))?;
107
108 inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
115
116 Ok(Self { inner })
117 }
118}
119
120impl PyValueMesh {
121 pub fn build_dense_from_extent(extent: &Extent, values: Vec<Py<PyAny>>) -> PyResult<Self> {
123 let mut inner = <ValueMesh<Py<PyAny>> as BuildFromRegion<Py<PyAny>>>::build_dense(
124 ndslice::View::region(extent),
125 values,
126 )
127 .map_err(|e| PyValueError::new_err(e.to_string()))?;
128 inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
129
130 Ok(Self { inner })
131 }
132}
133
134#[pyfunction]
138fn _make_test_value_mesh(
139 labels: Vec<String>,
140 sizes: Vec<usize>,
141 values: Bound<'_, PyList>,
142) -> PyResult<PyValueMesh> {
143 let strides: Vec<usize> = {
144 let mut s = vec![1usize; sizes.len()];
145 for i in (0..sizes.len().saturating_sub(1)).rev() {
146 s[i] = s[i + 1] * sizes[i + 1];
147 }
148 s
149 };
150 let slice =
151 ndslice::Slice::new(0, sizes, strides).map_err(|e| PyValueError::new_err(e.to_string()))?;
152 let region = Region::new(labels, slice);
153 let vals: Vec<Py<PyAny>> = values.extract()?;
154 let mut inner = <ValueMesh<Py<PyAny>> as BuildFromRegion<Py<PyAny>>>::build_dense(region, vals)
155 .map_err(|e| PyValueError::new_err(e.to_string()))?;
156 inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
157 Ok(PyValueMesh { inner })
158}
159
160pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
161 module.add_class::<PyValueMesh>()?;
162 module.add_function(wrap_pyfunction!(_make_test_value_mesh, module)?)?;
163 Ok(())
164}