monarch_hyperactor/
py_cell.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::fmt;
10use std::mem::take;
11use std::sync::Mutex;
12
13use pyo3::Py;
14use pyo3::PyClass;
15use pyo3::PyResult;
16use pyo3::Python;
17
18/// A PyCell holds a `#[pyclass]` value constructed on the Rust heap;
19/// when it is first used, it is moved to the Python heap.
20pub struct PyCell<T> {
21    inner: Mutex<PyCellState<T>>,
22}
23
24impl<T> fmt::Debug for PyCell<T> {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        f.debug_struct("PyCell").finish_non_exhaustive()
27    }
28}
29
30#[derive(Default)]
31enum PyCellState<T> {
32    #[default]
33    Invalid,
34
35    Rust(T),
36    Python(Py<T>),
37}
38
39impl<T> PyCell<T>
40where
41    T: PyClass,
42{
43    /// Create a new PyCell with a Rust-owned value.
44    pub fn new(value: T) -> Self {
45        Self {
46            inner: Mutex::new(PyCellState::Rust(value)),
47        }
48    }
49
50    /// Clone the PyCell, returning a reference to the Python-owned value.
51    pub fn clone_ref(&self, py: Python<'_>) -> PyResult<Py<T>>
52    where
53        T: Into<pyo3::PyClassInitializer<T>>,
54    {
55        let mut inner = self.inner.lock().unwrap();
56
57        match take(&mut *inner) {
58            PyCellState::Rust(value) => {
59                let py_value = Py::new(py, value)?;
60                *inner = PyCellState::Python(py_value.clone_ref(py));
61                Ok(py_value)
62            }
63            PyCellState::Python(py_value) => {
64                *inner = PyCellState::Python(py_value.clone_ref(py));
65                Ok(py_value)
66            }
67            PyCellState::Invalid => panic!("invalid state"),
68        }
69    }
70}
71
#[cfg(test)]
mod tests {
    use pyo3::prelude::*;

    use super::*;

    #[pyclass]
    struct TestClass {
        value: i32,
    }

    #[test]
    fn test_clone_ref() {
        Python::initialize();
        Python::attach(|py| {
            let cell = PyCell::new(TestClass { value: 42 });

            let py_obj1 = cell.clone_ref(py).unwrap();
            let py_obj2 = cell.clone_ref(py).unwrap();

            // Both calls must hand out references to the same Python object...
            assert!(py_obj1.is(&py_obj2));
            // ...and the Rust value must survive the move to the Python heap.
            assert_eq!(py_obj1.borrow(py).value, 42);
        });
    }

    #[test]
    fn test_debug_is_opaque() {
        // Formatting must not need the interpreter or expose the contents.
        let cell = PyCell::new(TestClass { value: 7 });
        assert_eq!(format!("{cell:?}"), "PyCell { .. }");
    }
}