Skip to main content

monarch_hyperactor/
proc_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt::Debug;
10use std::ops::Deref;
11
12use hyperactor::id::Label;
13use hyperactor_mesh::ProcMesh;
14use hyperactor_mesh::ProcMeshRef;
15use hyperactor_mesh::mesh_id::ActorMeshId;
16use hyperactor_mesh::shared_cell::SharedCell;
17use monarch_types::PickledPyObject;
18use monarch_types::py_module_add_function;
19use ndslice::View;
20use ndslice::view::RankedSliceable;
21use pyo3::IntoPyObjectExt;
22use pyo3::exceptions::PyRuntimeError;
23use pyo3::exceptions::PyValueError;
24use pyo3::prelude::*;
25use pyo3::types::PyBytes;
26use pyo3::types::PyType;
27
28use crate::actor::PythonActorParams;
29use crate::actor_mesh::PythonActorMesh;
30use crate::actor_mesh::PythonActorMeshImpl;
31use crate::actor_mesh::SupervisableActorMesh;
32use crate::context::PyInstance;
33use crate::pickle::PendingMessage;
34use crate::pytokio::PyPythonTask;
35use crate::pytokio::PyShared;
36use crate::runtime::get_tokio_runtime;
37use crate::runtime::monarch_with_gil;
38use crate::runtime::monarch_with_gil_blocking;
39use crate::shape::PyRegion;
40
41#[pyclass(
42    name = "ProcMesh",
43    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
44)]
45#[expect(
46    clippy::large_enum_variant,
47    reason = "PyO3 #[pyclass] enum; Box wrapping interacts with PyO3 codegen and Python interop — separate diff"
48)]
49pub enum PyProcMesh {
50    Owned(PyProcMeshImpl),
51    Ref(PyProcMeshRefImpl),
52}
53
54impl PyProcMesh {
55    pub fn new_owned(inner: ProcMesh) -> Self {
56        Self::Owned(PyProcMeshImpl(inner.into()))
57    }
58
59    pub(crate) fn new_ref(inner: ProcMeshRef) -> Self {
60        Self::Ref(PyProcMeshRefImpl(inner))
61    }
62
63    pub fn mesh_ref(&self) -> PyResult<ProcMeshRef> {
64        match self {
65            PyProcMesh::Owned(inner) => Ok(inner
66                .0
67                .borrow()
68                .map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
69                .clone()),
70            PyProcMesh::Ref(inner) => Ok(inner.0.clone()),
71        }
72    }
73}
74
75#[pymethods]
76impl PyProcMesh {
77    #[staticmethod]
78    #[pyo3(signature = (proc_mesh, instance, mesh_base_name, actor, init_message, emulated, supervision_display_name = None))]
79    fn spawn_async(
80        proc_mesh: &mut PyShared,
81        instance: &PyInstance,
82        mesh_base_name: String,
83        actor: Py<PyType>,
84        init_message: &mut PendingMessage,
85        emulated: bool,
86        supervision_display_name: Option<String>,
87    ) -> PyResult<Py<PyAny>> {
88        let init_message = init_message.take()?;
89        let task = proc_mesh.task()?.take_task()?;
90        let instance = instance.clone();
91        let mesh_impl = async move {
92            let proc_mesh = task.await?;
93
94            let init_message = init_message.resolve().await?;
95
96            let (proc_mesh, params) = monarch_with_gil(|py| -> PyResult<_> {
97                let slf: Bound<PyProcMesh> = proc_mesh.extract(py)?;
98                let slf = slf.borrow();
99                let pickled_type = PickledPyObject::pickle(actor.bind(py).as_any())?;
100                Ok((
101                    slf.mesh_ref()?.clone(),
102                    // Plumb `mesh_base_name` into the actor so the
103                    // direct actor-handled supervision path can
104                    // populate `MeshFailure.actor_mesh_name` without
105                    // a lookup. Kept separate from
106                    // `supervision_display_name` below (rendered
107                    // supervision display string).
108                    PythonActorParams::new(
109                        pickled_type,
110                        Some(init_message),
111                        Some(mesh_base_name.clone()),
112                    ),
113                ))
114            })
115            .await?;
116
117            let mesh_name = ActorMeshId::instance(Label::strip(&mesh_base_name));
118            let actor_mesh = proc_mesh
119                .spawn_with_name(
120                    instance.deref(),
121                    mesh_name,
122                    &params,
123                    supervision_display_name,
124                    false,
125                )
126                .await
127                .map_err(anyhow::Error::from)?;
128            Ok::<_, PyErr>(Box::new(PythonActorMeshImpl::new_owned(actor_mesh)))
129        };
130        if emulated {
131            // we give up on doing mesh spawn async for the emulated old version
132            // it is too complicated to make both work.
133            let r = get_tokio_runtime().block_on(mesh_impl)?;
134            monarch_with_gil_blocking(|py| r.into_py_any(py))
135        } else {
136            let r = PythonActorMesh::new(
137                async move {
138                    let mesh_impl: Box<dyn SupervisableActorMesh> = mesh_impl.await?;
139                    Ok(mesh_impl)
140                },
141                true,
142            );
143            monarch_with_gil_blocking(|py| r.into_py_any(py))
144        }
145    }
146
147    fn __repr__(&self) -> PyResult<String> {
148        match self {
149            PyProcMesh::Owned(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
150            PyProcMesh::Ref(inner) => Ok(format!("<ProcMesh: {:?}>", inner.__repr__()?)),
151        }
152    }
153
154    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
155        let bytes = bincode::serde::encode_to_vec(&self.mesh_ref()?, bincode::config::legacy())
156            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
157        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
158        let from_bytes =
159            PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.proc_mesh")?
160                .getattr("py_proc_mesh_from_bytes")?;
161        Ok((from_bytes, py_bytes))
162    }
163
164    #[getter]
165    fn region(&self) -> PyResult<PyRegion> {
166        Ok(self.mesh_ref()?.region().into())
167    }
168
169    fn stop_nonblocking(&self, instance: &PyInstance, reason: String) -> PyResult<PyPythonTask> {
170        // Clone the necessary fields from self to avoid capturing self in the async block
171        let (owned_inner, instance) = monarch_with_gil_blocking(|_py| {
172            let owned_inner = match self {
173                PyProcMesh::Owned(inner) => inner.clone(),
174                PyProcMesh::Ref(_) => {
175                    return Err(PyValueError::new_err(
176                        "ProcMesh is not owned; must be stopped by an owner",
177                    ));
178                }
179            };
180
181            let instance = instance.clone();
182            Ok((owned_inner, instance))
183        })?;
184        PyPythonTask::new(async move {
185            let mesh = owned_inner.0.take().await;
186            match mesh {
187                Ok(mut mesh) => mesh
188                    .stop(instance.deref(), reason)
189                    .await
190                    .map_err(|e| PyValueError::new_err(format!("error stopping mesh: {}", e))),
191                Err(e) => {
192                    // Don't return an exception, silently ignore the stop request
193                    // because it was already done.
194                    tracing::info!("proc mesh already stopped: {}", e);
195                    Ok(())
196                }
197            }
198        })
199    }
200
201    fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
202        Ok(Self::new_ref(
203            self.mesh_ref()?.sliced(region.as_inner().clone()),
204        ))
205    }
206}
207
208#[derive(Clone)]
209#[pyclass(
210    name = "ProcMeshImpl",
211    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
212)]
213pub struct PyProcMeshImpl(SharedCell<ProcMesh>);
214
215impl PyProcMeshImpl {
216    fn __repr__(&self) -> PyResult<String> {
217        Ok(format!(
218            "<ProcMeshImpl {:?}>",
219            *self.0.borrow().map_err(anyhow::Error::from)?
220        ))
221    }
222}
223
224#[derive(Debug, Clone)]
225#[pyclass(
226    name = "ProcMeshRefImpl",
227    module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
228)]
229pub struct PyProcMeshRefImpl(ProcMeshRef);
230
231impl PyProcMeshRefImpl {
232    fn __repr__(&self) -> PyResult<String> {
233        Ok(format!("<ProcMeshRefImpl {:?}>", self.0))
234    }
235}
236
237#[pyfunction]
238fn py_proc_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyProcMesh> {
239    let r: PyResult<ProcMeshRef> =
240        bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
241            .map(|(v, _)| v)
242            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
243    r.map(PyProcMesh::new_ref)
244}
245
/// Register this module's Python bindings on `hyperactor_mod`.
///
/// Adds the `ProcMesh` class and the `py_proc_mesh_from_bytes` function,
/// which `PyProcMesh::__reduce__` looks up by name when pickling.
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    hyperactor_mod.add_class::<PyProcMesh>()?;
    // NOTE(review): presumably the string is the module path recorded for the
    // function — confirm against the `py_module_add_function!` definition.
    py_module_add_function!(
        hyperactor_mod,
        "monarch._rust_bindings.monarch_hyperactor.proc_mesh",
        py_proc_mesh_from_bytes
    );
    Ok(())
}