monarch_hyperactor/
value_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use hyperactor_mesh::ValueMesh;
10use ndslice::Extent;
11use ndslice::Region;
12use ndslice::view::BuildFromRegion;
13use ndslice::view::Ranked;
14use ndslice::view::ViewExt;
15use pyo3::exceptions::PyValueError;
16use pyo3::prelude::*;
17use pyo3::types::PyAny;
18use pyo3::types::PyList;
19
20use crate::shape::PyShape;
21
/// Python-facing wrapper around a `ValueMesh` whose values are arbitrary
/// Python objects.
///
/// Exposed to Python as `monarch._src.actor.actor_mesh.ValueMesh`. The
/// constructors in this module coalesce runs of adjacent entries that are
/// the *same* Python object (pointer identity) into RLE runs.
#[pyclass(name = "ValueMesh", module = "monarch._src.actor.actor_mesh")]
pub struct PyValueMesh {
    /// Underlying mesh: one Python value per rank of its region.
    inner: ValueMesh<Py<PyAny>>,
}
26
27#[pymethods]
28impl PyValueMesh {
29    /// __init__(self, shape: Shape, values: list)
30    #[new]
31    fn new(_py: Python<'_>, shape: &PyShape, values: Bound<'_, PyList>) -> PyResult<Self> {
32        // Convert shape to region, preserving the original Slice
33        // (offset/strides) so linear rank order matches the Python
34        // Shape.
35        let s = shape.get_inner();
36        let region = Region::new(s.labels().to_vec(), s.slice().clone());
37        let vals: Vec<Py<PyAny>> = values.extract()?;
38
39        // Build & validate cardinality against region.
40        let mut inner =
41            <ValueMesh<Py<PyAny>> as BuildFromRegion<Py<PyAny>>>::build_dense(region, vals)
42                .map_err(|e| PyValueError::new_err(e.to_string()))?;
43
44        // Coalesce adjacent identical Python objects (same pointer
45        // identity). For Py<PyAny>, we treat equality as object
46        // identity: consecutive references to the *same* object
47        // pointer are merged into RLE runs. This tends to compress
48        // sentinel/categorical/boolean data, but not freshly
49        // allocated numerics/strings.
50        inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
51
52        Ok(Self { inner })
53    }
54
55    /// Return number of ranks (Python: len(vm))
56    fn __len__(&self) -> usize {
57        self.inner.region().num_ranks()
58    }
59
60    /// Expose the shape so Python MeshTrait methods can access labels/ndslice.
61    #[getter]
62    fn _shape(&self) -> PyShape {
63        PyShape::from(ndslice::Shape::from(self.inner.region().clone()))
64    }
65
66    /// Return the values in region/iteration order as a Python list.
67    fn values(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
68        // Clone the inner Py objects into a Python list (just bumps
69        // refcounts).
70        let vec: Vec<Py<PyAny>> = self.inner.values().collect();
71        Ok(PyList::new(py, vec)?.into())
72    }
73
74    /// Get value by linear rank (0..num_ranks-1).
75    fn get(&self, py: Python<'_>, rank: usize) -> PyResult<Py<PyAny>> {
76        let n = self.inner.region().num_ranks();
77        if rank >= n {
78            return Err(PyValueError::new_err(format!(
79                "index {} out of range (len={})",
80                rank, n
81            )));
82        }
83
84        // ValueMesh::get() returns &Py<PyAny>; we clone the smart
85        // pointer (incrementing the Python refcount) to return an
86        // owned Py<PyAny>. `unwrap` is safe because the bounds have
87        // been checked.
88        let v: Py<PyAny> = self.inner.get(rank).unwrap().clone_ref(py);
89
90        Ok(v)
91    }
92
93    /// Build from (rank, value) pairs with last-write-wins semantics.
94    #[staticmethod]
95    fn from_indexed(
96        _py: Python<'_>,
97        shape: &PyShape,
98        pairs: Vec<(usize, Py<PyAny>)>,
99    ) -> PyResult<Self> {
100        // Preserve the shape's original Slice (offset/strides).
101        let s = shape.get_inner();
102        let region = Region::new(s.labels().to_vec(), s.slice().clone());
103        let mut inner = <ValueMesh<Py<PyAny>> as ndslice::view::BuildFromRegionIndexed<
104            Py<PyAny>,
105        >>::build_indexed(region, pairs)
106        .map_err(|e| PyValueError::new_err(e.to_string()))?;
107
108        // Coalesce adjacent identical Python objects (same pointer
109        // identity). For Py<PyAny>, we treat equality as object
110        // identity: consecutive references to the *same* object
111        // pointer are merged into RLE runs. This tends to compress
112        // sentinel/categorical/boolean data, but not freshly
113        // allocated numerics/strings.
114        inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
115
116        Ok(Self { inner })
117    }
118}
119
120impl PyValueMesh {
121    /// Create a ValueMesh from an extent and a pre-populated Vec of values.
122    pub fn build_dense_from_extent(extent: &Extent, values: Vec<Py<PyAny>>) -> PyResult<Self> {
123        let mut inner = <ValueMesh<Py<PyAny>> as BuildFromRegion<Py<PyAny>>>::build_dense(
124            ndslice::View::region(extent),
125            values,
126        )
127        .map_err(|e| PyValueError::new_err(e.to_string()))?;
128        inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
129
130        Ok(Self { inner })
131    }
132}
133
134/// Test helper: create a ValueMesh entirely from Rust and return it to Python.
135/// This lets us verify that Python extension methods (patched via @rust_struct)
136/// are available on objects returned from Rust functions.
137#[pyfunction]
138fn _make_test_value_mesh(
139    labels: Vec<String>,
140    sizes: Vec<usize>,
141    values: Bound<'_, PyList>,
142) -> PyResult<PyValueMesh> {
143    let strides: Vec<usize> = {
144        let mut s = vec![1usize; sizes.len()];
145        for i in (0..sizes.len().saturating_sub(1)).rev() {
146            s[i] = s[i + 1] * sizes[i + 1];
147        }
148        s
149    };
150    let slice =
151        ndslice::Slice::new(0, sizes, strides).map_err(|e| PyValueError::new_err(e.to_string()))?;
152    let region = Region::new(labels, slice);
153    let vals: Vec<Py<PyAny>> = values.extract()?;
154    let mut inner = <ValueMesh<Py<PyAny>> as BuildFromRegion<Py<PyAny>>>::build_dense(region, vals)
155        .map_err(|e| PyValueError::new_err(e.to_string()))?;
156    inner.compress_adjacent_in_place_by(|a, b| a.as_ptr() == b.as_ptr());
157    Ok(PyValueMesh { inner })
158}
159
160pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
161    module.add_class::<PyValueMesh>()?;
162    module.add_function(wrap_pyfunction!(_make_test_value_mesh, module)?)?;
163    Ok(())
164}