monarch_hyperactor/
ndslice.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::hash::DefaultHasher;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::sync::Arc;
13
14use pyo3::IntoPyObjectExt;
15use pyo3::exceptions::PyIndexError;
16use pyo3::exceptions::PyValueError;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use pyo3::types::PyList;
20use pyo3::types::PySliceMethods;
21use pyo3::types::PyTuple;
22
23/// A wrapper around [ndslice::Slice] to expose it to python.
24/// It is a compact representation of indices into the flat
25/// representation of an n-dimensional array. Given an offset, sizes of
26/// each dimension, and strides for each dimension, Slice can compute
27/// indices into the flat array.
28#[pyclass(
29    name = "Slice",
30    frozen,
31    module = "monarch._rust_bindings.monarch_hyperactor.shape"
32)]
33#[derive(Clone)]
34pub struct PySlice {
35    inner: Arc<ndslice::Slice>,
36}
37
38#[pymethods]
39impl PySlice {
40    #[new]
41    #[pyo3(signature = (*, offset, sizes, strides))]
42    fn new(offset: usize, sizes: Vec<usize>, strides: Vec<usize>) -> PyResult<Self> {
43        Ok(Self {
44            inner: Arc::new(
45                ndslice::Slice::new(offset, sizes, strides)
46                    .map_err(|err| PyValueError::new_err(err.to_string()))?,
47            ),
48        })
49    }
50
51    /// Returns the number of dimensions of the slice.
52    #[getter]
53    fn ndim(&self) -> usize {
54        self.inner.sizes().len()
55    }
56
57    /// Returns the offset of the slice.
58    #[getter]
59    fn offset(&self) -> usize {
60        self.inner.offset()
61    }
62
63    /// Returns the sizes of each of the dimensions of the slice.
64    #[getter]
65    fn sizes(&self) -> Vec<usize> {
66        self.inner.sizes().to_vec()
67    }
68
69    /// Returns the strides of each of the dimensions of the slice.
70    #[getter]
71    fn strides(&self) -> Vec<usize> {
72        self.inner.strides().to_vec()
73    }
74
75    /// Returns the index of the given value in the slice or raises a `ValueError`
76    /// if the value is not in the slice.
77    fn index(&self, value: usize) -> PyResult<usize> {
78        self.inner
79            .index(value)
80            .map_err(|err| PyValueError::new_err(err.to_string()))
81    }
82
83    /// Returns the coordinates of the given value in the slice or raises a `ValueError`
84    /// if the value is not in the slice.
85    fn coordinates(&self, value: usize) -> PyResult<Vec<usize>> {
86        self.inner
87            .coordinates(value)
88            .map_err(|err| PyValueError::new_err(err.to_string()))
89    }
90
91    /// Returns the value at the given coordinates or raises an `IndexError` if the coordinates
92    /// are out of bounds.
93    fn nditem(&self, coordinates: Vec<usize>) -> PyResult<usize> {
94        self.inner
95            .location(&coordinates)
96            .map_err(|err| PyIndexError::new_err(err.to_string()))
97    }
98
99    /// Returns the value at the given index or raises an `IndexError` if the index is out of bounds.
100    fn __getitem__(&self, py: Python<'_>, range: Range<'_>) -> PyResult<Py<PyAny>> {
101        match range {
102            Range::Single(index) => self
103                .inner
104                .get(index)
105                .map(|res| res.into_py_any(py))
106                .map_err(|err| PyIndexError::new_err(err.to_string()))?,
107            Range::Slice(slice) => {
108                let indices =
109                    slice.indices((self.inner.len() as std::os::raw::c_long).try_into()?)?;
110                let (start, stop, step) = (indices.start, indices.stop, indices.step);
111                if start < 0 || stop < 0 {
112                    return Err(PyIndexError::new_err("Only positive indices are support"));
113                }
114                let mut result = Vec::new();
115                let mut i = start;
116                while if step > 0 { i < stop } else { i > stop } {
117                    result.push(
118                        self.inner
119                            .get(i as usize)
120                            .map_err(|err| PyIndexError::new_err(err.to_string()))?,
121                    );
122                    i += step;
123                }
124                PyTuple::new(py, result)?.into_py_any(py)
125            }
126        }
127    }
128
129    fn __iter__(&self) -> PySliceIterator {
130        PySliceIterator::new(self.inner.clone())
131    }
132
133    fn __len__(&self) -> usize {
134        self.inner.len()
135    }
136
137    fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
138        let kwargs = PyDict::new(py);
139        kwargs.set_item("offset", self.inner.offset()).unwrap();
140        kwargs.set_item("sizes", self.inner.sizes()).unwrap();
141        kwargs.set_item("strides", self.inner.strides()).unwrap();
142
143        PyTuple::new(
144            py,
145            vec![
146                PyTuple::empty(py).unbind().into_any(),
147                kwargs.unbind().into_any(),
148            ],
149        )
150    }
151
152    fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
153        if let Ok(other) = other.extract::<PySlice>() {
154            Ok(self.inner == other.inner)
155        } else {
156            Ok(false)
157        }
158    }
159
160    fn __hash__(&self) -> u64 {
161        let mut hasher = DefaultHasher::new();
162        self.inner.hash(&mut hasher);
163        hasher.finish()
164    }
165
166    /// Returns a list of slices that cover the given list of ranks.
167    #[staticmethod]
168    fn from_list(py: Python<'_>, ranks: Vec<usize>) -> PyResult<Py<PyAny>> {
169        if ranks.is_empty() {
170            return PyList::empty(py).into_py_any(py);
171        }
172        let mut ranks = ranks;
173        ranks.sort();
174
175        let mut result = Vec::new();
176        let mut offset = ranks[0];
177        let mut size = 1;
178        let mut stride = 1;
179        for &rank in &ranks[1..] {
180            if size == 1 && rank > offset {
181                stride = rank - offset;
182                size += 1;
183            } else if offset + size * stride == rank {
184                size += 1;
185            } else {
186                result.push(Self::new(offset, vec![size], vec![stride])?);
187                offset = rank;
188                size = 1;
189                stride = 1;
190            }
191        }
192        result.push(Self::new(offset, vec![size], vec![stride])?);
193        result.into_py_any(py)
194    }
195
196    fn __repr__(&self) -> PyResult<String> {
197        Ok(format!("{:?}", self.inner))
198    }
199
200    #[staticmethod]
201    fn new_row_major(sizes: Vec<usize>) -> PySlice {
202        ndslice::Slice::new_row_major(sizes).into()
203    }
204
205    fn get(&self, index: usize) -> PyResult<usize> {
206        self.inner
207            .get(index)
208            .map_err(|err| PyValueError::new_err(err.to_string()))
209    }
210}
211
212impl From<&PySlice> for ndslice::Slice {
213    fn from(slice: &PySlice) -> Self {
214        slice.inner.as_ref().clone()
215    }
216}
217
218impl From<PySlice> for ndslice::Slice {
219    fn from(slice: PySlice) -> Self {
220        slice.inner.as_ref().clone()
221    }
222}
223
224impl From<ndslice::Slice> for PySlice {
225    fn from(value: ndslice::Slice) -> Self {
226        Self {
227            inner: Arc::new(value),
228        }
229    }
230}
231
232#[derive(Debug, Clone, FromPyObject)]
233enum Range<'s> {
234    #[pyo3(transparent, annotation = "int")]
235    Single(usize),
236    #[pyo3(transparent, annotation = "slice")]
237    Slice(Bound<'s, pyo3::types::PySlice>),
238}
239
240#[pyclass]
241struct PySliceIterator {
242    data: Arc<ndslice::Slice>,
243    index: usize,
244}
245
246impl PySliceIterator {
247    fn new(data: Arc<ndslice::Slice>) -> Self {
248        Self { data, index: 0 }
249    }
250}
251
252#[pymethods]
253impl PySliceIterator {
254    fn __iter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
255        self_
256    }
257
258    fn __next__(&mut self) -> Option<usize> {
259        let dims = self.data.sizes();
260        if self.index >= dims.iter().product::<usize>() {
261            return None;
262        }
263
264        let mut coords: Vec<usize> = vec![0; dims.len()];
265        let mut rest = self.index;
266        for (i, dim) in dims.iter().enumerate().rev() {
267            coords[i] = rest % dim;
268            rest /= dim;
269        }
270        self.index += 1;
271        Some(self.data.location(&coords).unwrap())
272    }
273}