1use std::collections::HashMap;
10use std::ops::Deref;
11use std::path::PathBuf;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use hyperactor::ActorHandle;
16use hyperactor::Endpoint as _;
17use hyperactor::Instance;
18use hyperactor::Proc;
19use hyperactor::id::Label;
20use hyperactor_mesh::ProcMeshRef;
21use hyperactor_mesh::bootstrap::BootstrapCommand;
22use hyperactor_mesh::bootstrap::ProcBind;
23use hyperactor_mesh::bootstrap::host;
24use hyperactor_mesh::host_mesh;
25use hyperactor_mesh::host_mesh::HostMesh;
26use hyperactor_mesh::host_mesh::HostMeshRef;
27use hyperactor_mesh::host_mesh::PerRankBootstrapFn;
28use hyperactor_mesh::host_mesh::host_agent::GetLocalProcClient;
29use hyperactor_mesh::host_mesh::host_agent::HostAgent;
30use hyperactor_mesh::host_mesh::host_agent::ShutdownHost;
31use hyperactor_mesh::mesh_admin::MeshAdminMessageClient;
32use hyperactor_mesh::mesh_id::HostMeshId;
33use hyperactor_mesh::mesh_id::ProcMeshId;
34use hyperactor_mesh::proc_agent::GetProcClient;
35use hyperactor_mesh::proc_mesh::ProcRef;
36use hyperactor_mesh::shared_cell::SharedCell;
37use hyperactor_mesh::transport::default_bind_spec;
38use ndslice::View;
39use ndslice::view::RankedSliceable;
40use pyo3::IntoPyObjectExt;
41use pyo3::exceptions::PyException;
42use pyo3::exceptions::PyRuntimeError;
43use pyo3::exceptions::PyValueError;
44use pyo3::prelude::*;
45use pyo3::types::PyBytes;
46
47use crate::actor::PythonActor;
48use crate::actor::to_py_error;
49use crate::context::PyInstance;
50use crate::proc_mesh::PyProcMesh;
51use crate::pytokio::PyPythonTask;
52use crate::runtime::monarch_with_gil;
53use crate::shape::PyExtent;
54use crate::shape::PyPoint;
55use crate::shape::PyRegion;
56
57#[pyclass(
58 name = "BootstrapCommand",
59 module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
60)]
61#[derive(Clone)]
62pub struct PyBootstrapCommand {
63 #[pyo3(get, set)]
64 pub program: String,
65 #[pyo3(get, set)]
66 pub arg0: Option<String>,
67 #[pyo3(get, set)]
68 pub args: Vec<String>,
69 #[pyo3(get, set)]
70 pub env: HashMap<String, String>,
71}
72
73#[pymethods]
74impl PyBootstrapCommand {
75 #[new]
76 fn new(
77 program: String,
78 arg0: Option<String>,
79 args: Vec<String>,
80 env: HashMap<String, String>,
81 ) -> Self {
82 Self {
83 program,
84 arg0,
85 args,
86 env,
87 }
88 }
89
90 fn __repr__(&self) -> String {
91 format!(
92 "BootstrapCommand(program='{}', args={:?}, env={:?})",
93 self.program, self.args, self.env
94 )
95 }
96
97 fn with_env(&self, env: HashMap<String, String>) -> Self {
101 let mut new_env = self.env.clone();
102 new_env.extend(env);
103 Self {
104 program: self.program.clone(),
105 arg0: self.arg0.clone(),
106 args: self.args.clone(),
107 env: new_env,
108 }
109 }
110}
111
112impl PyBootstrapCommand {
113 pub fn to_rust(&self) -> BootstrapCommand {
114 BootstrapCommand {
115 program: PathBuf::from(&self.program),
116 arg0: self.arg0.clone(),
117 args: self.args.clone(),
118 env: self.env.clone(),
119 }
120 }
121}
122
/// Python `HostMesh`: either an owned host mesh or a reference to one.
///
/// Only the `Owned` variant supports `shutdown`/`stop`. `sliced` and
/// unpickling (`py_host_mesh_from_bytes`) always produce `Ref` values.
#[pyclass(
    name = "HostMesh",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
