use hyperactor::Actor;
use hyperactor::ActorHandle;
use hyperactor::accum::ReducerOpts;
use hyperactor::channel::ChannelTransport;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::config;
use hyperactor::config::CONFIG;
use hyperactor::config::ConfigAttr;
use hyperactor::declare_attrs;
use hyperactor::host::Host;
use ndslice::view::CollectMeshExt;

pub mod mesh_agent;

use std::collections::HashSet;
use std::ops::Deref;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use hyperactor::ActorRef;
use hyperactor::Named;
use hyperactor::ProcId;
use hyperactor::channel::ChannelAddr;
use hyperactor::context;
use ndslice::Extent;
use ndslice::Region;
use ndslice::ViewExt;
use ndslice::extent;
use ndslice::view;
use ndslice::view::Ranked;
use ndslice::view::RegionParseError;
use serde::Deserialize;
use serde::Serialize;

use crate::Bootstrap;
use crate::alloc::Alloc;
use crate::bootstrap::BootstrapCommand;
use crate::bootstrap::BootstrapProcManager;
use crate::proc_mesh::DEFAULT_TRANSPORT;
use crate::resource;
use crate::resource::CreateOrUpdateClient;
use crate::resource::GetRankStatus;
use crate::resource::GetRankStatusClient;
use crate::resource::ProcSpec;
use crate::resource::RankedValues;
use crate::resource::Status;
use crate::v1;
use crate::v1::Name;
use crate::v1::ProcMesh;
use crate::v1::ProcMeshRef;
use crate::v1::ValueMesh;
use crate::v1::host_mesh::mesh_agent::HostAgentMode;
pub use crate::v1::host_mesh::mesh_agent::HostMeshAgent;
use crate::v1::host_mesh::mesh_agent::HostMeshAgentProcMeshTrampoline;
use crate::v1::host_mesh::mesh_agent::ProcState;
use crate::v1::host_mesh::mesh_agent::ShutdownHostClient;
use crate::v1::mesh_controller::HostMeshController;
use crate::v1::mesh_controller::ProcMeshController;
use crate::v1::proc_mesh::ProcRef;

declare_attrs! {
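    /// Maximum idle time to wait for proc spawn status updates before the
    /// spawn is reported as timed out.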
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_PROC_SPAWN_MAX_IDLE".to_string()),
        py_name: None,
    })
    pub attr PROC_SPAWN_MAX_IDLE: Duration = Duration::from_secs(30);

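    /// Maximum idle time to wait for proc stop status updates before the
    /// stop is reported as timed out.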
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_PROC_STOP_MAX_IDLE".to_string()),
        py_name: None,
    })
    pub attr PROC_STOP_MAX_IDLE: Duration = Duration::from_secs(30);

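    /// Maximum time to wait for a proc state reply before the remaining
    /// ranks are marked as timed out.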
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_GET_PROC_STATE_MAX_IDLE".to_string()),
        py_name: None,
    })
    pub attr GET_PROC_STATE_MAX_IDLE: Duration = Duration::from_secs(60);
}

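/// A reference to a single host in a host mesh. A `HostRef` wraps the host's
/// channel address and derives from it the identities of the host's procs and
/// of the [`HostMeshAgent`] that manages them.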
#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
pub struct HostRef(ChannelAddr);

impl HostRef {
    fn mesh_agent(&self) -> ActorRef<HostMeshAgent> {
        ActorRef::attest(self.service_proc().actor_id("agent", 0))
    }

    fn named_proc(&self, name: &Name) -> ProcId {
        ProcId::Direct(self.0.clone(), name.to_string())
    }

    fn service_proc(&self) -> ProcId {
        ProcId::Direct(self.0.clone(), "service".to_string())
    }

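    /// Shut down this host: ask its [`HostMeshAgent`] to terminate the host's
    /// procs (bounded by the configured terminate timeout and concurrency)
    /// and then stop the host itself.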
    pub(crate) async fn shutdown(
        &self,
        cx: &impl hyperactor::context::Actor,
    ) -> anyhow::Result<()> {
        let agent = self.mesh_agent();
        let terminate_timeout =
            hyperactor::config::global::get(crate::bootstrap::MESH_TERMINATE_TIMEOUT);
        let max_in_flight =
            hyperactor::config::global::get(crate::bootstrap::MESH_TERMINATE_CONCURRENCY);
        agent
            .shutdown_host(cx, terminate_timeout, max_in_flight.clamp(1, 256))
            .await?;
        Ok(())
    }
}

impl std::fmt::Display for HostRef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl FromStr for HostRef {
    type Err = <ChannelAddr as FromStr>::Err;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(HostRef(ChannelAddr::from_str(s)?))
    }
}

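/// An owned mesh of hosts. `HostMesh` remembers how its hosts were allocated
/// and performs a best-effort shutdown of them when dropped. It dereferences
/// to [`HostMeshRef`] for read-only mesh operations.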
#[allow(dead_code)]
pub struct HostMesh {
    name: Name,
    extent: Extent,
    allocation: HostMeshAllocation,
    current_ref: HostMeshRef,
}

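/// How the hosts backing a [`HostMesh`] were obtained: either bootstrapped on
/// top of an underlying [`ProcMesh`] (via the trampoline agent), or owned
/// directly as a set of host references.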
#[allow(dead_code)]
enum HostMeshAllocation {
    ProcMesh {
        proc_mesh: ProcMesh,
        proc_mesh_ref: ProcMeshRef,
        hosts: Vec<HostRef>,
    },
    Owned { hosts: Vec<HostRef> },
}

impl HostMesh {
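    /// Create a single-host mesh served from the local process, using the
    /// current binary as the bootstrap command.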
    pub async fn local() -> v1::Result<HostMesh> {
        Self::local_with_bootstrap(BootstrapCommand::current()?).await
    }

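    /// Create a single-host mesh served from the local process, spawning
    /// procs with the provided bootstrap command. If this process was itself
    /// launched as a bootstrap child, it runs that bootstrap instead and
    /// never returns.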
    pub async fn local_with_bootstrap(bootstrap_cmd: BootstrapCommand) -> v1::Result<HostMesh> {
        if let Ok(Some(boot)) = Bootstrap::get_from_env() {
            let err = boot.bootstrap().await;
            tracing::error!("failed to bootstrap local host mesh process: {}", err);
            std::process::exit(1);
        }

        let addr = config::global::get_cloned(DEFAULT_TRANSPORT).any();

        let manager = BootstrapProcManager::new(bootstrap_cmd)?;
        let (host, _handle) = Host::serve(manager, addr).await?;
        let addr = host.addr().clone();
        let host_mesh_agent = host
            .system_proc()
            .clone()
            .spawn::<HostMeshAgent>("agent", HostAgentMode::Process(host))
            .await
            .map_err(v1::Error::SingletonActorSpawnError)?;
        host_mesh_agent.bind::<HostMeshAgent>();

        let host = HostRef(addr);
        let host_mesh_ref =
            HostMeshRef::new(Name::new("local"), extent!(hosts = 1).into(), vec![host])?;
        Ok(HostMesh::take(host_mesh_ref))
    }

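    /// Create a host mesh of the given extent by launching one host process
    /// per rank on the local machine, each of which spawns its procs with
    /// `command`.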
    pub async fn process(extent: Extent, command: BootstrapCommand) -> v1::Result<HostMesh> {
        if let Ok(Some(boot)) = Bootstrap::get_from_env() {
            let err = boot.bootstrap().await;
            tracing::error!("failed to bootstrap process host mesh process: {}", err);
            std::process::exit(1);
        }

        let transport = config::global::get_cloned(DEFAULT_TRANSPORT);
        let mut hosts = Vec::with_capacity(extent.num_ranks());
        for _ in 0..extent.num_ranks() {
            let addr = transport.any();
            let bootstrap = Bootstrap::Host {
                addr: addr.clone(),
                command: Some(command.clone()),
                config: Some(config::global::attrs()),
            };

            let mut cmd = command.new();
            bootstrap.to_env(&mut cmd);
            cmd.spawn()?;
            hosts.push(HostRef(addr));
        }

        let host_mesh_ref = HostMeshRef::new(Name::new("process"), extent.into(), hosts)?;
        Ok(HostMesh::take(host_mesh_ref))
    }

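    /// Allocate a host mesh from the provided [`Alloc`]: each allocated proc
    /// is turned into a host by a trampoline actor that starts a
    /// [`HostMeshAgent`] on it. The optional `bootstrap_params` override the
    /// command used to spawn procs on those hosts.
    ///
    /// A usage sketch (names are illustrative; see the tests below for a
    /// concrete version):
    ///
    /// ```ignore
    /// let host_mesh = HostMesh::allocate(&instance, alloc, "trainers", None).await?;
    /// let proc_mesh = host_mesh.spawn(&instance, "workers", extent!(gpus = 8)).await?;
    /// ```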
    pub async fn allocate(
        cx: &impl context::Actor,
        alloc: Box<dyn Alloc + Send + Sync>,
        name: &str,
        bootstrap_params: Option<BootstrapCommand>,
    ) -> v1::Result<Self> {
        Self::allocate_inner(cx, alloc, Name::new(name), bootstrap_params).await
    }

    #[hyperactor::instrument(fields(host_mesh=name.to_string()))]
    async fn allocate_inner(
        cx: &impl context::Actor,
        alloc: Box<dyn Alloc + Send + Sync>,
        name: Name,
        bootstrap_params: Option<BootstrapCommand>,
    ) -> v1::Result<Self> {
        tracing::info!(name = "HostMeshStatus", status = "Allocate::Attempt");
        let transport = alloc.transport();
        let extent = alloc.extent().clone();
        let is_local = alloc.is_local();
        let proc_mesh = ProcMesh::allocate(cx, alloc, name.name()).await?;

        let (mesh_agents, mut mesh_agents_rx) = cx.mailbox().open_port();
        let _trampoline_actor_mesh = proc_mesh
            .spawn::<HostMeshAgentProcMeshTrampoline>(
                cx,
                "host_mesh_trampoline",
                &(transport, mesh_agents.bind(), bootstrap_params, is_local),
            )
            .await?;

        let mut hosts = Vec::new();
        for _rank in 0..extent.num_ranks() {
            let mesh_agent = mesh_agents_rx.recv().await?;

            let Some((addr, _)) = mesh_agent.actor_id().proc_id().as_direct() else {
                return Err(v1::Error::HostMeshAgentConfigurationError(
                    mesh_agent.actor_id().clone(),
                    "host mesh agent must be a direct actor".to_string(),
                ));
            };

            let host_ref = HostRef(addr.clone());
            if host_ref.mesh_agent() != mesh_agent {
                return Err(v1::Error::HostMeshAgentConfigurationError(
                    mesh_agent.actor_id().clone(),
                    format!(
                        "expected mesh agent actor id to be {}",
                        host_ref.mesh_agent().actor_id()
                    ),
                ));
            }
            hosts.push(host_ref);
        }

        let proc_mesh_ref = proc_mesh.clone();
        let mesh = Self {
            name: name.clone(),
            extent: extent.clone(),
            allocation: HostMeshAllocation::ProcMesh {
                proc_mesh,
                proc_mesh_ref,
                hosts: hosts.clone(),
            },
            current_ref: HostMeshRef::new(name, extent.into(), hosts).unwrap(),
        };

        let _controller: ActorHandle<HostMeshController> =
            HostMeshController::spawn(cx, mesh.deref().clone())
                .await
                .map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?;

        tracing::info!(name = "HostMeshStatus", status = "Allocate::Created");
        Ok(mesh)
    }

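    /// Take ownership of an externally constructed [`HostMeshRef`], producing
    /// a `HostMesh` that will best-effort shut the referenced hosts down when
    /// dropped.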
    pub fn take(mesh: HostMeshRef) -> Self {
        let region = mesh.region().clone();
        let hosts: Vec<HostRef> = mesh.values().collect();

        let current_ref = HostMeshRef::new(mesh.name.clone(), region.clone(), hosts.clone())
            .expect("region/hosts cardinality must match");

        Self {
            name: mesh.name,
            extent: region.extent().clone(),
            allocation: HostMeshAllocation::Owned { hosts },
            current_ref,
        }
    }

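    /// Shut down every host in this mesh, logging (but not failing on)
    /// individual host shutdown errors.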
    #[hyperactor::instrument(fields(host_mesh=self.name.to_string()))]
    pub async fn shutdown(&self, cx: &impl hyperactor::context::Actor) -> anyhow::Result<()> {
        tracing::info!(name = "HostMeshStatus", status = "Shutdown::Attempt");
        let mut failed_hosts = vec![];
        for host in self.current_ref.values() {
            if let Err(e) = host.shutdown(cx).await {
                tracing::warn!(
                    name = "HostMeshStatus",
                    status = "Shutdown::Host::Failed",
                    host = %host,
                    error = %e,
                    "host shutdown failed"
                );
                failed_hosts.push(host);
            }
        }
        if failed_hosts.is_empty() {
            tracing::info!(name = "HostMeshStatus", status = "Shutdown::Success");
        } else {
            tracing::error!(
                name = "HostMeshStatus",
                status = "Shutdown::Failed",
                "host mesh shutdown failed; check the logs of the failed hosts for details: {:?}",
                failed_hosts
            );
        }
        Ok(())
    }
}

impl Deref for HostMesh {
    type Target = HostMeshRef;

    fn deref(&self) -> &Self::Target {
        &self.current_ref
    }
}

impl Drop for HostMesh {
    fn drop(&mut self) {
        tracing::info!(
            name = "HostMeshStatus",
            host_mesh = %self.name,
            status = "Dropping",
        );
        let hosts: Vec<HostRef> = match &self.allocation {
            HostMeshAllocation::ProcMesh { hosts, .. } | HostMeshAllocation::Owned { hosts } => {
                hosts.clone()
            }
        };

        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let mesh_name = self.name.clone();
            let allocation_label = match &self.allocation {
                HostMeshAllocation::ProcMesh { .. } => "proc_mesh",
                HostMeshAllocation::Owned { .. } => "owned",
            }
            .to_string();

            handle.spawn(async move {
                let span = tracing::info_span!(
                    "hostmesh_drop_cleanup",
                    host_mesh = %mesh_name,
                    allocation = %allocation_label,
                    hosts = hosts.len(),
                );
                let _g = span.enter();

                match hyperactor::Proc::direct(
                    ChannelTransport::Unix.any(),
                    "hostmesh-drop".to_string(),
                )
                .await
                {
                    Err(e) => {
                        tracing::warn!(
                            error = %e,
                            "failed to construct ephemeral Proc for drop-cleanup; \
                             relying on PDEATHSIG/manager Drop"
                        );
                    }
                    Ok(proc) => {
                        match proc.instance("drop") {
                            Err(e) => {
                                tracing::warn!(
                                    error = %e,
                                    "failed to create ephemeral instance for drop-cleanup; \
                                     relying on PDEATHSIG/manager Drop"
                                );
                            }
                            Ok((instance, _guard)) => {
                                let mut attempted = 0usize;
                                let mut ok = 0usize;
                                let mut err = 0usize;

                                for host in hosts {
                                    attempted += 1;
                                    tracing::debug!(host = %host, "drop-cleanup: shutdown start");
                                    match host.shutdown(&instance).await {
                                        Ok(()) => {
                                            ok += 1;
                                            tracing::debug!(host = %host, "drop-cleanup: shutdown ok");
                                        }
                                        Err(e) => {
                                            err += 1;
                                            tracing::warn!(host = %host, error = %e, "drop-cleanup: shutdown failed");
                                        }
                                    }
                                }

                                tracing::info!(
                                    attempted, ok, err,
                                    "hostmesh drop-cleanup summary"
                                );
                            }
                        }
                    }
                }
            });
        } else {
            tracing::warn!(
                host_mesh = %self.name,
                hosts = hosts.len(),
                "HostMesh dropped without a tokio runtime; skipping best-effort shutdown"
            );
        }

        tracing::info!(
            name = "HostMeshStatus",
            host_mesh = %self.name,
            status = "Dropped",
        );
    }
}

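/// Convert a [`ValueMesh`] into the legacy [`RankedValues`] representation:
/// every rank starts out as `default`, and non-sentinel values from the mesh
/// are merged in at their respective ranks.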
pub(crate) fn mesh_to_rankedvalues_with_default<T, F>(
    mesh: &ValueMesh<T>,
    default: T,
    is_sentinel: F,
    len: usize,
) -> RankedValues<T>
where
    T: Eq + Clone + 'static,
    F: Fn(&T) -> bool,
{
    let mut out = RankedValues::from((0..len, default));
    for (i, s) in mesh.values().enumerate() {
        if !is_sentinel(&s) {
            out.merge_from(RankedValues::from((i..i + 1, s)));
        }
    }
    out
}

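/// An unowned reference to a mesh of hosts. A `HostMeshRef` pairs a [`Region`]
/// with one [`HostRef`] per rank; it can be sliced, displayed, and parsed back
/// from its display form (see the `FromStr` impl below).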
#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
pub struct HostMeshRef {
    name: Name,
    region: Region,
    ranks: Arc<Vec<HostRef>>,
}

impl HostMeshRef {
    #[allow(clippy::result_large_err)]
    fn new(name: Name, region: Region, ranks: Vec<HostRef>) -> v1::Result<Self> {
        if region.num_ranks() != ranks.len() {
            return Err(v1::Error::InvalidRankCardinality {
                expected: region.num_ranks(),
                actual: ranks.len(),
            });
        }
        Ok(Self {
            name,
            region,
            ranks: Arc::new(ranks),
        })
    }

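    /// Construct a one-dimensional host mesh ref (extent `hosts = N`) from a
    /// list of host channel addresses.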
    pub fn from_hosts(name: Name, hosts: Vec<ChannelAddr>) -> Self {
        Self {
            name,
            region: extent!(hosts = hosts.len()).into(),
            ranks: Arc::new(hosts.into_iter().map(HostRef).collect()),
        }
    }

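    /// Spawn a [`ProcMesh`] across this host mesh, creating `per_host` procs
    /// on each host. The resulting mesh's extent is the concatenation of the
    /// host mesh extent and `per_host`, so `per_host` must not reuse any of
    /// the host mesh's dimension labels.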
    #[allow(clippy::result_large_err)]
    pub async fn spawn(
        &self,
        cx: &impl context::Actor,
        name: &str,
        per_host: Extent,
    ) -> v1::Result<ProcMesh> {
        self.spawn_inner(cx, Name::new(name), per_host).await
    }

    #[hyperactor::instrument(fields(host_mesh=self.name.to_string(), proc_mesh=proc_mesh_name.to_string()))]
    async fn spawn_inner(
        &self,
        cx: &impl context::Actor,
        proc_mesh_name: Name,
        per_host: Extent,
    ) -> v1::Result<ProcMesh> {
        tracing::info!(name = "HostMeshStatus", status = "ProcMesh::Spawn::Attempt");
        tracing::info!(name = "ProcMeshStatus", status = "Spawn::Attempt",);
        let result = self.spawn_inner_inner(cx, proc_mesh_name, per_host).await;
        match &result {
            Ok(_) => {
                tracing::info!(name = "HostMeshStatus", status = "ProcMesh::Spawn::Success");
                tracing::info!(name = "ProcMeshStatus", status = "Spawn::Success");
            }
            Err(error) => {
                tracing::error!(name = "HostMeshStatus", status = "ProcMesh::Spawn::Failed", %error);
                tracing::error!(name = "ProcMeshStatus", status = "Spawn::Failed", %error);
            }
        }
        result
    }

    async fn spawn_inner_inner(
        &self,
        cx: &impl context::Actor,
        proc_mesh_name: Name,
        per_host: Extent,
    ) -> v1::Result<ProcMesh> {
        let per_host_labels = per_host.labels().iter().collect::<HashSet<_>>();
        let host_labels = self.region.labels().iter().collect::<HashSet<_>>();
        if !per_host_labels
            .intersection(&host_labels)
            .collect::<Vec<_>>()
            .is_empty()
        {
            return Err(v1::Error::ConfigurationError(anyhow::anyhow!(
                "per_host dims overlap with existing dims when spawning proc mesh"
            )));
        }

        let extent = self
            .region
            .extent()
            .concat(&per_host)
            .map_err(|err| v1::Error::ConfigurationError(err.into()))?;

        let region: Region = extent.clone().into();

        tracing::info!(
            name = "ProcMeshStatus",
            status = "Spawn::Attempt",
            %region,
            "spawning proc mesh"
        );

        let mut procs = Vec::new();
        let num_ranks = region.num_ranks();
        let (port, rx) = cx.mailbox().open_accum_port_opts(
            crate::v1::StatusMesh::from_single(region.clone(), Status::NotExist),
            Some(ReducerOpts {
                max_update_interval: Some(Duration::from_millis(50)),
            }),
        );

        let mut proc_names = Vec::new();
        let client_config_override = config::global::attrs();
        for (host_rank, host) in self.ranks.iter().enumerate() {
            for per_host_rank in 0..per_host.num_ranks() {
                let create_rank = per_host.num_ranks() * host_rank + per_host_rank;
                let proc_name = Name::new(format!("{}_{}", proc_mesh_name.name(), per_host_rank));
                proc_names.push(proc_name.clone());
                host.mesh_agent()
                    .create_or_update(
                        cx,
                        proc_name.clone(),
                        resource::Rank::new(create_rank),
                        ProcSpec::new(client_config_override.clone()),
                    )
                    .await
                    .map_err(|e| {
                        v1::Error::HostMeshAgentConfigurationError(
                            host.mesh_agent().actor_id().clone(),
                            format!("failed while creating proc: {}", e),
                        )
                    })?;
                let mut reply_port = port.bind();
                reply_port.return_undeliverable(false);
                host.mesh_agent()
                    .get_rank_status(cx, proc_name.clone(), reply_port)
                    .await
                    .map_err(|e| {
                        v1::Error::HostMeshAgentConfigurationError(
                            host.mesh_agent().actor_id().clone(),
                            format!("failed while querying proc status: {}", e),
                        )
                    })?;
                let proc_id = host.named_proc(&proc_name);
                tracing::info!(
                    name = "ProcMeshStatus",
                    status = "Spawn::CreatingProc",
                    %proc_id,
                    rank = create_rank,
                );
                procs.push(ProcRef::new(
                    proc_id,
                    create_rank,
                    ActorRef::attest(host.named_proc(&proc_name).actor_id("agent", 0)),
                ));
            }
        }

        let start_time = RealClock.now();

        match GetRankStatus::wait(
            rx,
            num_ranks,
            config::global::get(PROC_SPAWN_MAX_IDLE),
            region.clone(),
        )
        .await
        {
            Ok(statuses) => {
                if let Some((rank, status)) = statuses
                    .values()
                    .enumerate()
                    .find(|(_, s)| s.is_terminating())
                {
                    let proc_name = &proc_names[rank];
                    let host_rank = rank / per_host.num_ranks();
                    let mesh_agent = self.ranks[host_rank].mesh_agent();
                    let (reply_tx, mut reply_rx) = cx.mailbox().open_port();
                    let mut reply_tx = reply_tx.bind();
                    reply_tx.return_undeliverable(false);
                    mesh_agent
                        .send(
                            cx,
                            resource::GetState {
                                name: proc_name.clone(),
                                reply: reply_tx,
                            },
                        )
                        .map_err(|e| {
                            v1::Error::SendingError(mesh_agent.actor_id().clone(), e.into())
                        })?;
                    let state = match RealClock
                        .timeout(config::global::get(PROC_SPAWN_MAX_IDLE), reply_rx.recv())
                        .await
                    {
                        Ok(Ok(state)) => state,
                        _ => resource::State {
                            name: proc_name.clone(),
                            status,
                            state: None,
                        },
                    };

                    tracing::error!(
                        name = "ProcMeshStatus",
                        status = "Spawn::GetRankStatus",
                        rank = host_rank,
                        "rank {} is terminating with state: {}",
                        host_rank,
                        state
                    );

                    return Err(v1::Error::ProcCreationError {
                        state,
                        host_rank,
                        mesh_agent,
                    });
                }
            }
            Err(complete) => {
                tracing::error!(
                    name = "ProcMeshStatus",
                    status = "Spawn::GetRankStatus",
                    "timeout after {:?} when waiting for procs being created",
                    config::global::get(PROC_SPAWN_MAX_IDLE),
                );
                let legacy = mesh_to_rankedvalues_with_default(
                    &complete,
                    Status::Timeout(start_time.elapsed()),
                    Status::is_not_exist,
                    num_ranks,
                );
                return Err(v1::Error::ProcSpawnError { statuses: legacy });
            }
        }

        let mesh =
            ProcMesh::create_owned_unchecked(cx, proc_mesh_name, extent, self.clone(), procs).await;
        if let Ok(ref mesh) = mesh {
            let _controller: ActorHandle<ProcMeshController> =
                ProcMeshController::spawn(cx, mesh.deref().clone())
                    .await
                    .map_err(|e| v1::Error::ControllerActorSpawnError(mesh.name().clone(), e))?;
        }
        mesh
    }

    pub fn name(&self) -> &Name {
        &self.name
    }

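    /// Stop the given procs (previously spawned on this host mesh) by asking
    /// each owning host's agent to stop them, then wait for all ranks to
    /// report a terminating status or time out.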
    #[hyperactor::instrument(fields(host_mesh=self.name.to_string(), proc_mesh=proc_mesh_name.to_string()))]
    pub(crate) async fn stop_proc_mesh(
        &self,
        cx: &impl hyperactor::context::Actor,
        proc_mesh_name: &Name,
        procs: impl IntoIterator<Item = ProcId>,
        region: Region,
    ) -> anyhow::Result<()> {
        let mut proc_names = Vec::new();
        let num_ranks = region.num_ranks();
        let (port, rx) = cx.mailbox().open_accum_port_opts(
            crate::v1::StatusMesh::from_single(region.clone(), Status::NotExist),
            Some(ReducerOpts {
                max_update_interval: Some(Duration::from_millis(50)),
            }),
        );
        for proc_id in procs.into_iter() {
            let Some((addr, proc_name)) = proc_id.as_direct() else {
                return Err(anyhow::anyhow!(
                    "host mesh proc {} must be direct addressed",
                    proc_id,
                ));
            };
            let proc_name = proc_name.parse::<Name>()?;
            proc_names.push(proc_name.clone());

            let host = HostRef(addr.clone());
            host.mesh_agent().send(
                cx,
                resource::Stop {
                    name: proc_name.clone(),
                },
            )?;
            host.mesh_agent()
                .get_rank_status(cx, proc_name, port.bind())
                .await?;

            tracing::info!(
                name = "ProcMeshStatus",
                %proc_id,
                status = "Stop::Sent",
            );
        }
        tracing::info!(
            name = "HostMeshStatus",
            status = "ProcMesh::Stop::Sent",
            "sending Stop to proc mesh for {} procs: {}",
            proc_names.len(),
            proc_names
                .iter()
                .map(|n| n.to_string())
                .collect::<Vec<_>>()
                .join(", ")
        );

        let start_time = RealClock.now();

        match GetRankStatus::wait(
            rx,
            num_ranks,
            config::global::get(PROC_STOP_MAX_IDLE),
            region.clone(),
        )
        .await
        {
            Ok(statuses) => {
                let all_stopped = statuses.values().all(|s| s.is_terminating());
                if !all_stopped {
                    tracing::error!(
                        name = "ProcMeshStatus",
                        status = "FailedToStop",
                        "failed to terminate proc mesh: {:?}",
                        statuses,
                    );
                    return Err(anyhow::anyhow!(
                        "failed to terminate proc mesh: {:?}",
                        statuses,
                    ));
                }
                tracing::info!(name = "ProcMeshStatus", status = "Stopped");
            }
            Err(complete) => {
                let legacy = mesh_to_rankedvalues_with_default(
                    &complete,
                    Status::Timeout(start_time.elapsed()),
                    Status::is_not_exist,
                    num_ranks,
                );
                tracing::error!(
                    name = "ProcMeshStatus",
                    status = "StoppingTimeout",
                    "failed to terminate proc mesh before timeout: {:?}",
                    legacy,
                );
                return Err(anyhow::anyhow!(
                    "failed to terminate proc mesh {} before timeout: {:?}",
                    proc_mesh_name,
                    legacy
                ));
            }
        }
        Ok(())
    }

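    /// Query the current [`ProcState`] of each of the given procs from their
    /// host agents, returning a [`ValueMesh`] over `region`. Ranks that do
    /// not reply before the configured idle timeout are reported with a
    /// `Timeout` status.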
    #[allow(clippy::result_large_err)]
    pub(crate) async fn proc_states(
        &self,
        cx: &impl context::Actor,
        procs: impl IntoIterator<Item = ProcId>,
        region: Region,
    ) -> v1::Result<ValueMesh<resource::State<ProcState>>> {
        let (tx, mut rx) = cx.mailbox().open_port();

        let mut num_ranks = 0;
        let procs: Vec<ProcId> = procs.into_iter().collect();
        let mut proc_names = Vec::new();
        for proc_id in procs.iter() {
            num_ranks += 1;
            let Some((addr, proc_name)) = proc_id.as_direct() else {
                return Err(v1::Error::ConfigurationError(anyhow::anyhow!(
                    "host mesh proc {} must be direct addressed",
                    proc_id,
                )));
            };

            let host = HostRef(addr.clone());
            let proc_name = proc_name.parse::<Name>()?;
            proc_names.push(proc_name.clone());
            let mut reply = tx.bind();
            reply.return_undeliverable(false);
            host.mesh_agent()
                .send(
                    cx,
                    resource::GetState {
                        name: proc_name,
                        reply,
                    },
                )
                .map_err(|e| {
                    v1::Error::CallError(host.mesh_agent().actor_id().clone(), e.into())
                })?;
        }

        let mut states = Vec::with_capacity(num_ranks);
        let timeout = config::global::get(GET_PROC_STATE_MAX_IDLE);
        for _ in 0..num_ranks {
            let state = RealClock.timeout(timeout, rx.recv()).await;
            if let Ok(state) = state {
                let state = state?;
                match state.state {
                    Some(ref inner) => {
                        states.push((inner.create_rank, state));
                    }
                    None => {
                        return Err(v1::Error::NotExist(state.name));
                    }
                }
            } else {
                tracing::warn!(
                    "Timeout waiting for response from host mesh agent for proc_states after {:?}",
                    timeout
                );
                let all_ranks = (0..num_ranks).collect::<HashSet<_>>();
                let completed_ranks = states.iter().map(|(rank, _)| *rank).collect::<HashSet<_>>();
                let mut leftover_ranks = all_ranks.difference(&completed_ranks).collect::<Vec<_>>();
                assert_eq!(leftover_ranks.len(), num_ranks - states.len());
                while states.len() < num_ranks {
                    let rank = *leftover_ranks
                        .pop()
                        .expect("leftover ranks should not be empty");
                    states.push((
                        rank,
                        resource::State {
                            name: proc_names[rank].clone(),
                            status: resource::Status::Timeout(timeout),
                            state: None,
                        },
                    ));
                }
                break;
            }
        }
        states.sort_by_key(|(rank, _)| *rank);
        let vm = states
            .into_iter()
            .map(|(_, state)| state)
            .collect_mesh::<ValueMesh<_>>(region)?;
        Ok(vm)
    }
}

impl view::Ranked for HostMeshRef {
    type Item = HostRef;

    fn region(&self) -> &Region {
        &self.region
    }

    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.ranks.get(rank)
    }
}

impl view::RankedSliceable for HostMeshRef {
    fn sliced(&self, region: Region) -> Self {
        let ranks = self
            .region()
            .remap(&region)
            .unwrap()
            .map(|index| self.get(index).unwrap().clone());
        Self::new(self.name.clone(), region, ranks.collect()).unwrap()
    }
}

impl std::fmt::Display for HostMeshRef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:", self.name)?;
        for (rank, host) in self.ranks.iter().enumerate() {
            if rank > 0 {
                write!(f, ",")?;
            }
            write!(f, "{}", host)?;
        }
        write!(f, "@{}", self.region)
    }
}

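/// Errors that can occur while parsing a [`HostMeshRef`] from its string
/// representation (`name:host1,host2,...@region`).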
#[derive(thiserror::Error, Debug)]
pub enum HostMeshRefParseError {
    #[error(transparent)]
    RegionParseError(#[from] RegionParseError),

    #[error("invalid host mesh ref: missing region")]
    MissingRegion,

    #[error("invalid host mesh ref: missing name")]
    MissingName,

    #[error(transparent)]
    InvalidName(#[from] v1::NameParseError),

    #[error(transparent)]
    InvalidHostMeshRef(#[from] Box<v1::Error>),

    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl From<v1::Error> for HostMeshRefParseError {
    fn from(err: v1::Error) -> Self {
        Self::InvalidHostMeshRef(Box::new(err))
    }
}

impl FromStr for HostMeshRef {
    type Err = HostMeshRefParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (name, rest) = s
            .split_once(':')
            .ok_or(HostMeshRefParseError::MissingName)?;

        let name = Name::from_str(name)?;

        let (hosts, region) = rest
            .split_once('@')
            .ok_or(HostMeshRefParseError::MissingRegion)?;
        let hosts = hosts
            .split(',')
            .map(|host| host.trim())
            .map(|host| host.parse::<HostRef>())
            .collect::<Result<Vec<_>, _>>()?;
        let region = region.parse()?;
        Ok(HostMeshRef::new(name, region, hosts)?)
    }
}

#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::collections::VecDeque;

    use hyperactor::attrs::Attrs;
    use hyperactor::context::Mailbox as _;
    use itertools::Itertools;
    use ndslice::ViewExt;
    use ndslice::extent;
    use tokio::process::Command;

    use super::*;
    use crate::Bootstrap;
    use crate::bootstrap::MESH_TAIL_LOG_LINES;
    use crate::resource::Status;
    use crate::v1::ActorMesh;
    use crate::v1::testactor;
    use crate::v1::testactor::GetConfigAttrs;
    use crate::v1::testactor::SetConfigAttrs;
    use crate::v1::testing;

    #[test]
    fn test_host_mesh_subset() {
        let hosts: HostMeshRef = "test:local:1,local:2,local:3,local:4@replica=2/2,host=2/1"
            .parse()
            .unwrap();
        assert_eq!(
            hosts.range("replica", 1).unwrap().to_string(),
            "test:local:3,local:4@2+replica=1/2,host=2/1"
        );
    }

    #[test]
    fn test_host_mesh_ref_parse_roundtrip() {
        let host_mesh_ref = HostMeshRef::new(
            Name::new("test"),
            extent!(replica = 2, host = 2).into(),
            vec![
                "tcp:127.0.0.1:123".parse().unwrap(),
                "tcp:127.0.0.1:123".parse().unwrap(),
                "tcp:127.0.0.1:123".parse().unwrap(),
                "tcp:127.0.0.1:123".parse().unwrap(),
            ],
        )
        .unwrap();

        assert_eq!(
            host_mesh_ref.to_string().parse::<HostMeshRef>().unwrap(),
            host_mesh_ref
        );
    }

    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_allocate() {
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);

        let instance = testing::instance().await;

        for alloc in testing::allocs(extent!(replicas = 4)).await {
            let host_mesh = HostMesh::allocate(instance, alloc, "test", None)
                .await
                .unwrap();

            let proc_mesh1 = host_mesh
                .spawn(instance, "test_1", Extent::unity())
                .await
                .unwrap();

            let actor_mesh1: ActorMesh<testactor::TestActor> =
                proc_mesh1.spawn(instance, "test", &()).await.unwrap();

            let proc_mesh2 = host_mesh
                .spawn(instance, "test_2", extent!(gpus = 3, extra = 2))
                .await
                .unwrap();
            assert_eq!(
                proc_mesh2.extent(),
                extent!(replicas = 4, gpus = 3, extra = 2)
            );
            assert_eq!(proc_mesh2.values().count(), 24);

            let actor_mesh2: ActorMesh<testactor::TestActor> =
                proc_mesh2.spawn(instance, "test", &()).await.unwrap();
            assert_eq!(
                actor_mesh2.extent(),
                extent!(replicas = 4, gpus = 3, extra = 2)
            );
            assert_eq!(actor_mesh2.values().count(), 24);

            let host_mesh_ref: HostMeshRef = host_mesh.clone();
            assert_eq!(
                host_mesh_ref.iter().collect::<Vec<_>>(),
                host_mesh.iter().collect::<Vec<_>>(),
            );

            for actor_mesh in [&actor_mesh1, &actor_mesh2] {
                let (port, mut rx) = instance.mailbox().open_port();
                actor_mesh
                    .cast(instance, testactor::GetActorId(port.bind()))
                    .unwrap();

                let mut expected_actor_ids: HashSet<_> = actor_mesh
                    .values()
                    .map(|actor_ref| actor_ref.actor_id().clone())
                    .collect();

                while !expected_actor_ids.is_empty() {
                    let actor_id = rx.recv().await.unwrap();
                    assert!(
                        expected_actor_ids.remove(&actor_id),
                        "got {actor_id}, expect {expected_actor_ids:?}"
                    );
                }
            }

            let mut to_visit: VecDeque<_> = actor_mesh1
                .values()
                .chain(actor_mesh2.values())
                .map(|actor_ref| actor_ref.port())
                .permutations(2)
                .flatten()
                .collect();

            let expect_visited: Vec<_> = to_visit.clone().into();

            let (last, mut last_rx) = instance.mailbox().open_port();
            to_visit.push_back(last.bind());

            let forward = testactor::Forward {
                to_visit,
                visited: Vec::new(),
            };
            let first = forward.to_visit.front().unwrap().clone();
            first.send(instance, forward).unwrap();

            let forward = last_rx.recv().await.unwrap();
            assert_eq!(forward.visited, expect_visited);

            let _ = host_mesh.shutdown(&instance).await;
        }
    }

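    /// Returns a localhost TCP address that is (momentarily) free, by binding
    /// an ephemeral port and immediately dropping the listener.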
    fn free_localhost_addr() -> ChannelAddr {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        ChannelAddr::Tcp(listener.local_addr().unwrap())
    }

    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_extrinsic_allocation() {
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);

        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");

        let hosts = vec![free_localhost_addr(), free_localhost_addr()];

        let mut children = Vec::new();
        for host in hosts.iter() {
            let mut cmd = Command::new(program.clone());
            let boot = Bootstrap::Host {
                addr: host.clone(),
                command: None,
                config: None,
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(true);
            children.push(cmd.spawn().unwrap());
        }

        let instance = testing::instance().await;
        let host_mesh = HostMeshRef::from_hosts(Name::new("test"), hosts);

        let proc_mesh = host_mesh
            .spawn(&testing::instance().await, "test", Extent::unity())
            .await
            .unwrap();

        let actor_mesh: ActorMesh<testactor::TestActor> = proc_mesh
            .spawn(&testing::instance().await, "test", &())
            .await
            .unwrap();

        testactor::assert_mesh_shape(actor_mesh).await;

        HostMesh::take(host_mesh)
            .shutdown(&instance)
            .await
            .expect("hosts shutdown");
    }

    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_failing_proc_allocation() {
        let lock = hyperactor::config::global::lock();
        let _guard = lock.override_key(MESH_TAIL_LOG_LINES, 100);

        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");

        let hosts = vec![free_localhost_addr(), free_localhost_addr()];

        let mut children = Vec::new();
        for host in hosts.iter() {
            let mut cmd = Command::new(program.clone());
            let boot = Bootstrap::Host {
                addr: host.clone(),
                config: None,
                command: Some(BootstrapCommand::from("false")),
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(true);
            children.push(cmd.spawn().unwrap());
        }
        let host_mesh = HostMeshRef::from_hosts(Name::new("test"), hosts);

        let instance = testing::instance().await;

        let err = host_mesh
            .spawn(&instance, "test", Extent::unity())
            .await
            .unwrap_err();
        assert_matches!(
            err, v1::Error::ProcCreationError { state: resource::State { status: resource::Status::Failed(msg), ..}, .. }
            if msg.contains("failed to configure process: Terminal(Stopped { exit_code: 1")
        );
    }

    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_halting_proc_allocation() {
        let config = config::global::lock();
        let _guard1 = config.override_key(PROC_SPAWN_MAX_IDLE, Duration::from_secs(5));

        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");

        let hosts = vec![free_localhost_addr(), free_localhost_addr()];

        let mut children = Vec::new();

        for (index, host) in hosts.iter().enumerate() {
            let mut cmd = Command::new(program.clone());
            let command = if index == 0 {
                let mut command = BootstrapCommand::from("sleep");
                command.args.push("60".to_string());
                Some(command)
            } else {
                None
            };
            let boot = Bootstrap::Host {
                addr: host.clone(),
                config: None,
                command,
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(true);
            children.push(cmd.spawn().unwrap());
        }
        let host_mesh = HostMeshRef::from_hosts(Name::new("test"), hosts);

        let instance = testing::instance().await;

        let err = host_mesh
            .spawn(&instance, "test", Extent::unity())
            .await
            .unwrap_err();
        let statuses = err.into_proc_spawn_error().unwrap();
        assert_matches!(
            &statuses.materialized_iter(2).cloned().collect::<Vec<_>>()[..],
            &[Status::Timeout(_), Status::Running]
        );
    }

    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_client_config_override() {
        let config = hyperactor::config::global::lock();
        let _guard1 = config.override_key(crate::bootstrap::MESH_BOOTSTRAP_ENABLE_PDEATHSIG, false);
        let _guard2 =
            config.override_key(config::HOST_SPAWN_READY_TIMEOUT, Duration::from_secs(120));
        let _guard3 =
            config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(60));

        let instance = testing::instance().await;

        let proc_meshes = testing::proc_meshes(instance, extent!(replicas = 2)).await;
        let proc_mesh = proc_meshes.get(1).unwrap();

        let actor_mesh: ActorMesh<testactor::TestActor> =
            proc_mesh.spawn(instance, "test", &()).await.unwrap();

        let mut attrs_override = Attrs::new();
        attrs_override.set(config::HOST_SPAWN_READY_TIMEOUT, Duration::from_secs(180));
        actor_mesh
            .cast(
                instance,
                SetConfigAttrs(bincode::serialize(&attrs_override).unwrap()),
            )
            .unwrap();

        let (tx, mut rx) = instance.open_port();
        actor_mesh
            .cast(instance, GetConfigAttrs(tx.bind()))
            .unwrap();
        let actual_attrs = rx.recv().await.unwrap();
        let actual_attrs = bincode::deserialize::<Attrs>(&actual_attrs).unwrap();

        assert_eq!(
            *actual_attrs.get(config::HOST_SPAWN_READY_TIMEOUT).unwrap(),
            Duration::from_secs(180)
        );
        assert_eq!(
            *actual_attrs.get(config::MESSAGE_DELIVERY_TIMEOUT).unwrap(),
            Duration::from_secs(60)
        );
    }
}