monarch_hyperactor/
host_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::ops::Deref;
11use std::path::PathBuf;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use hyperactor::ActorHandle;
16use hyperactor::Instance;
17use hyperactor::Proc;
18use hyperactor_mesh::ProcMeshRef;
19use hyperactor_mesh::bootstrap::BootstrapCommand;
20use hyperactor_mesh::bootstrap::ProcBind;
21use hyperactor_mesh::bootstrap::host;
22use hyperactor_mesh::host_mesh;
23use hyperactor_mesh::host_mesh::HostMesh;
24use hyperactor_mesh::host_mesh::HostMeshRef;
25use hyperactor_mesh::host_mesh::host_agent::GetLocalProcClient;
26use hyperactor_mesh::host_mesh::host_agent::HostAgent;
27use hyperactor_mesh::host_mesh::host_agent::ShutdownHost;
28use hyperactor_mesh::proc_agent::GetProcClient;
29use hyperactor_mesh::proc_mesh::ProcRef;
30use hyperactor_mesh::shared_cell::SharedCell;
31use hyperactor_mesh::transport::default_bind_spec;
32use ndslice::View;
33use ndslice::view::RankedSliceable;
34use pyo3::IntoPyObjectExt;
35use pyo3::exceptions::PyException;
36use pyo3::exceptions::PyRuntimeError;
37use pyo3::exceptions::PyValueError;
38use pyo3::prelude::*;
39use pyo3::types::PyBytes;
40use pyo3::types::PyType;
41
42use crate::actor::PythonActor;
43use crate::actor::to_py_error;
44use crate::alloc::PyAlloc;
45use crate::context::PyInstance;
46use crate::proc_mesh::PyProcMesh;
47use crate::pytokio::PyPythonTask;
48use crate::runtime::monarch_with_gil;
49use crate::shape::PyExtent;
50use crate::shape::PyRegion;
51
/// Python-facing mirror of `hyperactor_mesh::bootstrap::BootstrapCommand`.
///
/// Exposed to Python under the name `BootstrapCommand`. Every field has a
/// generated getter and setter, so instances are freely mutable from Python.
/// `to_rust` converts an instance into the Rust-side command.
#[pyclass(
    name = "BootstrapCommand",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
#[derive(Clone)]
pub struct PyBootstrapCommand {
    /// Program of the command, held as a string; `to_rust` turns it into a `PathBuf`.
    #[pyo3(get, set)]
    pub program: String,
    /// Optional `arg0` of the command; copied as-is by `to_rust`.
    #[pyo3(get, set)]
    pub arg0: Option<String>,
    /// Argument list of the command; copied as-is by `to_rust`.
    #[pyo3(get, set)]
    pub args: Vec<String>,
    /// String-to-string map stored as the command's `env`; copied as-is by `to_rust`.
    #[pyo3(get, set)]
    pub env: HashMap<String, String>,
}
67
68#[pymethods]
69impl PyBootstrapCommand {
70    #[new]
71    fn new(
72        program: String,
73        arg0: Option<String>,
74        args: Vec<String>,
75        env: HashMap<String, String>,
76    ) -> Self {
77        Self {
78            program,
79            arg0,
80            args,
81            env,
82        }
83    }
84
85    fn __repr__(&self) -> String {
86        format!(
87            "BootstrapCommand(program='{}', args={:?}, env={:?})",
88            self.program, self.args, self.env
89        )
90    }
91}
92
93impl PyBootstrapCommand {
94    pub fn to_rust(&self) -> BootstrapCommand {
95        BootstrapCommand {
96            program: PathBuf::from(&self.program),
97            arg0: self.arg0.clone(),
98            args: self.args.clone(),
99            env: self.env.clone(),
100        }
101    }
102
103    pub fn default<'py>(py: Python<'py>) -> PyResult<Bound<'py, Self>> {
104        py.import("monarch._src.actor.host_mesh")?
105            .getattr("_bootstrap_cmd")?
106            .call0()?
107            .downcast::<PyBootstrapCommand>()
108            .cloned()
109            .map_err(to_py_error)
110    }
111}
112
/// Python-facing host mesh, exposed to Python under the name `HostMesh`.
///
/// A value is either an owned mesh or a reference to one. `shutdown` and
/// `stop` are only accepted on the owned variant; the remaining methods work
/// on a `HostMeshRef` obtained from either variant.
#[pyclass(
    name = "HostMesh",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
