1use hyperactor::channel::ChannelTransport;
10pub mod mesh_agent;
11
12use std::collections::HashSet;
13use std::ops::Deref;
14use std::str::FromStr;
15use std::sync::Arc;
16
17use hyperactor::ActorRef;
18use hyperactor::Named;
19use hyperactor::ProcId;
20use hyperactor::channel::ChannelAddr;
21use hyperactor::context;
22use ndslice::Extent;
23use ndslice::Region;
24use ndslice::ViewExt;
25use ndslice::extent;
26use ndslice::view;
27use ndslice::view::Ranked;
28use ndslice::view::RegionParseError;
29use serde::Deserialize;
30use serde::Serialize;
31
32use crate::alloc::Alloc;
33use crate::bootstrap::BootstrapCommand;
34use crate::resource;
35use crate::resource::CreateOrUpdateClient;
36use crate::resource::GetRankStatusClient;
37use crate::resource::RankedValues;
38use crate::v1;
39use crate::v1::Name;
40use crate::v1::ProcMesh;
41use crate::v1::ProcMeshRef;
42pub use crate::v1::host_mesh::mesh_agent::HostMeshAgent;
43use crate::v1::host_mesh::mesh_agent::HostMeshAgentProcMeshTrampoline;
44use crate::v1::host_mesh::mesh_agent::ShutdownHostClient;
45use crate::v1::proc_mesh::ProcRef;
46
47#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
49pub struct HostRef(ChannelAddr);
50
51impl HostRef {
52 fn mesh_agent(&self) -> ActorRef<HostMeshAgent> {
54 ActorRef::attest(self.service_proc().actor_id("agent", 0))
55 }
56
57 fn named_proc(&self, name: &Name) -> ProcId {
59 ProcId::Direct(self.0.clone(), name.to_string())
60 }
61
62 fn service_proc(&self) -> ProcId {
64 ProcId::Direct(self.0.clone(), "service".to_string())
65 }
66
67 async fn shutdown(&self, cx: &impl hyperactor::context::Actor) -> anyhow::Result<()> {
86 let agent = self.mesh_agent();
87 let terminate_timeout =
88 hyperactor::config::global::get(crate::bootstrap::MESH_TERMINATE_TIMEOUT);
89 let max_in_flight =
90 hyperactor::config::global::get(crate::bootstrap::MESH_TERMINATE_CONCURRENCY);
91 agent
92 .shutdown_host(cx, terminate_timeout, max_in_flight.clamp(1, 256))
93 .await?;
94 Ok(())
95 }
96}
97
98impl std::fmt::Display for HostRef {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 self.0.fmt(f)
101 }
102}
103
104impl FromStr for HostRef {
105 type Err = <ChannelAddr as FromStr>::Err;
106
107 fn from_str(s: &str) -> Result<Self, Self::Err> {
108 Ok(HostRef(ChannelAddr::from_str(s)?))
109 }
110}
111
112#[allow(dead_code)]
125pub struct HostMesh {
126 name: Name,
127 extent: Extent,
128 allocation: HostMeshAllocation,
129 current_ref: HostMeshRef,
130}
131
132#[allow(dead_code)]
150enum HostMeshAllocation {
151 ProcMesh {
158 proc_mesh: ProcMesh,
159 proc_mesh_ref: ProcMeshRef,
160 hosts: Vec<HostRef>,
161 },
162 Owned { hosts: Vec<HostRef> },
170}
171
172impl HostMesh {
173 pub async fn allocate(
217 cx: &impl context::Actor,
218 alloc: Box<dyn Alloc + Send + Sync>,
219 name: &str,
220 bootstrap_params: Option<BootstrapCommand>,
221 ) -> v1::Result<Self> {
222 let transport = alloc.transport();
223 let extent = alloc.extent().clone();
224 let is_local = alloc.is_local();
225 let proc_mesh = ProcMesh::allocate(cx, alloc, name).await?;
226 let name = Name::new(name);
227
228 let (mesh_agents, mut mesh_agents_rx) = cx.mailbox().open_port();
233 let _trampoline_actor_mesh = proc_mesh
234 .spawn::<HostMeshAgentProcMeshTrampoline>(
235 cx,
236 "host_mesh_trampoline",
237 &(transport, mesh_agents.bind(), bootstrap_params, is_local),
238 )
239 .await?;
240
241 let mut hosts = Vec::new();
243 for _rank in 0..extent.num_ranks() {
244 let mesh_agent = mesh_agents_rx.recv().await?;
245
246 let Some((addr, _)) = mesh_agent.actor_id().proc_id().as_direct() else {
247 return Err(v1::Error::HostMeshAgentConfigurationError(
248 mesh_agent.actor_id().clone(),
249 "host mesh agent must be a direct actor".to_string(),
250 ));
251 };
252
253 let host_ref = HostRef(addr.clone());
254 if host_ref.mesh_agent() != mesh_agent {
255 return Err(v1::Error::HostMeshAgentConfigurationError(
256 mesh_agent.actor_id().clone(),
257 format!(
258 "expected mesh agent actor id to be {}",
259 host_ref.mesh_agent().actor_id()
260 ),
261 ));
262 }
263 hosts.push(host_ref);
264 }
265
266 let proc_mesh_ref = proc_mesh.clone();
267 Ok(Self {
268 name,
269 extent: extent.clone(),
270 allocation: HostMeshAllocation::ProcMesh {
271 proc_mesh,
272 proc_mesh_ref,
273 hosts: hosts.clone(),
274 },
275 current_ref: HostMeshRef::new(extent.into(), hosts).unwrap(),
276 })
277 }
278
279 pub fn take(name: impl Into<Name>, mesh: HostMeshRef) -> Self {
286 let name = name.into();
287 let region = mesh.region().clone();
288 let hosts: Vec<HostRef> = mesh.values().collect();
289
290 let current_ref = HostMeshRef::new(region.clone(), hosts.clone())
291 .expect("region/hosts cardinality must match");
292
293 Self {
294 name,
295 extent: region.extent().clone(),
296 allocation: HostMeshAllocation::Owned { hosts },
297 current_ref,
298 }
299 }
300
301 pub async fn shutdown(&self, cx: &impl hyperactor::context::Actor) -> anyhow::Result<()> {
311 let mut attempted = 0;
312 let mut ok = 0;
313 for host in self.current_ref.values() {
314 attempted += 1;
315 if let Err(e) = host.shutdown(cx).await {
316 tracing::warn!(host = %host, error = %e, "host shutdown failed");
317 } else {
318 ok += 1;
319 }
320 }
321 tracing::info!(attempted, ok, "hostmesh shutdown summary");
322 Ok(())
323 }
324}
325
326impl Deref for HostMesh {
327 type Target = HostMeshRef;
328
329 fn deref(&self) -> &Self::Target {
330 &self.current_ref
331 }
332}
333
334impl Drop for HostMesh {
335 fn drop(&mut self) {
353 let hosts: Vec<HostRef> = match &self.allocation {
355 HostMeshAllocation::ProcMesh { hosts, .. } | HostMeshAllocation::Owned { hosts } => {
356 hosts.clone()
357 }
358 };
359
360 if let Ok(handle) = tokio::runtime::Handle::try_current() {
362 let mesh_name = self.name.clone();
363 let allocation_label = match &self.allocation {
364 HostMeshAllocation::ProcMesh { .. } => "proc_mesh",
365 HostMeshAllocation::Owned { .. } => "owned",
366 }
367 .to_string();
368
369 handle.spawn(async move {
370 let span = tracing::info_span!(
371 "hostmesh_drop_cleanup",
372 %mesh_name,
373 allocation = %allocation_label,
374 hosts = hosts.len(),
375 );
376 let _g = span.enter();
377
378 match hyperactor::Proc::direct(
381 ChannelTransport::Unix.any(),
382 "hostmesh-drop".to_string(),
383 )
384 .await
385 {
386 Err(e) => {
387 tracing::warn!(
388 error = %e,
389 "failed to construct ephemeral Proc for drop-cleanup; \
390 relying on PDEATHSIG/manager Drop"
391 );
392 }
393 Ok(proc) => {
394 match proc.instance("drop") {
395 Err(e) => {
396 tracing::warn!(
397 error = %e,
398 "failed to create ephemeral instance for drop-cleanup; \
399 relying on PDEATHSIG/manager Drop"
400 );
401 }
402 Ok((instance, _guard)) => {
403 let mut attempted = 0usize;
404 let mut ok = 0usize;
405 let mut err = 0usize;
406
407 for host in hosts {
408 attempted += 1;
409 tracing::debug!(host = %host, "drop-cleanup: shutdown start");
410 match host.shutdown(&instance).await {
411 Ok(()) => {
412 ok += 1;
413 tracing::debug!(host = %host, "drop-cleanup: shutdown ok");
414 }
415 Err(e) => {
416 err += 1;
417 tracing::warn!(host = %host, error = %e, "drop-cleanup: shutdown failed");
418 }
419 }
420 }
421
422 tracing::info!(
423 attempted, ok, err,
424 "hostmesh drop-cleanup summary"
425 );
426 }
427 }
428 }
429 }
430 });
431 } else {
432 tracing::warn!(
435 hosts = hosts.len(),
436 "HostMesh dropped without a tokio runtime; skipping best-effort shutdown"
437 );
438 }
439 }
440}
441
442#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
461pub struct HostMeshRef {
462 region: Region,
463 ranks: Arc<Vec<HostRef>>,
464}
465
466impl HostMeshRef {
467 fn new(region: Region, ranks: Vec<HostRef>) -> v1::Result<Self> {
470 if region.num_ranks() != ranks.len() {
471 return Err(v1::Error::InvalidRankCardinality {
472 expected: region.num_ranks(),
473 actual: ranks.len(),
474 });
475 }
476 Ok(Self {
477 region,
478 ranks: Arc::new(ranks),
479 })
480 }
481
482 pub fn from_hosts(hosts: Vec<ChannelAddr>) -> Self {
485 Self {
486 region: extent!(hosts = hosts.len()).into(),
487 ranks: Arc::new(hosts.into_iter().map(HostRef).collect()),
488 }
489 }
490
491 pub async fn spawn(
497 &self,
498 cx: &impl context::Actor,
499 name: &str,
500 per_host: Extent,
501 ) -> v1::Result<ProcMesh> {
502 let per_host_labels = per_host.labels().iter().collect::<HashSet<_>>();
503 let host_labels = self.region.labels().iter().collect::<HashSet<_>>();
504 if !per_host_labels
505 .intersection(&host_labels)
506 .collect::<Vec<_>>()
507 .is_empty()
508 {
509 return Err(v1::Error::ConfigurationError(anyhow::anyhow!(
510 "per_host dims overlap with existing dims when spawning proc mesh"
511 )));
512 }
513
514 let extent = self
515 .region
516 .extent()
517 .concat(&per_host)
518 .map_err(|err| v1::Error::ConfigurationError(err.into()))?;
519
520 let mesh_name = Name::new(name);
521 let mut procs = Vec::new();
522 let num_ranks = self.region().num_ranks() * per_host.num_ranks();
523 let (port, mut rx) = cx.mailbox().open_accum_port(RankedValues::default());
524 for (host_rank, host) in self.ranks.iter().enumerate() {
530 for per_host_rank in 0..per_host.num_ranks() {
531 let create_rank = per_host.num_ranks() * host_rank + per_host_rank;
532 let proc_name = Name::new(format!("{}-{}", name, per_host_rank));
533 host.mesh_agent()
534 .create_or_update(cx, proc_name.clone(), resource::Rank::new(create_rank), ())
535 .await
536 .map_err(|e| {
537 v1::Error::HostMeshAgentConfigurationError(
538 host.mesh_agent().actor_id().clone(),
539 format!("failed while creating proc: {}", e),
540 )
541 })?;
542 host.mesh_agent()
543 .get_rank_status(cx, proc_name.clone(), port.bind())
544 .await
545 .map_err(|e| {
546 v1::Error::HostMeshAgentConfigurationError(
547 host.mesh_agent().actor_id().clone(),
548 format!("failed while querying proc status: {}", e),
549 )
550 })?;
551 procs.push(ProcRef::new(
552 host.named_proc(&proc_name),
553 create_rank,
554 ActorRef::attest(host.named_proc(&proc_name).actor_id("agent", 0)),
556 ));
557 }
558 }
559
560 loop {
562 let statuses = rx.recv().await?;
563 if let Some((ranks, status)) =
564 statuses.iter().find(|(_, status)| status.is_terminating())
565 {
566 let rank = ranks.start;
567 let proc_name = Name::new(format!("{}-{}", name, rank % per_host.num_ranks()));
568 return Err(v1::Error::ProcCreationError {
569 proc_name,
570 mesh_agent: self.ranks[rank].mesh_agent(),
571 host_rank: rank / per_host.num_ranks(),
572 status: status.clone(),
573 });
574 }
575
576 if statuses.rank(num_ranks) == num_ranks {
577 break;
578 }
579 }
580
581 ProcMesh::create_owned_unchecked(cx, mesh_name, extent, self.clone(), procs).await
582 }
583}
584
585impl view::Ranked for HostMeshRef {
586 type Item = HostRef;
587
588 fn region(&self) -> &Region {
589 &self.region
590 }
591
592 fn get(&self, rank: usize) -> Option<&Self::Item> {
593 self.ranks.get(rank)
594 }
595}
596
597impl view::RankedSliceable for HostMeshRef {
598 fn sliced(&self, region: Region) -> Self {
599 let ranks = self
600 .region()
601 .remap(®ion)
602 .unwrap()
603 .map(|index| self.get(index).unwrap().clone());
604 Self::new(region, ranks.collect()).unwrap()
605 }
606}
607
608impl std::fmt::Display for HostMeshRef {
609 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
610 for (rank, host) in self.ranks.iter().enumerate() {
611 if rank > 0 {
612 write!(f, ",")?;
613 }
614 write!(f, "{}", host)?;
615 }
616 write!(f, "@{}", self.region)
617 }
618}
619
620#[derive(thiserror::Error, Debug)]
622pub enum HostMeshRefParseError {
623 #[error(transparent)]
624 RegionParseError(#[from] RegionParseError),
625
626 #[error("invalid host mesh ref: missing region")]
627 MissingRegion,
628
629 #[error(transparent)]
630 InvalidHostMeshRef(#[from] Box<v1::Error>),
631
632 #[error(transparent)]
633 Other(#[from] anyhow::Error),
634}
635
636impl From<v1::Error> for HostMeshRefParseError {
637 fn from(err: v1::Error) -> Self {
638 Self::InvalidHostMeshRef(Box::new(err))
639 }
640}
641
642impl FromStr for HostMeshRef {
643 type Err = HostMeshRefParseError;
644
645 fn from_str(s: &str) -> Result<Self, Self::Err> {
646 let (hosts, region) = s
647 .split_once('@')
648 .ok_or(HostMeshRefParseError::MissingRegion)?;
649 let hosts = hosts
650 .split(',')
651 .map(|host| host.trim())
652 .map(|host| host.parse::<HostRef>())
653 .collect::<Result<Vec<_>, _>>()?;
654 let region = region.parse()?;
655 Ok(HostMeshRef::new(region, hosts)?)
656 }
657}
658
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::collections::VecDeque;

    use hyperactor::context::Mailbox as _;
    use itertools::Itertools;
    use ndslice::ViewExt;
    use ndslice::extent;
    use tokio::process::Command;

    use super::*;
    use crate::Bootstrap;
    use crate::v1::ActorMesh;
    use crate::v1::testactor;
    use crate::v1::testing;

    /// Selecting one index along a dimension yields the matching hosts and
    /// an offset region.
    #[test]
    fn test_host_mesh_subset() {
        let hosts: HostMeshRef = "local:1,local:2,local:3,local:4@replica=2/2,host=2/1"
            .parse()
            .unwrap();
        assert_eq!(
            hosts.range("replica", 1).unwrap().to_string(),
            "local:3,local:4@2+replica=1/2,host=2/1"
        );
    }

    /// `Display` followed by `FromStr` reproduces the original reference.
    #[test]
    fn test_host_mesh_ref_parse_roundtrip() {
        let host_mesh_ref = HostMeshRef::new(
            extent!(replica = 2, host = 2).into(),
            vec![
                "tcp:127.0.0.1:123".parse().unwrap(),
                "tcp:127.0.0.1:123".parse().unwrap(),
                "tcp:127.0.0.1:123".parse().unwrap(),
                "tcp:127.0.0.1:123".parse().unwrap(),
            ],
        )
        .unwrap();

        assert_eq!(
            host_mesh_ref.to_string().parse::<HostMeshRef>().unwrap(),
            host_mesh_ref
        );
    }

    /// Allocates a host mesh for every test alloc, spawns proc and actor
    /// meshes on it, and checks casting and actor-to-actor forwarding.
    #[tokio::test]
    async fn test_allocate() {
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);

        let instance = testing::instance().await;

        for alloc in testing::allocs(extent!(replicas = 4)).await {
            let host_mesh = HostMesh::allocate(instance, alloc, "test", None)
                .await
                .unwrap();

            // One proc per host.
            let proc_mesh1 = host_mesh
                .spawn(instance, "test_1", Extent::unity())
                .await
                .unwrap();

            let actor_mesh1: ActorMesh<testactor::TestActor> =
                proc_mesh1.spawn(instance, "test", &()).await.unwrap();

            // Several procs per host, over two extra dimensions.
            let proc_mesh2 = host_mesh
                .spawn(instance, "test_2", extent!(gpus = 3, extra = 2))
                .await
                .unwrap();
            assert_eq!(
                proc_mesh2.extent(),
                extent!(replicas = 4, gpus = 3, extra = 2)
            );
            assert_eq!(proc_mesh2.values().count(), 24);

            let actor_mesh2: ActorMesh<testactor::TestActor> =
                proc_mesh2.spawn(instance, "test", &()).await.unwrap();
            assert_eq!(
                actor_mesh2.extent(),
                extent!(replicas = 4, gpus = 3, extra = 2)
            );
            assert_eq!(actor_mesh2.values().count(), 24);

            // A cloned reference sees the same hosts as the owning mesh.
            let host_mesh_ref: HostMeshRef = host_mesh.clone();
            assert_eq!(
                host_mesh_ref.iter().collect::<Vec<_>>(),
                host_mesh.iter().collect::<Vec<_>>(),
            );

            // Every actor in each mesh must answer a cast exactly once.
            for actor_mesh in [&actor_mesh1, &actor_mesh2] {
                let (port, mut rx) = instance.mailbox().open_port();
                actor_mesh
                    .cast(instance, testactor::GetActorId(port.bind()))
                    .unwrap();

                let mut expected_actor_ids: HashSet<_> = actor_mesh
                    .values()
                    .map(|actor_ref| actor_ref.actor_id().clone())
                    .collect();

                while !expected_actor_ids.is_empty() {
                    let actor_id = rx.recv().await.unwrap();
                    assert!(
                        expected_actor_ids.remove(&actor_id),
                        "got {actor_id}, expect {expected_actor_ids:?}"
                    );
                }
            }

            // Build a route through every ordered pair of actors across both
            // meshes, ending at a port owned by the test.
            let mut to_visit: VecDeque<_> = actor_mesh1
                .values()
                .chain(actor_mesh2.values())
                .map(|actor_ref| actor_ref.port())
                .permutations(2)
                .flatten()
                .collect();

            let expect_visited: Vec<_> = to_visit.clone().into();

            let (last, mut last_rx) = instance.mailbox().open_port();
            to_visit.push_back(last.bind());

            let forward = testactor::Forward {
                to_visit,
                visited: Vec::new(),
            };
            let first = forward.to_visit.front().unwrap().clone();
            first.send(instance, forward).unwrap();

            // The message that reaches the final port must have visited
            // exactly the planned route, in order.
            let forward = last_rx.recv().await.unwrap();
            assert_eq!(forward.visited, expect_visited);

            let _ = host_mesh.shutdown(&instance).await;
        }
    }

    /// Returns a localhost TCP address with an OS-assigned port.
    ///
    /// The listener used to obtain the port is dropped on return, so the
    /// port is only probably free by the time a caller binds it.
    fn free_localhost_addr() -> ChannelAddr {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        ChannelAddr::Tcp(listener.local_addr().unwrap())
    }

    /// Hosts started outside of any alloc can be addressed through
    /// `HostMeshRef::from_hosts`, used to spawn procs and actors, and then
    /// shut down through `HostMesh::take`.
    #[tokio::test]
    async fn test_extrinsic_allocation() {
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);

        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");

        let hosts = vec![free_localhost_addr(), free_localhost_addr()];

        let mut children = Vec::new();
        for host in hosts.iter() {
            let mut cmd = Command::new(program.clone());
            let boot = Bootstrap::Host {
                addr: host.clone(),
                command: None,
                config: None,
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(true);
            children.push(cmd.spawn().unwrap());
        }

        let instance = testing::instance().await;
        let host_mesh = HostMeshRef::from_hosts(hosts);

        let proc_mesh = host_mesh
            .spawn(&testing::instance().await, "test", Extent::unity())
            .await
            .unwrap();

        let actor_mesh: ActorMesh<testactor::TestActor> = proc_mesh
            .spawn(&testing::instance().await, "test", &())
            .await
            .unwrap();

        testactor::assert_mesh_shape(actor_mesh).await;

        HostMesh::take(Name::new("extrinsic"), host_mesh)
            .shutdown(&instance)
            .await
            .expect("hosts shutdown");
    }

    /// When hosts are told to start procs with `/bin/false`, spawning a
    /// proc mesh fails with a `ProcCreationError` carrying the failure.
    #[tokio::test]
    async fn test_failing_proc_allocation() {
        // Locate the bootstrap binary the same way as
        // `test_extrinsic_allocation`, through the crate's resource helper.
        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");

        let hosts = vec![free_localhost_addr(), free_localhost_addr()];

        let mut children = Vec::new();
        for host in hosts.iter() {
            let mut cmd = Command::new(program.clone());
            let boot = Bootstrap::Host {
                addr: host.clone(),
                config: None,
                command: Some(BootstrapCommand::from("/bin/false")),
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(true);
            children.push(cmd.spawn().unwrap());
        }
        let host_mesh = HostMeshRef::from_hosts(hosts);

        let instance = testing::instance().await;

        let err = host_mesh
            .spawn(&instance, "test", Extent::unity())
            .await
            .unwrap_err();
        assert_matches!(
            err, v1::Error::ProcCreationError { status: resource::Status::Failed(msg), .. }
            if msg.contains("failed to configure process: Terminal(Stopped { exit_code: 1")
        );
    }
}