1use monarch_types::MapPyErr;
10use ndslice::Extent;
11use ndslice::Point;
12use ndslice::Region;
13use ndslice::Shape;
14use ndslice::Slice;
15use ndslice::View;
16use pyo3::IntoPyObjectExt;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use pyo3::types::PyBytes;
20use pyo3::types::PyDict;
21use pyo3::types::PyMapping;
22use serde::Deserialize;
23use serde::Serialize;
24
25use crate::ndslice::PySlice;
26
27#[derive(Serialize, Deserialize, Clone)]
28#[pyclass(
29 name = "Extent",
30 module = "monarch._rust_bindings.monarch_hyperactor.shape",
31 frozen
32)]
33pub struct PyExtent {
34 inner: Extent,
35}
36
37#[pymethods]
38impl PyExtent {
39 #[new]
40 pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> PyResult<PyExtent> {
41 Ok(PyExtent {
42 inner: Extent::new(labels, sizes).map_pyerr()?,
43 })
44 }
45 #[getter]
46 fn nelements(&self) -> usize {
47 self.inner.num_ranks()
48 }
49 fn __repr__(&self) -> String {
50 self.inner.to_string()
51 }
52 #[getter]
53 fn labels(&self) -> &[String] {
54 self.inner.labels()
55 }
56 #[getter]
57 fn sizes(&self) -> &[usize] {
58 self.inner.sizes()
59 }
60
61 #[staticmethod]
62 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
63 let extent: PyExtent = bincode::deserialize(bytes.as_bytes())
64 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
65 Ok(extent)
66 }
67
68 fn __reduce__<'py>(
69 slf: &Bound<'py, Self>,
70 ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
71 let bytes = bincode::serialize(&*slf.borrow())
72 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
73 let py_bytes = PyBytes::new(slf.py(), &bytes);
74 Ok((slf.getattr("from_bytes")?, (py_bytes,)))
75 }
76
77 fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
78 Ok(self
79 .labels()
80 .into_bound_py_any(py)?
81 .call_method0("__iter__")?
82 .into())
83 }
84
85 fn __getitem__(&self, label: &str) -> PyResult<usize> {
86 self.inner.size(label).ok_or_else(|| {
87 PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
88 })
89 }
90
91 fn __len__(&self) -> usize {
92 self.inner.len()
93 }
94
95 fn keys<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
96 Ok(self.inner.labels().into_bound_py_any(py)?.into())
97 }
98
99 #[getter]
100 fn region(&self) -> PyRegion {
101 PyRegion {
102 inner: self.inner.region(),
103 }
104 }
105
106 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
107 if let Ok(other) = other.extract::<PyExtent>() {
108 Ok(self.inner == other.inner)
109 } else {
110 Ok(false)
111 }
112 }
113}
114
115impl From<Extent> for PyExtent {
116 fn from(inner: Extent) -> Self {
117 PyExtent { inner }
118 }
119}
120
121impl From<PyExtent> for Extent {
122 fn from(py_extent: PyExtent) -> Self {
123 py_extent.inner
124 }
125}
126
127#[derive(Serialize, Deserialize, Clone)]
128#[pyclass(
129 name = "Region",
130 module = "monarch._rust_bindings.monarch_hyperactor.shape",
131 frozen
132)]
133pub struct PyRegion {
134 pub(crate) inner: Region,
135}
136
137impl PyRegion {
138 pub(crate) fn as_inner(&self) -> &Region {
139 &self.inner
140 }
141}
142
143#[pymethods]
144impl PyRegion {
145 #[new]
146 fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
147 Ok(PyRegion {
148 inner: Region::new(labels, slice.into()),
149 })
150 }
151
152 fn as_shape(&self) -> PyShape {
153 PyShape {
154 inner: (&self.inner).into(),
155 }
156 }
157
158 #[getter]
159 fn labels(&self) -> Vec<String> {
160 self.inner.labels().to_vec()
161 }
162
163 fn slice(&self) -> PySlice {
164 self.inner.slice().clone().into()
165 }
166
167 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
168 let bytes = bincode::serialize(&self.inner)
169 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
170 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
171 let from_bytes = py
172 .import("monarch._rust_bindings.monarch_hyperactor.shape")?
173 .getattr("Region")?
174 .getattr("from_bytes")?;
175 Ok((from_bytes, py_bytes))
176 }
177
178 #[staticmethod]
179 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
180 Ok(bincode::deserialize::<Region>(bytes.as_bytes())
181 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?
182 .into())
183 }
184
185 fn point_of_base_rank(&self, rank: usize) -> PyResult<PyPoint> {
186 self.inner
187 .point_of_base_rank(rank)
188 .map_pyerr()
189 .map(PyPoint::from)
190 }
191
192 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
193 if let Ok(other) = other.extract::<PyRegion>() {
194 Ok(self.inner == other.inner)
195 } else {
196 Ok(false)
197 }
198 }
199}
200
201impl From<Region> for PyRegion {
202 fn from(inner: Region) -> Self {
203 PyRegion { inner }
204 }
205}
206
207#[pyclass(
208 name = "Shape",
209 module = "monarch._rust_bindings.monarch_hyperactor.shape",
210 frozen
211)]
212#[derive(Clone)]
213pub struct PyShape {
214 pub(super) inner: Shape,
215}
216
217impl PyShape {
218 pub fn get_inner(&self) -> &Shape {
219 &self.inner
220 }
221}
222
223#[pymethods]
224impl PyShape {
225 #[new]
226 fn new(labels: Vec<String>, slice: PySlice) -> PyResult<Self> {
227 let shape = Shape::new(labels, Slice::from(slice))
228 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
229 Ok(PyShape { inner: shape })
230 }
231
232 #[getter]
233 fn ndslice(&self) -> PySlice {
234 self.inner.slice().clone().into()
235 }
236 #[getter]
237 fn labels(&self) -> Vec<String> {
238 self.inner.labels().to_vec()
239 }
240 fn __str__(&self) -> PyResult<String> {
241 Ok(self.inner.to_string())
242 }
243 fn __repr__(&self) -> PyResult<String> {
244 Ok(format!("{:?}", self.inner))
245 }
246 fn coordinates<'py>(
247 &self,
248 py: Python<'py>,
249 rank: usize,
250 ) -> PyResult<pyo3::Bound<'py, pyo3::types::PyDict>> {
251 self.inner
252 .coordinates(rank)
253 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
254 .and_then(|x| PyDict::from_sequence(&x.into_bound_py_any(py)?))
255 }
256
257 fn at(&self, label: &str, index: usize) -> PyResult<PyShape> {
258 Ok(PyShape {
259 inner: self
260 .inner
261 .at(label, index)
262 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
263 })
264 }
265
266 #[pyo3(signature = (**kwargs))]
267 fn index(&self, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<PyShape> {
268 if let Some(kwargs) = kwargs {
269 let mut indices: Vec<(String, usize)> = Vec::new();
270 for (key, value) in kwargs.iter() {
272 let key_str = key.extract::<String>()?;
273 let idx = value.extract::<usize>()?;
274 indices.push((key_str, idx));
275 }
276 Ok(PyShape {
277 inner: self
278 .inner
279 .index(indices)
280 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
281 })
282 } else {
283 Ok(PyShape {
284 inner: self.inner.clone(),
285 })
286 }
287 }
288
289 fn select(&self, label: &str, slice: &Bound<'_, pyo3::types::PySlice>) -> PyResult<PyShape> {
290 let dim = self
291 .inner
292 .dim(label)
293 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
294 let size = self.inner.slice().sizes()[dim];
295
296 let indices = slice.indices(size as isize)?;
297 let start = indices.start as usize;
298 let stop = indices.stop as usize;
299 let step = indices.step as usize;
300
301 let range = ndslice::shape::Range(start, Some(stop), step);
302 Ok(PyShape {
303 inner: self
304 .inner
305 .select(label, range)
306 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?,
307 })
308 }
309
310 #[staticmethod]
311 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
312 let shape: Shape = bincode::deserialize(bytes.as_bytes())
313 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
314 Ok(PyShape::from(shape))
315 }
316
317 fn __reduce__<'py>(
318 slf: &Bound<'py, Self>,
319 ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
320 let bytes = bincode::serialize(&slf.borrow().inner)
321 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
322 let py_bytes = PyBytes::new(slf.py(), &bytes);
323 Ok((slf.getattr("from_bytes")?, (py_bytes,)))
324 }
325
326 fn ranks(&self) -> Vec<usize> {
327 self.inner.slice().iter().collect()
328 }
329
330 fn __len__(&self) -> usize {
331 self.inner.slice().len()
332 }
333
334 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
335 if let Ok(other) = other.extract::<PyShape>() {
336 Ok(self.inner == other.inner)
337 } else {
338 Ok(false)
339 }
340 }
341
342 #[staticmethod]
343 fn unity() -> PyShape {
344 Shape::unity().into()
345 }
346
347 #[getter]
348 fn extent(&self) -> PyExtent {
349 self.inner.extent().into()
350 }
351
352 #[getter]
353 fn region(&self) -> PyRegion {
354 PyRegion {
355 inner: self.inner.region(),
356 }
357 }
358}
359
360impl From<Shape> for PyShape {
361 fn from(shape: Shape) -> Self {
362 PyShape { inner: shape }
363 }
364}
365
366#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
367#[pyclass(
368 name = "Point",
369 module = "monarch._rust_bindings.monarch_hyperactor.shape",
370 subclass,
371 frozen
372)]
373pub struct PyPoint {
374 rank: usize,
375 extent: Extent,
376}
377
378#[pymethods]
379impl PyPoint {
380 #[new]
381 pub fn new(rank: usize, extent: PyExtent) -> Self {
382 PyPoint {
383 rank,
384 extent: extent.inner,
385 }
386 }
387 fn __getitem__(&self, label: &str) -> PyResult<usize> {
388 let index = self.extent.position(label).ok_or_else(|| {
389 PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
390 })?;
391 let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
392 Ok(point.coords()[index])
393 }
394
395 fn size(&self, label: &str) -> PyResult<usize> {
396 self.extent.size(label).ok_or_else(|| {
397 PyErr::new::<PyValueError, _>(format!("Dimension '{}' not found", label))
398 })
399 }
400
401 fn __len__(&self) -> usize {
402 self.extent.len()
403 }
404 fn __iter__<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
405 Ok(self
406 .extent
407 .labels()
408 .into_bound_py_any(py)?
409 .call_method0("__iter__")?
410 .into())
411 }
412
413 #[staticmethod]
414 fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
415 let point: PyPoint = bincode::deserialize(bytes.as_bytes())
416 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
417 Ok(point)
418 }
419
420 fn __reduce__<'py>(
421 slf: &Bound<'py, Self>,
422 ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
423 let bytes = bincode::serialize(&*slf.borrow())
424 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
425 let py_bytes = PyBytes::new(slf.py(), &bytes);
426 Ok((slf.getattr("from_bytes")?, (py_bytes,)))
427 }
428
429 #[getter]
430 fn extent(&self) -> PyExtent {
431 PyExtent {
432 inner: self.extent.clone(),
433 }
434 }
435 #[getter]
436 fn rank(&self) -> usize {
437 self.rank
438 }
439 fn __repr__(&self) -> PyResult<String> {
440 let point = self.extent.point_of_rank(self.rank).map_pyerr()?;
441 let coords = point.coords();
442 let labels = self.extent.labels();
443 let sizes = self.extent.sizes();
444 let mut parts = Vec::new();
445 for (i, label) in labels.iter().enumerate() {
446 parts.push(format!("'{}': {}/{}", label, coords[i], sizes[i]));
447 }
448
449 Ok(format!("{{{}}}", parts.join(", ")))
450 }
451
452 fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult<bool> {
453 if let Ok(other) = other.extract::<PyPoint>() {
454 Ok(*self == other)
455 } else {
456 Ok(false)
457 }
458 }
459
460 fn keys<'py>(&self, py: Python<'py>) -> PyResult<Py<PyAny>> {
461 Ok(self.extent.labels().into_bound_py_any(py)?.into())
462 }
463}
464impl From<Point> for PyPoint {
465 fn from(inner: Point) -> Self {
466 PyPoint {
467 rank: inner.rank(),
468 extent: inner.extent().clone(),
469 }
470 }
471}
472
473pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
474 let py = module.py();
475 module.add_class::<PyShape>()?;
476 module.add_class::<PySlice>()?;
477 module.add_class::<PyPoint>()?;
478 PyMapping::register::<PyPoint>(py)?;
479 module.add_class::<PyExtent>()?;
480 PyMapping::register::<PyExtent>(py)?;
481 module.add_class::<PyRegion>()?;
482 Ok(())
483}