1use std::collections::HashMap;
10use std::ops::Deref;
11use std::path::PathBuf;
12use std::sync::OnceLock;
13use std::time::Duration;
14
15use hyperactor::ActorHandle;
16use hyperactor::Instance;
17use hyperactor::Proc;
18use hyperactor_mesh::ProcMeshRef;
19use hyperactor_mesh::bootstrap::BootstrapCommand;
20use hyperactor_mesh::bootstrap::ProcBind;
21use hyperactor_mesh::bootstrap::host;
22use hyperactor_mesh::host_mesh;
23use hyperactor_mesh::host_mesh::HostMesh;
24use hyperactor_mesh::host_mesh::HostMeshRef;
25use hyperactor_mesh::host_mesh::host_agent::GetLocalProcClient;
26use hyperactor_mesh::host_mesh::host_agent::HostAgent;
27use hyperactor_mesh::host_mesh::host_agent::ShutdownHost;
28use hyperactor_mesh::proc_agent::GetProcClient;
29use hyperactor_mesh::proc_mesh::ProcRef;
30use hyperactor_mesh::shared_cell::SharedCell;
31use hyperactor_mesh::transport::default_bind_spec;
32use ndslice::View;
33use ndslice::view::RankedSliceable;
34use pyo3::IntoPyObjectExt;
35use pyo3::exceptions::PyException;
36use pyo3::exceptions::PyRuntimeError;
37use pyo3::exceptions::PyValueError;
38use pyo3::prelude::*;
39use pyo3::types::PyBytes;
40use pyo3::types::PyType;
41
42use crate::actor::PythonActor;
43use crate::actor::to_py_error;
44use crate::alloc::PyAlloc;
45use crate::context::PyInstance;
46use crate::proc_mesh::PyProcMesh;
47use crate::pytokio::PyPythonTask;
48use crate::runtime::monarch_with_gil;
49use crate::shape::PyExtent;
50use crate::shape::PyRegion;
51
/// Python-visible description of a bootstrap command, exposed as
/// `BootstrapCommand`.
///
/// Holds the same four fields that [`PyBootstrapCommand::to_rust`] copies into
/// a `BootstrapCommand`. Every field is readable and writable from Python.
#[pyclass(
    name = "BootstrapCommand",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
#[derive(Clone)]
pub struct PyBootstrapCommand {
    /// Program to run; converted to a `PathBuf` by `to_rust`.
    #[pyo3(get, set)]
    pub program: String,
    /// Optional `arg0` value — presumably an argv[0] override; confirm against
    /// `BootstrapCommand`.
    #[pyo3(get, set)]
    pub arg0: Option<String>,
    /// Arguments passed through unchanged to `BootstrapCommand::args`.
    #[pyo3(get, set)]
    pub args: Vec<String>,
    /// Key/value pairs passed through unchanged to `BootstrapCommand::env`.
    #[pyo3(get, set)]
    pub env: HashMap<String, String>,
}
67
68#[pymethods]
69impl PyBootstrapCommand {
70 #[new]
71 fn new(
72 program: String,
73 arg0: Option<String>,
74 args: Vec<String>,
75 env: HashMap<String, String>,
76 ) -> Self {
77 Self {
78 program,
79 arg0,
80 args,
81 env,
82 }
83 }
84
85 fn __repr__(&self) -> String {
86 format!(
87 "BootstrapCommand(program='{}', args={:?}, env={:?})",
88 self.program, self.args, self.env
89 )
90 }
91}
92
93impl PyBootstrapCommand {
94 pub fn to_rust(&self) -> BootstrapCommand {
95 BootstrapCommand {
96 program: PathBuf::from(&self.program),
97 arg0: self.arg0.clone(),
98 args: self.args.clone(),
99 env: self.env.clone(),
100 }
101 }
102
103 pub fn default<'py>(py: Python<'py>) -> PyResult<Bound<'py, Self>> {
104 py.import("monarch._src.actor.host_mesh")?
105 .getattr("_bootstrap_cmd")?
106 .call0()?
107 .downcast::<PyBootstrapCommand>()
108 .cloned()
109 .map_err(to_py_error)
110 }
111}
112
/// Python-visible `HostMesh`, backed either by an owned mesh or by a
/// reference to one.
///
/// Only the `Owned` variant can be shut down or stopped; the `Ref` variant
/// rejects those calls with a runtime error.
#[pyclass(
    name = "HostMesh",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
