monarch_hyperactor/
proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt::Debug;
10use std::ops::Deref;
11
12use hyperactor_mesh::ProcMesh;
13use hyperactor_mesh::ProcMeshRef;
14use hyperactor_mesh::shared_cell::SharedCell;
15use monarch_types::PickledPyObject;
16use monarch_types::py_module_add_function;
17use ndslice::View;
18use ndslice::view::RankedSliceable;
19use pyo3::IntoPyObjectExt;
20use pyo3::exceptions::PyRuntimeError;
21use pyo3::exceptions::PyValueError;
22use pyo3::prelude::*;
23use pyo3::types::PyBytes;
24use pyo3::types::PyType;
25
26use crate::actor::PythonActorParams;
27use crate::actor_mesh::PythonActorMesh;
28use crate::actor_mesh::PythonActorMeshImpl;
29use crate::actor_mesh::SupervisableActorMesh;
30use crate::context::PyInstance;
31use crate::pickle::PendingMessage;
32use crate::pytokio::PyPythonTask;
33use crate::pytokio::PyShared;
34use crate::runtime::get_tokio_runtime;
35use crate::runtime::monarch_with_gil;
36use crate::runtime::monarch_with_gil_blocking;
37use crate::shape::PyRegion;
38
39#[pyclass(
40    name = "ProcMesh",
41    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
42)]
43pub enum PyProcMesh {
44    Owned(PyProcMeshImpl),
45    Ref(PyProcMeshRefImpl),
46}
47
48impl PyProcMesh {
49    pub fn new_owned(inner: ProcMesh) -> Self {
50        Self::Owned(PyProcMeshImpl(inner.into()))
51    }
52
53    pub(crate) fn new_ref(inner: ProcMeshRef) -> Self {
54        Self::Ref(PyProcMeshRefImpl(inner))
55    }
56
57    pub fn mesh_ref(&self) -> PyResult<ProcMeshRef> {
58        match self {
59            PyProcMesh::Owned(inner) => Ok(inner
60                .0
61                .borrow()
62                .map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
63                .clone()),
64            PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
65        }
66    }
67}
68
69#[pymethods]
70impl PyProcMesh {
71    #[staticmethod]
72    #[pyo3(signature = (proc_mesh, instance, name, actor, init_message, emulated, supervision_display_name = None))]
73    fn spawn_async(
74        proc_mesh: &mut PyShared,
75        instance: &PyInstance,
76        name: String,
77        actor: Py<PyType>,
78        init_message: &mut PendingMessage,
79        emulated: bool,
80        supervision_display_name: Option<String>,
81    ) -> PyResult<Py<PyAny>> {
82        let init_message = init_message.take()?;
83        let task = proc_mesh.task()?.take_task()?;
84        let instance = instance.clone();
85        let mesh_impl = async move {
86            let proc_mesh = task.await?;
87
88            let init_message = init_message.resolve().await?;
89
90            let (proc_mesh, params) = monarch_with_gil(|py| -> PyResult<_> {
91                let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
92                let slf = slf.borrow();
93                let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
94                Ok((
95                    slf.mesh_ref()?.clone(),
96                    PythonActorParams::new(pickled_type, Some(init_message)),
97                ))
98            })
99            .await?;
100
101            let full_name = hyperactor_mesh::Name::new(name).unwrap();
102            let actor_mesh = proc_mesh
103                .spawn_with_name(
104                    instance.deref(),
105                    full_name,
106                    &params,
107                    supervision_display_name,
108                    false,
109                )
110                .await
111                .map_err(anyhow::Error::from)?;
112            Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh)))
113        };
114        if emulated {
115            // we give up on doing mesh spawn async for the emulated old version
116            // it is too complicated to make both work.
117            let r = get_tokio_runtime().block_on(mesh_impl)?;
118            monarch_with_gil_blocking(|py| r.into_py_any(py))
119        } else {
120            let r = PythonActorMesh::new(
121                async move {
122                    let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
123                    Ok(mesh_impl)
124                },
125                true,
126            );
127            monarch_with_gil_blocking(|py| r.into_py_any(py))
128        }
129    }
130
131    fn __repr__(&self) -> PyResult<String> {
132        match self {
133            PyProcMesh::Owned(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
134            PyProcMesh::Ref(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
135        }
136    }
137
138    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
139        let bytes = bincode::serialize(&self.mesh_ref()?)
140            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
141        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
142        let from_bytes =
143            PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.proc_mesh")?
144                .getattr("py_proc_mesh_from_bytes")?;
145        Ok((from_bytes, py_bytes))
146    }
147
148    #[getter]
149    fn region(&self) -> PyResult<PyRegion> {
150        Ok(self.mesh_ref()?.region().into())
151    }
152
153    fn stop_nonblocking(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
154        // Clone the necessary fields from self to avoid capturing self in the async block
155        let (owned_inner, instance) = monarch_with_gil_blocking(|_py| {
156            let owned_inner = match self {
157                PyProcMesh::Owned(inner) => inner.clone(),
158                PyProcMesh::Ref(_) => {
159                    return Err(PyValueError::new_err(
160                        "ProcMesh is not owned; must be stopped by an owner",
161                    ));
162                }
163            };
164
165            let instance = instance.clone();
166            Ok((owned_inner, instance))
167        })?;
168        PyPythonTask::new(async move {
169            let mesh = owned_inner.0.take().await;
170            match mesh {
171                Ok(mut mesh) => mesh
172                    .stop(instance.deref(), reason)
173                    .await
174                    .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))),
175                Err(e) => {
176                    // Don't return an exception, silently ignore the stop request
177                    // because it was already done.
178                    tracing::info!("proc mesh already stopped: {}", e);
179                    Ok(())
180                }
181            }
182        })
183    }
184
185    fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
186        Ok(Self::new_ref(
187            self.mesh_ref()?.sliced(region.as_inner().clone()),
188        ))
189    }
190}
191
192#[derive(Clone)]
193#[pyclass(
194    name = "ProcMeshImpl",
195    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
196)]
197pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
198
199impl PyProcMeshImpl {
200    fn __repr__(&self) -> PyResult<String> {
201        Ok(format!(
202            "<ProcMeshImpl {:?}>",
203            *self.0.borrow().map_err(anyhow::Error::from)?
204        ))
205    }
206}
207
208#[derive(Debug, Clone)]
209#[pyclass(
210    name = "ProcMeshRefImpl",
211    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
212)]
213pub struct PyProcMeshRefImpl(ProcMeshRef);
214
215impl PyProcMeshRefImpl {
216    fn __repr__(&self) -> PyResult<String> {
217        Ok(format!("<ProcMeshRefImpl {:?}>", self.0))
218    }
219}
220
221#[pyfunction]
222fn py_proc_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyProcMesh> {
223    let r: PyResult<ProcMeshRef> = bincode::deserialize(bytes.as_bytes())
224        .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
225    r.map(PyProcMesh::new_ref)
226}
227
/// Registers this file's Python bindings on `hyperactor_mod`.
///
/// It registers the `ProcMesh` class and the `py_proc_mesh_from_bytes`
/// function, which `ProcMesh.__reduce__` looks up as its unpickling
/// constructor.
///
/// # Errors
///
/// Propagates the error from `add_class`. The function-registration macro
/// presumably propagates its own failures too; its expansion is not shown here.
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    hyperactor_mod.add_class::<PyProcMesh>()?;
    // `__reduce__` imports "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
    // and fetches `py_proc_mesh_from_bytes` from it. Presumably the module path
    // passed here is what makes that lookup resolve — confirm against the macro.
    py_module_add_function!(
        hyperactor_mod,
        "monarch._rust_bindings.monarch_hyperactor.proc_mesh",
        py_proc_mesh_from_bytes
    );
    Ok(())
}