monarch_hyperactor/
proc_mesh.rs1use std::fmt::Debug;
10use std::ops::Deref;
11
12use hyperactor_mesh::ProcMesh;
13use hyperactor_mesh::ProcMeshRef;
14use hyperactor_mesh::shared_cell::SharedCell;
15use monarch_types::PickledPyObject;
16use monarch_types::py_module_add_function;
17use ndslice::View;
18use ndslice::view::RankedSliceable;
19use pyo3::IntoPyObjectExt;
20use pyo3::exceptions::PyException;
21use pyo3::exceptions::PyRuntimeError;
22use pyo3::exceptions::PyValueError;
23use pyo3::prelude::*;
24use pyo3::types::PyBytes;
25use pyo3::types::PyType;
26
27use crate::actor::PythonActorParams;
28use crate::actor_mesh::PythonActorMesh;
29use crate::actor_mesh::PythonActorMeshImpl;
30use crate::actor_mesh::SupervisableActorMesh;
31use crate::alloc::PyAlloc;
32use crate::context::PyInstance;
33use crate::pickle::PendingMessage;
34use crate::pytokio::PyPythonTask;
35use crate::pytokio::PyShared;
36use crate::runtime::get_tokio_runtime;
37use crate::runtime::monarch_with_gil;
38use crate::runtime::monarch_with_gil_blocking;
39use crate::shape::PyRegion;
40
41#[pyclass(
42 name = "ProcMesh",
43 module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
44)]
45pub enum PyProcMesh {
46 Owned(PyProcMeshImpl),
47 Ref(PyProcMeshRefImpl),
48}
49
50impl PyProcMesh {
51 pub fn new_owned(inner: ProcMesh) -> Self {
52 Self::Owned(PyProcMeshImpl(inner.into()))
53 }
54
55 pub(crate) fn new_ref(inner: ProcMeshRef) -> Self {
56 Self::Ref(PyProcMeshRefImpl(inner))
57 }
58
59 pub fn mesh_ref(&self) -> PyResult<ProcMeshRef> {
60 match self {
61 PyProcMesh::Owned(inner) => Ok(inner
62 .0
63 .borrow()
64 .map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
65 .clone()),
66 PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
67 }
68 }
69}
70
71#[pymethods]
72impl PyProcMesh {
73 #[classmethod]
74 fn allocate_nonblocking<'py>(
75 _cls: &Bound<'_, PyType>,
76 _py: Python<'py>,
77 instance: &PyInstance,
78 alloc: &mut PyAlloc,
79 name: String,
80 ) -> PyResult<PyPythonTask> {
81 let alloc = match alloc.take() {
82 Some(alloc) => alloc,
83 None => {
84 return Err(PyException::new_err(
85 "Alloc object already used".to_string(),
86 ));
87 }
88 };
89 let instance = instance.clone();
90 PyPythonTask::new(async move {
91 let mesh = ProcMesh::allocate(instance.deref(), alloc, &name)
92 .await
93 .map_err(|err| PyException::new_err(err.to_string()))?;
94 Ok(Self::new_owned(mesh))
95 })
96 }
97
98 #[staticmethod]
99 #[pyo3(signature = (proc_mesh, instance, name, actor, init_message, emulated, supervision_display_name = None))]
100 fn spawn_async(
101 proc_mesh: &mut PyShared,
102 instance: &PyInstance,
103 name: String,
104 actor: Py<PyType>,
105 init_message: &mut PendingMessage,
106 emulated: bool,
107 supervision_display_name: Option<String>,
108 ) -> PyResult<Py<PyAny>> {
109 let init_message = init_message.take()?;
110 let task = proc_mesh.task()?.take_task()?;
111 let instance = instance.clone();
112 let mesh_impl = async move {
113 let proc_mesh = task.await?;
114
115 let init_message = init_message.resolve().await?;
116
117 let (proc_mesh, params) = monarch_with_gil(|py| -> PyResult<_> {
118 let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
119 let slf = slf.borrow();
120 let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
121 Ok((
122 slf.mesh_ref()?.clone(),
123 PythonActorParams::new(pickled_type, Some(init_message)),
124 ))
125 })
126 .await?;
127
128 let full_name = hyperactor_mesh::Name::new(name).unwrap();
129 let actor_mesh = proc_mesh
130 .spawn_with_name(
131 instance.deref(),
132 full_name,
133 ¶ms,
134 supervision_display_name,
135 false,
136 )
137 .await
138 .map_err(anyhow::Error::from)?;
139 Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh)))
140 };
141 if emulated {
142 let r = get_tokio_runtime().block_on(mesh_impl)?;
145 monarch_with_gil_blocking(|py| r.into_py_any(py))
146 } else {
147 let r = PythonActorMesh::new(
148 async move {
149 let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
150 Ok(mesh_impl)
151 },
152 true,
153 );
154 monarch_with_gil_blocking(|py| r.into_py_any(py))
155 }
156 }
157
158 fn __repr__(&self) -> PyResult<String> {
159 match self {
160 PyProcMesh::Owned(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
161 PyProcMesh::Ref(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
162 }
163 }
164
165 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
166 let bytes = bincode::serialize(&self.mesh_ref()?)
167 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
168 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
169 let from_bytes =
170 PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.proc_mesh")?
171 .getattr("py_proc_mesh_from_bytes")?;
172 Ok((from_bytes, py_bytes))
173 }
174
175 #[getter]
176 fn region(&self) -> PyResult<PyRegion> {
177 Ok(self.mesh_ref()?.region().into())
178 }
179
180 fn stop_nonblocking(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
181 let (owned_inner, instance) = monarch_with_gil_blocking(|_py| {
183 let owned_inner = match self {
184 PyProcMesh::Owned(inner) => inner.clone(),
185 PyProcMesh::Ref(_) => {
186 return Err(PyValueError::new_err(
187 "ProcMesh is not owned; must be stopped by an owner",
188 ));
189 }
190 };
191
192 let instance = instance.clone();
193 Ok((owned_inner, instance))
194 })?;
195 PyPythonTask::new(async move {
196 let mesh = owned_inner.0.take().await;
197 match mesh {
198 Ok(mut mesh) => mesh
199 .stop(instance.deref(), reason)
200 .await
201 .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))),
202 Err(e) => {
203 tracing::info!("proc mesh already stopped: {}", e);
206 Ok(())
207 }
208 }
209 })
210 }
211
212 fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
213 Ok(Self::new_ref(
214 self.mesh_ref()?.sliced(region.as_inner().clone()),
215 ))
216 }
217}
218
219#[derive(Clone)]
220#[pyclass(
221 name = "ProcMeshImpl",
222 module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
223)]
224pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
225
226impl PyProcMeshImpl {
227 fn __repr__(&self) -> PyResult<String> {
228 Ok(format!(
229 "<ProcMeshImpl {:?}>",
230 *self.0.borrow().map_err(anyhow::Error::from)?
231 ))
232 }
233}
234
235#[derive(Debug, Clone)]
236#[pyclass(
237 name = "ProcMeshRefImpl",
238 module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
239)]
240pub struct PyProcMeshRefImpl(ProcMeshRef);
241
242impl PyProcMeshRefImpl {
243 fn __repr__(&self) -> PyResult<String> {
244 Ok(format!("<ProcMeshRefImpl {:?}>", self.0))
245 }
246}
247
248#[pyfunction]
249fn py_proc_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyProcMesh> {
250 let r: PyResult<ProcMeshRef> = bincode::deserialize(bytes.as_bytes())
251 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
252 r.map(PyProcMesh::new_ref)
253}
254
255pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
256 hyperactor_mod.add_class::<PyProcMesh>()?;
257 py_module_add_function!(
258 hyperactor_mod,
259 "monarch._rust_bindings.monarch_hyperactor.proc_mesh",
260 py_proc_mesh_from_bytes
261 );
262 Ok(())
263}