monarch_hyperactor/
proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt::Debug;
10use std::ops::Deref;
11
12use hyperactor_mesh::ProcMesh;
13use hyperactor_mesh::ProcMeshRef;
14use hyperactor_mesh::shared_cell::SharedCell;
15use monarch_types::PickledPyObject;
16use monarch_types::py_module_add_function;
17use ndslice::View;
18use ndslice::view::RankedSliceable;
19use pyo3::IntoPyObjectExt;
20use pyo3::exceptions::PyException;
21use pyo3::exceptions::PyRuntimeError;
22use pyo3::exceptions::PyValueError;
23use pyo3::prelude::*;
24use pyo3::types::PyBytes;
25use pyo3::types::PyType;
26
27use crate::actor::PythonActorParams;
28use crate::actor_mesh::PythonActorMesh;
29use crate::actor_mesh::PythonActorMeshImpl;
30use crate::actor_mesh::SupervisableActorMesh;
31use crate::alloc::PyAlloc;
32use crate::context::PyInstance;
33use crate::pickle::PendingMessage;
34use crate::pytokio::PyPythonTask;
35use crate::pytokio::PyShared;
36use crate::runtime::get_tokio_runtime;
37use crate::runtime::monarch_with_gil;
38use crate::runtime::monarch_with_gil_blocking;
39use crate::shape::PyRegion;
40
41#[pyclass(
42    name = "ProcMesh",
43    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
44)]
45pub enum PyProcMesh {
46    Owned(PyProcMeshImpl),
47    Ref(PyProcMeshRefImpl),
48}
49
50impl PyProcMesh {
51    pub fn new_owned(inner: ProcMesh) -> Self {
52        Self::Owned(PyProcMeshImpl(inner.into()))
53    }
54
55    pub(crate) fn new_ref(inner: ProcMeshRef) -> Self {
56        Self::Ref(PyProcMeshRefImpl(inner))
57    }
58
59    pub fn mesh_ref(&self) -> PyResult<ProcMeshRef> {
60        match self {
61            PyProcMesh::Owned(inner) => Ok(inner
62                .0
63                .borrow()
64                .map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
65                .clone()),
66            PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
67        }
68    }
69}
70
71#[pymethods]
72impl PyProcMesh {
73    #[classmethod]
74    fn allocate_nonblocking<'py>(
75        _cls: &Bound<'_, PyType>,
76        _py: Python<'py>,
77        instance: &PyInstance,
78        alloc: &mut PyAlloc,
79        name: String,
80    ) -> PyResult<PyPythonTask> {
81        let alloc = match alloc.take() {
82            Some(alloc) => alloc,
83            None => {
84                return Err(PyException::new_err(
85                    "Alloc object already used".to_string(),
86                ));
87            }
88        };
89        let instance = instance.clone();
90        PyPythonTask::new(async move {
91            let mesh = ProcMesh::allocate(instance.deref(), alloc, &name)
92                .await
93                .map_err(|err| PyException::new_err(err.to_string()))?;
94            Ok(Self::new_owned(mesh))
95        })
96    }
97
98    #[staticmethod]
99    #[pyo3(signature = (proc_mesh, instance, name, actor, init_message, emulated, supervision_display_name = None))]
100    fn spawn_async(
101        proc_mesh: &mut PyShared,
102        instance: &PyInstance,
103        name: String,
104        actor: Py<PyType>,
105        init_message: &mut PendingMessage,
106        emulated: bool,
107        supervision_display_name: Option<String>,
108    ) -> PyResult<Py<PyAny>> {
109        let init_message = init_message.take()?;
110        let task = proc_mesh.task()?.take_task()?;
111        let instance = instance.clone();
112        let mesh_impl = async move {
113            let proc_mesh = task.await?;
114
115            let init_message = init_message.resolve().await?;
116
117            let (proc_mesh, params) = monarch_with_gil(|py| -> PyResult<_> {
118                let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
119                let slf = slf.borrow();
120                let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
121                Ok((
122                    slf.mesh_ref()?.clone(),
123                    PythonActorParams::new(pickled_type, Some(init_message)),
124                ))
125            })
126            .await?;
127
128            let full_name = hyperactor_mesh::Name::new(name).unwrap();
129            let actor_mesh = proc_mesh
130                .spawn_with_name(
131                    instance.deref(),
132                    full_name,
133                    &params,
134                    supervision_display_name,
135                    false,
136                )
137                .await
138                .map_err(anyhow::Error::from)?;
139            Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh)))
140        };
141        if emulated {
142            // we give up on doing mesh spawn async for the emulated old version
143            // it is too complicated to make both work.
144            let r = get_tokio_runtime().block_on(mesh_impl)?;
145            monarch_with_gil_blocking(|py| r.into_py_any(py))
146        } else {
147            let r = PythonActorMesh::new(
148                async move {
149                    let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
150                    Ok(mesh_impl)
151                },
152                true,
153            );
154            monarch_with_gil_blocking(|py| r.into_py_any(py))
155        }
156    }
157
158    fn __repr__(&self) -> PyResult<String> {
159        match self {
160            PyProcMesh::Owned(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
161            PyProcMesh::Ref(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
162        }
163    }
164
165    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
166        let bytes = bincode::serialize(&self.mesh_ref()?)
167            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
168        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
169        let from_bytes =
170            PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.proc_mesh")?
171                .getattr("py_proc_mesh_from_bytes")?;
172        Ok((from_bytes, py_bytes))
173    }
174
175    #[getter]
176    fn region(&self) -> PyResult<PyRegion> {
177        Ok(self.mesh_ref()?.region().into())
178    }
179
180    fn stop_nonblocking(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
181        // Clone the necessary fields from self to avoid capturing self in the async block
182        let (owned_inner, instance) = monarch_with_gil_blocking(|_py| {
183            let owned_inner = match self {
184                PyProcMesh::Owned(inner) => inner.clone(),
185                PyProcMesh::Ref(_) => {
186                    return Err(PyValueError::new_err(
187                        "ProcMesh is not owned; must be stopped by an owner",
188                    ));
189                }
190            };
191
192            let instance = instance.clone();
193            Ok((owned_inner, instance))
194        })?;
195        PyPythonTask::new(async move {
196            let mesh = owned_inner.0.take().await;
197            match mesh {
198                Ok(mut mesh) => mesh
199                    .stop(instance.deref(), reason)
200                    .await
201                    .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))),
202                Err(e) => {
203                    // Don't return an exception, silently ignore the stop request
204                    // because it was already done.
205                    tracing::info!("proc mesh already stopped: {}", e);
206                    Ok(())
207                }
208            }
209        })
210    }
211
212    fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
213        Ok(Self::new_ref(
214            self.mesh_ref()?.sliced(region.as_inner().clone()),
215        ))
216    }
217}
218
219#[derive(Clone)]
220#[pyclass(
221    name = "ProcMeshImpl",
222    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
223)]
224pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
225
226impl PyProcMeshImpl {
227    fn __repr__(&self) -> PyResult<String> {
228        Ok(format!(
229            "<ProcMeshImpl {:?}>",
230            *self.0.borrow().map_err(anyhow::Error::from)?
231        ))
232    }
233}
234
235#[derive(Debug, Clone)]
236#[pyclass(
237    name = "ProcMeshRefImpl",
238    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
239)]
240pub struct PyProcMeshRefImpl(ProcMeshRef);
241
242impl PyProcMeshRefImpl {
243    fn __repr__(&self) -> PyResult<String> {
244        Ok(format!("<ProcMeshRefImpl {:?}>", self.0))
245    }
246}
247
248#[pyfunction]
249fn py_proc_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyProcMesh> {
250    let r: PyResult<ProcMeshRef> = bincode::deserialize(bytes.as_bytes())
251        .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
252    r.map(PyProcMesh::new_ref)
253}
254
255pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
256    hyperactor_mod.add_class::<PyProcMesh>()?;
257    py_module_add_function!(
258        hyperactor_mod,
259        "monarch._rust_bindings.monarch_hyperactor.proc_mesh",
260        py_proc_mesh_from_bytes
261    );
262    Ok(())
263}