1use std::any::type_name;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::fmt;
13use std::hash::Hash;
14use std::ops::Deref;
15use std::panic::Location;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::sync::atomic::AtomicUsize;
19use std::sync::atomic::Ordering;
20use std::time::Duration;
21
22use hyperactor::Actor;
23use hyperactor::Handler;
24use hyperactor::RemoteMessage;
25use hyperactor::RemoteSpawn;
26use hyperactor::accum::StreamingReducerOpts;
27use hyperactor::actor::ActorStatus;
28use hyperactor::actor::Referable;
29use hyperactor::actor::remote::Remote;
30use hyperactor::channel;
31use hyperactor::channel::ChannelAddr;
32use hyperactor::context;
33use hyperactor::mailbox::DialMailboxRouter;
34use hyperactor::mailbox::MailboxServer;
35use hyperactor::reference as hyperactor_reference;
36use hyperactor::supervision::ActorSupervisionEvent;
37use hyperactor_config::CONFIG;
38use hyperactor_config::ConfigAttr;
39use hyperactor_config::attrs::declare_attrs;
40use ndslice::Extent;
41use ndslice::ViewExt as _;
42use ndslice::view;
43use ndslice::view::CollectMeshExt;
44use ndslice::view::MapIntoExt;
45use ndslice::view::Ranked;
46use ndslice::view::Region;
47use serde::Deserialize;
48use serde::Serialize;
49use tokio::sync::Notify;
50use tracing::Instrument;
51use typeuri::Named;
52
53use crate::ActorMesh;
54use crate::ActorMeshRef;
55use crate::CommActor;
56use crate::Error;
57use crate::HostMeshRef;
58use crate::Name;
59use crate::ValueMesh;
60use crate::alloc::Alloc;
61use crate::alloc::AllocExt;
62use crate::alloc::AllocatedProc;
63use crate::assign::Ranks;
64use crate::comm::CommMeshConfig;
65use crate::host_mesh::host_agent::ProcState;
66use crate::host_mesh::mesh_to_rankedvalues_with_default;
67use crate::mesh_controller::ActorMeshController;
68use crate::proc_agent;
69use crate::proc_agent::ActorState;
70use crate::proc_agent::MeshAgentMessageClient;
71use crate::proc_agent::ProcAgent;
72use crate::proc_agent::ReconfigurableMailboxSender;
73use crate::resource;
74use crate::resource::GetRankStatus;
75use crate::resource::Status;
76use crate::supervision::MeshFailure;
77
// Configuration knobs for proc-mesh operations. Each attribute may be
// overridden via the named environment variable or config key.
declare_attrs! {
    // Maximum idle time between status updates while waiting for actors
    // to spawn across the mesh before giving up.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_ACTOR_SPAWN_MAX_IDLE".to_string()),
        Some("actor_spawn_max_idle".to_string()),
    ))
    pub attr ACTOR_SPAWN_MAX_IDLE: Duration = Duration::from_secs(30);

    // Maximum idle time between replies while querying per-rank actor
    // state before the remaining ranks are reported as timed out.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_GET_ACTOR_STATE_MAX_IDLE".to_string()),
        Some("get_actor_state_max_idle".to_string()),
    ))
    pub attr GET_ACTOR_STATE_MAX_IDLE: Duration = Duration::from_secs(30);
}
95
/// Name under which the per-proc comm (multicast) actor is spawned.
pub const COMM_ACTOR_NAME: &str = "comm";
101
/// A reference to a single proc that is a member of a proc mesh.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProcRef {
    /// The proc's id.
    proc_id: hyperactor_reference::ProcId,
    /// The rank this proc occupied in its mesh at creation time.
    create_rank: usize,
    /// The agent actor managing this proc (status, spawn, configure).
    agent: hyperactor_reference::ActorRef<ProcAgent>,
}
111
112impl ProcRef {
113 pub fn new(
115 proc_id: hyperactor_reference::ProcId,
116 create_rank: usize,
117 agent: hyperactor_reference::ActorRef<ProcAgent>,
118 ) -> Self {
119 Self {
120 proc_id,
121 create_rank,
122 agent,
123 }
124 }
125
126 pub(crate) async fn status(&self, cx: &impl context::Actor) -> crate::Result<bool> {
129 let (port, mut rx) = cx.mailbox().open_port();
130 self.agent
131 .status(cx, port.bind())
132 .await
133 .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e))?;
134 loop {
135 let (rank, status) = rx
136 .recv()
137 .await
138 .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
139 if rank == self.create_rank {
140 break Ok(status);
141 }
142 }
143 }
144
145 pub fn proc_id(&self) -> &hyperactor_reference::ProcId {
146 &self.proc_id
147 }
148
149 pub(crate) fn actor_id(&self, name: &Name) -> hyperactor_reference::ActorId {
150 self.proc_id.actor_id(name.to_string(), 0)
151 }
152
153 pub(crate) fn attest<A: Referable>(&self, name: &Name) -> hyperactor_reference::ActorRef<A> {
156 hyperactor_reference::ActorRef::attest(self.actor_id(name))
157 }
158}
159
/// An owned mesh of procs. Dropping/stopping it tears down the
/// underlying allocation (see `ProcMesh::stop`). Dereferences to a
/// shareable `ProcMeshRef` for read-only mesh operations.
#[derive(Debug)]
pub struct ProcMesh {
    #[allow(dead_code)]
    name: Name,
    /// How the procs backing this mesh were obtained.
    allocation: ProcMeshAllocation,
    /// The name under which comm actors were spawned, if any.
    #[allow(dead_code)]
    comm_actor_name: Option<Name>,
    /// The current shareable view of this mesh.
    current_ref: ProcMeshRef,
}
170
171impl ProcMesh {
172 async fn create<C: context::Actor>(
173 cx: &C,
174 name: Name,
175 allocation: ProcMeshAllocation,
176 spawn_comm_actor: bool,
177 ) -> crate::Result<Self>
178 where
179 C::A: Handler<MeshFailure>,
180 {
181 let comm_actor_name = if spawn_comm_actor {
182 Some(Name::new(COMM_ACTOR_NAME).unwrap())
183 } else {
184 None
185 };
186
187 let region = allocation.extent().clone().into();
188 let ranks = allocation.ranks();
189
190 if let Some(first) = ranks.first() {
194 crate::global_context::set_global_supervision_sink(
195 first.agent.port::<ActorSupervisionEvent>(),
196 );
197 }
198
199 let root_comm_actor = comm_actor_name.as_ref().map(|name| {
200 hyperactor_reference::ActorRef::attest(
201 ranks
202 .first()
203 .expect("root mesh cannot be empty")
204 .actor_id(name),
205 )
206 });
207 let host_mesh = allocation.hosts();
208 let current_ref = ProcMeshRef::new(
209 name.clone(),
210 region,
211 ranks,
212 host_mesh.cloned(),
213 None, None, )
216 .unwrap();
217
218 {
220 let name_str = name.to_string();
221 let mesh_id_hash = hyperactor_telemetry::hash_to_u64(&name_str);
222
223 let (parent_mesh_id, parent_view_json) = match host_mesh {
224 Some(hm) => (
225 Some(hyperactor_telemetry::hash_to_u64(&hm.name().to_string())),
226 serde_json::to_string(hm.region()).ok(),
227 ),
228 None => (None, None),
229 };
230
231 hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
232 id: mesh_id_hash,
233 timestamp: std::time::SystemTime::now(),
234 class: "Proc".to_string(),
235 given_name: name.name().to_string(),
236 full_name: name_str,
237 shape_json: serde_json::to_string(¤t_ref.region.extent()).unwrap_or_default(),
238 parent_mesh_id,
239 parent_view_json,
240 });
241
242 let now = std::time::SystemTime::now();
245 for rank in current_ref.ranks.iter() {
246 let actor_id = rank.agent.actor_id();
247
248 hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
249 id: hyperactor_telemetry::hash_to_u64(actor_id),
250 timestamp: now,
251 mesh_id: mesh_id_hash,
252 rank: rank.create_rank as u64,
253 full_name: actor_id.to_string(),
254 display_name: None,
255 });
256 }
257 }
258
259 let mut proc_mesh = Self {
260 name,
261 allocation,
262 comm_actor_name: comm_actor_name.clone(),
263 current_ref,
264 };
265
266 if let Some(comm_actor_name) = comm_actor_name {
267 let comm_actor_mesh: ActorMesh<CommActor> = proc_mesh
271 .spawn_with_name(cx, comm_actor_name, &Default::default(), None, true)
272 .await?;
273 let address_book: HashMap<_, _> = comm_actor_mesh
274 .iter()
275 .map(|(point, actor_ref)| (point.rank(), actor_ref))
276 .collect();
277 for (rank, comm_actor) in &address_book {
280 comm_actor
281 .send(cx, CommMeshConfig::new(*rank, address_book.clone()))
282 .map_err(|e| Error::SendingError(comm_actor.actor_id().clone(), Box::new(e)))?
283 }
284
285 proc_mesh.current_ref.root_comm_actor = root_comm_actor;
287 }
288
289 Ok(proc_mesh)
290 }
291
292 pub(crate) async fn create_owned_unchecked<C: context::Actor>(
293 cx: &C,
294 name: Name,
295 extent: Extent,
296 hosts: HostMeshRef,
297 ranks: Vec<ProcRef>,
298 ) -> crate::Result<Self>
299 where
300 C::A: Handler<MeshFailure>,
301 {
302 Self::create(
303 cx,
304 name,
305 ProcMeshAllocation::Owned {
306 hosts,
307 extent,
308 ranks: Arc::new(ranks),
309 },
310 true,
311 )
312 .await
313 }
314
315 fn alloc_counter() -> &'static AtomicUsize {
316 static C: OnceLock<AtomicUsize> = OnceLock::new();
317 C.get_or_init(|| AtomicUsize::new(0))
318 }
319
320 #[track_caller]
323 pub async fn allocate<C: context::Actor>(
324 cx: &C,
325 alloc: Box<dyn Alloc + Send + Sync + 'static>,
326 name: &str,
327 ) -> crate::Result<Self>
328 where
329 C::A: Handler<MeshFailure>,
330 {
331 let caller = Location::caller();
332 Self::allocate_inner(cx, alloc, Name::new(name)?, caller).await
333 }
334
    /// Drive `alloc` to completion and assemble a `ProcMesh` from the
    /// resulting procs: serve a local channel, bind routes to each
    /// allocated proc, configure every mesh agent, and spawn a background
    /// task that monitors (and, on request, stops) the allocation.
    #[hyperactor::instrument(fields(proc_mesh=name.to_string()))]
    async fn allocate_inner<C: context::Actor>(
        cx: &C,
        mut alloc: Box<dyn Alloc + Send + Sync + 'static>,
        name: Name,
        caller: &'static Location<'static>,
    ) -> crate::Result<Self>
    where
        C::A: Handler<MeshFailure>,
    {
        // Monotonic id tying together all log lines for this attempt.
        let alloc_id = Self::alloc_counter().fetch_add(1, Ordering::Relaxed) + 1;
        tracing::info!(
            name = "ProcMeshStatus",
            status = "Allocate::Attempt",
            %caller,
            alloc_id,
            shape = ?alloc.shape(),
            "allocating proc mesh"
        );

        // Wait until every proc in the allocation is up.
        let running = alloc
            .initialize()
            .instrument(tracing::info_span!(
                "ProcMeshStatus::Allocate::Initialize",
                alloc_id,
                proc_mesh = %name
            ))
            .await?;

        let proc = cx.instance().proc();

        // Serve the calling proc on a fresh channel so allocated procs
        // can dial back to us.
        let proc_channel_addr = {
            let _guard =
                tracing::info_span!("allocate_serve_proc", proc_id = %proc.proc_id()).entered();
            let (addr, rx) = channel::serve(ChannelAddr::any(alloc.transport()))?;
            proc.clone().serve(rx);
            tracing::info!(
                name = "ProcMeshStatus",
                status = "Allocate::ChannelServe",
                proc_mesh = %name,
                %addr,
                "proc started listening on addr: {addr}"
            );
            addr
        };

        // Bind a route for every allocated proc whose address differs
        // from the one already derivable from its proc id.
        let bind_allocated_procs = |router: &DialMailboxRouter| {
            for AllocatedProc { proc_id, addr, .. } in running.iter() {
                if proc_id.addr() != addr {
                    router.bind(proc_id.clone().into(), addr.clone());
                }
            }
        };

        // The local forwarder may be a plain DialMailboxRouter or a
        // reconfigurable wrapper around one; anything else is unroutable.
        if let Some(router) = proc.forwarder().downcast_ref() {
            bind_allocated_procs(router);
        } else if let Some(router) = proc
            .forwarder()
            .downcast_ref::<ReconfigurableMailboxSender>()
        {
            bind_allocated_procs(
                router
                    .as_inner()
                    .map_err(|_| Error::UnroutableMesh())?
                    .as_configured()
                    .ok_or(Error::UnroutableMesh())?
                    .downcast_ref()
                    .ok_or(Error::UnroutableMesh())?,
            );
        } else {
            return Err(Error::UnroutableMesh());
        }

        // proc id -> channel address, distributed to every agent below.
        let address_book: HashMap<_, _> = running
            .iter()
            .map(
                |AllocatedProc {
                     addr, mesh_agent, ..
                 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
            )
            .collect();

        // Configure each mesh agent with its rank, our address, and the
        // address book; agents ack by sending their rank back.
        let (config_handle, mut config_receiver) = cx.mailbox().open_port();
        for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
            mesh_agent
                .configure(
                    cx,
                    rank,
                    proc_channel_addr.clone(),
                    None, address_book.clone(),
                    config_handle.bind(),
                    true,
                )
                .await
                .map_err(Error::ConfigurationError)?;
        }
        // Wait for every rank to acknowledge configuration.
        let mut completed = Ranks::new(running.len());
        while !completed.is_full() {
            let rank = config_receiver
                .recv()
                .await
                .map_err(|err| Error::ConfigurationError(err.into()))?;
            if completed.insert(rank, rank).is_some() {
                tracing::warn!("multiple completions received for rank {}", rank);
            }
        }

        // Freeze the running procs into rank-ordered ProcRefs.
        let ranks: Vec<_> = running
            .into_iter()
            .enumerate()
            .map(|(create_rank, allocated)| ProcRef {
                proc_id: allocated.proc_id,
                create_rank,
                agent: allocated.mesh_agent,
            })
            .collect();

        let stop = Arc::new(Notify::new());
        let extent = alloc.extent().clone();
        let alloc_name = alloc.alloc_name().to_string();

        // Background task: drain allocation events until the alloc ends,
        // or stop-and-wait the alloc when `stop` is notified (see
        // `ProcMesh::stop`).
        let alloc_task = {
            let stop = Arc::clone(&stop);

            tokio::spawn(
                async move {
                    loop {
                        tokio::select! {
                            _ = stop.notified() => {
                                if let Err(error) = alloc.stop_and_wait().await {
                                    tracing::error!(
                                        name = "ProcMeshStatus",
                                        alloc_name = %alloc.alloc_name(),
                                        status = "FailedToStopAlloc",
                                        %error,
                                    );
                                }
                                break;
                            }
                            proc_state = alloc.next() => {
                                match proc_state {
                                    None => break,
                                    Some(proc_state) => {
                                        tracing::debug!(
                                            alloc_name = %alloc.alloc_name(),
                                            "unmonitored allocation event: {}", proc_state);
                                    }
                                }

                            }
                        }
                    }
                }
                .instrument(tracing::info_span!("alloc_monitor")),
            )
        };

        let mesh = Self::create(
            cx,
            name,
            ProcMeshAllocation::Allocated {
                alloc_name,
                stop,
                extent,
                ranks: Arc::new(ranks),
                alloc_task: Some(alloc_task),
            },
            true, )
        .await;
        match &mesh {
            Ok(_) => tracing::info!(name = "ProcMeshStatus", status = "Allocate::Created"),
            Err(error) => {
                tracing::info!(name = "ProcMeshStatus", status = "Allocate::Failed", %error)
            }
        }
        mesh
    }
