monarch_hyperactor/
value_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::sync::Mutex;
10
11use hyperactor_mesh::ValueMesh;
12use ndslice::Extent;
13use ndslice::Region;
14use ndslice::view::BuildFromRegion;
15use ndslice::view::Ranked;
16use pyo3::exceptions::PyValueError;
17use pyo3::prelude::*;
18use pyo3::types::PyAny;
19use pyo3::types::PyList;
20use serde_multipart::Part;
21
22use crate::buffers::FrozenBuffer;
23use crate::pickle::unpickle;
24use crate::shape::PyShape;
25
26/// A value that is either raw pickled bytes or an already-unpickled
27/// Python object. On first access, [`Pickled`] is unpickled and
28/// replaced with [`Unpickled`] so subsequent accesses skip
29/// deserialization.
30#[derive(Clone)]
31enum LazyPyObject {
32    Pickled(Part),
33    Unpickled(Py<PyAny>),
34}
35
36type LazyCell = Mutex<LazyPyObject>;
37
38impl LazyPyObject {
39    /// Resolve to a Python object, caching the result in place.
40    /// After this call the cell will contain [`Unpickled`].
41    fn resolve(cell: &LazyCell, py: Python<'_>) -> PyResult<Py<PyAny>> {
42        let mut guard = cell.lock().unwrap();
43
44        match &*guard {
45            LazyPyObject::Unpickled(obj) => Ok(obj.clone_ref(py)),
46            LazyPyObject::Pickled(part) => {
47                let py_obj = unpickle(
48                    py,
49                    FrozenBuffer {
50                        inner: part.clone().into_bytes(),
51                    },
52                )?
53                .unbind();
54
55                *guard = LazyPyObject::Unpickled(py_obj.clone_ref(py));
56
57                Ok(py_obj)
58            }
59        }
60    }
61}
62
63fn compress(inner: &mut ValueMesh<LazyCell>) {
64    inner.compress_adjacent_in_place_by(|a, b| match (&*a.lock().unwrap(), &*b.lock().unwrap()) {
65        (LazyPyObject::Unpickled(a), LazyPyObject::Unpickled(b)) => a.as_ptr() == b.as_ptr(),
66        (LazyPyObject::Pickled(a), LazyPyObject::Pickled(b)) => a == b,
67        _ => false,
68    });
69}
70
71#[pyclass(name = "ValueMesh", module = "monarch._src.actor.actor_mesh")]
72pub struct PyValueMesh {
73    inner: ValueMesh<LazyCell>,
74}
75
76#[pymethods]
77impl PyValueMesh {
78    /// __init__(self, shape: Shape, values: list)
79    #[new]
80    fn new(_py: Python<'_>, shape: &PyShape, values: Bound<'_, PyList>) -> PyResult<Self> {
81        // Convert shape to region, preserving the original Slice
82        // (offset/strides) so linear rank order matches the Python
83        // Shape.
84        let s = shape.get_inner();
85        let region = Region::new(s.labels().to_vec(), s.slice().clone());
86        let vals: Vec<LazyCell> = values
87            .extract::<Vec<Py<PyAny>>>()?
88            .into_iter()
89            .map(|v| Mutex::new(LazyPyObject::Unpickled(v)))
90            .collect();
91
92        let mut inner =
93            <ValueMesh<LazyCell> as BuildFromRegion<LazyCell>>::build_dense(region, vals)
94                .map_err(|e| PyValueError::new_err(e.to_string()))?;
95
96        // Coalesce adjacent identical Python objects (same pointer
97        // identity). For Py<PyAny>, we treat equality as object
98        // identity: consecutive references to the *same* object
99        // pointer are merged into RLE runs. This tends to compress
100        // sentinel/categorical/boolean data, but not freshly
101        // allocated numerics/strings.
102        compress(&mut inner);
103
104        Ok(Self { inner })
105    }
106
107    /// Return number of ranks (Python: len(vm))
108    fn __len__(&self) -> usize {
109        self.inner.region().num_ranks()
110    }
111
112    /// Expose the shape so Python MeshTrait methods can access labels/ndslice.
113    #[getter]
114    fn _shape(&self) -> PyShape {
115        PyShape::from(ndslice::Shape::from(self.inner.region().clone()))
116    }
117
118    /// Return the values in region/iteration order as a Python list.
119    fn values(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
120        let n = self.inner.region().num_ranks();
121        let mut vec: Vec<Py<PyAny>> = Vec::with_capacity(n);
122        for rank in 0..n {
123            vec.push(LazyPyObject::resolve(self.inner.get(rank).unwrap(), py)?);
124        }
125        Ok(PyList::new(py, vec)?.into())
126    }
127
128    /// Get value by linear rank (0..num_ranks-1).
129    fn get(&self, py: Python<'_>, rank: usize) -> PyResult<Py<PyAny>> {
130        let n = self.inner.region().num_ranks();
131        if rank >= n {
132            return Err(PyValueError::new_err(format!(
133                "index {} out of range (len={})",
134                rank, n
135            )));
136        }
137
138        LazyPyObject::resolve(self.inner.get(rank).unwrap(), py)
139    }
140
141    /// Build from (rank, value) pairs with last-write-wins semantics.
142    #[staticmethod]
143    fn from_indexed(
144        _py: Python<'_>,
145        shape: &PyShape,
146        pairs: Vec<(usize, Py<PyAny>)>,
147    ) -> PyResult<Self> {
148        // Preserve the shape's original Slice (offset/strides).
149        let s = shape.get_inner();
150        let region = Region::new(s.labels().to_vec(), s.slice().clone());
151        let lazy_pairs: Vec<(usize, LazyCell)> = pairs
152            .into_iter()
153            .map(|(rank, obj)| (rank, Mutex::new(LazyPyObject::Unpickled(obj))))
154            .collect();
155        let mut inner = <ValueMesh<LazyCell> as ndslice::view::BuildFromRegionIndexed<
156            LazyCell,
157        >>::build_indexed(region, lazy_pairs)
158        .map_err(|e| PyValueError::new_err(e.to_string()))?;
159
160        // Coalesce adjacent identical Python objects (same pointer
161        // identity). For Py<PyAny>, we treat equality as object
162        // identity: consecutive references to the *same* object
163        // pointer are merged into RLE runs. This tends to compress
164        // sentinel/categorical/boolean data, but not freshly
165        // allocated numerics/strings.
166        compress(&mut inner);
167
168        Ok(Self { inner })
169    }
170}
171
172impl PyValueMesh {
173    /// Create a lazy ValueMesh from an extent and raw pickled parts.
174    /// Values are unpickled on demand when accessed via `get()` or `values()`.
175    pub fn build_from_parts(extent: &Extent, parts: Vec<Part>) -> PyResult<Self> {
176        let lazy_values: Vec<LazyCell> = parts
177            .into_iter()
178            .map(|p| Mutex::new(LazyPyObject::Pickled(p)))
179            .collect();
180        let mut inner = <ValueMesh<LazyCell> as BuildFromRegion<LazyCell>>::build_dense(
181            ndslice::View::region(extent),
182            lazy_values,
183        )
184        .map_err(|e| PyValueError::new_err(e.to_string()))?;
185        compress(&mut inner);
186
187        Ok(Self { inner })
188    }
189}
190
191/// Test helper: create a ValueMesh entirely from Rust and return it to Python.
192/// This lets us verify that Python extension methods (patched via @rust_struct)
193/// are available on objects returned from Rust functions.
194#[pyfunction]
195fn _make_test_value_mesh(
196    labels: Vec<String>,
197    sizes: Vec<usize>,
198    values: Bound<'_, PyList>,
199) -> PyResult<PyValueMesh> {
200    let strides: Vec<usize> = {
201        let mut s = vec![1usize; sizes.len()];
202        for i in (0..sizes.len().saturating_sub(1)).rev() {
203            s[i] = s[i + 1] * sizes[i + 1];
204        }
205        s
206    };
207    let slice =
208        ndslice::Slice::new(0, sizes, strides).map_err(|e| PyValueError::new_err(e.to_string()))?;
209    let region = Region::new(labels, slice);
210    let vals: Vec<LazyCell> = values
211        .extract::<Vec<Py<PyAny>>>()?
212        .into_iter()
213        .map(|v| Mutex::new(LazyPyObject::Unpickled(v)))
214        .collect();
215    let mut inner = <ValueMesh<LazyCell> as BuildFromRegion<LazyCell>>::build_dense(region, vals)
216        .map_err(|e| PyValueError::new_err(e.to_string()))?;
217    compress(&mut inner);
218    Ok(PyValueMesh { inner })
219}
220
221pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
222    module.add_class::<PyValueMesh>()?;
223    module.add_function(wrap_pyfunction!(_make_test_value_mesh, module)?)?;
224    Ok(())
225}