pub(crate) enum PyHostMesh {
    /// Wraps a `HostMesh` held in a `SharedCell`.
    Owned(PyHostMeshImpl),
    /// Wraps a plain `HostMeshRef`.
    Ref(PyHostMeshRefImpl),
}
121
122impl PyHostMesh {
123    pub(crate) fn new_owned(inner: HostMesh) -> Self {
124        Self::Owned(PyHostMeshImpl(SharedCell::from(inner)))
125    }
126
127    pub(crate) fn new_ref(inner: HostMeshRef) -> Self {
128        Self::Ref(PyHostMeshRefImpl(inner))
129    }
130
131    fn mesh_ref(&self) -> Result<HostMeshRef, anyhow::Error> {
132        match self {
133            PyHostMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
134            PyHostMesh::Ref(inner) => Ok(inner.0.clone()),
135        }
136    }
137}
138
139#[pymethods]
140impl PyHostMesh {
141    #[classmethod]
142    fn allocate_nonblocking(
143        _cls: &Bound<'_, PyType>,
144        instance: &PyInstance,
145        alloc: &mut PyAlloc,
146        name: String,
147        bootstrap_params: Option<PyBootstrapCommand>,
148    ) -> PyResult<PyPythonTask> {
149        let bootstrap_params =
150            bootstrap_params.map_or_else(|| alloc.bootstrap_command.clone(), |b| Some(b.to_rust()));
151        let alloc = match alloc.take() {
152            Some(alloc) => alloc,
153            None => {
154                return Err(PyException::new_err(
155                    "Alloc object already used".to_string(),
156                ));
157            }
158        };
159        let instance = instance.clone();
160        PyPythonTask::new(async move {
161            let mesh = HostMesh::allocate(instance.deref(), alloc, &name, bootstrap_params)
162                .await
163                .map_err(|err| PyException::new_err(err.to_string()))?;
164            Ok(Self::new_owned(mesh))
165        })
166    }
167
168    #[pyo3(signature = (instance, name, per_host, proc_bind = None))]
169    fn spawn_nonblocking(
170        &self,
171        instance: &PyInstance,
172        name: String,
173        per_host: &PyExtent,
174        proc_bind: Option<Vec<HashMap<String, String>>>,
175    ) -> PyResult<PyPythonTask> {
176        let host_mesh = self.mesh_ref()?.clone();
177        let instance = instance.clone();
178        let per_host = per_host.clone().into();
179        let proc_bind = proc_bind.map(|v| v.into_iter().map(ProcBind::from).collect());
180        let mesh_impl = async move {
181            let proc_mesh = host_mesh
182                .spawn(instance.deref(), &name, per_host, proc_bind)
183                .await
184                .map_err(to_py_error)?;
185            Ok(PyProcMesh::new_owned(proc_mesh))
186        };
187        PyPythonTask::new(mesh_impl)
188    }
189
190    fn with_bootstrap(&self, bootstrap_command: &PyBootstrapCommand) -> PyResult<Self> {
191        match self {
192            PyHostMesh::Owned(inner) => {
193                let cmd = bootstrap_command.to_rust();
194                inner
195                    .0
196                    .try_with_mut(|mesh| mesh.set_bootstrap(cmd))
197                    .map_err(|e| PyException::new_err(e.to_string()))?;
198                Ok(Self::Owned(inner.clone()))
199            }
200            PyHostMesh::Ref(_) => Ok(Self::new_ref(
201                self.mesh_ref()?.with_bootstrap(bootstrap_command.to_rust()),
202            )),
203        }
204    }
205
206    fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
207        Ok(Self::new_ref(
208            self.mesh_ref()?.sliced(region.as_inner().clone()),
209        ))
210    }
211
212    #[getter]
213    fn region(&self) -> PyResult<PyRegion> {
214        Ok(PyRegion::from(self.mesh_ref()?.region()))
215    }
216
217    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
218        let bytes = bincode::serialize(&self.mesh_ref()?)
219            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
220        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
221        let from_bytes =
222            PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.host_mesh")?
223                .getattr("py_host_mesh_from_bytes")?;
224        Ok((from_bytes, py_bytes))
225    }
226
227    fn __eq__(&self, other: &PyHostMesh) -> PyResult<bool> {
228        Ok(self.mesh_ref()? == other.mesh_ref()?)
229    }
230
231    fn shutdown(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
232        match self {
233            PyHostMesh::Owned(inner) => {
234                let instance = instance.clone();
235                let mesh_borrow = inner.0.clone();
236                let fut = async move {
237                    match mesh_borrow.take().await {
238                        Ok(mut mesh) => {
239                            mesh.shutdown(instance.deref()).await?;
240                            Ok(())
241                        }
242                        Err(_) => {
243                            // Don't return an exception, silently ignore the stop request
244                            // because it was already done.
245                            tracing::info!("shutdown was already called on host mesh");
246                            Ok(())
247                        }
248                    }
249                };
250                PyPythonTask::new(fut)
251            }
252            PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
253                "cannot shut down `HostMesh` that is a reference instead of owned",
254            )),
255        }
256    }
257
258    fn stop(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
259        match self {
260            PyHostMesh::Owned(inner) => {
261                let instance = instance.clone();
262                let mesh_borrow = inner.0.clone();
263                let fut = async move {
264                    match mesh_borrow.take().await {
265                        Ok(mut mesh) => {
266                            mesh.stop(instance.deref()).await?;
267                            Ok(())
268                        }
269                        Err(_) => {
270                            tracing::info!("stop was already called on host mesh");
271                            Ok(())
272                        }
273                    }
274                };
275                PyPythonTask::new(fut)
276            }
277            PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
278                "cannot stop `HostMesh` that is a reference instead of owned",
279            )),
280        }
281    }
282}
283
/// Payload of `PyHostMesh::Owned`: a `HostMesh` held in a `SharedCell`.
///
/// Exposed to Python under the name `HostMeshImpl`.
#[derive(Clone)]
#[pyclass(
    name = "HostMeshImpl",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