536
    /// Stop all the procs in this mesh.
    ///
    /// For alloc-backed meshes, signals the monitor task (which calls
    /// `alloc.stop_and_wait`) and joins it; for host-owned meshes,
    /// delegates to the backing host mesh.
    pub async fn stop(&mut self, cx: &impl context::Actor, reason: String) -> anyhow::Result<()> {
        // `region` resolves via Deref to the current ProcMeshRef.
        let region = self.region.clone();
        match &mut self.allocation {
            ProcMeshAllocation::Allocated {
                stop,
                alloc_task,
                alloc_name,
                ..
            } => {
                // Wake the alloc monitor task spawned in allocate_inner.
                stop.notify_one();
                // Join the task so the alloc is fully stopped on return;
                // `take()` makes repeated stop() calls harmless.
                if let Some(handle) = alloc_task.take() {
                    if let Err(e) = handle.await {
                        tracing::warn!(
                            name = "ProcMeshStatus",
                            proc_mesh = %self.name,
                            alloc_name,
                            %e,
                            "alloc monitor task failed"
                        );
                    }
                }
                tracing::info!(
                    name = "ProcMeshStatus",
                    proc_mesh = %self.name,
                    alloc_name,
                    status = "StoppingAlloc",
                    "alloc {alloc_name} has stopped",
                );

                Ok(())
            }
            ProcMeshAllocation::Owned { hosts, .. } => {
                // Name every proc explicitly so the host mesh knows the
                // exact set to tear down.
                let procs = self
                    .current_ref
                    .proc_ids()
                    .collect::<Vec<hyperactor_reference::ProcId>>();
                hosts
                    .stop_proc_mesh(cx, &self.name, procs, region, reason)
                    .await
            }
        }
    }