pub(crate) enum PyHostMesh {
    /// An owned `HostMesh`, held in a `SharedCell`.
    Owned(PyHostMeshImpl),
    /// A plain `HostMeshRef`.
    Ref(PyHostMeshRefImpl),
}
121
122impl PyHostMesh {
123 pub(crate) fn new_owned(inner: HostMesh) -> Self {
124 Self::Owned(PyHostMeshImpl(SharedCell::from(inner)))
125 }
126
127 pub(crate) fn new_ref(inner: HostMeshRef) -> Self {
128 Self::Ref(PyHostMeshRefImpl(inner))
129 }
130
131 fn mesh_ref(&self) -> Result<HostMeshRef, anyhow::Error> {
132 match self {
133 PyHostMesh::Owned(inner) => Ok(inner.0.borrow()?.clone()),
134 PyHostMesh::Ref(inner) => Ok(inner.0.clone()),
135 }
136 }
137}
138
139#[pymethods]
140impl PyHostMesh {
141 #[classmethod]
142 fn allocate_nonblocking(
143 _cls: &Bound<'_, PyType>,
144 instance: &PyInstance,
145 alloc: &mut PyAlloc,
146 name: String,
147 bootstrap_params: Option<PyBootstrapCommand>,
148 ) -> PyResult<PyPythonTask> {
149 let bootstrap_params =
150 bootstrap_params.map_or_else(|| alloc.bootstrap_command.clone(), |b| Some(b.to_rust()));
151 let alloc = match alloc.take() {
152 Some(alloc) => alloc,
153 None => {
154 return Err(PyException::new_err(
155 "Alloc object already used".to_string(),
156 ));
157 }
158 };
159 let instance = instance.clone();
160 PyPythonTask::new(async move {
161 let mesh = HostMesh::allocate(instance.deref(), alloc, &name, bootstrap_params)
162 .await
163 .map_err(|err| PyException::new_err(err.to_string()))?;
164 Ok(Self::new_owned(mesh))
165 })
166 }
167
168 #[pyo3(signature = (instance, name, per_host, proc_bind = None))]
169 fn spawn_nonblocking(
170 &self,
171 instance: &PyInstance,
172 name: String,
173 per_host: &PyExtent,
174 proc_bind: Option<Vec<HashMap<String, String>>>,
175 ) -> PyResult<PyPythonTask> {
176 let host_mesh = self.mesh_ref()?.clone();
177 let instance = instance.clone();
178 let per_host = per_host.clone().into();
179 let proc_bind = proc_bind.map(|v| v.into_iter().map(ProcBind::from).collect());
180 let mesh_impl = async move {
181 let proc_mesh = host_mesh
182 .spawn(instance.deref(), &name, per_host, proc_bind)
183 .await
184 .map_err(to_py_error)?;
185 Ok(PyProcMesh::new_owned(proc_mesh))
186 };
187 PyPythonTask::new(mesh_impl)
188 }
189
190 fn with_bootstrap(&self, bootstrap_command: &PyBootstrapCommand) -> PyResult<Self> {
191 match self {
192 PyHostMesh::Owned(inner) => {
193 let cmd = bootstrap_command.to_rust();
194 inner
195 .0
196 .try_with_mut(|mesh| mesh.set_bootstrap(cmd))
197 .map_err(|e| PyException::new_err(e.to_string()))?;
198 Ok(Self::Owned(inner.clone()))
199 }
200 PyHostMesh::Ref(_) => Ok(Self::new_ref(
201 self.mesh_ref()?.with_bootstrap(bootstrap_command.to_rust()),
202 )),
203 }
204 }
205
206 fn sliced(&self, region: &PyRegion) -> PyResult<Self> {
207 Ok(Self::new_ref(
208 self.mesh_ref()?.sliced(region.as_inner().clone()),
209 ))
210 }
211
212 #[getter]
213 fn region(&self) -> PyResult<PyRegion> {
214 Ok(PyRegion::from(self.mesh_ref()?.region()))
215 }
216
217 fn __reduce__<'py>(&self, py: Python<'py>) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
218 let bytes = bincode::serialize(&self.mesh_ref()?)
219 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
220 let py_bytes = (PyBytes::new(py, &bytes),).into_bound_py_any(py).unwrap();
221 let from_bytes =
222 PyModule::import(py, "monarch._rust_bindings.monarch_hyperactor.host_mesh")?
223 .getattr("py_host_mesh_from_bytes")?;
224 Ok((from_bytes, py_bytes))
225 }
226
227 fn __eq__(&self, other: &PyHostMesh) -> PyResult<bool> {
228 Ok(self.mesh_ref()? == other.mesh_ref()?)
229 }
230
231 fn shutdown(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
232 match self {
233 PyHostMesh::Owned(inner) => {
234 let instance = instance.clone();
235 let mesh_borrow = inner.0.clone();
236 let fut = async move {
237 match mesh_borrow.take().await {
238 Ok(mut mesh) => {
239 mesh.shutdown(instance.deref()).await?;
240 Ok(())
241 }
242 Err(_) => {
243 tracing::info!("shutdown was already called on host mesh");
246 Ok(())
247 }
248 }
249 };
250 PyPythonTask::new(fut)
251 }
252 PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
253 "cannot shut down `HostMesh` that is a reference instead of owned",
254 )),
255 }
256 }
257
258 fn stop(&self, instance: &PyInstance) -> PyResult<PyPythonTask> {
259 match self {
260 PyHostMesh::Owned(inner) => {
261 let instance = instance.clone();
262 let mesh_borrow = inner.0.clone();
263 let fut = async move {
264 match mesh_borrow.take().await {
265 Ok(mut mesh) => {
266 mesh.stop(instance.deref()).await?;
267 Ok(())
268 }
269 Err(_) => {
270 tracing::info!("stop was already called on host mesh");
271 Ok(())
272 }
273 }
274 };
275 PyPythonTask::new(fut)
276 }
277 PyHostMesh::Ref(_) => Err(PyRuntimeError::new_err(
278 "cannot stop `HostMesh` that is a reference instead of owned",
279 )),
280 }
281 }
282}
283
/// Payload of `PyHostMesh::Owned`: an owned `HostMesh` inside a `SharedCell`.
///
/// Cloning this wrapper clones the cell handle; `PyHostMesh::shutdown` and
/// `PyHostMesh::stop` take the mesh out of the cell.
#[derive(Clone)]
#[pyclass(
    name = "HostMeshImpl",
    module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
)]
pub(crate) struct PyHostMeshImpl(SharedCell<HostMesh>);
290
291#[derive(Debug, Clone)]
292#[pyclass(
293 name = "HostMeshRefImpl",
294 module = "monarch._rust_bindings.monarch_hyperactor.host_mesh"
295)]
296pub(crate) struct PyHostMeshRefImpl(HostMeshRef);
297
298impl PyHostMeshRefImpl {
299 fn __repr__(&self) -> PyResult<String> {
300 Ok(format!("<HostMeshRefImpl {:?}>", self.0))
301 }
302}
303
// Process-wide slot handed to `PythonActor::bootstrap_client_inner` by
// `bootstrap_host` for the root client instance.
static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock<Instance<PythonActor>> = OnceLock::new();