pub(crate) struct PyHostMeshImpl(SharedCell<HostMesh>);
290
/// Payload of `PyHostMesh::Ref`: a plain `HostMeshRef`.
///
/// Exposed to Python under the name `HostMeshRefImpl`.
#[derive(Debug, Clone)]
#[pyclass(
    name = "HostMeshRefImpl",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
pub(crate) struct PyHostMeshRefImpl(HostMeshRef);
297
298impl PyHostMeshRefImpl {
299    fn __repr__(&self) -> PyResult<String> {
300        Ok(format!("<HostMeshRefImpl {:?}>", self.0))
301    }
302}
303
/// Static storage for the root client instance when using host-based bootstrap.
///
/// `bootstrap_host` passes a reference to this cell to
/// `PythonActor::bootstrap_client_inner`; nothing else in this file reads it.
static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock<Instance<PythonActor>> = OnceLock::new();

/// Static storage for the host mesh agent created by bootstrap_host().
///
/// Set by `bootstrap_host` (a later attempt to set it again is ignored) and
/// read by `shutdown_local_host_mesh`, which reports an error if it is unset.
static HOST_MESH_AGENT_FOR_HOST: OnceLock<ActorHandle<HostAgent>> = OnceLock::new();

/// Static storage for the host shutdown handle created by bootstrap_host().
/// Used during shutdown_context to join the mailbox server and flush
/// receive-side acks.
///
/// In this file the handle is consumed by `shutdown_local_host_mesh`, which
/// takes it out of the `Option` under the mutex and joins it, so it can be
/// joined at most once.
static HOST_SHUTDOWN_HANDLE: OnceLock<
    tokio::sync::Mutex<Option<hyperactor_mesh::bootstrap::HostShutdownHandle>>,
