monarch_hyperactor/
proc_mesh.rs1use std::fmt::Debug;
10use std::ops::Deref;
11
12use hyperactor::id::Label;
13use hyperactor_mesh::ProcMesh;
14use hyperactor_mesh::ProcMeshRef;
15use hyperactor_mesh::mesh_id::ActorMeshId;
16use hyperactor_mesh::shared_cell::SharedCell;
17use monarch_types::PickledPyObject;
18use monarch_types::py_module_add_function;
19use ndslice::View;
20use ndslice::view::RankedSliceable;
21use pyo3::IntoPyObjectExt;
22use pyo3::exceptions::PyRuntimeError;
23use pyo3::exceptions::PyValueError;
24use pyo3::prelude::*;
25use pyo3::types::PyBytes;
26use pyo3::types::PyType;
27
28use crate::actor::PythonActorParams;
29use crate::actor_mesh::PythonActorMesh;
30use crate::actor_mesh::PythonActorMeshImpl;
31use crate::actor_mesh::SupervisableActorMesh;
32use crate::context::PyInstance;
33use crate::pickle::PendingMessage;
34use crate::pytokio::PyPythonTask;
35use crate::pytokio::PyShared;
36use crate::runtime::get_tokio_runtime;
37use crate::runtime::monarch_with_gil;
38use crate::runtime::monarch_with_gil_blocking;
39use crate::shape::PyRegion;
40
/// Python-visible handle to a proc mesh (exposed to Python as `ProcMesh`).
///
/// `Owned` holds the mesh itself and is the only variant that can be stopped
/// through `stop_nonblocking`; `Ref` is a non-owning [`ProcMeshRef`], which is
/// what `sliced` and unpickling (`py_proc_mesh_from_bytes`) produce.
#[pyclass(
    name = "ProcMesh",
    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
)]
#[expect(
    clippy::large_enum_variant,
    reason = "PyO3 #[pyclass] enum; Box wrapping interacts with PyO3 codegen and Python interop — separate diff"
)]
pub enum PyProcMesh {
    /// Owns the underlying [`ProcMesh`] (stored in a shared cell).
    Owned(PyProcMeshImpl),
    /// Non-owning reference to a proc mesh, or to a region of one.
    Ref(PyProcMeshRefImpl),
}
53
54impl PyProcMesh {
55 pub fn new_owned(inner: ProcMesh) -> Self {
56 Self::Owned(PyProcMeshImpl(inner.into()))
57 }
58
59 pub(crate) fn new_ref(inner: ProcMeshRef) -> Self {
60 Self::Ref(PyProcMeshRefImpl(inner))
61 }
62
63 pub fn mesh_ref(&self) -> PyResult<ProcMeshRef> {
64 match self {
65 PyProcMesh::Owned(inner) => Ok(inner
66 .0
67 .borrow()
68 .map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
69 .clone()),
70 PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
71 }
72 }
73}
74
75#[pymethods]
76impl PyProcMesh {
77 #[staticmethod]
78 #[pyo3(signature = (proc_mesh, instance, mesh_base_name, actor, init_message, emulated, supervision_display_name = None))]
79 fn spawn_async(
80 proc_mesh: &mut PyShared,
81 instance: &PyInstance,
82 mesh_base_name: String,
83 actor: Py<PyType>,
84 init_message: &mut PendingMessage,
85 emulated: bool,
86 supervision_display_name: Option<String>,
87 ) -> PyResult<Py<PyAny>> {
88 let init_message = init_message.take()?;
89 let task = proc_mesh.task()?.take_task()?;
90 let instance = instance.clone();
91 let mesh_impl = async move {
92 let proc_mesh = task.await?;
93
94 let init_message = init_message.resolve().await?;
95
96 let (proc_mesh, params) = monarch_with_gil(|py| -> PyResult<_> {
97 let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
98 let slf = slf.borrow();
99 let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
100 Ok((
101 slf.mesh_ref()?.clone(),
102 PythonActorParams::new(
109 pickled_type,
110 Some(init_message),
111 Some(mesh_base_name.clone()),
112 ),
113 ))
114 })
115 .await?;
116
117 let mesh_name = ActorMeshId::instance(Label::strip(&mesh_base_name));
118 let actor_mesh = proc_mesh
119 .spawn_with_name(
120 instance.deref(),
121 mesh_name,
122 ¶ms,
123 supervision_display_name,
124 false,
125 )
126 .await
127 .map_err(anyhow::Error::from)?;
128 Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh)))
129 };
130 if emulated {
131 let r = get_tokio_runtime().block_on(mesh_impl)?;
134 monarch_with_gil_blocking(|py| r.into_py_any(py))
135 } else {
136 let r = PythonActorMesh::new(
137 async move {
138 let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
139 Ok(mesh_impl)
140 },
141 true,
142 );
143 monarch_with_gil_blocking(|py| r.into_py_any(py))
144 }
145 }
146
147 fn __repr__(&self) -> PyResult<String> {
148 match self {
149 PyProcMesh::Owned(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
150 PyProcMesh::Ref(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
151 }
152 }
153
154 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
155 let bytes = bincode::serde::encode_to_vec(&self.mesh_ref()?, bincode::config::legacy())
156 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
157 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
158 let from_bytes =
159 PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.proc_mesh")?
160 .getattr("py_proc_mesh_from_bytes")?;
161 Ok((from_bytes, py_bytes))
162 }
163
164 #[getter]
165 fn region(&self) -> PyResult<PyRegion> {
166 Ok(self.mesh_ref()?.region().into())
167 }
168
169 fn stop_nonblocking(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
170 let (owned_inner, instance) = monarch_with_gil_blocking(|_py| {
172 let owned_inner = match self {
173 PyProcMesh::Owned(inner) => inner.clone(),
174 PyProcMesh::Ref(_) => {
175 return Err(PyValueError::new_err(
176 "ProcMesh is not owned; must be stopped by an owner",
177 ));
178 }
179 };
180
181 let instance = instance.clone();
182 Ok((owned_inner, instance))
183 })?;
184 PyPythonTask::new(async move {
185 let mesh = owned_inner.0.take().await;
186 match mesh {
187 Ok(mut mesh) => mesh
188 .stop(instance.deref(), reason)
189 .await
190 .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))),
191 Err(e) => {
192 tracing::info!("proc mesh already stopped: {}", e);
195 Ok(())
196 }
197 }
198 })
199 }
200
201 fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
202 Ok(Self::new_ref(
203 self.mesh_ref()?.sliced(region.as_inner().clone()),
204 ))
205 }
206}
207
/// Owning storage for a [`ProcMesh`], kept in a [`SharedCell`].
///
/// `PyProcMesh::mesh_ref` borrows from the cell, and `stop_nonblocking` takes
/// the mesh out of a clone of this value. That presumes clones share the same
/// cell; NOTE(review): confirm against `SharedCell`'s `Clone` semantics.
#[derive(Clone)]
#[pyclass(
    name = "ProcMeshImpl",
    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
)]
pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
214
215impl PyProcMeshImpl {
216 fn __repr__(&self) -> PyResult<String> {
217 Ok(format!(
218 "<ProcMeshImpl {:?}>",
219 *self.0.borrow().map_err(anyhow::Error::from)?
220 ))
221 }
222}
223
/// Non-owning wrapper around a [`ProcMeshRef`]; the payload of the `Ref`
/// variant of `PyProcMesh`.
#[derive(Debug, Clone)]
#[pyclass(
    name = "ProcMeshRefImpl",
    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
)]
pub struct PyProcMeshRefImpl(ProcMeshRef);
230
231impl PyProcMeshRefImpl {
232 fn __repr__(&self) -> PyResult<String> {
233 Ok(format!("<ProcMeshRefImpl {:?}>", self.0))
234 }
235}
236
237#[pyfunction]
238fn py_proc_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyProcMesh> {
239 let r: PyResult<ProcMeshRef> =
240 bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
241 .map(|(v, _)| v)
242 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
243 r.map(PyProcMesh::new_ref)
244}
245
246pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
247 hyperactor_mod.add_class::<PyProcMesh>()?;
248 py_module_add_function!(
249 hyperactor_mod,
250 "monarch._rust_bindings.monarch_hyperactor.proc_mesh",
251 py_proc_mesh_from_bytes
252 );
253 Ok(())
254}