Skip to main content

monarch_hyperactor/
shape.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use monarch_types::MapPyErr;
10use ndslice::Extent;
11use ndslice::Point;
12use ndslice::Region;
13use ndslice::Shape;
14use ndslice::Slice;
15use ndslice::View;
16use pyo3::IntoPyObjectExt;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use pyo3::types::PyBytes;
20use pyo3::types::PyDict;
21use pyo3::types::PyMapping;
22use serde::Deserialize;
23use serde::Serialize;
24
25use crate::ndslice::PySlice;
26
27#[derive(Serialize, Deserialize, Clone)]
28#[pyclass(
29    name = "Extent",
30    module = "monarch._rust_bindings.monarch_hyperactor.shape",
31    frozen
32)]
33pub struct PyExtent {
34    inner: Extent,
35}
36
37#[pymethods]
38impl PyExtent {
39    #[new]
40    pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> PyResult<PyExtent> {
41        Ok(PyExtent {
42            inner: Extent::new(labels, sizes).map_pyerr()?,
43        })
44    }
45    #[getter]
46    fn nelements(&self) -> usize {
47        self.inner.num_ranks()
48    }
49    fn __repr__(&self) -> String {
50        self.inner.to_string()
51    }
52    #[getter]
53    fn labels(&self) -> &[String] {
54        self.inner.labels()
55    }
56    #[getter]
57    fn sizes(&self) -> &[usize] {
58        self.inner.sizes()
59    }
60
61    #[staticmethod]
62    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
63        let extent: PyExtent =
64            bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
65                .map(|(v, _)| v)
66                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
67        Ok(extent)
68    }
69
70    fn __reduce__<'py>(
71        slf: &Bound<'py, Self>,
72    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
73        let bytes = bincode::serde::encode_to_vec(&*slf.borrow(), bincode::config::legacy())
74            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
75        let py_bytes = PyBytes::new(slf.py(), &bytes);
76        Ok((slf.getattr("from_bytes")?, (py_bytes,)))
77    }
78
79    fn __iter__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
80        Ok(self
81            .labels()
82            .into_bound_py_any(py)?
83            .call_method0("__iter__")?
84            .into())
85    }
86
87    fn __getitem__(&self, label: &str) -> PyResult<usize> {
88        self.inner.size(label).ok_or_else(|| {
89            PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
90        })
91    }
92
93    fn __len__(&self) -> usize {
94        self.inner.len()
95    }
96
97    fn keys(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
98        Ok(self.inner.labels().into_bound_py_any(py)?.into())
99    }
100
101    #[getter]
102    fn region(&self) -> PyRegion {
103        PyRegion {
104            inner: self.inner.region(),
105        }
106    }
107
108    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
109        if let Ok(other) = other.extract::<PyExtent>() {
110            Ok(self.inner == other.inner)
111        } else {
112            Ok(false)
113        }
114    }
115}
116
117impl From<Extent> for PyExtent {
118    fn from(inner: Extent) -> Self {
119        PyExtent { inner }
120    }
121}
122
123impl From<PyExtent> for Extent {
124    fn from(py_extent: PyExtent) -> Self {
125        py_extent.inner
126    }
127}
128
129#[derive(Serialize, Deserialize, Clone)]
130#[pyclass(
131    name = "Region",
132    module = "monarch._rust_bindings.monarch_hyperactor.shape",
133    frozen
134)]
135pub struct PyRegion {
136    pub(crate) inner: Region,
137}
138
139impl PyRegion {
140    pub(crate) fn as_inner(&self) -> &Region {
141        &self.inner
142    }
143}
144
145#[pymethods]
146impl PyRegion {
147    #[new]
148    fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
149        Ok(PyRegion {
150            inner: Region::new(labels, slice.into()),
151        })
152    }
153
154    fn as_shape(&self) -> PyShape {
155        PyShape {
156            inner: (&self.inner).into(),
157        }
158    }
159
160    #[getter]
161    fn labels(&self) -> Vec<String> {
162        self.inner.labels().to_vec()
163    }
164
165    fn slice(&self) -> PySlice {
166        self.inner.slice().clone().into()
167    }
168
169    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
170        let bytes = bincode::serde::encode_to_vec(&self.inner, bincode::config::legacy())
171            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
172        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
173        let from_bytes = py
174            .import("monarch._rust_bindings.monarch_hyperactor.shape")?
175            .getattr("Region")?
176            .getattr("from_bytes")?;
177        Ok((from_bytes, py_bytes))
178    }
179
180    #[staticmethod]
181    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
182        Ok(bincode::serde::decode_from_slice::<Region, _>(
183            bytes.as_bytes(),
184            bincode::config::legacy(),
185        )
186        .map(|(v, _)| v)
187        .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?
188        .into())
189    }
190
191    fn point_of_base_rank(&self, rank: usize) -> PyResult<PyPoint> {
192        self.inner
193            .point_of_base_rank(rank)
194            .map_pyerr()
195            .map(PyPoint::from)
196    }
197
198    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
199        if let Ok(other) = other.extract::<PyRegion>() {
200            Ok(self.inner == other.inner)
201        } else {
202            Ok(false)
203        }
204    }
205}
206
207impl From<Region> for PyRegion {
208    fn from(inner: Region) -> Self {
209        PyRegion { inner }
210    }
211}
212
213#[pyclass(
214    name = "Shape",
215    module = "monarch._rust_bindings.monarch_hyperactor.shape",
216    frozen
217)]
218#[derive(Clone)]
219pub struct PyShape {
220    pub(super) inner: Shape,
221}
222
223impl PyShape {
224    pub fn get_inner(&self) -> &Shape {
225        &self.inner
226    }
227}
228
229#[pymethods]
230impl PyShape {
231    #[new]
232    fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
233        let shape = Shape::new(labels, Slice::from(slice))
234            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
235        Ok(PyShape { inner: shape })
236    }
237
238    #[getter]
239    fn ndslice(&self) -> PySlice {
240        self.inner.slice().clone().into()
241    }
242    #[getter]
243    fn labels(&self) -> Vec<String> {
244        self.inner.labels().to_vec()
245    }
246    fn __str__(&self) -> PyResult<String> {
247        Ok(self.inner.to_string())
248    }
249    fn __repr__(&self) -> PyResult<String> {
250        Ok(format!("{:?}", self.inner))
251    }
252    fn coordinates<'py>(
253        &self,
254        py: Python<'py>,
255        rank: usize,
256    ) -> PyResult<pyo3::Bound<'py, pyo3::types::PyDict>> {
257        self.inner
258            .coordinates(rank)
259            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
260            .and_then(|x| PyDict::from_sequence(&x.into_bound_py_any(py)?))
261    }
262
263    fn at(&self, label: &str, index: usize) -> PyResult<PyShape> {
264        Ok(PyShape {
265            inner: self
266                .inner
267                .at(label, index)
268                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
269        })
270    }
271
272    #[pyo3(signature = (**kwargs))]
273    fn index(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyShape> {
274        if let Some(kwargs) = kwargs {
275            let mut indices: Vec<(String, usize)> = Vec::new();
276            // translate kwargs into indices
277            for (key, value) in kwargs.iter() {
278                let key_str = key.extract::<String>()?;
279                let idx = value.extract::<usize>()?;
280                indices.push((key_str, idx));
281            }
282            Ok(PyShape {
283                inner: self
284                    .inner
285                    .index(indices)
286                    .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
287            })
288        } else {
289            Ok(PyShape {
290                inner: self.inner.clone(),
291            })
292        }
293    }
294
295    fn select(&self, label: &str, slice: &Bound<'_, pyo3::types::PySlice>) -> PyResult<PyShape> {
296        let dim = self
297            .inner
298            .dim(label)
299            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
300        let size = self.inner.slice().sizes()[dim];
301
302        let indices = slice.indices(size as isize)?;
303        let start = indices.start as usize;
304        let stop = indices.stop as usize;
305        let step = indices.step as usize;
306
307        let range = ndslice::shape::Range(start, Some(stop), step);
308        Ok(PyShape {
309            inner: self
310                .inner
311                .select(label, range)
312                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
313        })
314    }
315
316    #[staticmethod]
317    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
318        let shape: Shape =
319            bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
320                .map(|(v, _)| v)
321                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
322        Ok(PyShape::from(shape))
323    }
324
325    fn __reduce__<'py>(
326        slf: &Bound<'py, Self>,
327    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
328        let bytes = bincode::serde::encode_to_vec(&slf.borrow().inner, bincode::config::legacy())
329            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
330        let py_bytes = PyBytes::new(slf.py(), &bytes);
331        Ok((slf.getattr("from_bytes")?, (py_bytes,)))
332    }
333
334    fn ranks(&self) -> Vec<usize> {
335        self.inner.slice().iter().collect()
336    }
337
338    fn __len__(&self) -> usize {
339        self.inner.slice().len()
340    }
341
342    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
343        if let Ok(other) = other.extract::<PyShape>() {
344            Ok(self.inner == other.inner)
345        } else {
346            Ok(false)
347        }
348    }
349
350    #[staticmethod]
351    fn unity() -> PyShape {
352        Shape::unity().into()
353    }
354
355    #[getter]
356    fn extent(&self) -> PyExtent {
357        self.inner.extent().into()
358    }
359
360    #[getter]
361    fn region(&self) -> PyRegion {
362        PyRegion {
363            inner: self.inner.region(),
364        }
365    }
366}
367
368impl From<Shape> for PyShape {
369    fn from(shape: Shape) -> Self {
370        PyShape { inner: shape }
371    }
372}
373
374#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
375#[pyclass(
376    name = "Point",
377    module = "monarch._rust_bindings.monarch_hyperactor.shape",
378    subclass,
379    frozen
380)]
381pub struct PyPoint {
382    rank: usize,
383    extent: Extent,
384}
385
386#[pymethods]
387impl PyPoint {
388    #[new]
389    pub fn new(rank: usize, extent: PyExtent) -> Self {
390        PyPoint {
391            rank,
392            extent: extent.inner,
393        }
394    }
395    fn __getitem__(&self, label: &str) -> PyResult<usize> {
396        let index = self.extent.position(label).ok_or_else(|| {
397            PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
398        })?;
399        let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
400        Ok(point.coords()[index])
401    }
402
403    fn size(&self, label: &str) -> PyResult<usize> {
404        self.extent.size(label).ok_or_else(|| {
405            PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
406        })
407    }
408
409    fn __len__(&self) -> usize {
410        self.extent.len()
411    }
412    fn __iter__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
413        Ok(self
414            .extent
415            .labels()
416            .into_bound_py_any(py)?
417            .call_method0("__iter__")?
418            .into())
419    }
420
421    #[staticmethod]
422    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
423        let point: PyPoint =
424            bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
425                .map(|(v, _)| v)
426                .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
427        Ok(point)
428    }
429
430    fn __reduce__<'py>(
431        slf: &Bound<'py, Self>,
432    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
433        let bytes = bincode::serde::encode_to_vec(&*slf.borrow(), bincode::config::legacy())
434            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
435        let py_bytes = PyBytes::new(slf.py(), &bytes);
436        Ok((slf.getattr("from_bytes")?, (py_bytes,)))
437    }
438
439    #[getter]
440    fn extent(&self) -> PyExtent {
441        PyExtent {
442            inner: self.extent.clone(),
443        }
444    }
445    #[getter]
446    fn rank(&self) -> usize {
447        self.rank
448    }
449    fn __repr__(&self) -> PyResult<String> {
450        let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
451        let coords = point.coords();
452        let labels = self.extent.labels();
453        let sizes = self.extent.sizes();
454        let mut parts = Vec::new();
455        for (i, label) in labels.iter().enumerate() {
456            parts.push(format!("'{}': {}/{}", label, coords[i], sizes[i]));
457        }
458
459        Ok(format!("{{{}}}", parts.join(", ")))
460    }
461
462    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
463        if let Ok(other) = other.extract::<PyPoint>() {
464            Ok(*self == other)
465        } else {
466            Ok(false)
467        }
468    }
469
470    fn keys(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
471        Ok(self.extent.labels().into_bound_py_any(py)?.into())
472    }
473}
474impl From<Point> for PyPoint {
475    fn from(inner: Point) -> Self {
476        PyPoint {
477            rank: inner.rank(),
478            extent: inner.extent().clone(),
479        }
480    }
481}
482
483pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
484    let py = module.py();
485    module.add_class::<PyShape>()?;
486    module.add_class::<PySlice>()?;
487    module.add_class::<PyPoint>()?;
488    PyMapping::register::<PyPoint>(py)?;
489    module.add_class::<PyExtent>()?;
490    PyMapping::register::<PyExtent>(py)?;
491    module.add_class::<PyRegion>()?;
492    Ok(())
493}