1use monarch_types::MapPyErr;
10use ndslice::Extent;
11use ndslice::Point;
12use ndslice::Region;
13use ndslice::Shape;
14use ndslice::Slice;
15use ndslice::View;
16use pyo3::IntoPyObjectExt;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use pyo3::types::PyBytes;
20use pyo3::types::PyDict;
21use pyo3::types::PyMapping;
22use serde::Deserialize;
23use serde::Serialize;
24
25use crate::ndslice::PySlice;
26
/// Python-facing wrapper around an `ndslice::Extent`, exposed to Python as
/// `monarch._rust_bindings.monarch_hyperactor.shape.Extent`.
///
/// The class is `frozen`, so Python code cannot mutate it after construction.
/// `Serialize`/`Deserialize` are derived so the whole wrapper can be
/// bincode-encoded for pickling (see `__reduce__` / `from_bytes`).
#[derive(Serialize, Deserialize, Clone)]
#[pyclass(
    name = "Extent",
    module = "monarch._rust_bindings.monarch_hyperactor.shape",
    frozen
)]
pub struct PyExtent {
    /// The wrapped Rust extent (dimension labels and sizes).
    inner: Extent,
}
36
37#[pymethods]
38impl PyExtent {
39 #[new]
40 pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> PyResult<PyExtent> {
41 Ok(PyExtent {
42 inner: Extent::new(labels, sizes).map_pyerr()?,
43 })
44 }
45 #[getter]
46 fn nelements(&self) -> usize {
47 self.inner.num_ranks()
48 }
49 fn __repr__(&self) -> String {
50 self.inner.to_string()
51 }
52 #[getter]
53 fn labels(&self) -> &[String] {
54 self.inner.labels()
55 }
56 #[getter]
57 fn sizes(&self) -> &[usize] {
58 self.inner.sizes()
59 }
60
61 #[staticmethod]
62 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
63 let extent: PyExtent =
64 bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
65 .map(|(v, _)| v)
66 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
67 Ok(extent)
68 }
69
70 fn __reduce__<'py>(
71 slf: &Bound<'py, Self>,
72 ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
73 let bytes = bincode::serde::encode_to_vec(&*slf.borrow(), bincode::config::legacy())
74 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
75 let py_bytes = PyBytes::new(slf.py(), &bytes);
76 Ok((slf.getattr("from_bytes")?, (py_bytes,)))
77 }
78
79 fn __iter__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
80 Ok(self
81 .labels()
82 .into_bound_py_any(py)?
83 .call_method0("__iter__")?
84 .into())
85 }
86
87 fn __getitem__(&self, label: &str) -> PyResult<usize> {
88 self.inner.size(label).ok_or_else(|| {
89 PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
90 })
91 }
92
93 fn __len__(&self) -> usize {
94 self.inner.len()
95 }
96
97 fn keys(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
98 Ok(self.inner.labels().into_bound_py_any(py)?.into())
99 }
100
101 #[getter]
102 fn region(&self) -> PyRegion {
103 PyRegion {
104 inner: self.inner.region(),
105 }
106 }
107
108 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
109 if let Ok(other) = other.extract::<PyExtent>() {
110 Ok(self.inner == other.inner)
111 } else {
112 Ok(false)
113 }
114 }
115}
116
117impl From<Extent> for PyExtent {
118 fn from(inner: Extent) -> Self {
119 PyExtent { inner }
120 }
121}
122
123impl From<PyExtent> for Extent {
124 fn from(py_extent: PyExtent) -> Self {
125 py_extent.inner
126 }
127}
128
/// Python-facing wrapper around an `ndslice::Region`, exposed to Python as
/// `monarch._rust_bindings.monarch_hyperactor.shape.Region`.
///
/// The class is `frozen`; the inner region is visible crate-wide so other
/// bindings can read it directly.
#[derive(Serialize, Deserialize, Clone)]
#[pyclass(
    name = "Region",
    module = "monarch._rust_bindings.monarch_hyperactor.shape",
    frozen
)]
pub struct PyRegion {
    /// The wrapped Rust region (labels plus slice).
    pub(crate) inner: Region,
}
138
139impl PyRegion {
140 pub(crate) fn as_inner(&self) -> &Region {
141 &self.inner
142 }
143}
144
145#[pymethods]
146impl PyRegion {
147 #[new]
148 fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
149 Ok(PyRegion {
150 inner: Region::new(labels, slice.into()),
151 })
152 }
153
154 fn as_shape(&self) -> PyShape {
155 PyShape {
156 inner: (&self.inner).into(),
157 }
158 }
159
160 #[getter]
161 fn labels(&self) -> Vec<String> {
162 self.inner.labels().to_vec()
163 }
164
165 fn slice(&self) -> PySlice {
166 self.inner.slice().clone().into()
167 }
168
169 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
170 let bytes = bincode::serde::encode_to_vec(&self.inner, bincode::config::legacy())
171 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
172 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
173 let from_bytes = py
174 .import("monarch._rust_bindings.monarch_hyperactor.shape")?
175 .getattr("Region")?
176 .getattr("from_bytes")?;
177 Ok((from_bytes, py_bytes))
178 }
179
180 #[staticmethod]
181 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
182 Ok(bincode::serde::decode_from_slice::<Region, _>(
183 bytes.as_bytes(),
184 bincode::config::legacy(),
185 )
186 .map(|(v, _)| v)
187 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?
188 .into())
189 }
190
191 fn point_of_base_rank(&self, rank: usize) -> PyResult<PyPoint> {
192 self.inner
193 .point_of_base_rank(rank)
194 .map_pyerr()
195 .map(PyPoint::from)
196 }
197
198 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
199 if let Ok(other) = other.extract::<PyRegion>() {
200 Ok(self.inner == other.inner)
201 } else {
202 Ok(false)
203 }
204 }
205}
206
207impl From<Region> for PyRegion {
208 fn from(inner: Region) -> Self {
209 PyRegion { inner }
210 }
211}
212
/// Python-facing wrapper around an `ndslice::Shape`, exposed to Python as
/// `monarch._rust_bindings.monarch_hyperactor.shape.Shape`.
///
/// Unlike the other wrappers here, this struct does not derive serde traits.
/// Pickling encodes the inner `Shape` directly (see `__reduce__` / `from_bytes`).
#[pyclass(
    name = "Shape",
    module = "monarch._rust_bindings.monarch_hyperactor.shape",
    frozen
)]
#[derive(Clone)]
pub struct PyShape {
    /// The wrapped Rust shape (labels plus slice).
    pub(super) inner: Shape,
}
222
223impl PyShape {
224 pub fn get_inner(&self) -> &Shape {
225 &self.inner
226 }
227}
228
229#[pymethods]
230impl PyShape {
231 #[new]
232 fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
233 let shape = Shape::new(labels, Slice::from(slice))
234 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
235 Ok(PyShape { inner: shape })
236 }
237
238 #[getter]
239 fn ndslice(&self) -> PySlice {
240 self.inner.slice().clone().into()
241 }
242 #[getter]
243 fn labels(&self) -> Vec<String> {
244 self.inner.labels().to_vec()
245 }
246 fn __str__(&self) -> PyResult<String> {
247 Ok(self.inner.to_string())
248 }
249 fn __repr__(&self) -> PyResult<String> {
250 Ok(format!("{:?}", self.inner))
251 }
252 fn coordinates<'py>(
253 &self,
254 py: Python<'py>,
255 rank: usize,
256 ) -> PyResult<pyo3::Bound<'py, pyo3::types::PyDict>> {
257 self.inner
258 .coordinates(rank)
259 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
260 .and_then(|x| PyDict::from_sequence(&x.into_bound_py_any(py)?))
261 }
262
263 fn at(&self, label: &str, index: usize) -> PyResult<PyShape> {
264 Ok(PyShape {
265 inner: self
266 .inner
267 .at(label, index)
268 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
269 })
270 }
271
272 #[pyo3(signature = (**kwargs))]
273 fn index(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyShape> {
274 if let Some(kwargs) = kwargs {
275 let mut indices: Vec<(String, usize)> = Vec::new();
276 for (key, value) in kwargs.iter() {
278 let key_str = key.extract::<String>()?;
279 let idx = value.extract::<usize>()?;
280 indices.push((key_str, idx));
281 }
282 Ok(PyShape {
283 inner: self
284 .inner
285 .index(indices)
286 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
287 })
288 } else {
289 Ok(PyShape {
290 inner: self.inner.clone(),
291 })
292 }
293 }
294
295 fn select(&self, label: &str, slice: &Bound<'_, pyo3::types::PySlice>) -> PyResult<PyShape> {
296 let dim = self
297 .inner
298 .dim(label)
299 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
300 let size = self.inner.slice().sizes()[dim];
301
302 let indices = slice.indices(size as isize)?;
303 let start = indices.start as usize;
304 let stop = indices.stop as usize;
305 let step = indices.step as usize;
306
307 let range = ndslice::shape::Range(start, Some(stop), step);
308 Ok(PyShape {
309 inner: self
310 .inner
311 .select(label, range)
312 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
313 })
314 }
315
316 #[staticmethod]
317 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
318 let shape: Shape =
319 bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
320 .map(|(v, _)| v)
321 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
322 Ok(PyShape::from(shape))
323 }
324
325 fn __reduce__<'py>(
326 slf: &Bound<'py, Self>,
327 ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
328 let bytes = bincode::serde::encode_to_vec(&slf.borrow().inner, bincode::config::legacy())
329 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
330 let py_bytes = PyBytes::new(slf.py(), &bytes);
331 Ok((slf.getattr("from_bytes")?, (py_bytes,)))
332 }
333
334 fn ranks(&self) -> Vec<usize> {
335 self.inner.slice().iter().collect()
336 }
337
338 fn __len__(&self) -> usize {
339 self.inner.slice().len()
340 }
341
342 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
343 if let Ok(other) = other.extract::<PyShape>() {
344 Ok(self.inner == other.inner)
345 } else {
346 Ok(false)
347 }
348 }
349
350 #[staticmethod]
351 fn unity() -> PyShape {
352 Shape::unity().into()
353 }
354
355 #[getter]
356 fn extent(&self) -> PyExtent {
357 self.inner.extent().into()
358 }
359
360 #[getter]
361 fn region(&self) -> PyRegion {
362 PyRegion {
363 inner: self.inner.region(),
364 }
365 }
366}
367
368impl From<Shape> for PyShape {
369 fn from(shape: Shape) -> Self {
370 PyShape { inner: shape }
371 }
372}
373
/// A point inside an extent, stored as a linear `rank` plus the `extent`
/// it belongs to. Exposed to Python as
/// `monarch._rust_bindings.monarch_hyperactor.shape.Point`.
///
/// The class is `frozen` but allows Python subclasses (`subclass`).
/// NOTE: with the derived `Serialize`, bincode encodes fields in declaration
/// order, so reordering `rank`/`extent` would break previously pickled bytes.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
#[pyclass(
    name = "Point",
    module = "monarch._rust_bindings.monarch_hyperactor.shape",
    subclass,
    frozen
)]
pub struct PyPoint {
    /// Linear rank of the point, resolved to coordinates via `Extent::point_of_rank`.
    rank: usize,
    /// The extent whose labels and sizes give the point its dimensions.
    extent: Extent,
}
385
386#[pymethods]
387impl PyPoint {
388 #[new]
389 pub fn new(rank: usize, extent: PyExtent) -> Self {
390 PyPoint {
391 rank,
392 extent: extent.inner,
393 }
394 }
395 fn __getitem__(&self, label: &str) -> PyResult<usize> {
396 let index = self.extent.position(label).ok_or_else(|| {
397 PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
398 })?;
399 let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
400 Ok(point.coords()[index])
401 }
402
403 fn size(&self, label: &str) -> PyResult<usize> {
404 self.extent.size(label).ok_or_else(|| {
405 PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
406 })
407 }
408
409 fn __len__(&self) -> usize {
410 self.extent.len()
411 }
412 fn __iter__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
413 Ok(self
414 .extent
415 .labels()
416 .into_bound_py_any(py)?
417 .call_method0("__iter__")?
418 .into())
419 }
420
421 #[staticmethod]
422 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
423 let point: PyPoint =
424 bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
425 .map(|(v, _)| v)
426 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
427 Ok(point)
428 }
429
430 fn __reduce__<'py>(
431 slf: &Bound<'py, Self>,
432 ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
433 let bytes = bincode::serde::encode_to_vec(&*slf.borrow(), bincode::config::legacy())
434 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
435 let py_bytes = PyBytes::new(slf.py(), &bytes);
436 Ok((slf.getattr("from_bytes")?, (py_bytes,)))
437 }
438
439 #[getter]
440 fn extent(&self) -> PyExtent {
441 PyExtent {
442 inner: self.extent.clone(),
443 }
444 }
445 #[getter]
446 fn rank(&self) -> usize {
447 self.rank
448 }
449 fn __repr__(&self) -> PyResult<String> {
450 let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
451 let coords = point.coords();
452 let labels = self.extent.labels();
453 let sizes = self.extent.sizes();
454 let mut parts = Vec::new();
455 for (i, label) in labels.iter().enumerate() {
456 parts.push(format!("'{}': {}/{}", label, coords[i], sizes[i]));
457 }
458
459 Ok(format!("{{{}}}", parts.join(", ")))
460 }
461
462 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
463 if let Ok(other) = other.extract::<PyPoint>() {
464 Ok(*self == other)
465 } else {
466 Ok(false)
467 }
468 }
469
470 fn keys(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
471 Ok(self.extent.labels().into_bound_py_any(py)?.into())
472 }
473}
474impl From<Point> for PyPoint {
475 fn from(inner: Point) -> Self {
476 PyPoint {
477 rank: inner.rank(),
478 extent: inner.extent().clone(),
479 }
480 }
481}
482
483pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
484 let py = module.py();
485 module.add_class::<PyShape>()?;
486 module.add_class::<PySlice>()?;
487 module.add_class::<PyPoint>()?;
488 PyMapping::register::<PyPoint>(py)?;
489 module.add_class::<PyExtent>()?;
490 PyMapping::register::<PyExtent>(py)?;
491 module.add_class::<PyRegion>()?;
492 Ok(())
493}