monarch_hyperactor/
py_cell.rs1use std::fmt;
10use std::mem::take;
11use std::sync::Mutex;
12
13use pyo3::Py;
14use pyo3::PyClass;
15use pyo3::PyResult;
16use pyo3::Python;
17
18pub struct PyCell<T> {
21 inner: Mutex<PyCellState<T>>,
22}
23
24impl<T> fmt::Debug for PyCell<T> {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 f.debug_struct("PyCell").finish_non_exhaustive()
27 }
28}
29
30#[derive(Default)]
31enum PyCellState<T> {
32 #[default]
33 Invalid,
34
35 Rust(T),
36 Python(Py<T>),
37}
38
39impl<T> PyCell<T>
40where
41 T: PyClass,
42{
43 pub fn new(value: T) -> Self {
45 Self {
46 inner: Mutex::new(PyCellState::Rust(value)),
47 }
48 }
49
50 pub fn clone_ref(&self, py: Python<'_>) -> PyResult<Py<T>>
52 where
53 T: Into<pyo3::PyClassInitializer<T>>,
54 {
55 let mut inner = self.inner.lock().unwrap();
56
57 match take(&mut *inner) {
58 PyCellState::Rust(value) => {
59 let py_value = Py::new(py, value)?;
60 *inner = PyCellState::Python(py_value.clone_ref(py));
61 Ok(py_value)
62 }
63 PyCellState::Python(py_value) => {
64 *inner = PyCellState::Python(py_value.clone_ref(py));
65 Ok(py_value)
66 }
67 PyCellState::Invalid => panic!("invalid state"),
68 }
69 }
70}
71
#[cfg(test)]
mod tests {
    use pyo3::prelude::*;

    use super::*;

    #[pyclass]
    struct TestClass {
        value: i32,
    }

    /// Repeated `clone_ref` calls must return the same Python object, and that
    /// object must carry the value the cell was created with.
    #[test]
    fn test_clone_ref() {
        Python::initialize();
        Python::attach(|py| {
            let cell = PyCell::new(TestClass { value: 42 });

            let py_obj1 = cell.clone_ref(py).unwrap();
            let py_obj2 = cell.clone_ref(py).unwrap();

            assert!(py_obj1.is(&py_obj2));
            assert_eq!(py_obj1.borrow(py).value, 42);
        });
    }

    /// `Debug` output hides the cell's contents.
    #[test]
    fn test_debug_is_opaque() {
        let cell = PyCell::new(TestClass { value: 7 });
        assert_eq!(format!("{cell:?}"), "PyCell { .. }");
    }
}