monarch_hyperactor/
proc_mesh.rs1use std::fmt::Debug;
10use std::ops::Deref;
11
12use hyperactor_mesh::ProcMesh;
13use hyperactor_mesh::ProcMeshRef;
14use hyperactor_mesh::shared_cell::SharedCell;
15use monarch_types::PickledPyObject;
16use monarch_types::py_module_add_function;
17use ndslice::View;
18use ndslice::view::RankedSliceable;
19use pyo3::IntoPyObjectExt;
20use pyo3::exceptions::PyRuntimeError;
21use pyo3::exceptions::PyValueError;
22use pyo3::prelude::*;
23use pyo3::types::PyBytes;
24use pyo3::types::PyType;
25
26use crate::actor::PythonActorParams;
27use crate::actor_mesh::PythonActorMesh;
28use crate::actor_mesh::PythonActorMeshImpl;
29use crate::actor_mesh::SupervisableActorMesh;
30use crate::context::PyInstance;
31use crate::pickle::PendingMessage;
32use crate::pytokio::PyPythonTask;
33use crate::pytokio::PyShared;
34use crate::runtime::get_tokio_runtime;
35use crate::runtime::monarch_with_gil;
36use crate::runtime::monarch_with_gil_blocking;
37use crate::shape::PyRegion;
38
39#[pyclass(
40 name = "ProcMesh",
41 module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
42)]
43pub enum PyProcMesh {
44 Owned(PyProcMeshImpl),
45 Ref(PyProcMeshRefImpl),
46}
47
48impl PyProcMesh {
49 pub fn new_owned(inner: ProcMesh) -> Self {
50 Self::Owned(PyProcMeshImpl(inner.into()))
51 }
52
53 pub(crate) fn new_ref(inner: ProcMeshRef) -> Self {
54 Self::Ref(PyProcMeshRefImpl(inner))
55 }
56
57 pub fn mesh_ref(&self) -> PyResult<ProcMeshRef> {
58 match self {
59 PyProcMesh::Owned(inner) => Ok(inner
60 .0
61 .borrow()
62 .map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
63 .clone()),
64 PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
65 }
66 }
67}
68
69#[pymethods]
70impl PyProcMesh {
71 #[staticmethod]
72 #[pyo3(signature = (proc_mesh, instance, name, actor, init_message, emulated, supervision_display_name = None))]
73 fn spawn_async(
74 proc_mesh: &mut PyShared,
75 instance: &PyInstance,
76 name: String,
77 actor: Py<PyType>,
78 init_message: &mut PendingMessage,
79 emulated: bool,
80 supervision_display_name: Option<String>,
81 ) -> PyResult<Py<PyAny>> {
82 let init_message = init_message.take()?;
83 let task = proc_mesh.task()?.take_task()?;
84 let instance = instance.clone();
85 let mesh_impl = async move {
86 let proc_mesh = task.await?;
87
88 let init_message = init_message.resolve().await?;
89
90 let (proc_mesh, params) = monarch_with_gil(|py| -> PyResult<_> {
91 let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
92 let slf = slf.borrow();
93 let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
94 Ok((
95 slf.mesh_ref()?.clone(),
96 PythonActorParams::new(pickled_type, Some(init_message)),
97 ))
98 })
99 .await?;
100
101 let full_name = hyperactor_mesh::Name::new(name).unwrap();
102 let actor_mesh = proc_mesh
103 .spawn_with_name(
104 instance.deref(),
105 full_name,
106 ¶ms,
107 supervision_display_name,
108 false,
109 )
110 .await
111 .map_err(anyhow::Error::from)?;
112 Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh)))
113 };
114 if emulated {
115 let r = get_tokio_runtime().block_on(mesh_impl)?;
118 monarch_with_gil_blocking(|py| r.into_py_any(py))
119 } else {
120 let r = PythonActorMesh::new(
121 async move {
122 let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
123 Ok(mesh_impl)
124 },
125 true,
126 );
127 monarch_with_gil_blocking(|py| r.into_py_any(py))
128 }
129 }
130
131 fn __repr__(&self) -> PyResult<String> {
132 match self {
133 PyProcMesh::Owned(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
134 PyProcMesh::Ref(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
135 }
136 }
137
138 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
139 let bytes = bincode::serialize(&self.mesh_ref()?)
140 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
141 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
142 let from_bytes =
143 PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.proc_mesh")?
144 .getattr("py_proc_mesh_from_bytes")?;
145 Ok((from_bytes, py_bytes))
146 }
147
148 #[getter]
149 fn region(&self) -> PyResult<PyRegion> {
150 Ok(self.mesh_ref()?.region().into())
151 }
152
153 fn stop_nonblocking(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
154 let (owned_inner, instance) = monarch_with_gil_blocking(|_py| {
156 let owned_inner = match self {
157 PyProcMesh::Owned(inner) => inner.clone(),
158 PyProcMesh::Ref(_) => {
159 return Err(PyValueError::new_err(
160 "ProcMesh is not owned; must be stopped by an owner",
161 ));
162 }
163 };
164
165 let instance = instance.clone();
166 Ok((owned_inner, instance))
167 })?;
168 PyPythonTask::new(async move {
169 let mesh = owned_inner.0.take().await;
170 match mesh {
171 Ok(mut mesh) => mesh
172 .stop(instance.deref(), reason)
173 .await
174 .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))),
175 Err(e) => {
176 tracing::info!("proc mesh already stopped: {}", e);
179 Ok(())
180 }
181 }
182 })
183 }
184
185 fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
186 Ok(Self::new_ref(
187 self.mesh_ref()?.sliced(region.as_inner().clone()),
188 ))
189 }
190}
191
192#[derive(Clone)]
193#[pyclass(
194 name = "ProcMeshImpl",
195 module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
196)]
197pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
198
199impl PyProcMeshImpl {
200 fn __repr__(&self) -> PyResult<String> {
201 Ok(format!(
202 "<ProcMeshImpl {:?}>",
203 *self.0.borrow().map_err(anyhow::Error::from)?
204 ))
205 }
206}
207
208#[derive(Debug, Clone)]
209#[pyclass(
210 name = "ProcMeshRefImpl",
211 module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
212)]
213pub struct PyProcMeshRefImpl(ProcMeshRef);
214
215impl PyProcMeshRefImpl {
216 fn __repr__(&self) -> PyResult<String> {
217 Ok(format!("<ProcMeshRefImpl {:?}>", self.0))
218 }
219}
220
221#[pyfunction]
222fn py_proc_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyProcMesh> {
223 let r: PyResult<ProcMeshRef> = bincode::deserialize(bytes.as_bytes())
224 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
225 r.map(PyProcMesh::new_ref)
226}
227
228pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
229 hyperactor_mod.add_class::<PyProcMesh>()?;
230 py_module_add_function!(
231 hyperactor_mod,
232 "monarch._rust_bindings.monarch_hyperactor.proc_mesh",
233 py_proc_mesh_from_bytes
234 );
235 Ok(())
236}