> = OnceLock::new();
316
/// Bootstrap the client host and root client actor.
///
/// This creates a proper Host with BootstrapProcManager, spawns the root client
/// actor on the Host's local_proc.
///
/// When `bootstrap_cmd` is `None`, `BootstrapCommand::current()` is used; a
/// failure to obtain it is raised immediately, before any task is created.
///
/// Returns a task resolving to a tuple of (HostMesh, ProcMesh, PyInstance) where:
/// - PyHostMesh: the bootstrapped (local) host mesh; and
/// - PyProcMesh: the local ProcMesh on this HostMesh; and
/// - PyInstance: the root client actor instance, on the ProcMesh.
///
/// Both meshes are returned as references, not owned meshes.
///
/// The HostMesh is served on the default transport.
///
/// As side effects, the host agent and the host shutdown handle are stored in
/// process-wide statics, the host mesh is registered as the client host, and
/// telemetry is notified of the created meshes and actors.
///
/// This should be called only once, at process initialization.
#[pyfunction]
fn bootstrap_host(bootstrap_cmd: Option<PyBootstrapCommand>) -> PyResult<PyPythonTask> {
    // Resolve the command synchronously so an error surfaces at call time.
    let bootstrap_cmd = match bootstrap_cmd {
        Some(cmd) => cmd.to_rust(),
        None => BootstrapCommand::current().map_err(|e| PyException::new_err(e.to_string()))?,
    };

    PyPythonTask::new(async move {
        // NOTE(review): the meaning of the trailing `None` and `false`
        // arguments is defined by `hyperactor_mesh::bootstrap::host` - confirm there.
        let (host_mesh_agent, shutdown_handle) = host(
            default_bind_spec().binding_addr(),
            Some(bootstrap_cmd),
            None,
            false,
        )
        .await
        .map_err(|e| PyException::new_err(e.to_string()))?;

        // Store the agent and shutdown handle for later shutdown.
        // `OnceLock::set` fails when a value is already stored; `.ok()`
        // discards that error, leaving the first stored agent in place.
        HOST_MESH_AGENT_FOR_HOST.set(host_mesh_agent.clone()).ok();
        // Likewise, if the cell is already initialized the closure is not run
        // and this call's `shutdown_handle` is dropped without being stored.
        HOST_SHUTDOWN_HANDLE.get_or_init(|| tokio::sync::Mutex::new(Some(shutdown_handle)));

        // NOTE(review): `unwrap` panics if `new_reserved` rejects the name;
        // presumably it cannot fail for the fixed name "local" - confirm.
        let host_mesh_name = hyperactor_mesh::Name::new_reserved("local").unwrap();
        let host_mesh = HostMeshRef::from_host_agent(host_mesh_name, host_mesh_agent.bind())
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Register C so MeshAdminAgent can discover it ("A/C
        // invariant" - hyperactor_mesh/src/mesh_admin.rs).
        hyperactor_mesh::global_context::register_client_host(host_mesh.clone());

        // We require a temporary instance to make a call to the host/proc agent.
        let temp_proc = Proc::local();
        let (temp_instance, _) = temp_proc
            .instance("temp")
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Ask the host agent for the agent of its local proc.
        let local_proc_agent: hyperactor::ActorHandle<hyperactor_mesh::proc_agent::ProcAgent> =
            host_mesh_agent
                .get_local_proc(&temp_instance)
                .await
                .map_err(|e| PyException::new_err(e.to_string()))?;

        // Build a single-proc mesh around that local proc.
        // NOTE(review): the literal `0` is presumably the proc's rank - confirm
        // against `ProcRef::new`.
        let proc_mesh_name = hyperactor_mesh::Name::new_reserved("local").unwrap();
        let proc_mesh = ProcMeshRef::new_singleton(
            proc_mesh_name,
            ProcRef::new(
                local_proc_agent.actor_id().proc_id().clone(),
                0,
                local_proc_agent.bind(),
            ),
        );

        // Fetch the proc itself; the root client actor is bootstrapped on it.
        let local_proc = local_proc_agent
            .get_proc(&temp_instance)
            .await
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Bootstrapping the client needs the GIL; the handle it returns is not used here.
        let (instance, _handle) = monarch_with_gil(|py| {
            PythonActor::bootstrap_client_inner(py, local_proc, &ROOT_CLIENT_INSTANCE_FOR_HOST)
        })
        .await;

        // Notify telemetry of the bootstrap host mesh, proc mesh, and client actor.
        // Three mesh events are emitted (host, proc, client), each followed by one
        // actor event at rank 0. Ids are `hash_to_u64` of the corresponding name or
        // actor id, and every event shares the same timestamp.
        {
            let now = std::time::SystemTime::now();

            // Host mesh: has no parent mesh.
            let host_name_str = host_mesh.name().to_string();
            let host_mesh_id = hyperactor_telemetry::hash_to_u64(&host_name_str);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: host_mesh_id,
                timestamp: now,
                class: "Host".to_string(),
                given_name: host_mesh.name().name().to_string(),
                full_name: host_name_str,
                shape_json: serde_json::to_string(&host_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: None,
                parent_view_json: None,
            });

            // Host agent actor, attached to the host mesh.
            let host_agent_id = host_mesh_agent.actor_id();
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(host_agent_id),
                timestamp: now,
                mesh_id: host_mesh_id,
                rank: 0,
                full_name: host_agent_id.to_string(),
                display_name: None,
            });

            // Proc mesh: child of the host mesh.
            let proc_name_str = proc_mesh.name().to_string();
            let proc_mesh_id = hyperactor_telemetry::hash_to_u64(&proc_name_str);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: proc_mesh_id,
                timestamp: now,
                class: "Proc".to_string(),
                given_name: proc_mesh.name().name().to_string(),
                full_name: proc_name_str,
                shape_json: serde_json::to_string(&proc_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(host_mesh_id),
                parent_view_json: None,
            });

            // Proc agent actor, attached to the proc mesh.
            let proc_agent_id = local_proc_agent.actor_id();
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(proc_agent_id),
                timestamp: now,
                mesh_id: proc_mesh_id,
                rank: 0,
                full_name: proc_agent_id.to_string(),
                display_name: None,
            });

            // Client "mesh": named "<proc mesh name>/client", child of the proc
            // mesh, and reported with the proc mesh's extent as its shape.
            let client_mesh_name = format!("{}/client", proc_mesh.name());
            let client_mesh_id = hyperactor_telemetry::hash_to_u64(&client_mesh_name);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: client_mesh_id,
                timestamp: now,
                class: <PythonActor as typeuri::Named>::typename().to_string(),
                given_name: "client".to_string(),
                full_name: client_mesh_name,
                shape_json: serde_json::to_string(&proc_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(proc_mesh_id),
                parent_view_json: None,
            });

            // Root client actor, displayed as "<root>".
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(instance.self_id()),
                timestamp: now,
                mesh_id: client_mesh_id,
                rank: 0,
                full_name: instance.self_id().to_string(),
                display_name: Some("<root>".to_string()),
            });
        }

        Ok((
            PyHostMesh::new_ref(host_mesh),
            PyProcMesh::new_ref(proc_mesh),
            PyInstance::from(instance),
        ))
    })
}
471
472#[pyfunction]
473fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyHostMesh> {
474    let r: PyResult<HostMeshRef> = bincode::deserialize(bytes.as_bytes())
475        .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
476    r.map(PyHostMesh::new_ref)
477}
478
479#[pyfunction]
480fn shutdown_local_host_mesh() -> PyResult<PyPythonTask> {
481    let agent = HOST_MESH_AGENT_FOR_HOST
482        .get()
483        .ok_or_else(|| PyException::new_err("No local host mesh to shutdown"))?
484        .clone();
485
486    PyPythonTask::new(async move {
487        // Create a temporary instance to send the shutdown message
488        let temp_proc = hyperactor::Proc::local();
489        let (instance, _) = temp_proc
490            .instance("shutdown_requester")
491            .map_err(|e| PyException::new_err(e.to_string()))?;
492
493        tracing::info!(
494            "sending shutdown_host request to agent {}",
495            agent.actor_id()
496        );
497        // Use same defaults as HostMesh::shutdown():
498        // - MESH_TERMINATE_TIMEOUT = 10 seconds
499        // - MESH_TERMINATE_CONCURRENCY = 16
500
501        let (port, _) = instance.open_port();
502        let mut port = port.bind();
503        // We don't need the ack, and this temporary proc doesn't have a mailbox
504        // receiver set up anyways. Just ignore the message.
505        port.return_undeliverable(false);
506        agent
507            .send(
508                &instance,
509                ShutdownHost {
510                    timeout: Duration::from_secs(10),
511                    max_in_flight: 16,
512                    ack: port,
513                },
514            )
515            .map_err(|e| PyException::new_err(e.to_string()))?;
516
517        // Join the host's mailbox server to flush receive-side acks
518        // before the process exits.
519        if let Some(lock) = HOST_SHUTDOWN_HANDLE.get() {
520            if let Some(handle) = lock.lock().await.take() {
521                handle.join().await;
522            }
523        }
524
525        Ok(())
526    })
527}
528
529/// Spawn a MeshAdminAgent aggregating topology across one or more meshes.
530///
531/// The admin runs on the caller's local proc and serves the
532/// mesh-admin HTTP API. Returns the admin HTTP URL. When
533/// `admin_addr` is `None`, the bind address is read from
534/// `MESH_ADMIN_ADDR` config.
535///
536/// Python-facing wrapper around
537/// [`hyperactor_mesh::host_mesh::spawn_admin`].
538#[pyfunction]
539fn _spawn_admin(
540    host_meshes: Vec<PyRef<'_, PyHostMesh>>,
541    instance: &PyInstance,
542    admin_addr: Option<String>,
543    telemetry_url: Option<String>,
544) -> PyResult<PyPythonTask> {
545    if host_meshes.is_empty() {
546        return Err(PyException::new_err("at least one mesh is required"));
547    }
548
549    let admin_addr = admin_addr
550        .map(|s| {
551            s.parse::<std::net::SocketAddr>()
552                .map_err(|e| PyException::new_err(format!("invalid admin_addr '{}': {}", s, e)))
553        })
554        .transpose()?;
555
556    let mesh_refs = host_meshes
557        .iter()
558        .map(|m| -> PyResult<HostMeshRef> { Ok(m.mesh_ref()?.clone()) })
559        .collect::<PyResult<Vec<HostMeshRef>>>()?;
560
561    let instance = instance.clone();
562    PyPythonTask::new(async move {
563        let addr = host_mesh::spawn_admin(&mesh_refs, instance.deref(), admin_addr, telemetry_url)
564            .await
565            .map_err(|e| PyException::new_err(e.to_string()))?;
566        Ok(addr)
567    })
568}
569
570pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
571    let f = wrap_pyfunction!(py_host_mesh_from_bytes, hyperactor_mod)?;
572    f.setattr(
573        "__module__",
574        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
575    )?;
576    hyperactor_mod.add_function(f)?;
577
578    let f2 = wrap_pyfunction!(bootstrap_host, hyperactor_mod)?;
579    f2.setattr(
580        "__module__",
581        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
582    )?;
583    hyperactor_mod.add_function(f2)?;
584
585    let f3 = wrap_pyfunction!(shutdown_local_host_mesh, hyperactor_mod)?;
586    f3.setattr(
587        "__module__",
588        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
589    )?;
590    hyperactor_mod.add_function(f3)?;
591
592    let f4 = wrap_pyfunction!(_spawn_admin, hyperactor_mod)?;
593    f4.setattr(
594        "__module__",
595        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
596    )?;
597    hyperactor_mod.add_function(f4)?;
598
599    hyperactor_mod.add_class::<PyHostMesh>()?;
600    hyperactor_mod.add_class::<PyBootstrapCommand>()?;
601    Ok(())
602}