#[expect(
    clippy::large_enum_variant,
    reason = "PyO3 #[pyclass] enum; Box wrapping interacts with PyO3 codegen and Python interop — separate diff"
)]
pub(crate) enum PyHostMesh {
    /// Owned mesh held in a `SharedCell`; `shutdown`/`stop` take it out of the cell.
    Owned(PyHostMeshImpl),
    /// Non-owning `HostMeshRef`.
    Ref(PyHostMeshRefImpl),
}
135
136impl PyHostMesh {
137 pub(crate) fn new_owned(inner: HostMesh) -> Self {
138 Self::Owned(PyHostMeshImpl(SharedCell::from(inner)))
139 }
140
141 pub(crate) fn new_ref(inner: HostMeshRef) -> Self {
142 Self::Ref(PyHostMeshRefImpl(inner))
143 }
144
145 fn mesh_ref(&self) -> Result<HostMeshRef, anyhow::Error> {
146 match self {
147 PyHostMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
148 PyHostMesh::Ref(inner) => Ok(inner.0.clone()),
149 }
150 }
151}
152
153#[pymethods]
154impl PyHostMesh {
155 #[pyo3(signature = (instance, name, per_host, proc_bind = None, per_rank_bootstrap = None))]
156 fn spawn_nonblocking(
157 &self,
158 _py: Python<'_>,
159 instance: &PyInstance,
160 name: String,
161 per_host: &PyExtent,
162 proc_bind: Option<Vec<HashMap<String, String>>>,
163 per_rank_bootstrap: Option<Py<PyAny>>,
164 ) -> PyResult<PyPythonTask> {
165 let host_mesh = self.mesh_ref()?.clone();
166 let per_rank_bootstrap: Option<Box<PerRankBootstrapFn>> = per_rank_bootstrap
167 .map(|callable| -> PyResult<Box<PerRankBootstrapFn>> {
168 Ok(Box::new(move |point| {
169 Python::attach(|py| {
170 let result =
171 callable
172 .bind(py)
173 .call1((PyPoint::from(point),))
174 .map_err(|e| {
175 anyhow::anyhow!("per-rank bootstrap callable raised: {}", e)
176 })?;
177 let cmd: PyBootstrapCommand = result.extract().map_err(|e| {
178 anyhow::anyhow!(
179 "per-rank bootstrap callable did not return BootstrapCommand: {}",
180 e
181 )
182 })?;
183 Ok(cmd.to_rust())
184 })
185 }))
186 })
187 .transpose()?;
188 let instance = instance.clone();
189 let per_host = per_host.clone().into();
190 let proc_bind = proc_bind.map(|v| v.into_iter().map(ProcBind::from).collect());
191 let mesh_impl = async move {
192 let proc_mesh = host_mesh
193 .spawn(
194 instance.deref(),
195 &name,
196 per_host,
197 proc_bind,
198 per_rank_bootstrap,
199 )
200 .await
201 .map_err(to_py_error)?;
202 Ok(PyProcMesh::new_owned(proc_mesh))
203 };
204 PyPythonTask::new(mesh_impl)
205 }
206
207 fn with_bootstrap(&self, bootstrap_command: &PyBootstrapCommand) -> PyResult<Self> {
208 match self {
209 PyHostMesh::Owned(inner) => {
210 let cmd = bootstrap_command.to_rust();
211 inner
212 .0
213 .try_with_mut(|mesh| mesh.set_bootstrap(cmd))
214 .map_err(|e| PyException::new_err(e.to_string()))?;
215 Ok(Self::Owned(inner.clone()))
216 }
217 PyHostMesh::Ref(_) => Ok(Self::new_ref(
218 self.mesh_ref()?.with_bootstrap(bootstrap_command.to_rust()),
219 )),
220 }
221 }
222
223 fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
224 Ok(Self::new_ref(
225 self.mesh_ref()?.sliced(region.as_inner().clone()),
226 ))
227 }
228
229 #[getter]
230 fn region(&self) -> PyResult<PyRegion> {
231 Ok(PyRegion::from(self.mesh_ref()?.region()))
232 }
233
234 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
235 let bytes = bincode::serde::encode_to_vec(&self.mesh_ref()?, bincode::config::legacy())
236 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
237 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
238 let from_bytes =
239 PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.host_mesh")?
240 .getattr("py_host_mesh_from_bytes")?;
241 Ok((from_bytes, py_bytes))
242 }
243
244 fn __eq__(&self, other: &PyHostMesh) -> PyResult<bool> {
245 Ok(self.mesh_ref()? == other.mesh_ref()?)
246 }
247
248 fn shutdown(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
249 match self {
250 PyHostMesh::Owned(inner) => {
251 let instance = instance.clone();
252 let mesh_borrow = inner.0.clone();
253 let fut = async move {
254 match mesh_borrow.take().await {
255 Ok(mut mesh) => {
256 mesh.shutdown(instance.deref()).await?;
257 Ok(())
258 }
259 Err(_) => {
260 tracing::info!("shutdown was already called on host mesh");
263 Ok(())
264 }
265 }
266 };
267 PyPythonTask::new(fut)
268 }
269 PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
270 "cannot shut down `HostMesh` that is a reference instead of owned",
271 )),
272 }
273 }
274
275 fn stop(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
276 match self {
277 PyHostMesh::Owned(inner) => {
278 let instance = instance.clone();
279 let mesh_borrow = inner.0.clone();
280 let fut = async move {
281 match mesh_borrow.take().await {
282 Ok(mut mesh) => {
283 mesh.stop(instance.deref()).await?;
284 Ok(())
285 }
286 Err(_) => {
287 tracing::info!("stop was already called on host mesh");
288 Ok(())
289 }
290 }
291 };
292 PyPythonTask::new(fut)
293 }
294 PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
295 "cannot stop `HostMesh` that is a reference instead of owned",
296 )),
297 }
298 }
299}
300
301#[derive(Clone)]
302#[pyclass(
303 name = "HostMeshImpl",
304 module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
305)]
306pub(crate) struct PyHostMeshImpl(SharedCell<HostMesh>);
307
308#[derive(Debug, Clone)]
309#[pyclass(
310 name = "HostMeshRefImpl",
311 module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
312)]
313pub(crate) struct PyHostMeshRefImpl(HostMeshRef);
314
315impl PyHostMeshRefImpl {
316 fn __repr__(&self) -> PyResult<String> {
317 Ok(format!("<HostMeshRefImpl {:?}>", self.0))
318 }
319}
320
/// Storage for the root client instance created by [`bootstrap_host`]. It is
/// passed to `PythonActor::bootstrap_client_inner`, which presumably initializes
/// it — confirm.
static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock<Instance<PythonActor>> = OnceLock::new();