584
    /// Test-only accessor for the allocation's rank-ordered procs.
    #[cfg(test)]
    pub(crate) fn ranks(&self) -> Arc<Vec<ProcRef>> {
        self.allocation.ranks()
    }
589}
590
591impl fmt::Display for ProcMesh {
592 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
593 write!(f, "{}", self.current_ref)
594 }
595}
596
/// `ProcMesh` dereferences to its current `ProcMeshRef`, so all
/// read-only mesh-reference operations are available on the owned mesh.
impl Deref for ProcMesh {
    type Target = ProcMeshRef;

    fn deref(&self) -> &Self::Target {
        &self.current_ref
    }
}
604
/// Log when a mesh handle is dropped. Note this only logs: actual proc
/// teardown happens via `ProcMesh::stop`, not on drop.
impl Drop for ProcMesh {
    fn drop(&mut self) {
        tracing::info!(
            name = "ProcMeshStatus",
            proc_mesh = %self.name,
            status = "Dropped",
        );
    }
}
614
/// How the procs backing a `ProcMesh` were obtained.
enum ProcMeshAllocation {
    /// Procs produced by an `Alloc` (see `ProcMesh::allocate`).
    Allocated {
        /// Name of the originating alloc, used in log messages.
        alloc_name: String,

        /// Notified by `ProcMesh::stop` to make the monitor task stop
        /// the alloc.
        stop: Arc<Notify>,

        /// The shape of the allocation.
        extent: Extent,

        /// Proc references, in creation-rank order.
        ranks: Arc<Vec<ProcRef>>,

        /// The background task monitoring alloc events; taken and
        /// joined during `ProcMesh::stop`.
        alloc_task: Option<tokio::task::JoinHandle<()>>,
    },

