monarch_hyperactor/
ndslice.rs1use std::hash::DefaultHasher;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::sync::Arc;
13
14use pyo3::IntoPyObjectExt;
15use pyo3::exceptions::PyIndexError;
16use pyo3::exceptions::PyValueError;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use pyo3::types::PyList;
20use pyo3::types::PySliceMethods;
21use pyo3::types::PyTuple;
22
23#[pyclass(
29 name = "Slice",
30 frozen,
31 module = "monarch._rust_bindings.monarch_hyperactor.shape"
32)]
33#[derive(Clone)]
34pub struct PySlice {
35 inner: Arc<ndslice::Slice>,
36}
37
38#[pymethods]
39impl PySlice {
40 #[new]
41 #[pyo3(signature = (*, offset, sizes, strides))]
42 fn new(offset: usize, sizes: Vec<usize>, strides: Vec<usize>) -> PyResult<Self> {
43 Ok(Self {
44 inner: Arc::new(
45 ndslice::Slice::new(offset, sizes, strides)
46 .map_err(|err| PyValueError::new_err(err.to_string()))?,
47 ),
48 })
49 }
50
51 #[getter]
53 fn ndim(&self) -> usize {
54 self.inner.sizes().len()
55 }
56
57 #[getter]
59 fn offset(&self) -> usize {
60 self.inner.offset()
61 }
62
63 #[getter]
65 fn sizes(&self) -> Vec<usize> {
66 self.inner.sizes().to_vec()
67 }
68
69 #[getter]
71 fn strides(&self) -> Vec<usize> {
72 self.inner.strides().to_vec()
73 }
74
75 fn index(&self, value: usize) -> PyResult<usize> {
78 self.inner
79 .index(value)
80 .map_err(|err| PyValueError::new_err(err.to_string()))
81 }
82
83 fn coordinates(&self, value: usize) -> PyResult<Vec<usize>> {
86 self.inner
87 .coordinates(value)
88 .map_err(|err| PyValueError::new_err(err.to_string()))
89 }
90
91 fn nditem(&self, coordinates: Vec<usize>) -> PyResult<usize> {
94 self.inner
95 .location(&coordinates)
96 .map_err(|err| PyIndexError::new_err(err.to_string()))
97 }
98
99 fn __getitem__(&self, py: Python<'_>, range: Range<'_>) -> PyResult<Py<PyAny>> {
101 match range {
102 Range::Single(index) => self
103 .inner
104 .get(index)
105 .map(|res| res.into_py_any(py))
106 .map_err(|err| PyIndexError::new_err(err.to_string()))?,
107 Range::Slice(slice) => {
108 let indices =
109 slice.indices((self.inner.len() as std::os::raw::c_long).try_into()?)?;
110 let (start, stop, step) = (indices.start, indices.stop, indices.step);
111 if start < 0 || stop < 0 {
112 return Err(PyIndexError::new_err("Only positive indices are support"));
113 }
114 let mut result = Vec::new();
115 let mut i = start;
116 while if step > 0 { i < stop } else { i > stop } {
117 result.push(
118 self.inner
119 .get(i as usize)
120 .map_err(|err| PyIndexError::new_err(err.to_string()))?,
121 );
122 i += step;
123 }
124 PyTuple::new(py, result)?.into_py_any(py)
125 }
126 }
127 }
128
129 fn __iter__(&self) -> PySliceIterator {
130 PySliceIterator::new(self.inner.clone())
131 }
132
133 fn __len__(&self) -> usize {
134 self.inner.len()
135 }
136
137 fn __getnewargs_ex__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
138 let kwargs = PyDict::new(py);
139 kwargs.set_item("offset", self.inner.offset()).unwrap();
140 kwargs.set_item("sizes", self.inner.sizes()).unwrap();
141 kwargs.set_item("strides", self.inner.strides()).unwrap();
142
143 PyTuple::new(
144 py,
145 vec![
146 PyTuple::empty(py).unbind().into_any(),
147 kwargs.unbind().into_any(),
148 ],
149 )
150 }
151
152 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
153 if let Ok(other) = other.extract::<PySlice>() {
154 Ok(self.inner == other.inner)
155 } else {
156 Ok(false)
157 }
158 }
159
160 fn __hash__(&self) -> u64 {
161 let mut hasher = DefaultHasher::new();
162 self.inner.hash(&mut hasher);
163 hasher.finish()
164 }
165
166 #[staticmethod]
168 fn from_list(py: Python<'_>, ranks: Vec<usize>) -> PyResult<Py<PyAny>> {
169 if ranks.is_empty() {
170 return PyList::empty(py).into_py_any(py);
171 }
172 let mut ranks = ranks;
173 ranks.sort();
174
175 let mut result = Vec::new();
176 let mut offset = ranks[0];
177 let mut size = 1;
178 let mut stride = 1;
179 for &rank in &ranks[1..] {
180 if size == 1 && rank > offset {
181 stride = rank - offset;
182 size += 1;
183 } else if offset + size * stride == rank {
184 size += 1;
185 } else {
186 result.push(Self::new(offset, vec![size], vec![stride])?);
187 offset = rank;
188 size = 1;
189 stride = 1;
190 }
191 }
192 result.push(Self::new(offset, vec![size], vec![stride])?);
193 result.into_py_any(py)
194 }
195
196 fn __repr__(&self) -> PyResult<String> {
197 Ok(format!("{:?}", self.inner))
198 }
199
200 #[staticmethod]
201 fn new_row_major(sizes: Vec<usize>) -> PySlice {
202 ndslice::Slice::new_row_major(sizes).into()
203 }
204
205 fn get(&self, index: usize) -> PyResult<usize> {
206 self.inner
207 .get(index)
208 .map_err(|err| PyValueError::new_err(err.to_string()))
209 }
210}
211
212impl From<&PySlice> for ndslice::Slice {
213 fn from(slice: &PySlice) -> Self {
214 slice.inner.as_ref().clone()
215 }
216}
217
218impl From<PySlice> for ndslice::Slice {
219 fn from(slice: PySlice) -> Self {
220 slice.inner.as_ref().clone()
221 }
222}
223
224impl From<ndslice::Slice> for PySlice {
225 fn from(value: ndslice::Slice) -> Self {
226 Self {
227 inner: Arc::new(value),
228 }
229 }
230}
231
232#[derive(Debug, Clone, FromPyObject)]
233enum Range<'s> {
234 #[pyo3(transparent, annotation = "int")]
235 Single(usize),
236 #[pyo3(transparent, annotation = "slice")]
237 Slice(Bound<'s, pyo3::types::PySlice>),
238}
239
240#[pyclass]
241struct PySliceIterator {
242 data: Arc<ndslice::Slice>,
243 index: usize,
244}
245
246impl PySliceIterator {
247 fn new(data: Arc<ndslice::Slice>) -> Self {
248 Self { data, index: 0 }
249 }
250}
251
252#[pymethods]
253impl PySliceIterator {
254 fn __iter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
255 self_
256 }
257
258 fn __next__(&mut self) -> Option<usize> {
259 let dims = self.data.sizes();
260 if self.index >= dims.iter().product::<usize>() {
261 return None;
262 }
263
264 let mut coords: Vec<usize> = vec![0; dims.len()];
265 let mut rest = self.index;
266 for (i, dim) in dims.iter().enumerate().rev() {
267 coords[i] = rest % dim;
268 rest /= dim;
269 }
270 self.index += 1;
271 Some(self.data.location(&coords).unwrap())
272 }
273}