/// Host agent started by [`bootstrap_host`]; [`shutdown_local_host_mesh`] sends
/// `ShutdownHost` to it. Only the first `bootstrap_host` call stores an agent here.
static HOST_MESH_AGENT_FOR_HOST: OnceLock<ActorHandle<HostAgent>> = OnceLock::new();

/// Shutdown handle returned by `host(...)` in [`bootstrap_host`].
///
/// The `Mutex<Option<_>>` wrapper lets [`shutdown_local_host_mesh`] `take()` the
/// handle and join it; later calls find `None`.
static HOST_SHUTDOWN_HANDLE: OnceLock<
    tokio::sync::Mutex<Option<hyperactor_mesh::bootstrap::HostShutdownHandle>>,
> = OnceLock::new();
333
/// Bootstrap a host in this process and return handles to it.
///
/// The returned task performs these steps:
/// 1. Start a host agent on `default_bind_spec().binding_addr()` using
///    `bootstrap_cmd`, or `BootstrapCommand::current()` when it is `None`.
/// 2. Register the resulting single-host mesh as the client host.
/// 3. Resolve the host's local proc and bootstrap a root `PythonActor` client
///    instance on it.
/// 4. Report creation events to `hyperactor_telemetry` for the host mesh, its
///    agent, the local proc mesh, its agent, and the root client.
///
/// The task resolves to `(HostMesh, ProcMesh, Instance)`; both meshes are
/// references, not owned.
///
/// Raises `Exception` if the current bootstrap command cannot be determined, or if
/// any of these fail: starting the host, building the host mesh reference,
/// creating the temporary client, or resolving the local proc.
///
/// NOTE(review): the agent and shutdown handle are kept in process-wide
/// `OnceLock`s, and a later call does not replace them. Its agent is not recorded,
/// and its shutdown handle is dropped without being stored, so
/// `shutdown_local_host_mesh` only targets the first host. Confirm this is
/// expected to run once per process.
#[pyfunction]
fn bootstrap_host(bootstrap_cmd: Option<PyBootstrapCommand>) -> PyResult<PyPythonTask> {
    let bootstrap_cmd = match bootstrap_cmd {
        Some(cmd) => cmd.to_rust(),
        None => BootstrapCommand::current().map_err(|e| PyException::new_err(e.to_string()))?,
    };

    PyPythonTask::new(async move {
        // Start the host agent. The trailing arguments are passed as None/false; see
        // `hyperactor_mesh::bootstrap::host` for what they control.
        let (host_mesh_agent, shutdown_handle) = host(
            default_bind_spec().binding_addr(),
            Some(bootstrap_cmd),
            None,
            false,
            None,
        )
        .await
        .map_err(|e| PyException::new_err(e.to_string()))?;

        // Record the agent and shutdown handle for `shutdown_local_host_mesh`.
        // Only the first call's values are kept (see the NOTE on this function).
        HOST_MESH_AGENT_FOR_HOST.set(host_mesh_agent.clone()).ok();
        HOST_SHUTDOWN_HANDLE.get_or_init(|| tokio::sync::Mutex::new(Some(shutdown_handle)));

        // Expose the local host as a single-host mesh and register it as the client host.
        let host_mesh_id = HostMeshId::singleton(Label::new("local").unwrap());
        let host_mesh = HostMeshRef::from_host_agent(host_mesh_id, host_mesh_agent.bind())
            .map_err(|e| PyException::new_err(e.to_string()))?;

        hyperactor_mesh::global_context::register_client_host(host_mesh.clone());

        // Client on an isolated proc, used only to query the host and proc agents below.
        let temp_proc = Proc::isolated();
        let (temp_instance, _) = temp_proc
            .client("temp")
            .map_err(|e| PyException::new_err(e.to_string()))?;

        let local_proc_agent: hyperactor::ActorHandle<hyperactor_mesh::proc_agent::ProcAgent> =
            host_mesh_agent
                .get_local_proc(&temp_instance)
                .await
                .map_err(|e| PyException::new_err(e.to_string()))?;

        // Wrap the host's local proc agent as a one-proc mesh at rank 0.
        let proc_mesh = ProcMeshRef::new_singleton(
            ProcMeshId::singleton(Label::new("local").unwrap()),
            ProcRef::new(
                local_proc_agent.actor_addr().proc_addr(),
                0,
                local_proc_agent.bind(),
            ),
        );

        let local_proc = local_proc_agent
            .get_proc(&temp_instance)
            .await
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Bootstrap the root client on the local proc. Binding `_handle` (unlike `_`)
        // keeps the returned handle alive until this task completes.
        let (instance, _handle) = monarch_with_gil(|py| {
            PythonActor::bootstrap_client_inner(py, local_proc, &ROOT_CLIENT_INSTANCE_FOR_HOST)
        })
        .await;

        // Telemetry: report the host mesh, host agent, proc mesh, proc agent, and the
        // root client (as its own mesh under the proc mesh), all with one timestamp.
        {
            let now = std::time::SystemTime::now();

            let host_name_str = host_mesh.id().to_string();
            let host_mesh_id = hyperactor_telemetry::hash_to_u64(&host_name_str);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: host_mesh_id,
                timestamp: now,
                class: "Host".to_string(),
                given_name: host_mesh
                    .id()
                    .display_label()
                    .map(|l| l.as_str())
                    .unwrap_or("unnamed")
                    .to_string(),
                full_name: host_name_str,
                shape_json: serde_json::to_string(&host_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: None,
                parent_view_json: None,
            });

            let host_agent_id = host_mesh_agent.actor_addr();
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(host_agent_id),
                timestamp: now,
                mesh_id: host_mesh_id,
                rank: 0,
                full_name: host_agent_id.to_string(),
                display_name: None,
            });

            let proc_id_str = proc_mesh.id().to_string();
            let proc_mesh_id = hyperactor_telemetry::hash_to_u64(&proc_id_str);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: proc_mesh_id,
                timestamp: now,
                class: "Proc".to_string(),
                given_name: proc_mesh
                    .id()
                    .display_label()
                    .map(|l| l.as_str())
                    .unwrap_or("unnamed")
                    .to_string(),
                full_name: proc_id_str,
                shape_json: serde_json::to_string(&proc_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(host_mesh_id),
                parent_view_json: None,
            });

            let proc_agent_id = local_proc_agent.actor_addr();
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(proc_agent_id),
                timestamp: now,
                mesh_id: proc_mesh_id,
                rank: 0,
                full_name: proc_agent_id.to_string(),
                display_name: None,
            });

            let client_mesh_name = format!("{}/client", proc_mesh.id());
            let client_mesh_id = hyperactor_telemetry::hash_to_u64(&client_mesh_name);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: client_mesh_id,
                timestamp: now,
                class: <PythonActor as typeuri::Named>::typename().to_string(),
                given_name: "client".to_string(),
                full_name: client_mesh_name,
                shape_json: serde_json::to_string(&proc_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(proc_mesh_id),
                parent_view_json: None,
            });

            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(instance.self_addr()),
                timestamp: now,
                mesh_id: client_mesh_id,
                rank: 0,
                full_name: instance.self_addr().to_string(),
                display_name: Some("<root>".to_string()),
            });
        }

        Ok((
            PyHostMesh::new_ref(host_mesh),
            PyProcMesh::new_ref(proc_mesh),
            PyInstance::from(instance),
        ))
    })
}
498
499#[pyfunction]
500fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyHostMesh> {
501 let r: PyResult<HostMeshRef> =
502 bincode::serde::decode_from_slice(bytes.as_bytes(), bincode::config::legacy())
503 .map(|(v, _)| v)
504 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
505 r.map(PyHostMesh::new_ref)
506}
507
/// Shut down the host started by [`bootstrap_host`] in this process.
///
/// Raises `Exception` immediately if no host agent has been recorded.
///
/// Otherwise the returned task:
/// 1. Posts a `ShutdownHost` request (`timeout` = 10s, `max_in_flight` = 16) to
///    the host agent.
/// 2. If a shutdown handle is still stored, takes it and awaits `join()`.
///
/// A repeated call posts the request again but has no handle left to join.
#[pyfunction]
fn shutdown_local_host_mesh() -> PyResult<PyPythonTask> {
    let agent = HOST_MESH_AGENT_FOR_HOST
        .get()
        .ok_or_else(|| PyException::new_err("No local host mesh to shutdown"))?
        .clone();

    PyPythonTask::new(async move {
        // Client on an isolated proc to send the request from.
        let temp_proc = hyperactor::Proc::isolated();
        let (instance, _) = temp_proc
            .client("shutdown_requester")
            .map_err(|e| PyException::new_err(e.to_string()))?;

        tracing::info!(
            "sending shutdown_host request to agent {}",
            agent.actor_addr()
        );
        // The receiving half is discarded (`_`), so the ack is never awaited here.
        // Completion is instead observed by joining the stored shutdown handle below.
        // NOTE(review): undeliverable returns are presumably disabled so that an ack
        // sent to the dropped receiver is not reported back — confirm.
        let (port, _) = instance.open_port();
        let mut port = port.bind();
        port.return_undeliverable(false);
        agent.post(
            &instance,
            ShutdownHost {
                timeout: Duration::from_secs(10),
                max_in_flight: 16,
                ack: port,
            },
        );

        // Wait for the host to finish shutting down, if the handle has not been taken yet.
        if let Some(lock) = HOST_SHUTDOWN_HANDLE.get()
            && let Some(handle) = lock.lock().await.take()
        {
            handle.join().await;
        }

        Ok(())
    })
}
555
556#[pyclass(
561 name = "PyMeshAdminRef",
562 module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
563)]
564#[derive(Clone)]
565pub struct PyMeshAdminRef(hyperactor::ActorRef<hyperactor_mesh::mesh_admin::MeshAdminAgent>);
566
567impl PyMeshAdminRef {
568 pub fn actor_ref(&self) -> hyperactor::ActorRef<hyperactor_mesh::mesh_admin::MeshAdminAgent> {
569 self.0.clone()
570 }
571}
572
573#[pyfunction]
578fn _spawn_admin(
579 host_meshes: Vec<PyRef<'_, PyHostMesh>>,
580 instance: &PyInstance,
581 admin_addr: Option<String>,
582 telemetry_url: Option<String>,
583) -> PyResult<PyPythonTask> {
584 if host_meshes.is_empty() {
585 return Err(PyException::new_err("at least one mesh is required"));
586 }
587
588 let admin_addr = admin_addr
589 .map(|s| {
590 s.parse::<std::net::SocketAddr>()
591 .map_err(|e| PyException::new_err(format!("invalid admin_addr '{}': {}", s, e)))
592 })
593 .transpose()?;
594
595 let mesh_refs = host_meshes
596 .iter()
597 .map(|m| -> PyResult<HostMeshRef> { Ok(m.mesh_ref()?.clone()) })
598 .collect::<PyResult<Vec<HostMeshRef>>>()?;
599
600 let instance = instance.clone();
601 PyPythonTask::new(async move {
602 let admin_ref =
603 host_mesh::spawn_admin(&mesh_refs, instance.deref(), admin_addr, telemetry_url)
604 .await
605 .map_err(|e| PyException::new_err(e.to_string()))?;
606 let admin_url = admin_ref
607 .get_admin_addr(instance.deref())
608 .await
609 .map_err(|e| PyException::new_err(e.to_string()))?
610 .addr
611 .ok_or_else(|| PyException::new_err("mesh admin agent did not report an address"))?;
612 Ok((admin_url, PyMeshAdminRef(admin_ref)))
613 })
614}
615
616pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
617 let f = wrap_pyfunction!(py_host_mesh_from_bytes, hyperactor_mod)?;
618 f.setattr(
619 "__module__",
620 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
621 )?;
622 hyperactor_mod.add_function(f)?;
623
624 let f2 = wrap_pyfunction!(bootstrap_host, hyperactor_mod)?;
625 f2.setattr(
626 "__module__",
627 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
628 )?;
629 hyperactor_mod.add_function(f2)?;
630
631 let f3 = wrap_pyfunction!(shutdown_local_host_mesh, hyperactor_mod)?;
632 f3.setattr(
633 "__module__",
634 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
635 )?;
636 hyperactor_mod.add_function(f3)?;
637
638 let f4 = wrap_pyfunction!(_spawn_admin, hyperactor_mod)?;
639 f4.setattr(
640 "__module__",
641 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
642 )?;
643 hyperactor_mod.add_function(f4)?;
644
645 hyperactor_mod.add_class::<PyHostMesh>()?;
646 hyperactor_mod.add_class::<PyBootstrapCommand>()?;
647 hyperactor_mod.add_class::<PyMeshAdminRef>()?;
648 Ok(())
649}