    /// Procs owned and managed by a host mesh.
    Owned {
        /// The host mesh hosting the procs.
        hosts: HostMeshRef,
        /// The shape of the allocation.
        extent: Extent,
        /// Proc references, in creation-rank order.
        ranks: Arc<Vec<ProcRef>>,
    },
}
645
646impl ProcMeshAllocation {
647 fn extent(&self) -> &Extent {
648 match self {
649 ProcMeshAllocation::Allocated { extent, .. } => extent,
650 ProcMeshAllocation::Owned { extent, .. } => extent,
651 }
652 }
653
654 fn ranks(&self) -> Arc<Vec<ProcRef>> {
655 Arc::clone(match self {
656 ProcMeshAllocation::Allocated { ranks, .. } => ranks,
657 ProcMeshAllocation::Owned { ranks, .. } => ranks,
658 })
659 }
660
661 fn hosts(&self) -> Option<&HostMeshRef> {
662 match self {
663 ProcMeshAllocation::Allocated { .. } => None,
664 ProcMeshAllocation::Owned { hosts, .. } => Some(hosts),
665 }
666 }
667}
668
669impl fmt::Debug for ProcMeshAllocation {
670 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
671 match self {
672 ProcMeshAllocation::Allocated { ranks, .. } => f
673 .debug_struct("ProcMeshAllocation::Allocated")
674 .field("alloc", &"<dyn Alloc>")
675 .field("ranks", ranks)
676 .finish(),
677 ProcMeshAllocation::Owned {
678 hosts,
679 ranks,
680 extent: _,
681 } => f
682 .debug_struct("ProcMeshAllocation::Owned")
683 .field("hosts", hosts)
684 .field("ranks", ranks)
685 .finish(),
686 }
687 }
688}
689
/// A serializable, shareable reference to a proc mesh — possibly a
/// sliced view of a larger "root" mesh.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
pub struct ProcMeshRef {
    /// The mesh's name.
    name: Name,
    /// The region (shape) covered by this view.
    region: Region,
    /// Proc references, one per rank of `region` (validated in `new`).
    ranks: Arc<Vec<ProcRef>>,
    /// The backing host mesh, when host-allocated.
    host_mesh: Option<HostMeshRef>,
    /// For sliced views: the region of the unsliced root mesh
    /// (see `RankedSliceable::sliced`).
    pub(crate) root_region: Option<Region>,
    /// The root comm actor used for multicast casting, if comm actors
    /// were spawned for this mesh.
    pub(crate) root_comm_actor: Option<hyperactor_reference::ActorRef<CommActor>>,
}
wirevalue::register_type!(ProcMeshRef);
714
715impl ProcMeshRef {
716 #[allow(clippy::result_large_err)]
718 fn new(
719 name: Name,
720 region: Region,
721 ranks: Arc<Vec<ProcRef>>,
722 host_mesh: Option<HostMeshRef>,
723 root_region: Option<Region>,
724 root_comm_actor: Option<hyperactor_reference::ActorRef<CommActor>>,
725 ) -> crate::Result<Self> {
726 if region.num_ranks() != ranks.len() {
727 return Err(crate::Error::InvalidRankCardinality {
728 expected: region.num_ranks(),
729 actual: ranks.len(),
730 });
731 }
732 Ok(Self {
733 name,
734 region,
735 ranks,
736 host_mesh,
737 root_region,
738 root_comm_actor,
739 })
740 }
741
742 pub fn new_singleton(name: Name, proc_ref: ProcRef) -> Self {
746 Self {
747 name,
748 region: Extent::unity().into(),
749 ranks: Arc::new(vec![proc_ref]),
750 host_mesh: None,
751 root_region: None,
752 root_comm_actor: None,
753 }
754 }
755
    /// The root comm actor used for casting, if comm actors were spawned.
    pub(crate) fn root_comm_actor(&self) -> Option<&hyperactor_reference::ActorRef<CommActor>> {
        self.root_comm_actor.as_ref()
    }

    /// The name of this proc mesh.
    pub fn name(&self) -> &Name {
        &self.name
    }

    /// The name of the backing host mesh, if this mesh is host-allocated.
    pub fn host_mesh_name(&self) -> Option<&Name> {
        self.host_mesh.as_ref().map(|h| h.name())
    }

    /// The backing host mesh, if this mesh is host-allocated.
    pub fn hosts(&self) -> Option<&HostMeshRef> {
        self.host_mesh.as_ref()
    }
773
774 pub async fn status(&self, cx: &impl context::Actor) -> crate::Result<ValueMesh<bool>> {
776 let vm: ValueMesh<_> = self.map_into(|proc_ref| {
777 let proc_ref = proc_ref.clone();
778 async move { proc_ref.status(cx).await }
779 });
780 vm.join().await.transpose()
781 }
782
    /// View this proc mesh as a mesh of its `ProcAgent`s.
    pub(crate) fn agent_mesh(&self) -> ActorMeshRef<ProcAgent> {
        // NOTE(review): assumes the mesh is nonempty and that all ranks'
        // agents share the rank-0 agent's actor name — confirm these
        // invariants hold for every construction path.
        let agent_name = self.ranks.first().unwrap().agent.actor_id().name();
        ActorMeshRef::new(Name::new_reserved(agent_name).unwrap(), self.clone(), None)
    }
790
    /// Query the state of the actor named `name` on every proc in this
    /// mesh. One-shot variant of `actor_states_with_keepalive` (no
    /// keepalive deadline).
    pub async fn actor_states(
        &self,
        cx: &impl context::Actor,
        name: Name,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        self.actor_states_with_keepalive(cx, name, None).await
    }
799
    /// Query per-rank actor state, optionally wrapping the query in a
    /// keepalive that expires after `keepalive`.
    ///
    /// Waits for one reply per rank with a per-message idle timeout
    /// (`GET_ACTOR_STATE_MAX_IDLE`); ranks that never reply are
    /// back-filled with a synthesized `Timeout` state carrying a
    /// supervision event.
    pub(crate) async fn actor_states_with_keepalive(
        &self,
        cx: &impl context::Actor,
        name: Name,
        keepalive: Option<std::time::SystemTime>,
    ) -> crate::Result<ValueMesh<resource::State<ActorState>>> {
        let agent_mesh = self.agent_mesh();
        let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
        let get_state = resource::GetState::<ActorState> {
            name: name.clone(),
            reply: port.bind(),
        };
        // With a keepalive deadline, wrap the query; otherwise cast the
        // plain GetState.
        if let Some(expires_after) = keepalive {
            agent_mesh.cast(
                cx,
                resource::KeepaliveGetState {
                    expires_after,
                    get_state,
                },
            )?;
        } else {
            agent_mesh.cast(cx, get_state)?;
        }
        let expected = self.ranks.len();
        let mut states = Vec::with_capacity(expected);
        let timeout = hyperactor_config::global::get(GET_ACTOR_STATE_MAX_IDLE);
        for _ in 0..expected {
            // Idle timeout: resets on every received message.
            let state = tokio::time::timeout(timeout, rx.recv()).await;
            if let Ok(state) = state {
                let state = state?;
                match state.state {
                    Some(ref inner) => {
                        states.push((inner.create_rank, state));
                    }
                    None => {
                        // A rank reported that the actor does not exist.
                        return Err(Error::NotExist(state.name));
                    }
                }
            } else {
                tracing::error!(
                    "timeout waiting for a message after {:?} from proc mesh agent in mesh {}",
                    timeout,
                    agent_mesh
                );
                // Synthesize a Timeout state for every rank that has not
                // replied, so the returned mesh is still complete.
                let all_ranks = (0..self.ranks.len()).collect::<HashSet<_>>();
                let completed_ranks = states.iter().map(|(rank, _)| *rank).collect::<HashSet<_>>();
                let mut leftover_ranks = all_ranks.difference(&completed_ranks).collect::<Vec<_>>();
                assert_eq!(leftover_ranks.len(), expected - states.len());
                while states.len() < expected {
                    let rank = *leftover_ranks
                        .pop()
                        .expect("leftover ranks should not be empty");
                    let agent = agent_mesh.get(rank).expect("agent should exist");
                    let agent_id = agent.actor_id().clone();
                    states.push((
                        rank,
                        resource::State {
                            name: name.clone(),
                            // u64::MAX marks a synthesized (not reported)
                            // generation.
                            status: resource::Status::Timeout(timeout),
                            generation: u64::MAX,
                            timestamp: std::time::SystemTime::now(),
                            state: Some(ActorState {
                                actor_id: agent_id.clone(),
                                create_rank: rank,
                                supervision_events: vec![ActorSupervisionEvent::new(
                                    agent_id,
                                    None,
                                    ActorStatus::generic_failure(format!(
                                        "timeout waiting for message from proc mesh agent while querying for \"{}\". The process likely crashed",
                                        name,
                                    )),
                                    None,
                                )],
                            }),
                        },
                    ));
                }
                break;
            }
        }
        // Replies arrive in arbitrary order; sort into rank order before
        // assembling the mesh.
        states.sort_by_key(|(rank, _)| *rank);
        let vm = states
            .into_iter()
            .map(|(_, state)| state)
            .collect_mesh::<ValueMesh<_>>(self.region.clone())?;
        Ok(vm)
    }
910
911 pub async fn proc_states(
912 &self,
913 cx: &impl context::Actor,
914 ) -> crate::Result<Option<ValueMesh<resource::State<ProcState>>>> {
915 let names = self
916 .proc_ids()
917 .collect::<Vec<hyperactor_reference::ProcId>>();
918 if let Some(host_mesh) = &self.host_mesh {
919 Ok(Some(
920 host_mesh
921 .proc_states(cx, names, self.region.clone())
922 .await?,
923 ))
924 } else {
925 Ok(None)
926 }
927 }
928
    /// The ids of all procs in this mesh, in rank order.
    pub(crate) fn proc_ids(&self) -> impl Iterator<Item = hyperactor_reference::ProcId> {
        self.ranks.iter().map(|proc_ref| proc_ref.proc_id.clone())
    }
933
    /// Spawn an actor of type `A` named `name` on every proc in this
    /// mesh, returning the resulting actor mesh. The name is validated
    /// as a user name; a supervision controller is attached
    /// (`is_system_actor = false`).
    pub async fn spawn<A: RemoteSpawn, C: context::Actor>(
        &self,
        cx: &C,
        name: &str,
        params: &A::Params,
    ) -> crate::Result<ActorMesh<A>>
    where
        A::Params: RemoteMessage,
        C::A: Handler<MeshFailure>,
    {
        self.spawn_with_name(cx, Name::new(name)?, params, None, false)
            .await
    }
957
    /// Like [`Self::spawn`], but uses a reserved (service) name via
    /// `Name::new_reserved` rather than a user-validated name.
    pub async fn spawn_service<A: RemoteSpawn, C: context::Actor>(
        &self,
        cx: &C,
        name: &str,
        params: &A::Params,
    ) -> crate::Result<ActorMesh<A>>
    where
        A::Params: RemoteMessage,
        C::A: Handler<MeshFailure>,
    {
        self.spawn_with_name(cx, Name::new_reserved(name)?, params, None, false)
            .await
    }
978
979 #[hyperactor::instrument(fields(
997 host_mesh=self.host_mesh_name().map(|n| n.to_string()),
998 proc_mesh=self.name.to_string(),
999 actor_name=name.to_string(),
1000 ))]
1001 pub async fn spawn_with_name<A: RemoteSpawn, C: context::Actor>(
1002 &self,
1003 cx: &C,
1004 name: Name,
1005 params: &A::Params,
1006 supervision_display_name: Option<String>,
1007 is_system_actor: bool,
1008 ) -> crate::Result<ActorMesh<A>>
1009 where
1010 A::Params: RemoteMessage,
1011 C::A: Handler<MeshFailure>,
1012 {
1013 tracing::info!(
1014 name = "ProcMeshStatus",
1015 status = "ActorMesh::Spawn::Attempt",
1016 );
1017 tracing::info!(name = "ActorMeshStatus", status = "Spawn::Attempt");
1018 let result = self
1019 .spawn_with_name_inner(cx, name, params, supervision_display_name, is_system_actor)
1020 .await;
1021 match &result {
1022 Ok(_) => {
1023 tracing::info!(
1024 name = "ProcMeshStatus",
1025 status = "ActorMesh::Spawn::Success",
1026 );
1027 tracing::info!(name = "ActorMeshStatus", status = "Spawn::Success");
1028 }
1029 Err(error) => {
1030 tracing::error!(name = "ProcMeshStatus", status = "ActorMesh::Spawn::Failed", %error);
1031 tracing::error!(name = "ActorMeshStatus", status = "Spawn::Failed", %error);
1032 }
1033 }
1034 result
1035 }
1036
    /// Spawn logic proper: cast a CreateOrUpdate to every agent, wait
    /// for all ranks to report a status, attach a mesh controller (for
    /// non-system actors), and emit telemetry.
    async fn spawn_with_name_inner<A: RemoteSpawn, C: context::Actor>(
        &self,
        cx: &C,
        name: Name,
        params: &A::Params,
        supervision_display_name: Option<String>,
        is_system_actor: bool,
    ) -> crate::Result<ActorMesh<A>>
    where
        C::A: Handler<MeshFailure>,
    {
        // A must be registered in the remote-actor registry to be
        // spawnable by type name.
        let remote = Remote::collect();
        let actor_type = remote
            .name_of::<A>()
            .ok_or(Error::ActorTypeNotRegistered(type_name::<A>().to_string()))?
            .to_string();

        let serialized_params = bincode::serialize(params)?;
        let agent_mesh = self.agent_mesh();

        // Ask every proc agent to create (or update) the actor.
        agent_mesh.cast(
            cx,
            resource::CreateOrUpdate::<proc_agent::ActorSpec> {
                name: name.clone(),
                rank: Default::default(),
                spec: proc_agent::ActorSpec {
                    actor_type: actor_type.clone(),
                    params_data: serialized_params.clone(),
                },
            },
        )?;

        let region = self.region().clone();
        // Accumulating port: per-rank statuses are streamed and reduced
        // into a StatusMesh, throttled to at most one update per 50ms.
        let (port, rx) = cx.mailbox().open_accum_port_opts(
            crate::StatusMesh::from_single(region.clone(), Status::NotExist),
            StreamingReducerOpts {
                max_update_interval: Some(Duration::from_millis(50)),
                initial_update_interval: None,
            },
        );

        // Status replies are best-effort; don't bounce undeliverables.
        let mut reply = port.bind();
        reply.return_undeliverable(false);
        agent_mesh.cast(
            cx,
            resource::GetRankStatus {
                name: name.clone(),
                reply,
            },
        )?;

        let start_time = tokio::time::Instant::now();

        // Wait (with idle timeout ACTOR_SPAWN_MAX_IDLE) for all ranks to
        // report; any terminating status fails the whole spawn.
        let (statuses, mut mesh) = match GetRankStatus::wait(
            rx,
            self.ranks.len(),
            hyperactor_config::global::get(ACTOR_SPAWN_MAX_IDLE),
            region.clone(), )
        .await
        {
            Ok(statuses) => {
                let has_terminating = statuses.values().any(|s| s.is_terminating());
                if !has_terminating {
                    Ok((statuses, ActorMesh::new(self.clone(), name.clone(), None)))
                } else {
                    // Convert the status mesh to the legacy ranked-values
                    // form used by the error type.
                    let legacy = mesh_to_rankedvalues_with_default(
                        &statuses,
                        Status::NotExist,
                        Status::is_not_exist,
                        self.ranks.len(),
                    );
                    Err(Error::ActorSpawnError { statuses: legacy })
                }
            }
            Err(complete) => {
                // Idle-timeout path: ranks that never reported are marked
                // with a Timeout carrying the elapsed duration.
                let elapsed = start_time.elapsed();
                let legacy = mesh_to_rankedvalues_with_default(
                    &complete,
                    Status::Timeout(elapsed),
                    Status::is_not_exist,
                    self.ranks.len(),
                );
                Err(Error::ActorSpawnError { statuses: legacy })
            }
        }?;
        if !is_system_actor {
            // Attach a supervision controller for user actor meshes.
            let controller: ActorMeshController<A> = ActorMeshController::new(
                mesh.deref().clone(),
                supervision_display_name.clone(),
                Some(cx.instance().port().bind()),
                statuses,
            );
            let controller_name = format!(
                "{}_{}",
                crate::mesh_controller::ACTOR_MESH_CONTROLLER_NAME,
                mesh.name()
            );
            let controller = controller
                .spawn_with_name(cx, &controller_name)
                .map_err(|e| Error::ControllerActorSpawnError(mesh.name().clone(), e))?;
            mesh.set_controller(Some(controller.bind()));
        }
        // Telemetry: one mesh-created event plus one actor-created event
        // per rank of the new actor mesh.
        {
            let name_str = mesh.name().to_string();

            let mesh_id_hash = hyperactor_telemetry::hash_to_u64(&name_str);

            let parent_mesh_id_hash = hyperactor_telemetry::hash_to_u64(&self.name().to_string());

            hyperactor_telemetry::notify_mesh_created(hyperactor_telemetry::MeshEvent {
                id: mesh_id_hash,
                timestamp: std::time::SystemTime::now(),
                // Prefer a parsed Python class name when the supervision
                // display name carries one; otherwise the Rust type name.
                class: supervision_display_name
                    .as_deref()
                    .and_then(python_class_from_supervision_name)
                    .unwrap_or(actor_type),
                given_name: mesh.name().name().to_string(),
                full_name: name_str,
                shape_json: serde_json::to_string(&self.region().extent()).unwrap_or_default(),
                parent_mesh_id: Some(parent_mesh_id_hash),
                parent_view_json: serde_json::to_string(self.region()).ok(),
            });

            let now = std::time::SystemTime::now();
            for (rank, proc_ref) in self.ranks.iter().enumerate() {
                let display_name = supervision_display_name.as_ref().map(|sdn| {
                    let point = self.region().extent().point_of_rank(rank).unwrap();
                    crate::actor_display_name(sdn, &point)
                });
                let actor_id = proc_ref.actor_id(&name);
                hyperactor_telemetry::notify_actor_created(hyperactor_telemetry::ActorEvent {
                    id: hyperactor_telemetry::hash_to_u64(&actor_id),
                    timestamp: now,
                    mesh_id: mesh_id_hash,
                    rank: rank as u64,
                    full_name: actor_id.to_string(),
                    display_name,
                });
            }
        }

        Ok(mesh)
    }
