monarch_hyperactor/
host_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::ops::Deref;
11use std::path::PathBuf;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use hyperactor::ActorHandle;
16use hyperactor::Instance;
17use hyperactor::Proc;
18use hyperactor_mesh::ProcMeshRef;
19use hyperactor_mesh::bootstrap::BootstrapCommand;
20use hyperactor_mesh::bootstrap::host;
21use hyperactor_mesh::host_mesh::HostMesh;
22use hyperactor_mesh::host_mesh::HostMeshRef;
23use hyperactor_mesh::host_mesh::host_agent::GetLocalProcClient;
24use hyperactor_mesh::host_mesh::host_agent::HostAgent;
25use hyperactor_mesh::host_mesh::host_agent::ShutdownHost;
26use hyperactor_mesh::proc_agent::GetProcClient;
27use hyperactor_mesh::proc_mesh::ProcRef;
28use hyperactor_mesh::shared_cell::SharedCell;
29use hyperactor_mesh::transport::default_bind_spec;
30use ndslice::View;
31use ndslice::view::RankedSliceable;
32use pyo3::IntoPyObjectExt;
33use pyo3::exceptions::PyException;
34use pyo3::exceptions::PyRuntimeError;
35use pyo3::exceptions::PyValueError;
36use pyo3::prelude::*;
37use pyo3::types::PyBytes;
38use pyo3::types::PyType;
39
40use crate::actor::PythonActor;
41use crate::actor::to_py_error;
42use crate::alloc::PyAlloc;
43use crate::context::PyInstance;
44use crate::proc_mesh::PyProcMesh;
45use crate::pytokio::PyPythonTask;
46use crate::runtime::monarch_with_gil;
47use crate::shape::PyExtent;
48use crate::shape::PyRegion;
49
50#[pyclass(
51    name = "BootstrapCommand",
52    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
53)]
54#[derive(Clone)]
55pub struct PyBootstrapCommand {
56    #[pyo3(get, set)]
57    pub program: String,
58    #[pyo3(get, set)]
59    pub arg0: Option<String>,
60    #[pyo3(get, set)]
61    pub args: Vec<String>,
62    #[pyo3(get, set)]
63    pub env: HashMap<String, String>,
64}
65
66#[pymethods]
67impl PyBootstrapCommand {
68    #[new]
69    fn new(
70        program: String,
71        arg0: Option<String>,
72        args: Vec<String>,
73        env: HashMap<String, String>,
74    ) -> Self {
75        Self {
76            program,
77            arg0,
78            args,
79            env,
80        }
81    }
82
83    fn __repr__(&self) -> String {
84        format!(
85            "BootstrapCommand(program='{}', args={:?}, env={:?})",
86            self.program, self.args, self.env
87        )
88    }
89}
90
91impl PyBootstrapCommand {
92    pub fn to_rust(&self) -> BootstrapCommand {
93        BootstrapCommand {
94            program: PathBuf::from(&self.program),
95            arg0: self.arg0.clone(),
96            args: self.args.clone(),
97            env: self.env.clone(),
98        }
99    }
100
101    pub fn default<'py>(py: Python<'py>) -> PyResult<Bound<'py, Self>> {
102        py.import("monarch._src.actor.host_mesh")?
103            .getattr("_bootstrap_cmd")?
104            .call0()?
105            .downcast::<PyBootstrapCommand>()
106            .cloned()
107            .map_err(to_py_error)
108    }
109}
110
111#[pyclass(
112    name = "HostMesh",
113    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
114)]
115pub(crate) enum PyHostMesh {
116    Owned(PyHostMeshImpl),
117    Ref(PyHostMeshRefImpl),
118}
119
120impl PyHostMesh {
121    pub(crate) fn new_owned(inner: HostMesh) -> Self {
122        Self::Owned(PyHostMeshImpl(SharedCell::from(inner)))
123    }
124
125    pub(crate) fn new_ref(inner: HostMeshRef) -> Self {
126        Self::Ref(PyHostMeshRefImpl(inner))
127    }
128
129    fn mesh_ref(&self) -> Result<HostMeshRef, anyhow::Error> {
130        match self {
131            PyHostMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
132            PyHostMesh::Ref(inner) => Ok(inner.0.clone()),
133        }
134    }
135}
136
137#[pymethods]
138impl PyHostMesh {
139    #[classmethod]
140    fn allocate_nonblocking(
141        _cls: &Bound<'_, PyType>,
142        instance: &PyInstance,
143        alloc: &mut PyAlloc,
144        name: String,
145        bootstrap_params: Option<PyBootstrapCommand>,
146    ) -> PyResult<PyPythonTask> {
147        let bootstrap_params =
148            bootstrap_params.map_or_else(|| alloc.bootstrap_command.clone(), |b| Some(b.to_rust()));
149        let alloc = match alloc.take() {
150            Some(alloc) => alloc,
151            None => {
152                return Err(PyException::new_err(
153                    "Alloc object already used".to_string(),
154                ));
155            }
156        };
157        let instance = instance.clone();
158        PyPythonTask::new(async move {
159            let mesh = HostMesh::allocate(instance.deref(), alloc, &name, bootstrap_params)
160                .await
161                .map_err(|err| PyException::new_err(err.to_string()))?;
162            Ok(Self::new_owned(mesh))
163        })
164    }
165
166    fn spawn_nonblocking(
167        &self,
168        instance: &PyInstance,
169        name: String,
170        per_host: &PyExtent,
171    ) -> PyResult<PyPythonTask> {
172        let host_mesh = self.mesh_ref()?.clone();
173        let instance = instance.clone();
174        let per_host = per_host.clone().into();
175        let mesh_impl = async move {
176            let proc_mesh = host_mesh
177                .spawn(instance.deref(), &name, per_host)
178                .await
179                .map_err(to_py_error)?;
180            Ok(PyProcMesh::new_owned(proc_mesh))
181        };
182        PyPythonTask::new(mesh_impl)
183    }
184
185    /// Spawn a MeshAdminAgent on the head host's system proc and
186    /// return its HTTP address as a string.
187    ///
188    /// When `admin_addr` is provided (as a `"host:port"` string), the
189    /// HTTP server binds to that address; otherwise it reads
190    /// `MESH_ADMIN_ADDR` from config.
191    fn _spawn_admin(
192        &self,
193        instance: &PyInstance,
194        admin_addr: Option<String>,
195    ) -> PyResult<PyPythonTask> {
196        let admin_addr = admin_addr
197            .map(|s| {
198                s.parse::<std::net::SocketAddr>()
199                    .map_err(|e| PyException::new_err(format!("invalid admin_addr '{}': {}", s, e)))
200            })
201            .transpose()?;
202        let host_mesh = self.mesh_ref()?.clone();
203        let instance = instance.clone();
204        PyPythonTask::new(async move {
205            let addr = host_mesh
206                .spawn_admin(instance.deref(), admin_addr)
207                .await
208                .map_err(|e| PyException::new_err(e.to_string()))?;
209            Ok(addr)
210        })
211    }
212
213    fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
214        Ok(Self::new_ref(
215            self.mesh_ref()?.sliced(region.as_inner().clone()),
216        ))
217    }
218
219    #[getter]
220    fn region(&self) -> PyResult<PyRegion> {
221        Ok(PyRegion::from(self.mesh_ref()?.region()))
222    }
223
224    fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
225        let bytes = bincode::serialize(&self.mesh_ref()?)
226            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
227        let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
228        let from_bytes =
229            PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.host_mesh")?
230                .getattr("py_host_mesh_from_bytes")?;
231        Ok((from_bytes, py_bytes))
232    }
233
234    fn __eq__(&self, other: &PyHostMesh) -> PyResult<bool> {
235        Ok(self.mesh_ref()? == other.mesh_ref()?)
236    }
237
238    fn shutdown(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
239        match self {
240            PyHostMesh::Owned(inner) => {
241                let instance = instance.clone();
242                let mesh_borrow = inner.0.clone();
243                let fut = async move {
244                    match mesh_borrow.take().await {
245                        Ok(mut mesh) => {
246                            mesh.shutdown(instance.deref()).await?;
247                            Ok(())
248                        }
249                        Err(_) => {
250                            // Don't return an exception, silently ignore the stop request
251                            // because it was already done.
252                            tracing::info!("shutdown was already called on host mesh");
253                            Ok(())
254                        }
255                    }
256                };
257                PyPythonTask::new(fut)
258            }
259            PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
260                "cannot shut down `HostMesh` that is a reference instead of owned",
261            )),
262        }
263    }
264}
265
266#[derive(Clone)]
267#[pyclass(
268    name = "HostMeshImpl",
269    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
270)]
271pub(crate) struct PyHostMeshImpl(SharedCell<HostMesh>);
272
273#[derive(Debug, Clone)]
274#[pyclass(
275    name = "HostMeshRefImpl",
276    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
277)]
278pub(crate) struct PyHostMeshRefImpl(HostMeshRef);
279
280impl PyHostMeshRefImpl {
281    fn __repr__(&self) -> PyResult<String> {
282        Ok(format!("<HostMeshRefImpl {:?}>", self.0))
283    }
284}
285
286/// Static storage for the root client instance when using host-based bootstrap.
287static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock<Instance<PythonActor>> = OnceLock::new();
288
289/// Static storage for the host mesh agent created by bootstrap_host().
290static HOST_MESH_AGENT_FOR_HOST: OnceLock<ActorHandle<HostAgent>> = OnceLock::new();
291
292/// Bootstrap the client host and root client actor.
293///
294/// This creates a proper Host with BootstrapProcManager, spawns the root client
295/// actor on the Host's local_proc.
296///
297/// Returns a tuple of (HostMesh, ProcMesh, PyInstance) where:
298/// - PyHostMesh: the bootstrapped (local) host mesh; and
299/// - PyProcMesh: the local ProcMesh on this HostMesh; and
300/// - PyInstance: the root client actor instance, on the ProcMesh.
301///
302/// The HostMesh is served on the default transport.
303///
304/// This should be called only once, at process initialization
305#[pyfunction]
306fn bootstrap_host(bootstrap_cmd: Option<PyBootstrapCommand>) -> PyResult<PyPythonTask> {
307    let bootstrap_cmd = match bootstrap_cmd {
308        Some(cmd) => cmd.to_rust(),
309        None => BootstrapCommand::current().map_err(|e| PyException::new_err(e.to_string()))?,
310    };
311
312    PyPythonTask::new(async move {
313        let (host_mesh_agent, _shutdown) = host(
314            default_bind_spec().binding_addr(),
315            Some(bootstrap_cmd),
316            None,
317            false,
318        )
319        .await
320        .map_err(|e| PyException::new_err(e.to_string()))?;
321
322        // Store the agent for later shutdown
323        HOST_MESH_AGENT_FOR_HOST.set(host_mesh_agent.clone()).ok(); // Ignore error if already set
324
325        let host_mesh_name = hyperactor_mesh::Name::new_reserved("local").unwrap();
326        let host_mesh = HostMeshRef::from_host_agent(host_mesh_name, host_mesh_agent.bind())
327            .map_err(|e| PyException::new_err(e.to_string()))?;
328
329        // Register C so MeshAdminAgent can discover it ("A/C
330        // invariant" - hyperactor_mesh/src/mesh_admin.rs).
331        hyperactor_mesh::global_context::register_client_host(host_mesh.clone());
332
333        // We require a temporary instance to make a call to the host/proc agent.
334        let temp_proc = Proc::local();
335        let (temp_instance, _) = temp_proc
336            .instance("temp")
337            .map_err(|e| PyException::new_err(e.to_string()))?;
338
339        let local_proc_agent: hyperactor::ActorHandle<hyperactor_mesh::proc_agent::ProcAgent> =
340            host_mesh_agent
341                .get_local_proc(&temp_instance)
342                .await
343                .map_err(|e| PyException::new_err(e.to_string()))?;
344
345        let proc_mesh_name = hyperactor_mesh::Name::new_reserved("local").unwrap();
346        let proc_mesh = ProcMeshRef::new_singleton(
347            proc_mesh_name,
348            ProcRef::new(
349                local_proc_agent.actor_id().proc_id().clone(),
350                0,
351                local_proc_agent.bind(),
352            ),
353        );
354
355        let local_proc = local_proc_agent
356            .get_proc(&temp_instance)
357            .await
358            .map_err(|e| PyException::new_err(e.to_string()))?;
359
360        let (instance, _handle) = monarch_with_gil(|py| {
361            PythonActor::bootstrap_client_inner(py, local_proc, &ROOT_CLIENT_INSTANCE_FOR_HOST)
362        })
363        .await;
364
365        // Notify telemetry of the bootstrap host mesh, proc mesh, and client actor.
366        {
367            let now = std::time::SystemTime::now();
368
369            let host_name_str = host_mesh.name().to_string();
370            let host_mesh_id = hyperactor_telemetry::hash_to_u64(&host_name_str);
371            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
372                id: host_mesh_id,
373                timestamp: now,
374                class: "Host".to_string(),
375                given_name: host_mesh.name().name().to_string(),
376                full_name: host_name_str,
377                shape_json: host_mesh.region().extent().to_string(),
378                parent_mesh_id: None,
379                parent_view_json: None,
380            });
381
382            let host_agent_id = host_mesh_agent.actor_id();
383            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
384                id: hyperactor_telemetry::hash_to_u64(host_agent_id),
385                timestamp: now,
386                mesh_id: host_mesh_id,
387                rank: 0,
388                full_name: host_agent_id.to_string(),
389                display_name: None,
390            });
391
392            let proc_name_str = proc_mesh.name().to_string();
393            let proc_mesh_id = hyperactor_telemetry::hash_to_u64(&proc_name_str);
394            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
395                id: proc_mesh_id,
396                timestamp: now,
397                class: "Proc".to_string(),
398                given_name: proc_mesh.name().name().to_string(),
399                full_name: proc_name_str,
400                shape_json: proc_mesh.region().extent().to_string(),
401                parent_mesh_id: Some(host_mesh_id),
402                parent_view_json: None,
403            });
404
405            let proc_agent_id = local_proc_agent.actor_id();
406            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
407                id: hyperactor_telemetry::hash_to_u64(proc_agent_id),
408                timestamp: now,
409                mesh_id: proc_mesh_id,
410                rank: 0,
411                full_name: proc_agent_id.to_string(),
412                display_name: None,
413            });
414
415            let client_mesh_name = format!("{}/client", proc_mesh.name());
416            let client_mesh_id = hyperactor_telemetry::hash_to_u64(&client_mesh_name);
417            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
418                id: client_mesh_id,
419                timestamp: now,
420                class: <PythonActor as typeuri::Named>::typename().to_string(),
421                given_name: "client".to_string(),
422                full_name: client_mesh_name,
423                shape_json: proc_mesh.region().extent().to_string(),
424                parent_mesh_id: Some(proc_mesh_id),
425                parent_view_json: None,
426            });
427
428            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
429                id: hyperactor_telemetry::hash_to_u64(instance.self_id()),
430                timestamp: now,
431                mesh_id: client_mesh_id,
432                rank: 0,
433                full_name: instance.self_id().to_string(),
434                display_name: Some("<root>".to_string()),
435            });
436        }
437
438        Ok((
439            PyHostMesh::new_ref(host_mesh),
440            PyProcMesh::new_ref(proc_mesh),
441            PyInstance::from(instance),
442        ))
443    })
444}
445
446#[pyfunction]
447fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyHostMesh> {
448    let r: PyResult<HostMeshRef> = bincode::deserialize(bytes.as_bytes())
449        .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
450    r.map(PyHostMesh::new_ref)
451}
452
453#[pyfunction]
454fn shutdown_local_host_mesh() -> PyResult<PyPythonTask> {
455    let agent = HOST_MESH_AGENT_FOR_HOST
456        .get()
457        .ok_or_else(|| PyException::new_err("No local host mesh to shutdown"))?
458        .clone();
459
460    PyPythonTask::new(async move {
461        // Create a temporary instance to send the shutdown message
462        let temp_proc = hyperactor::Proc::local();
463        let (instance, _) = temp_proc
464            .instance("shutdown_requester")
465            .map_err(|e| PyException::new_err(e.to_string()))?;
466
467        tracing::info!(
468            "sending shutdown_host request to agent {}",
469            agent.actor_id()
470        );
471        // Use same defaults as HostMesh::shutdown():
472        // - MESH_TERMINATE_TIMEOUT = 10 seconds
473        // - MESH_TERMINATE_CONCURRENCY = 16
474
475        let (port, _) = instance.open_port();
476        let mut port = port.bind();
477        // We don't need the ack, and this temporary proc doesn't have a mailbox
478        // receiver set up anyways. Just ignore the message.
479        port.return_undeliverable(false);
480        agent
481            .send(
482                &instance,
483                ShutdownHost {
484                    timeout: Duration::from_secs(10),
485                    max_in_flight: 16,
486                    ack: port,
487                },
488            )
489            .map_err(|e| PyException::new_err(e.to_string()))?;
490
491        Ok(())
492    })
493}
494
495pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
496    let f = wrap_pyfunction!(py_host_mesh_from_bytes, hyperactor_mod)?;
497    f.setattr(
498        "__module__",
499        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
500    )?;
501    hyperactor_mod.add_function(f)?;
502
503    let f2 = wrap_pyfunction!(bootstrap_host, hyperactor_mod)?;
504    f2.setattr(
505        "__module__",
506        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
507    )?;
508    hyperactor_mod.add_function(f2)?;
509
510    let f3 = wrap_pyfunction!(shutdown_local_host_mesh, hyperactor_mod)?;
511    f3.setattr(
512        "__module__",
513        "monarch._rust_bindings.monarch_hyperactor.host_mesh",
514    )?;
515    hyperactor_mod.add_function(f3)?;
516
517    hyperactor_mod.add_class::<PyHostMesh>()?;
518    hyperactor_mod.add_class::<PyBootstrapCommand>()?;
519    Ok(())
520}