// Host agent handle recorded by `bootstrap_host`; `shutdown_local_host_mesh`
// reads it to know where to send the shutdown request.
static HOST_MESH_AGENT_FOR_HOST: OnceLock<ActorHandle<HostAgent>> = OnceLock::new();

// Shutdown handle returned by `host(...)` in `bootstrap_host`. Wrapped in
// `Mutex<Option<_>>` so `shutdown_local_host_mesh` can take it out and join it
// at most once.
static HOST_SHUTDOWN_HANDLE: OnceLock<
    tokio::sync::Mutex<Option<hyperactor_mesh::bootstrap::HostShutdownHandle>>,
> = OnceLock::new();
316
#[pyfunction]
/// Bootstrap a host in this process and return handles to it.
///
/// `bootstrap_cmd` is used when provided; otherwise
/// `BootstrapCommand::current()` supplies it (a failure there is raised
/// before any task is created).
///
/// The returned task resolves to a tuple of
/// `(HostMesh reference, ProcMesh reference, Instance)` for the local host,
/// its local proc, and the root client instance bootstrapped on that proc.
/// Along the way it records the host agent and shutdown handle in module-level
/// statics, registers the host mesh as the client host in the global context,
/// and emits telemetry events for the created meshes and actors.
fn bootstrap_host(bootstrap_cmd: Option<PyBootstrapCommand>) -> PyResult<PyPythonTask> {
    let bootstrap_cmd = match bootstrap_cmd {
        Some(cmd) => cmd.to_rust(),
        None => BootstrapCommand::current().map_err(|e| PyException::new_err(e.to_string()))?,
    };

    PyPythonTask::new(async move {
        // Bring up the host on the default bind address; yields the host agent
        // handle plus a handle used later to join the host's shutdown.
        let (host_mesh_agent, shutdown_handle) = host(
            default_bind_spec().binding_addr(),
            Some(bootstrap_cmd),
            None,
            false,
        )
        .await
        .map_err(|e| PyException::new_err(e.to_string()))?;

        // Record both for `shutdown_local_host_mesh`. If either static is
        // already populated the existing value is kept: `.ok()` discards the
        // `set` error, and `get_or_init` does not run its closure.
        HOST_MESH_AGENT_FOR_HOST.set(host_mesh_agent.clone()).ok();
        HOST_SHUTDOWN_HANDLE.get_or_init(|| tokio::sync::Mutex::new(Some(shutdown_handle)));

        // NOTE(review): `unwrap` panics if "local" is rejected as a reserved
        // name — presumably it never is; confirm against `Name::new_reserved`.
        let host_mesh_name = hyperactor_mesh::Name::new_reserved("local").unwrap();
        let host_mesh = HostMeshRef::from_host_agent(host_mesh_name, host_mesh_agent.bind())
            .map_err(|e| PyException::new_err(e.to_string()))?;

        hyperactor_mesh::global_context::register_client_host(host_mesh.clone());

        // Throwaway local proc and instance, used only as the sender for the
        // two agent requests below.
        let temp_proc = Proc::local();
        let (temp_instance, _) = temp_proc
            .instance("temp")
            .map_err(|e| PyException::new_err(e.to_string()))?;

        let local_proc_agent: hyperactor::ActorHandle<hyperactor_mesh::proc_agent::ProcAgent> =
            host_mesh_agent
                .get_local_proc(&temp_instance)
                .await
                .map_err(|e| PyException::new_err(e.to_string()))?;

        // Single-proc mesh (rank 0) wrapping the host's local proc agent.
        let proc_mesh_name = hyperactor_mesh::Name::new_reserved("local").unwrap();
        let proc_mesh = ProcMeshRef::new_singleton(
            proc_mesh_name,
            ProcRef::new(
                local_proc_agent.actor_id().proc_id().clone(),
                0,
                local_proc_agent.bind(),
            ),
        );

        let local_proc = local_proc_agent
            .get_proc(&temp_instance)
            .await
            .map_err(|e| PyException::new_err(e.to_string()))?;

        // Bootstrap the root client actor on the local proc.
        let (instance, _handle) = monarch_with_gil(|py| {
            PythonActor::bootstrap_client_inner(py, local_proc, &ROOT_CLIENT_INSTANCE_FOR_HOST)
        })
        .await;

        // Telemetry: announce three meshes (host -> proc -> client, linked via
        // `parent_mesh_id`) and one rank-0 actor in each, all sharing a single
        // timestamp.
        {
            let now = std::time::SystemTime::now();

            // Host mesh and its agent.
            let host_name_str = host_mesh.name().to_string();
            let host_mesh_id = hyperactor_telemetry::hash_to_u64(&host_name_str);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: host_mesh_id,
                timestamp: now,
                class: "Host".to_string(),
                given_name: host_mesh.name().name().to_string(),
                full_name: host_name_str,
                shape_json: serde_json::to_string(&host_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: None,
                parent_view_json: None,
            });

            let host_agent_id = host_mesh_agent.actor_id();
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(host_agent_id),
                timestamp: now,
                mesh_id: host_mesh_id,
                rank: 0,
                full_name: host_agent_id.to_string(),
                display_name: None,
            });

            // Proc mesh and its agent.
            let proc_name_str = proc_mesh.name().to_string();
            let proc_mesh_id = hyperactor_telemetry::hash_to_u64(&proc_name_str);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: proc_mesh_id,
                timestamp: now,
                class: "Proc".to_string(),
                given_name: proc_mesh.name().name().to_string(),
                full_name: proc_name_str,
                shape_json: serde_json::to_string(&proc_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(host_mesh_id),
                parent_view_json: None,
            });

            let proc_agent_id = local_proc_agent.actor_id();
            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(proc_agent_id),
                timestamp: now,
                mesh_id: proc_mesh_id,
                rank: 0,
                full_name: proc_agent_id.to_string(),
                display_name: None,
            });

            // Client mesh, named "<proc mesh name>/client", and the root
            // client actor; it reuses the proc mesh's extent for its shape.
            let client_mesh_name = format!("{}/client", proc_mesh.name());
            let client_mesh_id = hyperactor_telemetry::hash_to_u64(&client_mesh_name);
            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: client_mesh_id,
                timestamp: now,
                class: <PythonActor as typeuri::Named>::typename().to_string(),
                given_name: "client".to_string(),
                full_name: client_mesh_name,
                shape_json: serde_json::to_string(&proc_mesh.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(proc_mesh_id),
                parent_view_json: None,
            });

            hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                id: hyperactor_telemetry::hash_to_u64(instance.self_id()),
                timestamp: now,
                mesh_id: client_mesh_id,
                rank: 0,
                full_name: instance.self_id().to_string(),
                display_name: Some("<root>".to_string()),
            });
        }

        Ok((
            PyHostMesh::new_ref(host_mesh),
            PyProcMesh::new_ref(proc_mesh),
            PyInstance::from(instance),
        ))
    })
}
471
472#[pyfunction]
473fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyHostMesh> {
474 let r: PyResult<HostMeshRef> = bincode::deserialize(bytes.as_bytes())
475 .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()));
476 r.map(PyHostMesh::new_ref)
477}
478
479#[pyfunction]
480fn shutdown_local_host_mesh() -> PyResult<PyPythonTask> {
481 let agent = HOST_MESH_AGENT_FOR_HOST
482 .get()
483 .ok_or_else(|| PyException::new_err("No local host mesh to shutdown"))?
484 .clone();
485
486 PyPythonTask::new(async move {
487 let temp_proc = hyperactor::Proc::local();
489 let (instance, _) = temp_proc
490 .instance("shutdown_requester")
491 .map_err(|e| PyException::new_err(e.to_string()))?;
492
493 tracing::info!(
494 "sending shutdown_host request to agent {}",
495 agent.actor_id()
496 );
497 let (port, _) = instance.open_port();
502 let mut port = port.bind();
503 port.return_undeliverable(false);
506 agent
507 .send(
508 &instance,
509 ShutdownHost {
510 timeout: Duration::from_secs(10),
511 max_in_flight: 16,
512 ack: port,
513 },
514 )
515 .map_err(|e| PyException::new_err(e.to_string()))?;
516
517 if let Some(lock) = HOST_SHUTDOWN_HANDLE.get() {
520 if let Some(handle) = lock.lock().await.take() {
521 handle.join().await;
522 }
523 }
524
525 Ok(())
526 })
527}
528
529#[pyfunction]
539fn _spawn_admin(
540 host_meshes: Vec<PyRef<'_, PyHostMesh>>,
541 instance: &PyInstance,
542 admin_addr: Option<String>,
543 telemetry_url: Option<String>,
544) -> PyResult<PyPythonTask> {
545 if host_meshes.is_empty() {
546 return Err(PyException::new_err("at least one mesh is required"));
547 }
548
549 let admin_addr = admin_addr
550 .map(|s| {
551 s.parse::<std::net::SocketAddr>()
552 .map_err(|e| PyException::new_err(format!("invalid admin_addr '{}': {}", s, e)))
553 })
554 .transpose()?;
555
556 let mesh_refs = host_meshes
557 .iter()
558 .map(|m| -> PyResult<HostMeshRef> { Ok(m.mesh_ref()?.clone()) })
559 .collect::<PyResult<Vec<HostMeshRef>>>()?;
560
561 let instance = instance.clone();
562 PyPythonTask::new(async move {
563 let addr = host_mesh::spawn_admin(&mesh_refs, instance.deref(), admin_addr, telemetry_url)
564 .await
565 .map_err(|e| PyException::new_err(e.to_string()))?;
566 Ok(addr)
567 })
568}
569
570pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
571 let f = wrap_pyfunction!(py_host_mesh_from_bytes, hyperactor_mod)?;
572 f.setattr(
573 "__module__",
574 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
575 )?;
576 hyperactor_mod.add_function(f)?;
577
578 let f2 = wrap_pyfunction!(bootstrap_host, hyperactor_mod)?;
579 f2.setattr(
580 "__module__",
581 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
582 )?;
583 hyperactor_mod.add_function(f2)?;
584
585 let f3 = wrap_pyfunction!(shutdown_local_host_mesh, hyperactor_mod)?;
586 f3.setattr(
587 "__module__",
588 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
589 )?;
590 hyperactor_mod.add_function(f3)?;
591
592 let f4 = wrap_pyfunction!(_spawn_admin, hyperactor_mod)?;
593 f4.setattr(
594 "__module__",
595 "monarch._rust_bindings.monarch_hyperactor.host_mesh",
596 )?;
597 hyperactor_mod.add_function(f4)?;
598
599 hyperactor_mod.add_class::<PyHostMesh>()?;
600 hyperactor_mod.add_class::<PyBootstrapCommand>()?;
601 Ok(())
602}