1228
1229 #[hyperactor::instrument(fields(
1231 host_mesh = self.host_mesh_name().map(|n| n.to_string()),
1232 proc_mesh = self.name.to_string(),
1233 actor_mesh = mesh_name.to_string(),
1234 ))]
1235 pub(crate) async fn stop_actor_by_name(
1236 &self,
1237 cx: &impl context::Actor,
1238 mesh_name: Name,
1239 reason: String,
1240 ) -> crate::Result<ValueMesh<Status>> {
1241 tracing::info!(name = "ProcMeshStatus", status = "ActorMesh::Stop::Attempt");
1242 tracing::info!(name = "ActorMeshStatus", status = "Stop::Attempt");
1243 let result = self.stop_actor_by_name_inner(cx, mesh_name, reason).await;
1244 match &result {
1245 Ok(_) => {
1246 tracing::info!(name = "ProcMeshStatus", status = "ActorMesh::Stop::Success");
1247 tracing::info!(name = "ActorMeshStatus", status = "Stop::Success");
1248 }
1249 Err(error) => {
1250 tracing::error!(name = "ProcMeshStatus", status = "ActorMesh::Stop::Failed", %error);
1251 tracing::error!(name = "ActorMeshStatus", status = "Stop::Failed", %error);
1252 }
1253 }
1254 result
1255 }
1256
    /// Stop logic proper: cast a Stop to every agent, then wait until
    /// every rank reports at least `Status::Stopped` (or time out).
    async fn stop_actor_by_name_inner(
        &self,
        cx: &impl context::Actor,
        mesh_name: Name,
        reason: String,
    ) -> crate::Result<ValueMesh<Status>> {
        let region = self.region().clone();
        let agent_mesh = self.agent_mesh();
        agent_mesh.cast(
            cx,
            resource::Stop {
                name: mesh_name.clone(),
                reason,
            },
        )?;

        // Accumulating port: per-rank statuses are streamed and reduced
        // into a StatusMesh, throttled to one update per 50ms at most.
        let (port, rx) = cx.mailbox().open_accum_port_opts(
            crate::StatusMesh::from_single(region.clone(), Status::NotExist),
            StreamingReducerOpts {
                max_update_interval: Some(Duration::from_millis(50)),
                initial_update_interval: None,
            },
        );
        agent_mesh.cast(
            cx,
            resource::WaitRankStatus {
                name: mesh_name,
                min_status: Status::Stopped,
                reply: port.bind(),
            },
        )?;
        let start_time = tokio::time::Instant::now();

        // NOTE(review): reuses the spawn idle-timeout knob for stopping;
        // presumably intentional — confirm.
        let max_idle_time = hyperactor_config::global::get(ACTOR_SPAWN_MAX_IDLE);
        match GetRankStatus::wait(
            rx,
            self.ranks.len(),
            max_idle_time,
            region.clone(), )
        .await
        {
            Ok(statuses) => {
                let all_stopped = statuses.values().all(|s| s.is_terminating());
                if all_stopped {
                    Ok(statuses)
                } else {
                    // Some ranks ended in a non-terminating state.
                    let legacy = mesh_to_rankedvalues_with_default(
                        &statuses,
                        Status::NotExist,
                        Status::is_not_exist,
                        self.ranks.len(),
                    );
                    Err(Error::ActorStopError { statuses: legacy })
                }
            }
            Err(complete) => {
                // Idle-timeout path: mark unreported ranks as timed out.
                let legacy = mesh_to_rankedvalues_with_default(
                    &complete,
                    Status::Timeout(start_time.elapsed()),
                    Status::is_not_exist,
                    self.ranks.len(),
                );
                Err(Error::ActorStopError { statuses: legacy })
            }
        }
    }
