monarch_hyperactor/
shape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use monarch_types::MapPyErr;
10use ndslice::Extent;
11use ndslice::Point;
12use ndslice::Region;
13use ndslice::Shape;
14use ndslice::Slice;
15use ndslice::View;
16use pyo3::IntoPyObjectExt;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use pyo3::types::PyBytes;
20use pyo3::types::PyDict;
21use pyo3::types::PyMapping;
22use serde::Deserialize;
23use serde::Serialize;
24
25use crate::ndslice::PySlice;
26
27#[derive(Serialize, Deserialize, Clone)]
28#[pyclass(
29    name = "Extent",
30    module = "monarch._rust_bindings.monarch_hyperactor.shape",
31    frozen
32)]
33pub struct PyExtent {
34    inner: Extent,
35}
36
37#[pymethods]
38impl PyExtent {
39    #[new]
40    pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> PyResult<PyExtent> {
41        Ok(PyExtent {
42            inner: Extent::new(labels, sizes).map_pyerr()?,
43        })
44    }
45    #[getter]
46    fn nelements(&self) -> usize {
47        self.inner.num_ranks()
48    }
49    fn __repr__(&self) -> String {
50        self.inner.to_string()
51    }
52    #[getter]
53    fn labels(&self) -> &[String] {
54        self.inner.labels()
55    }
56    #[getter]
57    fn sizes(&self) -> &[usize] {
58        self.inner.sizes()
59    }
60
61    #[staticmethod]
62    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
63        let extent: PyExtent = bincode::deserialize(bytes.as_bytes())
64            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
65        Ok(extent)
66    }
67
68    fn __reduce__<'py>(
69        slf: &Bound<'py, Self>,
70    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
71        let bytes = bincode::serialize(&*slf.borrow())
72            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
73        let py_bytes = PyBytes::new(slf.py(), &bytes);
74        Ok((slf.getattr("from_bytes")?, (py_bytes,)))
75    }
76
77    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
78        Ok(self
79            .labels()
80            .into_bound_py_any(py)?
81            .call_method0("__iter__")?
82            .into())
83    }
84
85    fn __getitem__(&self, label: &str) -> PyResult<usize> {
86        self.inner.size(label).ok_or_else(|| {
87            PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
88        })
89    }
90
91    fn __len__(&self) -> usize {
92        self.inner.len()
93    }
94
95    fn keys<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
96        Ok(self.inner.labels().into_bound_py_any(py)?.into())
97    }
98
99    #[getter]
100    fn region(&self) -> PyRegion {
101        PyRegion {
102            inner: self.inner.region(),
103        }
104    }
105
106    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
107        if let Ok(other) = other.extract::<PyExtent>() {
108            Ok(self.inner == other.inner)
109        } else {
110            Ok(false)
111        }
112    }
113}
114
115impl From<Extent> for PyExtent {
116    fn from(inner: Extent) -> Self {
117        PyExtent { inner }
118    }
119}
120
121impl From<PyExtent> for Extent {
122    fn from(py_extent: PyExtent) -> Self {
123        py_extent.inner
124    }
125}
126
127#[derive(Serialize, Deserialize, Clone)]
128#[pyclass(
129    name = "Region",
130    module = "monarch._rust_bindings.monarch_hyperactor.shape",
131    frozen
132)]
133pub struct PyRegion {
134    pub(crate) inner: Region,
135}
136
137impl PyRegion {
138    pub(crate) fn as_inner(&self) -> &Region {
139        &self.inner
140    }
141}
142
143#[pymethods]
144impl PyRegion {
145    #[new]
146    fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
147        Ok(PyRegion {
148            inner: Region::new(labels, slice.into()),
149        })
150    }
151
152    fn as_shape(&self) -> PyShape {
153        PyShape {
154            inner: (&self.inner).into(),
155        }
156    }
157
158    #[getter]
159    fn labels(&self) -> Vec<String> {
160        self.inner.labels().to_vec()
161    }
162
163    fn slice(&self) -> PySlice {
164        self.inner.slice().clone().into()
165    }
166
167    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
168        let bytes = bincode::serialize(&self.inner)
169            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
170        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
171        let from_bytes = py
172            .import("monarch._rust_bindings.monarch_hyperactor.shape")?
173            .getattr("Region")?
174            .getattr("from_bytes")?;
175        Ok((from_bytes, py_bytes))
176    }
177
178    #[staticmethod]
179    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
180        Ok(bincode::deserialize::<Region>(bytes.as_bytes())
181            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?
182            .into())
183    }
184
185    fn point_of_base_rank(&self, rank: usize) -> PyResult<PyPoint> {
186        self.inner
187            .point_of_base_rank(rank)
188            .map_pyerr()
189            .map(PyPoint::from)
190    }
191
192    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
193        if let Ok(other) = other.extract::<PyRegion>() {
194            Ok(self.inner == other.inner)
195        } else {
196            Ok(false)
197        }
198    }
199}
200
201impl From<Region> for PyRegion {
202    fn from(inner: Region) -> Self {
203        PyRegion { inner }
204    }
205}
206
207#[pyclass(
208    name = "Shape",
209    module = "monarch._rust_bindings.monarch_hyperactor.shape",
210    frozen
211)]
212#[derive(Clone)]
213pub struct PyShape {
214    pub(super) inner: Shape,
215}
216
217impl PyShape {
218    pub fn get_inner(&self) -> &Shape {
219        &self.inner
220    }
221}
222
223#[pymethods]
224impl PyShape {
225    #[new]
226    fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
227        let shape = Shape::new(labels, Slice::from(slice))
228            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
229        Ok(PyShape { inner: shape })
230    }
231
232    #[getter]
233    fn ndslice(&self) -> PySlice {
234        self.inner.slice().clone().into()
235    }
236    #[getter]
237    fn labels(&self) -> Vec<String> {
238        self.inner.labels().to_vec()
239    }
240    fn __str__(&self) -> PyResult<String> {
241        Ok(self.inner.to_string())
242    }
243    fn __repr__(&self) -> PyResult<String> {
244        Ok(format!("{:?}", self.inner))
245    }
246    fn coordinates<'py>(
247        &self,
248        py: Python<'py>,
249        rank: usize,
250    ) -> PyResult<pyo3::Bound<'py, pyo3::types::PyDict>> {
251        self.inner
252            .coordinates(rank)
253            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
254            .and_then(|x| PyDict::from_sequence(&x.into_bound_py_any(py)?))
255    }
256
257    fn at(&self, label: &str, index: usize) -> PyResult<PyShape> {
258        Ok(PyShape {
259            inner: self
260                .inner
261                .at(label, index)
262                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
263        })
264    }
265
266    #[pyo3(signature = (**kwargs))]
267    fn index(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyShape> {
268        if let Some(kwargs) = kwargs {
269            let mut indices: Vec<(String, usize)> = Vec::new();
270            // translate kwargs into indices
271            for (key, value) in kwargs.iter() {
272                let key_str = key.extract::<String>()?;
273                let idx = value.extract::<usize>()?;
274                indices.push((key_str, idx));
275            }
276            Ok(PyShape {
277                inner: self
278                    .inner
279                    .index(indices)
280                    .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
281            })
282        } else {
283            Ok(PyShape {
284                inner: self.inner.clone(),
285            })
286        }
287    }
288
289    fn select(&self, label: &str, slice: &Bound<'_, pyo3::types::PySlice>) -> PyResult<PyShape> {
290        let dim = self
291            .inner
292            .dim(label)
293            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
294        let size = self.inner.slice().sizes()[dim];
295
296        let indices = slice.indices(size as isize)?;
297        let start = indices.start as usize;
298        let stop = indices.stop as usize;
299        let step = indices.step as usize;
300
301        let range = ndslice::shape::Range(start, Some(stop), step);
302        Ok(PyShape {
303            inner: self
304                .inner
305                .select(label, range)
306                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
307        })
308    }
309
310    #[staticmethod]
311    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
312        let shape: Shape = bincode::deserialize(bytes.as_bytes())
313            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
314        Ok(PyShape::from(shape))
315    }
316
317    fn __reduce__<'py>(
318        slf: &Bound<'py, Self>,
319    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
320        let bytes = bincode::serialize(&slf.borrow().inner)
321            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
322        let py_bytes = PyBytes::new(slf.py(), &bytes);
323        Ok((slf.getattr("from_bytes")?, (py_bytes,)))
324    }
325
326    fn ranks(&self) -> Vec<usize> {
327        self.inner.slice().iter().collect()
328    }
329
330    fn __len__(&self) -> usize {
331        self.inner.slice().len()
332    }
333
334    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
335        if let Ok(other) = other.extract::<PyShape>() {
336            Ok(self.inner == other.inner)
337        } else {
338            Ok(false)
339        }
340    }
341
342    #[staticmethod]
343    fn unity() -> PyShape {
344        Shape::unity().into()
345    }
346
347    #[getter]
348    fn extent(&self) -> PyExtent {
349        self.inner.extent().into()
350    }
351
352    #[getter]
353    fn region(&self) -> PyRegion {
354        PyRegion {
355            inner: self.inner.region(),
356        }
357    }
358}
359
360impl From<Shape> for PyShape {
361    fn from(shape: Shape) -> Self {
362        PyShape { inner: shape }
363    }
364}
365
366#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
367#[pyclass(
368    name = "Point",
369    module = "monarch._rust_bindings.monarch_hyperactor.shape",
370    subclass,
371    frozen
372)]
373pub struct PyPoint {
374    rank: usize,
375    extent: Extent,
376}
377
378#[pymethods]
379impl PyPoint {
380    #[new]
381    pub fn new(rank: usize, extent: PyExtent) -> Self {
382        PyPoint {
383            rank,
384            extent: extent.inner,
385        }
386    }
387    fn __getitem__(&self, label: &str) -> PyResult<usize> {
388        let index = self.extent.position(label).ok_or_else(|| {
389            PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
390        })?;
391        let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
392        Ok(point.coords()[index])
393    }
394
395    fn size(&self, label: &str) -> PyResult<usize> {
396        self.extent.size(label).ok_or_else(|| {
397            PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
398        })
399    }
400
401    fn __len__(&self) -> usize {
402        self.extent.len()
403    }
404    fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
405        Ok(self
406            .extent
407            .labels()
408            .into_bound_py_any(py)?
409            .call_method0("__iter__")?
410            .into())
411    }
412
413    #[staticmethod]
414    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
415        let point: PyPoint = bincode::deserialize(bytes.as_bytes())
416            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
417        Ok(point)
418    }
419
420    fn __reduce__<'py>(
421        slf: &Bound<'py, Self>,
422    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
423        let bytes = bincode::serialize(&*slf.borrow())
424            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
425        let py_bytes = PyBytes::new(slf.py(), &bytes);
426        Ok((slf.getattr("from_bytes")?, (py_bytes,)))
427    }
428
429    #[getter]
430    fn extent(&self) -> PyExtent {
431        PyExtent {
432            inner: self.extent.clone(),
433        }
434    }
435    #[getter]
436    fn rank(&self) -> usize {
437        self.rank
438    }
439    fn __repr__(&self) -> PyResult<String> {
440        let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
441        let coords = point.coords();
442        let labels = self.extent.labels();
443        let sizes = self.extent.sizes();
444        let mut parts = Vec::new();
445        for (i, label) in labels.iter().enumerate() {
446            parts.push(format!("'{}': {}/{}", label, coords[i], sizes[i]));
447        }
448
449        Ok(format!("{{{}}}", parts.join(", ")))
450    }
451
452    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
453        if let Ok(other) = other.extract::<PyPoint>() {
454            Ok(*self == other)
455        } else {
456            Ok(false)
457        }
458    }
459
460    fn keys<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
461        Ok(self.extent.labels().into_bound_py_any(py)?.into())
462    }
463}
464impl From<Point> for PyPoint {
465    fn from(inner: Point) -> Self {
466        PyPoint {
467            rank: inner.rank(),
468            extent: inner.extent().clone(),
469        }
470    }
471}
472
473pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
474    let py = module.py();
475    module.add_class::<PyShape>()?;
476    module.add_class::<PySlice>()?;
477    module.add_class::<PyPoint>()?;
478    PyMapping::register::<PyPoint>(py)?;
479    module.add_class::<PyExtent>()?;
480    PyMapping::register::<PyExtent>(py)?;
481    module.add_class::<PyRegion>()?;
482    Ok(())
483}