1344}
1345
/// Renders as `name{region}` (the doubled braces escape literal braces).
impl fmt::Display for ProcMeshRef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}{{{}}}", self.name, self.region)
    }
}
1351
/// A `ProcMeshRef` is a ranked view over its `ProcRef`s.
impl view::Ranked for ProcMeshRef {
    type Item = ProcRef;

    fn region(&self) -> &Region {
        &self.region
    }

    /// The proc at `rank`, or `None` when out of range.
    fn get(&self, rank: usize) -> Option<&Self::Item> {
        self.ranks.get(rank)
    }
}
1363
1364impl view::RankedSliceable for ProcMeshRef {
1365 fn sliced(&self, region: Region) -> Self {
1366 debug_assert!(region.is_subset(view::Ranked::region(self)));
1367 let ranks = self
1368 .region()
1369 .remap(®ion)
1370 .unwrap()
1371 .map(|index| self.get(index).unwrap().clone())
1372 .collect();
1373 Self::new(
1374 self.name.clone(),
1375 region,
1376 Arc::new(ranks),
1377 self.host_mesh.clone(),
1378 Some(self.root_region.as_ref().unwrap_or(&self.region).clone()),
1379 self.root_comm_actor.clone(),
1380 )
1381 .unwrap()
1382 }
1383}
1384
/// Extract a Python class name from a supervision display name of the
/// form `...<module.path.ClassName extra>` and render it as
/// `Python<ClassName>`. Returns `None` when the input does not match
/// that shape (no `<...>` segment, or no `.`-qualified path inside it).
fn python_class_from_supervision_name(sdn: &str) -> Option<String> {
    // Everything after the last '<', which must be closed by a trailing '>'.
    let (_, tail) = sdn.rsplit_once('<')?;
    let body = tail.strip_suffix('>')?;
    // The first whitespace-separated token is the qualified class path.
    let qualified = body.split_whitespace().next()?;
    // The class name is the segment after the last '.'.
    let (_, class_name) = qualified.rsplit_once('.')?;
    Some(format!("Python<{class_name}>"))
}
1395
#[cfg(test)]
mod tests {
    use hyperactor::Instance;
    use ndslice::ViewExt;
    use ndslice::extent;
    use timed_test::async_timed_test;

    use crate::resource::RankedValues;
    use crate::resource::Status;
    use crate::testactor;
    use crate::testing;

    // Allocating a local mesh yields one ProcRef per rank, all healthy.
    #[tokio::test]
    async fn test_proc_mesh_allocate() {
        let (mesh, actor, _router) = testing::local_proc_mesh(extent!(replica = 4)).await;
        assert_eq!(mesh.extent(), extent!(replica = 4));
        assert_eq!(mesh.ranks.len(), 4);

        // Each proc reports healthy individually...
        for proc_ref in mesh.values() {
            assert!(proc_ref.status(&actor).await.unwrap());
        }

        // ...and through the mesh-level status query.
        assert!(
            mesh.status(&actor)
                .await
                .unwrap()
                .values()
                .all(|status| status)
        );
    }

    // Spawning an actor mesh on a host-backed proc mesh preserves shape.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg(fbcode_build)]
    async fn test_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});

        // NOTE(review): `instance` is passed by value below and reused
        // afterwards — presumably `testing::instance()` returns a
        // reference or copyable handle; confirm.
        let instance = testing::instance();

        let mut hm = testing::host_mesh(4).await;
        let proc_mesh = hm
            .spawn(&instance, "test", extent!(gpus = 2), None)
            .await
            .unwrap();
        let actor_mesh = proc_mesh.spawn(instance, "test", &()).await.unwrap();
        testactor::assert_mesh_shape(actor_mesh).await;

        let _ = hm.shutdown(instance).await;
    }

    // An actor whose constructor fails surfaces per-rank spawn errors.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_failing_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});

        let instance = testing::instance();

        let mut hm = testing::host_mesh(4).await;
        let proc_mesh = hm
            .spawn(&instance, "test", extent!(gpus = 2), None)
            .await
            .unwrap();
        let err = proc_mesh
            .spawn::<testactor::FailingCreateTestActor, Instance<testing::TestRootClient>>(
                instance,
                "testfail",
                &(),
            )
            .await
            .unwrap_err();
        // 4 hosts x 2 gpus = 8 ranks, all expected to report the failure.
        let statuses = err.into_actor_spawn_error().unwrap();
        assert_eq!(
            statuses,
            RankedValues::from((0..8, Status::Failed("test failure".to_string()))),
        );

        let _ = hm.shutdown(instance).await;
    }

    // Pure-function check of the supervision-name parser.
    #[test]
    fn test_python_class_from_supervision_name() {
        use super::python_class_from_supervision_name;

        assert_eq!(
            python_class_from_supervision_name("instance0.<my_module.MyWorker test_mesh>"),
            Some("Python<MyWorker>".to_string()),
        );
        assert_eq!(
            python_class_from_supervision_name(
                "instance0.<package.submodule.TrainingActor mesh_0>"
            ),
            Some("Python<TrainingActor>".to_string()),
        );
        // No '<...>' segment at all.
        assert_eq!(python_class_from_supervision_name("plain_name"), None,);
        // No '.'-qualified module path inside the brackets.
        assert_eq!(
            python_class_from_supervision_name("instance0.<NoModule mesh>"),
            None,
        );
    }
}