1use std::collections::BTreeSet;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::collections::hash_map::Entry;
15use std::fmt::Display;
16use std::fmt::Formatter;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::LazyLock;
20use std::time::SystemTime;
21
22use async_trait::async_trait;
23use dashmap::DashMap;
24use enum_as_inner::EnumAsInner;
25use hyperactor::Actor;
26use hyperactor::ActorHandle;
27use hyperactor::ActorId;
28use hyperactor::ActorRef;
29use hyperactor::Context;
30use hyperactor::HandleClient;
31use hyperactor::Instance;
32use hyperactor::Named;
33use hyperactor::OncePortRef;
34use hyperactor::PortHandle;
35use hyperactor::PortRef;
36use hyperactor::ProcId;
37use hyperactor::RefClient;
38use hyperactor::WorldId;
39use hyperactor::actor::Handler;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::channel::sim::SimAddr;
42use hyperactor::clock::Clock;
43use hyperactor::clock::ClockKind;
44use hyperactor::id;
45use hyperactor::mailbox::BoxedMailboxSender;
46use hyperactor::mailbox::DialMailboxRouter;
47use hyperactor::mailbox::MailboxSender;
48use hyperactor::mailbox::MailboxSenderError;
49use hyperactor::mailbox::MessageEnvelope;
50use hyperactor::mailbox::PortSender;
51use hyperactor::mailbox::Undeliverable;
52use hyperactor::mailbox::mailbox_admin_message::MailboxAdminMessage;
53use hyperactor::mailbox::monitored_return_handle;
54use hyperactor::proc::Proc;
55use hyperactor::reference::Index;
56use serde::Deserialize;
57use serde::Serialize;
58use tokio::time::Duration;
59use tokio::time::Instant;
60
61use super::proc_actor::ProcMessage;
62use crate::proc_actor::Environment;
63use crate::proc_actor::ProcActor;
64use crate::proc_actor::ProcStopResult;
65use crate::supervision::ProcStatus;
66use crate::supervision::ProcSupervisionMessage;
67use crate::supervision::ProcSupervisionState;
68use crate::supervision::WorldSupervisionMessage;
69use crate::supervision::WorldSupervisionState;
70
/// Label-only view of a proc as it appears inside a [`WorldSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorldSnapshotProcInfo {
    /// Labels copied from the proc's `ProcInfo`.
    pub labels: HashMap<String, String>,
}
77
78impl From<&ProcInfo> for WorldSnapshotProcInfo {
79 fn from(proc_info: &ProcInfo) -> Self {
80 Self {
81 labels: proc_info.labels.clone(),
82 }
83 }
84}
85
/// Serializable point-in-time view of a single world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorldSnapshot {
    /// Proc ids of the hosts registered in the world's host map.
    pub host_procs: HashSet<ProcId>,

    /// Procs of the world, keyed by proc id, as selected by the snapshot filter.
    pub procs: HashMap<ProcId, WorldSnapshotProcInfo>,

    /// Status of the world when the snapshot was taken.
    pub status: WorldStatus,

    /// Labels attached to the world.
    pub labels: HashMap<String, String>,
}
107
108impl WorldSnapshot {
109 fn from_world_filtered(world: &World, filter: &SystemSnapshotFilter) -> Self {
110 WorldSnapshot {
111 host_procs: world.state.host_map.keys().map(|h| &h.0).cloned().collect(),
112 procs: world
113 .state
114 .procs
115 .iter()
116 .map_while(|(k, v)| filter.proc_matches(v).then_some((k.clone(), v.into())))
117 .collect(),
118 status: world.state.status.clone(),
119 labels: world.labels.clone(),
120 }
121 }
122}
123
/// Snapshot of the system: one [`WorldSnapshot`] per world that matched the filter.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
pub struct SystemSnapshot {
    /// Snapshots of the matching worlds, keyed by world id.
    pub worlds: HashMap<WorldId, WorldSnapshot>,
    /// Execution id, filled from `hyperactor_telemetry::env::execution_id()` by the snapshot handler.
    pub execution_id: String,
}
132
/// Filter selecting which worlds and procs are included in a [`SystemSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named, Default)]
pub struct SystemSnapshotFilter {
    /// World ids to include; an empty list places no restriction on the world id.
    pub worlds: Vec<WorldId>,
    /// Labels a world must carry (same key and value) to be included.
    pub world_labels: HashMap<String, String>,
    /// Labels a proc must carry (same key and value) to be included.
    pub proc_labels: HashMap<String, String>,
}
143
144impl SystemSnapshotFilter {
145 pub fn all() -> Self {
147 Self {
148 worlds: Vec::new(),
149 world_labels: HashMap::new(),
150 proc_labels: HashMap::new(),
151 }
152 }
153
154 fn world_matches(&self, world: &World) -> bool {
156 if !self.worlds.is_empty() && !self.worlds.contains(&world.world_id) {
157 return false;
158 }
159 Self::labels_match(&self.world_labels, &world.labels)
160 }
161
162 fn proc_matches(&self, proc_info: &ProcInfo) -> bool {
163 Self::labels_match(&self.proc_labels, &proc_info.labels)
164 }
165
166 fn labels_match(
168 filter_labels: &HashMap<String, String>,
169 labels: &HashMap<String, String>,
170 ) -> bool {
171 filter_labels.is_empty()
172 || filter_labels
173 .iter()
174 .all(|(k, v)| labels.contains_key(k) && labels.get(k).unwrap() == v)
175 }
176}
177
/// Message the system actor sends to itself from `init`.
// NOTE(review): the handler is not visible in this part of the file; presumably
// it drives periodic world health maintenance — confirm.
#[derive(Debug, Clone, PartialEq)]
struct MaintainWorldHealth;
182
/// How a proc's lifecycle relates to the system actor.
#[derive(Named, Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProcLifecycleMode {
    /// Not managed: the join handler neither creates supervision state for the
    /// proc nor adds it to a world.
    Detached,
    /// Managed: the join handler creates supervision state and adds the proc
    /// (or host) to its world.
    ManagedBySystem,
    /// Counts as managed for supervision, but the join handler does not add the
    /// proc to a world.
    ManagingSystem,
}
194
195impl ProcLifecycleMode {
196 pub fn is_managed(&self) -> bool {
198 matches!(
199 self,
200 ProcLifecycleMode::ManagedBySystem | ProcLifecycleMode::ManagingSystem
201 )
202 }
203}
204
/// Messages handled by the system actor.
#[derive(
    hyperactor::Handler,
    HandleClient,
    RefClient,
    Named,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq
)]
pub enum SystemMessage {
    /// Announces a proc (or host proc) to the system.
    Join {
        /// World the proc reports itself as belonging to.
        world_id: WorldId,
        /// Id of the joining proc.
        proc_id: ProcId,
        /// Port on which the proc receives `ProcMessage`s; the join handler
        /// sends `ProcMessage::Joined` to it.
        proc_message_port: PortRef<ProcMessage>,
        /// Channel address the join handler binds the proc id to in the router.
        proc_addr: ChannelAddr,
        /// Labels stored with the proc when it is added to a world.
        labels: HashMap<String, String>,
        /// Lifecycle mode of the proc; decides whether it is supervised and
        /// whether it is added to a world.
        lifecycle_mode: ProcLifecycleMode,
    },

    /// Creates a world, or fills in the parameters of a world that is still
    /// awaiting creation.
    UpsertWorld {
        /// Id of the world to create or update.
        world_id: WorldId,
        /// Shape of the world; the product of its dimensions is the number of procs.
        shape: Shape,
        /// Number of procs scheduled on each host.
        num_procs_per_host: usize,
        /// Environment forwarded to hosts in `ProcMessage::SpawnProc`.
        env: Environment,
        /// Labels to attach to the world.
        labels: HashMap<String, String>,
    },

    /// Requests a snapshot of the worlds and procs matching `filter`.
    #[log_level(debug)]
    Snapshot {
        /// Selects the worlds and procs to include.
        filter: SystemSnapshotFilter,
        /// Reply port receiving the snapshot.
        #[reply]
        ret: OncePortRef<SystemSnapshot>,
    },

    /// Stops worlds, or the whole system.
    Stop {
        /// Worlds to stop; `None` stops the system actor and all worlds.
        worlds: Option<Vec<WorldId>>,
        // NOTE(review): the use of this timeout is not visible in this part of
        // the file; presumably the time allowed for each proc to stop — confirm.
        proc_timeout: Duration,
        /// Port that receives `()` as acknowledgement.
        // NOTE(review): only the "no worlds, shutting down" reply is visible
        // here; confirm when the reply is sent otherwise.
        reply_port: OncePortRef<()>,
    },
}
275
/// Errors produced by the system actor and its world bookkeeping.
#[derive(thiserror::Error, Debug)]
pub enum SystemActorError {
    /// A proc tried to join a world that has not been created.
    #[error("procs cannot join uncreated world {0}")]
    UnknownWorldId(WorldId),

    /// Sending a message through the mailbox failed (converted from `MailboxSenderError`).
    #[error("failed to spawn procs")]
    SpawnProcsFailed(#[from] MailboxSenderError),

    /// The host id does not start with the host-world prefix.
    #[error("invalid host {0}: does not start with prefix '{SHADOW_PREFIX}'")]
    InvalidHostPrefix(HostId),

    /// A host with this id is already registered in the world.
    #[error("host ID {0} already exists in world")]
    DuplicatedHostId(HostId),

    /// No host with this id is registered in the world.
    #[error("host {0} does not exist in world")]
    HostNotExist(HostId),
}
299
/// Shape of a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Shape {
    /// Known dimensions; the number of procs is the product of the entries.
    Definite(Vec<usize>),
    /// Shape not known yet. Used for placeholder worlds created before
    /// `UpsertWorld` arrives; `SchedulerParams::num_procs` is unimplemented for it.
    Unknown,
}
310
/// A world tracked by the system actor: its identity, scheduling parameters,
/// labels and current state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct World {
    /// Id of the world.
    world_id: WorldId,
    /// Parameters used to assign procs to hosts.
    scheduler_params: SchedulerParams,
    /// Labels attached to the world.
    labels: HashMap<String, String>,
    /// Hosts, procs and status of the world.
    state: WorldState,
}
333
/// A host registered in a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Host {
    /// Number of procs assigned to this host so far.
    num_procs_assigned: usize,
    /// Port used to send `ProcMessage`s (such as `SpawnProc`) to the host.
    proc_message_port: PortRef<ProcMessage>,
    /// Rank of the host, taken from the rank of its host proc id.
    host_rank: usize,
}
340
341impl Host {
342 fn new(proc_message_port: PortRef<ProcMessage>, host_rank: usize) -> Self {
343 Self {
344 num_procs_assigned: 0,
345 proc_message_port,
346 host_rank,
347 }
348 }
349
350 fn get_assigned_procs(
351 &mut self,
352 world_id: &WorldId,
353 scheduler_params: &mut SchedulerParams,
354 ) -> Vec<ProcId> {
355 let mut proc_ids = Vec::new();
357
358 let num_saturated_hosts =
362 scheduler_params.num_procs() / scheduler_params.num_procs_per_host;
363 let num_scheduled = if self.host_rank == num_saturated_hosts {
368 scheduler_params.num_procs() % scheduler_params.num_procs_per_host
369 } else {
370 scheduler_params.num_procs_per_host
371 };
372
373 scheduler_params.num_procs_scheduled += num_scheduled;
374
375 for _ in 0..num_scheduled {
376 let rank =
387 self.host_rank * scheduler_params.num_procs_per_host + self.num_procs_assigned;
388 let proc_id = ProcId::Ranked(world_id.clone(), rank);
389 proc_ids.push(proc_id);
390 self.num_procs_assigned += 1;
391 }
392
393 proc_ids
394 }
395}
396
/// Parameters used to distribute a world's procs across its hosts.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SchedulerParams {
    /// Shape of the world; its product is the total number of procs.
    shape: Shape,
    /// Running total of procs handed out by `Host::get_assigned_procs`.
    num_procs_scheduled: usize,
    /// Number of procs a saturated host runs.
    num_procs_per_host: usize,
    // NOTE(review): initialised to 0 in `World::new`; no other use is visible
    // in this part of the file.
    next_rank: Index,
    /// Environment forwarded to hosts in `ProcMessage::SpawnProc`.
    env: Environment,
}
405
406impl SchedulerParams {
407 fn num_procs(&self) -> usize {
408 match &self.shape {
409 Shape::Definite(v) => v.iter().product(),
410 Shape::Unknown => unimplemented!(),
411 }
412 }
413}
414
/// Alias for the id of a host world.
pub type HostWorldId = WorldId;
/// Name prefix that marks a world as a host world; regular world names may not start with it.
static SHADOW_PREFIX: &str = "host";
418
/// Id of a host: a proc id whose world name starts with `SHADOW_PREFIX`.
#[derive(
    Debug,
    Serialize,
    Deserialize,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Hash,
    Ord
)]
pub struct HostId(ProcId);
432impl HostId {
433 pub fn new(proc_id: ProcId) -> Result<Self, anyhow::Error> {
435 if !proc_id
436 .world_name()
437 .expect("proc must be ranked for world_name check")
438 .starts_with(SHADOW_PREFIX)
439 {
440 anyhow::bail!(
441 "proc_id {} is not a valid HostId because it does not start with {}",
442 proc_id,
443 SHADOW_PREFIX
444 )
445 }
446 Ok(Self(proc_id))
447 }
448}
449
450impl TryFrom<ProcId> for HostId {
451 type Error = anyhow::Error;
452
453 fn try_from(proc_id: ProcId) -> Result<Self, anyhow::Error> {
454 if !proc_id
455 .world_name()
456 .expect("proc must be ranked for world_name check")
457 .starts_with(SHADOW_PREFIX)
458 {
459 anyhow::bail!(
460 "proc_id {} is not a valid HostId because it does not start with {}",
461 proc_id,
462 SHADOW_PREFIX
463 )
464 }
465 Ok(Self(proc_id))
466 }
467}
468
469impl std::fmt::Display for HostId {
470 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471 self.0.fmt(f)
472 }
473}
474
/// Hosts of a world, keyed by host id.
type HostMap = HashMap<HostId, Host>;
476
/// What the system actor records about a proc that joined a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct ProcInfo {
    /// Port on which the proc receives `ProcMessage`s.
    port_ref: PortRef<ProcMessage>,
    /// Labels the proc supplied when joining.
    labels: HashMap<String, String>,
}
482
/// Mutable state of a world: its hosts, its procs and its status.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WorldState {
    /// Hosts that joined the world, keyed by host id.
    host_map: HostMap,
    /// Procs that joined the world, keyed by proc id.
    procs: HashMap<ProcId, ProcInfo>,
    /// Current status of the world.
    status: WorldStatus,
}
489
/// Status of a world.
#[derive(Debug, Clone, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub enum WorldStatus {
    /// Placeholder world created when a proc joined before `UpsertWorld` arrived.
    AwaitingCreation,

    /// The world is live (set once an unhealthy world has at least as many
    /// procs as its shape requires).
    Live,

    /// The world is unhealthy; the payload is the system time at which it was
    /// marked so.
    Unhealthy(SystemTime),
}
505
506impl Display for WorldStatus {
507 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
508 match self {
509 WorldStatus::AwaitingCreation => write!(f, "Awaiting Creation"),
510 WorldStatus::Live => write!(f, "Live"),
511 WorldStatus::Unhealthy(_) => write!(f, "Unhealthy"),
512 }
513 }
514}
515
516impl WorldState {
517 fn get_host_map_mut(&mut self) -> &mut HostMap {
519 &mut self.host_map
520 }
521
522 fn get_host_map(&self) -> &HostMap {
524 &self.host_map
525 }
526}
527
528impl World {
529 fn new(
530 world_id: WorldId,
531 shape: Shape,
532 state: WorldState,
533 num_procs_per_host: usize,
534 env: Environment,
535 labels: HashMap<String, String>,
536 ) -> Result<Self, anyhow::Error> {
537 if world_id.name().starts_with(SHADOW_PREFIX) {
538 anyhow::bail!(
539 "world name {} cannot start with {}!",
540 world_id,
541 SHADOW_PREFIX
542 )
543 }
544 tracing::info!("creating world {}", world_id,);
545 Ok(Self {
546 world_id,
547 scheduler_params: SchedulerParams {
548 shape,
549 num_procs_per_host,
550 num_procs_scheduled: 0,
551 next_rank: 0,
552 env,
553 },
554 state,
555 labels,
556 })
557 }
558
559 fn get_real_world_id(proc_world_id: &WorldId) -> WorldId {
560 WorldId(
561 proc_world_id
562 .name()
563 .strip_prefix(SHADOW_PREFIX)
564 .unwrap_or(proc_world_id.name())
565 .to_string(),
566 )
567 }
568
569 fn is_host_world(world_id: &WorldId) -> bool {
570 world_id.name().starts_with(SHADOW_PREFIX)
571 }
572
573 fn get_port_ref_from_host(
574 &self,
575 host_id: &HostId,
576 ) -> Result<PortRef<ProcMessage>, SystemActorError> {
577 let host_map = self.state.get_host_map();
578 match host_map.get(host_id) {
580 Some(h) => Ok(h.proc_message_port.clone()),
581 None => Err(SystemActorError::HostNotExist(host_id.clone())),
582 }
583 }
584
585 fn add_proc(
587 &mut self,
588 proc_id: ProcId,
589 proc_message_port: PortRef<ProcMessage>,
590 labels: HashMap<String, String>,
591 ) -> Result<(), SystemActorError> {
592 self.state.procs.insert(
593 proc_id,
594 ProcInfo {
595 port_ref: proc_message_port,
596 labels,
597 },
598 );
599 if self.state.status.is_unhealthy()
600 && self.state.procs.len() >= self.scheduler_params.num_procs()
601 {
602 self.state.status = WorldStatus::Live;
603 tracing::info!(
604 "world {}: ready to serve with {} procs",
605 self.world_id,
606 self.state.procs.len()
607 );
608 }
609 Ok(())
610 }
611
612 async fn on_host_join(
616 &mut self,
617 host_id: HostId,
618 proc_message_port: PortRef<ProcMessage>,
619 router: &DialMailboxRouter,
620 ) -> Result<(), SystemActorError> {
621 let mut host_entry = match self.state.host_map.entry(host_id.clone()) {
622 Entry::Occupied(_) => {
623 return Err(SystemActorError::DuplicatedHostId(host_id));
624 }
625 Entry::Vacant(entry) => entry.insert_entry(Host::new(
626 proc_message_port.clone(),
627 host_id
628 .0
629 .rank()
630 .expect("host proc must be ranked for rank access"),
631 )),
632 };
633
634 if self.state.status == WorldStatus::AwaitingCreation {
635 return Ok(());
636 }
637
638 let proc_ids = host_entry
639 .get_mut()
640 .get_assigned_procs(&self.world_id, &mut self.scheduler_params);
641
642 router.serialize_and_send(
643 &proc_message_port,
644 ProcMessage::SpawnProc {
645 env: self.scheduler_params.env.clone(),
646 world_id: self.world_id.clone(),
647 proc_ids,
648 world_size: self.scheduler_params.num_procs(),
649 },
650 monitored_return_handle(),
651 )?;
652 Ok(())
653 }
654
655 fn get_hosts_to_procs(&mut self) -> Result<HashMap<HostId, Vec<ProcId>>, SystemActorError> {
656 let mut host_proc_map: HashMap<HostId, Vec<ProcId>> = HashMap::new();
658 let host_map = self.state.get_host_map_mut();
659 for (host_id, host) in host_map {
661 if host.num_procs_assigned == self.scheduler_params.num_procs_per_host {
663 continue;
664 }
665 let host_procs = host.get_assigned_procs(&self.world_id, &mut self.scheduler_params);
666 if host_procs.is_empty() {
667 continue;
668 }
669 host_proc_map.insert(host_id.clone(), host_procs);
670 }
671 Ok(host_proc_map)
672 }
673
674 async fn on_create(&mut self, router: &DialMailboxRouter) -> Result<(), anyhow::Error> {
675 let host_procs_map = self.get_hosts_to_procs()?;
676 for (host_id, procs_ids) in host_procs_map {
677 if procs_ids.is_empty() {
678 continue;
679 }
680
681 let world_id = procs_ids
683 .first()
684 .unwrap()
685 .clone()
686 .into_ranked()
687 .expect("proc must be ranked for world_id access")
688 .0
689 .clone();
690 tracing::info!("spawning procs for host {:?}", host_id);
692 router.serialize_and_send(
693 &self.get_port_ref_from_host(&host_id)?,
695 ProcMessage::SpawnProc {
696 env: self.scheduler_params.env.clone(),
697 world_id,
698 proc_ids: procs_ids,
700 world_size: self.scheduler_params.num_procs(),
701 },
702 monitored_return_handle(),
703 )?;
704 }
705 Ok(())
706 }
707}
708
/// Mailbox sender that forwards through a `DialMailboxRouter` and, while doing
/// so, tells sending procs the address of the procs they send to.
#[derive(Debug, Clone)]
pub struct ReportingRouter {
    /// Router used for address lookup and actual message delivery.
    router: DialMailboxRouter,
    /// For each destination proc, the sender procs that were sent its address;
    /// consulted to re-broadcast when the destination's address changes.
    address_cache: Arc<DashMap<ProcId, HashSet<ProcId>>>,
}
720
721impl MailboxSender for ReportingRouter {
722 fn post_unchecked(
723 &self,
724 envelope: MessageEnvelope,
725 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
726 ) {
727 let ReportingRouter { router, .. } = self;
728 self.post_update_address(&envelope);
729 router.post_unchecked(envelope, return_handle);
730 }
731}
732
733impl ReportingRouter {
734 fn new() -> Self {
735 Self {
736 router: DialMailboxRouter::new(),
737 address_cache: Arc::new(DashMap::new()),
738 }
739 }
740 fn post_update_address(&self, envelope: &MessageEnvelope) {
741 let system_proc_id = id!(system[0]);
742 if envelope.sender().proc_id() == &id!(unknown[0])
755 || envelope.sender().proc_id().world_id() == Some(&id!(user))
756 || envelope.sender().proc_id() == &system_proc_id
757 || envelope.dest().actor_id().proc_id() == &system_proc_id
758 || envelope.sender().proc_id() == envelope.dest().actor_id().proc_id()
759 {
760 return;
761 }
762 let (dst_proc_id, dst_proc_addr) = self.dest_proc_id_and_address(envelope);
763 let Some(dst_proc_addr) = dst_proc_addr else {
764 tracing::warn!("unknown address for {}", &dst_proc_id);
765 return;
766 };
767
768 let sender_proc_id = envelope.sender().proc_id();
769 self.upsert_address_cache(sender_proc_id, &dst_proc_id);
770 let sender_address = self.router.lookup_addr(envelope.sender());
773 let dst_proc_addr =
774 if let (Some(ChannelAddr::Sim(sender_sim_addr)), ChannelAddr::Sim(dest_sim_addr)) =
775 (sender_address, &dst_proc_addr)
776 {
777 ChannelAddr::Sim(
778 SimAddr::new_with_src(
779 sender_sim_addr.addr().clone(),
781 dest_sim_addr.addr().clone(),
783 )
784 .unwrap(),
785 )
786 } else {
787 dst_proc_addr
788 };
789 self.serialize_and_send(
790 &self.proc_port_ref(sender_proc_id),
791 MailboxAdminMessage::UpdateAddress {
792 proc_id: dst_proc_id,
793 addr: dst_proc_addr,
794 },
795 monitored_return_handle(),
796 )
797 .expect("unexpected serialization failure")
798 }
799
800 fn broadcast_addr(&self, dst_proc_id: &ProcId, dst_proc_addr: ChannelAddr) {
803 if let Some(r) = self.address_cache.get(dst_proc_id) {
804 for sender_proc_id in r.value() {
805 tracing::info!(
806 "broadcasting address change to {} for {}: {}",
807 sender_proc_id,
808 dst_proc_id,
809 dst_proc_addr
810 );
811 self.serialize_and_send(
812 &self.proc_port_ref(sender_proc_id),
813 MailboxAdminMessage::UpdateAddress {
814 proc_id: dst_proc_id.clone(),
815 addr: dst_proc_addr.clone(),
816 },
817 monitored_return_handle(),
818 )
819 .expect("unexpected serialization failure")
820 }
821 }
822 }
823
824 fn upsert_address_cache(&self, src_proc_id: &ProcId, dst_proc_id: &ProcId) {
825 self.address_cache
826 .entry(dst_proc_id.clone())
827 .and_modify(|src_proc_ids| {
828 src_proc_ids.insert(src_proc_id.clone());
829 })
830 .or_insert({
831 let mut set = HashSet::new();
832 set.insert(src_proc_id.clone());
833 set
834 });
835 }
836
837 fn dest_proc_id_and_address(
838 &self,
839 envelope: &MessageEnvelope,
840 ) -> (ProcId, Option<ChannelAddr>) {
841 let dest_proc_port_id = envelope.dest();
842 let dest_proc_actor_id = dest_proc_port_id.actor_id();
843 let dest_proc_id = dest_proc_actor_id.proc_id();
844 let dest_proc_addr = self.router.lookup_addr(dest_proc_actor_id);
845 (dest_proc_id.clone(), dest_proc_addr)
846 }
847
848 fn proc_port_ref(&self, proc_id: &ProcId) -> PortRef<MailboxAdminMessage> {
849 let proc_actor_id = ActorId(proc_id.clone(), "proc".to_string(), 0);
850 let proc_actor_ref = ActorRef::<ProcActor>::attest(proc_actor_id);
851 proc_actor_ref.port::<MailboxAdminMessage>()
852 }
853}
854
/// Parameters used to construct the system actor.
#[derive(Debug, Clone)]
pub struct SystemActorParams {
    /// Router used by the system proc to deliver messages and report addresses.
    mailbox_router: ReportingRouter,

    /// How long a proc may go without a supervision update before it is marked expired.
    supervision_update_timeout: Duration,

    // NOTE(review): not used in the visible part of the file; presumably how
    // long an unhealthy world is kept before eviction — confirm.
    world_eviction_timeout: Duration,
}
866
867impl SystemActorParams {
868 pub fn new(supervision_update_timeout: Duration, world_eviction_timeout: Duration) -> Self {
870 Self {
871 mailbox_router: ReportingRouter::new(),
872 supervision_update_timeout,
873 world_eviction_timeout,
874 }
875 }
876}
877
/// Supervision bookkeeping for all worlds.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SystemSupervisionState {
    /// Supervision info per world.
    supervision_map: HashMap<WorldId, WorldSupervisionInfo>,
    /// How long a proc may go without an update before it is marked expired.
    supervision_update_timeout: Duration,
}
886
/// Tracks the last update time of each proc, indexed both by time and by proc.
#[derive(Debug, Clone, Default)]
struct HeartbeatRecord {
    /// `(last update time, proc)` pairs ordered by time, so the least recently
    /// updated procs come first.
    btree_index: BTreeSet<(Instant, ProcId)>,

    /// Last update time per proc; used to find the proc's entry in `btree_index`.
    proc_last_update_time: HashMap<ProcId, Instant>,
}
897
898impl HeartbeatRecord {
899 fn update(&mut self, proc_id: &ProcId, clock: &impl Clock) {
901 if let Some(update_time) = self.proc_last_update_time.get(proc_id) {
903 self.btree_index
904 .remove(&(update_time.clone(), proc_id.clone()));
905 }
906
907 let now = clock.now();
909 self.proc_last_update_time
910 .insert(proc_id.clone(), now.clone());
911 self.btree_index.insert((now.clone(), proc_id.clone()));
912 }
913
914 fn mark_expired_procs(
917 &self,
918 state: &mut WorldSupervisionState,
919 clock: &impl Clock,
920 supervision_update_timeout: Duration,
921 ) {
922 let now = clock.now();
924 self.btree_index
925 .iter()
926 .take_while(|(last_update_time, _)| {
927 now > *last_update_time + supervision_update_timeout
928 })
929 .for_each(|(_, proc_id)| {
930 if let Some(proc_state) = state
931 .procs
932 .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
933 {
934 match proc_state.proc_health {
935 ProcStatus::Alive => proc_state.proc_health = ProcStatus::Expired,
936 _ => (),
938 }
939 }
940 });
941 }
942}
943
/// Supervision info of one world.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldSupervisionInfo {
    /// Supervision state of the world's procs.
    state: WorldSupervisionState,

    /// Lifecycle mode each proc declared when joining.
    lifecycle_mode: HashMap<ProcId, ProcLifecycleMode>,

    /// Last update times of the procs; skipped by serde.
    #[serde(skip)]
    heartbeat_record: HeartbeatRecord,
}
954
955impl WorldSupervisionInfo {
956 fn new() -> Self {
957 Self {
958 state: WorldSupervisionState {
959 procs: HashMap::new(),
960 },
961 lifecycle_mode: HashMap::new(),
962 heartbeat_record: HeartbeatRecord::default(),
963 }
964 }
965}
966
967impl SystemSupervisionState {
968 fn new(supervision_update_timeout: Duration) -> Self {
969 Self {
970 supervision_map: HashMap::new(),
971 supervision_update_timeout,
972 }
973 }
974
975 fn create(
977 &mut self,
978 proc_state: ProcSupervisionState,
979 lifecycle_mode: ProcLifecycleMode,
980 clock: &impl Clock,
981 ) {
982 if World::is_host_world(&proc_state.world_id) {
983 return;
984 }
985
986 let world = self
987 .supervision_map
988 .entry(proc_state.world_id.clone())
989 .or_insert_with(WorldSupervisionInfo::new);
990 world
991 .lifecycle_mode
992 .insert(proc_state.proc_id.clone(), lifecycle_mode);
993
994 self.update(proc_state, clock);
995 }
996
997 fn update(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
999 if World::is_host_world(&proc_state.world_id) {
1000 return;
1001 }
1002
1003 let world = self
1004 .supervision_map
1005 .entry(proc_state.world_id.clone())
1006 .or_insert_with(WorldSupervisionInfo::new);
1007
1008 world.heartbeat_record.update(&proc_state.proc_id, clock);
1009
1010 if let Some(info) = world.state.procs.get_mut(
1012 &proc_state
1013 .proc_id
1014 .rank()
1015 .expect("proc must be ranked for proc state update"),
1016 ) {
1017 match info.proc_health {
1018 ProcStatus::Alive => info.proc_health = proc_state.proc_health,
1019 _ => (),
1021 }
1022 info.failed_actors.extend(proc_state.failed_actors);
1023 } else {
1024 world.state.procs.insert(
1025 proc_state
1026 .proc_id
1027 .rank()
1028 .expect("proc must be ranked for rank access"),
1029 proc_state,
1030 );
1031 }
1032 }
1033
1034 fn report(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
1036 if World::is_host_world(&proc_state.world_id) {
1037 return;
1038 }
1039
1040 let proc_id = proc_state.proc_id.clone();
1041 match self.supervision_map.entry(proc_state.world_id.clone()) {
1042 Entry::Occupied(mut world_supervision_info) => {
1043 match world_supervision_info
1044 .get_mut()
1045 .state
1046 .procs
1047 .entry(proc_id.rank().expect("proc must be ranked for rank access"))
1048 {
1049 Entry::Occupied(_) => {
1050 self.update(proc_state, clock);
1051 }
1052 Entry::Vacant(_) => {
1053 tracing::error!("supervision not enabled for proc {}", &proc_id);
1054 }
1055 }
1056 }
1057 Entry::Vacant(_) => {
1058 tracing::error!("supervision not enabled for proc {}", &proc_id);
1059 }
1060 }
1061 }
1062
1063 fn get_world_with_failures(
1066 &mut self,
1067 world_id: &WorldId,
1068 clock: &impl Clock,
1069 ) -> Option<WorldSupervisionState> {
1070 if let Some(world) = self.supervision_map.get_mut(world_id) {
1071 world.heartbeat_record.mark_expired_procs(
1072 &mut world.state,
1073 clock,
1074 self.supervision_update_timeout,
1075 );
1076 let mut world_state_copy = world.state.clone();
1078 world_state_copy
1080 .procs
1081 .retain(|_, proc_state| !proc_state.is_healthy());
1082 return Some(world_state_copy);
1083 }
1084 None
1085 }
1086
1087 fn is_world_healthy(&mut self, world_id: &WorldId, clock: &impl Clock) -> bool {
1088 self.get_world_with_failures(world_id, clock)
1089 .is_none_or(|state| WorldSupervisionState::is_healthy(&state))
1090 }
1091}
1092
/// Progress of stopping one world.
// NOTE(review): the code that fills these sets is not visible in this part of
// the file; the field meanings below are taken from their names — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldStoppingState {
    /// Procs presumably asked to stop and not yet confirmed.
    stopping_procs: HashSet<ProcId>,
    /// Procs presumably confirmed as stopped.
    stopped_procs: HashSet<ProcId>,
}
1098
/// Internal stop-related message.
// NOTE(review): neither the sender nor the handler is visible in this part of
// the file; variant meanings are inferred from their names — confirm.
#[derive(Debug, Clone, PartialEq, EnumAsInner)]
enum SystemStopMessage {
    /// Presumably: stop the system actor itself.
    StopSystemActor,
    /// Presumably: evict the listed worlds.
    EvictWorlds(Vec<WorldId>),
}
1105
/// The system actor: tracks worlds, hosts and procs, and supervises procs.
/// Exported with handlers for `SystemMessage`, `ProcSupervisionMessage` and
/// `WorldSupervisionMessage`.
#[derive(Debug, Clone)]
#[hyperactor::export(
    handlers = [
        SystemMessage,
        ProcSupervisionMessage,
        WorldSupervisionMessage,
    ],
)]
pub struct SystemActor {
    /// Construction parameters, including the mailbox router.
    params: SystemActorParams,
    /// Supervision state of all worlds.
    supervision_state: SystemSupervisionState,
    /// Worlds known to the system, keyed by world id.
    worlds: HashMap<WorldId, World>,
    /// Worlds currently being stopped, with their stopping progress.
    worlds_to_stop: HashMap<WorldId, WorldStoppingState>,
    /// Set when a `Stop` without an explicit world list was received.
    shutting_down: bool,
}
1126
/// Id of the system world (`system`).
pub static SYSTEM_WORLD: LazyLock<WorldId> = LazyLock::new(|| id!(system));

/// Id of the system actor (`system[0].root`).
static SYSTEM_ACTOR_ID: LazyLock<ActorId> = LazyLock::new(|| id!(system[0].root));

/// Attested reference to the system actor (`system[0].root`).
pub static SYSTEM_ACTOR_REF: LazyLock<ActorRef<SystemActor>> =
    LazyLock::new(|| ActorRef::attest(id!(system[0].root)));
1136
1137impl SystemActor {
1138 fn add_new_world(&mut self, world_id: WorldId) -> Result<(), anyhow::Error> {
1140 let world_state = WorldState {
1141 host_map: HashMap::new(),
1142 procs: HashMap::new(),
1143 status: WorldStatus::AwaitingCreation,
1144 };
1145 let world = World::new(
1146 world_id.clone(),
1147 Shape::Unknown,
1148 world_state,
1149 0,
1150 Environment::Local,
1151 HashMap::new(),
1152 )?;
1153 self.worlds.insert(world_id.clone(), world);
1154 Ok(())
1155 }
1156
1157 fn router(&self) -> &ReportingRouter {
1158 &self.params.mailbox_router
1159 }
1160
1161 pub async fn bootstrap(
1165 params: SystemActorParams,
1166 ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1167 Self::bootstrap_with_clock(params, ClockKind::default()).await
1168 }
1169
1170 pub async fn bootstrap_with_clock(
1174 params: SystemActorParams,
1175 clock: ClockKind,
1176 ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1177 let system_proc = Proc::new_with_clock(
1178 SYSTEM_ACTOR_ID.proc_id().clone(),
1179 BoxedMailboxSender::new(params.mailbox_router.clone()),
1180 clock,
1181 );
1182 let actor_handle = system_proc
1183 .spawn::<SystemActor>(SYSTEM_ACTOR_ID.name(), params)
1184 .await?;
1185
1186 Ok((actor_handle, system_proc))
1187 }
1188
1189 fn evict_world(&mut self, world_id: &WorldId) {
1191 self.worlds.remove(world_id);
1192 self.supervision_state.supervision_map.remove(world_id);
1193 self.params
1195 .mailbox_router
1196 .router
1197 .unbind(&world_id.clone().into());
1198 }
1199}
1200
1201#[async_trait]
1202impl Actor for SystemActor {
1203 type Params = SystemActorParams;
1204
1205 async fn new(params: SystemActorParams) -> Result<Self, anyhow::Error> {
1206 let supervision_update_timeout = params.supervision_update_timeout.clone();
1207 Ok(Self {
1208 params,
1209 supervision_state: SystemSupervisionState::new(supervision_update_timeout),
1210 worlds: HashMap::new(),
1211 worlds_to_stop: HashMap::new(),
1212 shutting_down: false,
1213 })
1214 }
1215
1216 async fn init(&mut self, cx: &Instance<Self>) -> Result<(), anyhow::Error> {
1217 cx.self_message_with_delay(MaintainWorldHealth {}, Duration::from_secs(0))?;
1219 Ok(())
1220 }
1221
1222 async fn handle_undeliverable_message(
1223 &mut self,
1224 _cx: &Instance<Self>,
1225 Undeliverable(envelope): Undeliverable<MessageEnvelope>,
1226 ) -> Result<(), anyhow::Error> {
1227 let to = envelope.dest().clone();
1228 let from = envelope.sender().clone();
1229 tracing::info!(
1230 "a message from {} to {} was undeliverable and returned to the system actor",
1231 from,
1232 to,
1233 );
1234
1235 let proc_id = to.actor_id().proc_id();
1239 let world_id = proc_id
1240 .world_id()
1241 .expect("proc must be ranked for world_id access");
1242 if let Some(world) = &mut self.supervision_state.supervision_map.get_mut(world_id) {
1243 if let Some(proc) = world
1244 .state
1245 .procs
1246 .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
1247 {
1248 match proc.proc_health {
1249 ProcStatus::Alive => proc.proc_health = ProcStatus::ConnectionFailure,
1250 _ => (),
1253 }
1254 } else {
1255 tracing::error!(
1256 "can't update proc {} status because there isn't one",
1257 proc_id
1258 );
1259 }
1260 } else {
1261 tracing::error!(
1262 "can't update world {} status because there isn't one",
1263 world_id
1264 );
1265 }
1266 Ok(())
1267 }
1268}
1269
1270#[async_trait]
1282#[hyperactor::forward(SystemMessage)]
1283impl SystemMessageHandler for SystemActor {
1284 async fn join(
1285 &mut self,
1286 cx: &Context<Self>,
1287 world_id: WorldId,
1288 proc_id: ProcId,
1289 proc_message_port: PortRef<ProcMessage>,
1290 channel_addr: ChannelAddr,
1291 labels: HashMap<String, String>,
1292 lifecycle_mode: ProcLifecycleMode,
1293 ) -> Result<(), anyhow::Error> {
1294 tracing::info!("received join for proc {} in world {}", proc_id, world_id);
1295 self.router()
1297 .router
1298 .bind(proc_id.clone().into(), channel_addr.clone());
1299
1300 self.router().broadcast_addr(&proc_id, channel_addr.clone());
1301
1302 self.router().serialize_and_send(
1304 &proc_message_port,
1305 ProcMessage::Joined(),
1306 monitored_return_handle(),
1307 )?;
1308
1309 if lifecycle_mode.is_managed() {
1310 self.supervision_state.create(
1311 ProcSupervisionState {
1312 world_id: world_id.clone(),
1313 proc_id: proc_id.clone(),
1314 proc_addr: channel_addr.clone(),
1315 proc_health: ProcStatus::Alive,
1316 failed_actors: Vec::new(),
1317 },
1318 lifecycle_mode.clone(),
1319 cx.clock(),
1320 );
1321 }
1322
1323 if lifecycle_mode != ProcLifecycleMode::ManagedBySystem {
1326 tracing::info!("ignoring join for proc {} in world {}", proc_id, world_id);
1327 return Ok(());
1328 }
1329
1330 let world_id = World::get_real_world_id(&world_id);
1331 if !self.worlds.contains_key(&world_id) {
1332 self.add_new_world(world_id.clone())?;
1333 }
1334 let world = self
1335 .worlds
1336 .get_mut(&world_id)
1337 .ok_or(anyhow::anyhow!("failed to get world from map"))?;
1338
1339 match HostId::try_from(proc_id.clone()) {
1340 Ok(host_id) => {
1341 tracing::info!("{}: adding host {}", world_id, host_id);
1342 return world
1343 .on_host_join(
1344 host_id,
1345 proc_message_port,
1346 &self.params.mailbox_router.router,
1347 )
1348 .await
1349 .map_err(anyhow::Error::from);
1350 }
1351 Err(_) => {
1354 tracing::info!("proc {} joined to world {}", &proc_id, &world_id,);
1355 if let Err(e) = world.add_proc(proc_id.clone(), proc_message_port, labels) {
1359 tracing::warn!(
1360 "failed to add proc {} to world {}: {}",
1361 &proc_id,
1362 &world_id,
1363 e
1364 );
1365 }
1366 }
1367 };
1368 Ok(())
1369 }
1370
1371 async fn upsert_world(
1372 &mut self,
1373 cx: &Context<Self>,
1374 world_id: WorldId,
1375 shape: Shape,
1376 num_procs_per_host: usize,
1377 env: Environment,
1378 labels: HashMap<String, String>,
1379 ) -> Result<(), anyhow::Error> {
1380 tracing::info!("received upsert_world for world {}!", world_id);
1381 match self.worlds.get_mut(&world_id) {
1382 Some(world) => {
1383 tracing::info!("found existing world {}!", world_id);
1384 match &world.state.status {
1385 WorldStatus::AwaitingCreation => {
1386 world.scheduler_params.shape = shape;
1387 world.scheduler_params.num_procs_per_host = num_procs_per_host;
1388 world.scheduler_params.env = env;
1389 world.state = WorldState {
1390 host_map: world.state.host_map.clone(),
1391 procs: world.state.procs.clone(),
1392 status: if world.state.procs.len() < world.scheduler_params.num_procs()
1393 || !self
1394 .supervision_state
1395 .is_world_healthy(&world_id, cx.clock())
1396 {
1397 WorldStatus::Unhealthy(cx.clock().system_time_now())
1398 } else {
1399 WorldStatus::Live
1400 },
1401 };
1402 for (k, v) in labels {
1403 if world.labels.contains_key(&k) {
1404 anyhow::bail!("cannot overwrite world label: {}", k);
1405 }
1406 world.labels.insert(k.clone(), v.clone());
1407 }
1408 }
1409 _ => {
1410 anyhow::bail!("cannot modify world {}: already exists", world.world_id)
1411 }
1412 }
1413
1414 world.on_create(&self.params.mailbox_router.router).await?;
1415 tracing::info!(
1416 "modified parameters to world {} with shape: {:?} and labels {:?}",
1417 &world.world_id,
1418 world.scheduler_params.shape,
1419 world.labels
1420 );
1421 }
1422 None => {
1423 let world = World::new(
1424 world_id.clone(),
1425 shape.clone(),
1426 WorldState {
1427 host_map: HashMap::new(),
1428 procs: HashMap::new(),
1429 status: WorldStatus::Unhealthy(cx.clock().system_time_now()),
1430 },
1431 num_procs_per_host,
1432 env,
1433 labels,
1434 )?;
1435 tracing::info!("new world {} added with shape: {:?}", world_id, &shape);
1436 self.worlds.insert(world_id, world);
1437 }
1438 };
1439 Ok(())
1440 }
1441
1442 async fn snapshot(
1443 &mut self,
1444 _cx: &Context<Self>,
1445 filter: SystemSnapshotFilter,
1446 ) -> Result<SystemSnapshot, anyhow::Error> {
1447 let world_snapshots = self
1448 .worlds
1449 .iter()
1450 .filter(|(_, world)| filter.world_matches(world))
1451 .map(|(world_id, world)| {
1452 (
1453 world_id.clone(),
1454 WorldSnapshot::from_world_filtered(world, &filter),
1455 )
1456 })
1457 .collect();
1458 Ok(SystemSnapshot {
1459 worlds: world_snapshots,
1460 execution_id: hyperactor_telemetry::env::execution_id(),
1461 })
1462 }
1463
1464 async fn stop(
1465 &mut self,
1466 cx: &Context<Self>,
1467 worlds: Option<Vec<WorldId>>,
1468 proc_timeout: Duration,
1469 reply_port: OncePortRef<()>,
1470 ) -> Result<(), anyhow::Error> {
1471 match &worlds {
1474 Some(world_ids) => {
1475 tracing::info!("stopping worlds: {:?}", world_ids);
1476 }
1477 None => {
1478 tracing::info!("stopping system actor and all worlds");
1479 self.shutting_down = true;
1480 }
1481 }
1482
1483 if self.worlds.is_empty() && self.shutting_down {
1485 cx.stop()?;
1486 reply_port.send(cx, ())?;
1487 return Ok(());
1488 }
1489
1490 let mut world_ids = vec![];
1491 match &worlds {
1492 Some(worlds) => {
1493 world_ids.extend(worlds.clone().into_iter().collect::<Vec<_>>());
1495 }
1496 None => {
1497 world_ids.extend(
1499 self.worlds
1500 .keys()
1501 .filter(|x| x.name() != "user")
1502 .cloned()
1503 .collect::<Vec<_>>(),
1504 );
1505 }
1506 }
1507
1508 for world_id in &world_ids {
1509 if self.worlds_to_stop.contains_key(world_id) || !self.worlds.contains_key(world_id) {
1510 continue;
1512 }
1513 self.worlds_to_stop.insert(
1514 world_id.clone(),
1515 WorldStoppingState {
1516 stopping_procs: HashSet::new(),
1517 stopped_procs: HashSet::new(),
1518 },
1519 );
1520 }
1521
1522 let all_procs = self
1523 .worlds
1524 .iter()
1525 .filter(|(world_id, _)| match &worlds {
1526 Some(worlds_ids) => worlds_ids.contains(world_id),
1527 None => true,
1528 })
1529 .flat_map(|(_, world)| {
1530 world
1531 .state
1532 .host_map
1533 .iter()
1534 .map(|(host_id, host)| (host_id.0.clone(), host.proc_message_port.clone()))
1535 .chain(
1536 world
1537 .state
1538 .procs
1539 .iter()
1540 .map(|(proc_id, info)| (proc_id.clone(), info.port_ref.clone())),
1541 )
1542 .collect::<Vec<_>>()
1543 })
1544 .collect::<HashMap<_, _>>();
1545
1546 for (proc_id, port) in all_procs.into_iter() {
1550 let stopping_state = self
1551 .worlds_to_stop
1552 .get_mut(&World::get_real_world_id(
1553 proc_id
1554 .world_id()
1555 .expect("proc must be ranked for world_id access"),
1556 ))
1557 .unwrap();
1558 if !stopping_state.stopping_procs.insert(proc_id) {
1559 continue;
1560 }
1561
1562 let reply_to = cx.port::<ProcStopResult>().bind().into_once();
1568 port.send(
1569 cx,
1570 ProcMessage::Stop {
1571 timeout: proc_timeout,
1572 reply_to,
1573 },
1574 )?;
1575 }
1576
1577 let stop_msg = match &worlds {
1578 Some(_) => SystemStopMessage::EvictWorlds(world_ids.clone()),
1579 None => SystemStopMessage::StopSystemActor {},
1580 };
1581
1582 cx.self_message_with_delay(stop_msg, Duration::from_secs(8))?;
1584
1585 reply_port.send(cx, ())?;
1586 Ok(())
1587 }
1588}
1589
1590#[async_trait]
1591impl Handler<MaintainWorldHealth> for SystemActor {
1592 async fn handle(&mut self, cx: &Context<Self>, _: MaintainWorldHealth) -> anyhow::Result<()> {
1593 let mut next_check_delay = self.params.world_eviction_timeout;
1597 tracing::debug!("Checking world state. Got {} worlds", self.worlds.len());
1598
1599 for world in self.worlds.values_mut() {
1600 if world.state.status == WorldStatus::AwaitingCreation {
1601 continue;
1602 }
1603
1604 let Some(state) = self
1605 .supervision_state
1606 .get_world_with_failures(&world.world_id, cx.clock())
1607 else {
1608 tracing::debug!("world {} does not have failures, skipping.", world.world_id);
1609 continue;
1610 };
1611
1612 if state.is_healthy() {
1613 tracing::debug!(
1614 "world {} with procs {:?} is healthy, skipping.",
1615 world.world_id,
1616 state
1617 .procs
1618 .values()
1619 .map(|p| p.proc_id.clone())
1620 .collect::<Vec<_>>()
1621 );
1622 continue;
1623 }
1624 for (_, proc_state) in state.procs.iter() {
1626 if proc_state.proc_health == ProcStatus::Alive {
1627 tracing::debug!("proc {} is still alive.", proc_state.proc_id);
1628 continue;
1629 }
1630 if self
1631 .supervision_state
1632 .supervision_map
1633 .get(&world.world_id)
1634 .and_then(|world| world.lifecycle_mode.get(&proc_state.proc_id))
1635 .map_or(true, |mode| *mode != ProcLifecycleMode::ManagingSystem)
1636 {
1637 tracing::debug!(
1638 "proc {} with state {} does not manage system.",
1639 proc_state.proc_id,
1640 proc_state.proc_health
1641 );
1642 continue;
1643 }
1644
1645 tracing::error!(
1646 "proc {} is unhealthy, stop the system as the proc manages the system",
1647 proc_state.proc_id
1648 );
1649
1650 let (tx, _) = cx.open_once_port::<()>();
1652 cx.port().send(SystemMessage::Stop {
1653 worlds: None,
1654 proc_timeout: Duration::from_secs(5),
1655 reply_port: tx.bind(),
1656 })?;
1657 }
1658
1659 if world.state.status == WorldStatus::Live {
1660 world.state.status = WorldStatus::Unhealthy(cx.clock().system_time_now());
1661 }
1662
1663 match &world.state.status {
1664 WorldStatus::Unhealthy(last_unhealthy_time) => {
1665 let elapsed = last_unhealthy_time
1666 .elapsed()
1667 .inspect_err(|err| {
1668 tracing::error!(
1669 "failed to get elapsed time for unhealthy world {}: {}",
1670 world.world_id,
1671 err
1672 )
1673 })
1674 .unwrap_or_else(|_| Duration::from_secs(0));
1675
1676 if elapsed < self.params.world_eviction_timeout {
1677 next_check_delay = std::cmp::min(
1679 next_check_delay,
1680 self.params.world_eviction_timeout - elapsed,
1681 );
1682 } else {
1683 next_check_delay = Duration::from_secs(0);
1684 }
1685 }
1686 _ => {
1687 tracing::error!(
1688 "find a failed world {} with healthy state {}",
1689 world.world_id,
1690 world.state.status
1691 );
1692 continue;
1693 }
1694 }
1695 }
1696 cx.self_message_with_delay(MaintainWorldHealth {}, next_check_delay)?;
1697
1698 Ok(())
1699 }
1700}
1701
1702#[async_trait]
1703impl Handler<ProcSupervisionMessage> for SystemActor {
1704 async fn handle(
1705 &mut self,
1706 cx: &Context<Self>,
1707 msg: ProcSupervisionMessage,
1708 ) -> anyhow::Result<()> {
1709 match msg {
1710 ProcSupervisionMessage::Update(state, reply_port) => {
1711 self.supervision_state.report(state, cx.clock());
1712 let _ = reply_port.send(cx, ());
1713 }
1714 }
1715 Ok(())
1716 }
1717}
1718
1719#[async_trait]
1720impl Handler<WorldSupervisionMessage> for SystemActor {
1721 async fn handle(
1722 &mut self,
1723 cx: &Context<Self>,
1724 msg: WorldSupervisionMessage,
1725 ) -> anyhow::Result<()> {
1726 match msg {
1727 WorldSupervisionMessage::State(world_id, reply_port) => {
1728 let world_state = self
1729 .supervision_state
1730 .get_world_with_failures(&world_id, cx.clock());
1731 let _ = reply_port.send(cx, world_state);
1733 }
1734 }
1735 Ok(())
1736 }
1737}
1738
1739#[async_trait]
1742impl Handler<ProcStopResult> for SystemActor {
1743 async fn handle(&mut self, cx: &Context<Self>, msg: ProcStopResult) -> anyhow::Result<()> {
1744 fn stopping_proc_msg<'a>(sprocs: impl Iterator<Item = &'a ProcId>) -> String {
1745 let sprocs = sprocs.collect::<Vec<_>>();
1746 if sprocs.is_empty() {
1747 return "no procs left".to_string();
1748 }
1749 let msg = sprocs
1750 .iter()
1751 .take(3)
1752 .map(|proc_id| proc_id.to_string())
1753 .collect::<Vec<_>>()
1754 .join(", ");
1755 if sprocs.len() > 3 {
1756 format!("remaining procs: {} and {} more", msg, sprocs.len() - 3)
1757 } else {
1758 format!("remaining procs: {}", msg)
1759 }
1760 }
1761 let mut world_stopped = false;
1762 let world_id = &msg
1763 .proc_id
1764 .clone()
1765 .into_ranked()
1766 .expect("proc must be ranked for world_id access")
1767 .0;
1768 if let Some(stopping_state) = self.worlds_to_stop.get_mut(world_id) {
1769 stopping_state.stopped_procs.insert(msg.proc_id.clone());
1770 tracing::debug!(
1771 "received stop response from {}: {} stopped actors, {} aborted actors: {}",
1772 msg.proc_id,
1773 msg.actors_stopped,
1774 msg.actors_aborted,
1775 stopping_proc_msg(
1776 stopping_state
1777 .stopping_procs
1778 .difference(&stopping_state.stopped_procs)
1779 ),
1780 );
1781 world_stopped =
1782 stopping_state.stopping_procs.len() == stopping_state.stopped_procs.len();
1783 } else {
1784 tracing::warn!(
1785 "received stop response from {} but no inflight stopping request is found, possibly late response",
1786 msg.proc_id
1787 );
1788 }
1789
1790 if world_stopped {
1791 self.evict_world(world_id);
1792 self.worlds_to_stop.remove(world_id);
1793 }
1794
1795 if self.shutting_down && self.worlds.is_empty() {
1796 cx.stop()?;
1797 }
1798
1799 Ok(())
1800 }
1801}
1802
1803#[async_trait]
1804impl Handler<SystemStopMessage> for SystemActor {
1805 async fn handle(
1806 &mut self,
1807 cx: &Context<Self>,
1808 message: SystemStopMessage,
1809 ) -> anyhow::Result<()> {
1810 match message {
1811 SystemStopMessage::EvictWorlds(world_ids) => {
1812 for world_id in &world_ids {
1813 if self.worlds_to_stop.contains_key(world_id) {
1814 tracing::warn!(
1815 "Waiting for world to stop timed out, evicting world anyways: {:?}",
1816 world_id
1817 );
1818 self.evict_world(world_id);
1819 }
1820 }
1821 }
1822 SystemStopMessage::StopSystemActor => {
1823 if self.worlds_to_stop.is_empty() {
1824 tracing::warn!(
1825 "waiting for all worlds to stop timed out, stopping the system actor and evicting the these worlds anyways: {:?}",
1826 self.worlds_to_stop.keys()
1827 );
1828 } else {
1829 tracing::warn!(
1830 "waiting for all worlds to stop timed out, stopping the system actor"
1831 );
1832 }
1833
1834 cx.stop()?;
1835 }
1836 }
1837 Ok(())
1838 }
1839}
1840
1841#[cfg(test)]
1842mod tests {
1843 use std::assert_matches::assert_matches;
1844
1845 use anyhow::Result;
1846 use hyperactor::PortId;
1847 use hyperactor::actor::ActorStatus;
1848 use hyperactor::attrs::Attrs;
1849 use hyperactor::channel;
1850 use hyperactor::channel::ChannelTransport;
1851 use hyperactor::channel::Rx;
1852 use hyperactor::channel::TcpMode;
1853 use hyperactor::clock::Clock;
1854 use hyperactor::clock::RealClock;
1855 use hyperactor::data::Serialized;
1856 use hyperactor::mailbox::Mailbox;
1857 use hyperactor::mailbox::MailboxServer;
1858 use hyperactor::mailbox::MessageEnvelope;
1859 use hyperactor::mailbox::PortHandle;
1860 use hyperactor::mailbox::PortReceiver;
1861 use hyperactor::simnet;
1862 use hyperactor::test_utils::pingpong::PingPongActorParams;
1863
1864 use super::*;
1865 use crate::System;
1866
/// Test stand-in for a host: the pieces a test needs to join the system as
/// a host proc and to observe the `ProcMessage`s sent to it. Built by
/// `spawn_mock_host_actor`.
struct MockHostActor {
    /// Proc id the tests send as `proc_id` in `SystemMessage::Join`.
    local_proc_id: ProcId,
    /// Address the mock host's mailbox is served on; sent as `proc_addr`.
    local_proc_addr: ChannelAddr,
    /// Port the tests bind and send as `proc_message_port`.
    local_proc_message_port: PortHandle<ProcMessage>,
    /// Receiving side of `local_proc_message_port`; tests read the
    /// `ProcMessage`s delivered to the mock host from here.
    local_proc_message_receiver: PortReceiver<ProcMessage>,
}
1873
1874 async fn spawn_mock_host_actor(proc_world_id: WorldId, host_id: usize) -> MockHostActor {
1875 let local_proc_id = ProcId::Ranked(
1877 WorldId(format!("{}{}", SHADOW_PREFIX, proc_world_id.name())),
1878 host_id,
1879 );
1880 let (local_proc_addr, local_proc_rx) =
1881 channel::serve::<MessageEnvelope>(ChannelAddr::any(ChannelTransport::Local)).unwrap();
1882 let local_proc_mbox = Mailbox::new_detached(local_proc_id.actor_id("test".to_string(), 0));
1883 let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1884 let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1885 MockHostActor {
1886 local_proc_id,
1887 local_proc_addr,
1888 local_proc_message_port,
1889 local_proc_message_receiver,
1890 }
1891 }
1892
1893 #[tokio::test]
1894 async fn test_supervision_state() {
1895 let mut sv = SystemSupervisionState::new(Duration::from_secs(1));
1896 let world_id = id!(world);
1897 let proc_id_0 = world_id.proc_id(0);
1898 let clock = ClockKind::Real(RealClock);
1899 sv.create(
1900 ProcSupervisionState {
1901 world_id: world_id.clone(),
1902 proc_addr: ChannelAddr::any(ChannelTransport::Local),
1903 proc_id: proc_id_0.clone(),
1904 proc_health: ProcStatus::Alive,
1905 failed_actors: Vec::new(),
1906 },
1907 ProcLifecycleMode::ManagedBySystem,
1908 &clock,
1909 );
1910 let actor_id = id!(world[1].actor);
1911 let proc_id_1 = actor_id.proc_id();
1912 sv.create(
1913 ProcSupervisionState {
1914 world_id: world_id.clone(),
1915 proc_addr: ChannelAddr::any(ChannelTransport::Local),
1916 proc_id: proc_id_1.clone(),
1917 proc_health: ProcStatus::Alive,
1918 failed_actors: Vec::new(),
1919 },
1920 ProcLifecycleMode::ManagedBySystem,
1921 &clock,
1922 );
1923 let world_id = id!(world);
1924
1925 let unknown_world_id = id!(unknow_world);
1926 let failures = sv.get_world_with_failures(&unknown_world_id, &clock);
1927 assert!(failures.is_none());
1928
1929 let failures = sv.get_world_with_failures(&world_id, &clock);
1931 assert!(failures.is_some());
1932 assert_eq!(failures.unwrap().procs.len(), 0);
1933
1934 RealClock.sleep(Duration::from_secs(2)).await;
1936 sv.report(
1937 ProcSupervisionState {
1938 world_id: world_id.clone(),
1939 proc_addr: ChannelAddr::any(ChannelTransport::Local),
1940 proc_id: proc_id_1.clone(),
1941 proc_health: ProcStatus::Alive,
1942 failed_actors: Vec::new(),
1943 },
1944 &clock,
1945 );
1946 let failures = sv.get_world_with_failures(&world_id, &clock);
1947 let procs = failures.unwrap().procs;
1948 assert_eq!(procs.len(), 1);
1949 assert!(
1950 procs.contains_key(
1951 &proc_id_0
1952 .rank()
1953 .expect("proc must be ranked for rank access")
1954 )
1955 );
1956
1957 sv.report(
1959 ProcSupervisionState {
1960 world_id: world_id.clone(),
1961 proc_addr: ChannelAddr::any(ChannelTransport::Local),
1962 proc_id: proc_id_1.clone(),
1963 proc_health: ProcStatus::Alive,
1964 failed_actors: [(
1965 actor_id.clone(),
1966 ActorStatus::generic_failure("Actor failed"),
1967 )]
1968 .to_vec(),
1969 },
1970 &clock,
1971 );
1972
1973 let failures = sv.get_world_with_failures(&world_id, &clock);
1974 let procs = failures.unwrap().procs;
1975 assert_eq!(procs.len(), 2);
1976 assert!(
1977 procs.contains_key(
1978 &proc_id_0
1979 .rank()
1980 .expect("proc must be ranked for rank access")
1981 )
1982 );
1983 assert!(
1984 procs.contains_key(
1985 &proc_id_1
1986 .rank()
1987 .expect("proc must be ranked for rank access")
1988 )
1989 );
1990 }
1991
1992 #[tracing_test::traced_test]
1993 #[tokio::test]
1994 async fn test_host_join_before_world() {
1995 let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
1997 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
1998
1999 let mut host_actors: Vec<MockHostActor> = Vec::new();
2001
2002 let world_name = "test".to_string();
2003 let world_id = WorldId(world_name.clone());
2004 host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2005 host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2006
2007 for host_actor in host_actors.iter_mut() {
2008 system_actor_handle
2010 .send(SystemMessage::Join {
2011 proc_id: host_actor.local_proc_id.clone(),
2012 world_id: world_id.clone(),
2013 proc_message_port: host_actor.local_proc_message_port.bind(),
2014 proc_addr: host_actor.local_proc_addr.clone(),
2015 labels: HashMap::new(),
2016 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2017 })
2018 .unwrap();
2019
2020 assert_matches!(
2023 host_actor.local_proc_message_receiver.recv().await.unwrap(),
2024 ProcMessage::Joined()
2025 );
2026 }
2027
2028 let num_procs = 6;
2030 let shape = Shape::Definite(vec![2, 3]);
2031 system_actor_handle
2032 .send(SystemMessage::UpsertWorld {
2033 world_id: world_id.clone(),
2034 shape,
2035 num_procs_per_host: 3,
2036 env: Environment::Local,
2037 labels: HashMap::new(),
2038 })
2039 .unwrap();
2040
2041 let mut all_procs: Vec<ProcId> = Vec::new();
2042 for host_actor in host_actors.iter_mut() {
2043 let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2044 match m {
2045 ProcMessage::SpawnProc {
2046 env,
2047 world_id,
2048 mut proc_ids,
2049 world_size,
2050 } => {
2051 assert_eq!(world_id, WorldId(world_name.clone()));
2052 assert_eq!(env, Environment::Local);
2053 assert_eq!(world_size, num_procs);
2054 all_procs.append(&mut proc_ids);
2055 }
2056 _ => std::panic!("Unexpected message type!"),
2057 }
2058 }
2059 assert_eq!(all_procs.len(), num_procs);
2061 all_procs.sort();
2062 for (i, proc) in all_procs.iter().enumerate() {
2063 assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2064 }
2065 }
2066
2067 #[tokio::test]
2068 async fn test_host_join_after_world() {
2069 let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
2071 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2072
2073 let world_name = "test".to_string();
2075 let world_id = WorldId(world_name.clone());
2076 let num_procs = 6;
2077 let shape = Shape::Definite(vec![2, 3]);
2078 system_actor_handle
2079 .send(SystemMessage::UpsertWorld {
2080 world_id: world_id.clone(),
2081 shape,
2082 num_procs_per_host: 3,
2083 env: Environment::Local,
2084 labels: HashMap::new(),
2085 })
2086 .unwrap();
2087
2088 let mut host_actors: Vec<MockHostActor> = Vec::new();
2090
2091 host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2092 host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2093
2094 for host_actor in host_actors.iter_mut() {
2095 system_actor_handle
2097 .send(SystemMessage::Join {
2098 proc_id: host_actor.local_proc_id.clone(),
2099 world_id: world_id.clone(),
2100 proc_message_port: host_actor.local_proc_message_port.bind(),
2101 proc_addr: host_actor.local_proc_addr.clone(),
2102 labels: HashMap::new(),
2103 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2104 })
2105 .unwrap();
2106
2107 assert_matches!(
2110 host_actor.local_proc_message_receiver.recv().await.unwrap(),
2111 ProcMessage::Joined()
2112 );
2113 }
2114
2115 let mut all_procs: Vec<ProcId> = Vec::new();
2116 for host_actor in host_actors.iter_mut() {
2117 let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2118 match m {
2119 ProcMessage::SpawnProc {
2120 env,
2121 world_id,
2122 mut proc_ids,
2123 world_size,
2124 } => {
2125 assert_eq!(world_id, WorldId(world_name.clone()));
2126 assert_eq!(env, Environment::Local);
2127 assert_eq!(world_size, num_procs);
2128 all_procs.append(&mut proc_ids);
2129 }
2130 _ => std::panic!("Unexpected message type!"),
2131 }
2132 }
2133 assert_eq!(all_procs.len(), num_procs);
2135 all_procs.sort();
2136 for (i, proc) in all_procs.iter().enumerate() {
2137 assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2138 }
2139 }
2140
#[test]
fn test_snapshot_filter() {
    // Single-entry label map.
    let label = |key: &str, value: &str| HashMap::from([(key.to_string(), value.to_string())]);

    let test_world = World::new(
        WorldId("test_world".to_string()),
        Shape::Definite(vec![1]),
        WorldState {
            host_map: HashMap::new(),
            procs: HashMap::new(),
            status: WorldStatus::Live,
        },
        1,
        Environment::Local,
        label("foo", "bar"),
    )
    .unwrap();

    // The catch-all filter accepts the world, and empty filter labels match
    // any label set.
    let match_all = SystemSnapshotFilter::all();
    assert!(match_all.world_matches(&test_world));
    assert!(SystemSnapshotFilter::labels_match(
        &HashMap::new(),
        &label("foo", "bar")
    ));

    // Filtering by world id.
    let mut by_world = SystemSnapshotFilter::all();
    by_world.worlds = vec![WorldId("test_world".to_string())];
    assert!(by_world.world_matches(&test_world));
    by_world.worlds = vec![WorldId("unknow_world".to_string())];
    assert!(!by_world.world_matches(&test_world));

    // Filtering by label: equal values match, differing values do not.
    assert!(SystemSnapshotFilter::labels_match(
        &label("foo", "baz"),
        &label("foo", "baz"),
    ));
    assert!(!SystemSnapshotFilter::labels_match(
        &label("foo", "bar"),
        &label("foo", "baz"),
    ));
}
2178
/// End-to-end check of world health when a message between two procs
/// cannot be delivered: after one proc's mailbox is stopped the world is
/// still reported `Live`; once a ping-pong game has been attempted (and did
/// not finish within the timeout) the world is reported `Unhealthy`.
#[tokio::test]
async fn test_undeliverable_message_return() {
    use hyperactor::mailbox::MailboxClient;
    use hyperactor::test_utils::pingpong::PingPongActor;
    use hyperactor::test_utils::pingpong::PingPongMessage;

    use crate::System;
    use crate::proc_actor::ProcActor;
    use crate::supervision::ProcSupervisor;

    // Override the message delivery timeout with 1s; `_guard` presumably
    // keeps the override in place until the end of the test.
    let config = hyperactor::config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        Duration::from_secs(1),
    );

    // Serve a system over TCP and attach a client to it.
    // NOTE(review): the two 2s durations are presumably the supervision
    // timeout and the world eviction timeout — confirm against
    // `System::serve`.
    let server_handle = System::serve(
        ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
        Duration::from_secs(2),
        Duration::from_secs(2),
    )
    .await
    .unwrap();
    let system_actor_handle = server_handle.system_actor_handle();
    let mut system = System::new(server_handle.local_addr().clone());
    let client = system.attach().await.unwrap();

    // A fresh system has no worlds.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 0);

    // Create a one-proc world.
    let world_id = id!(world);
    system_actor_handle
        .send(SystemMessage::UpsertWorld {
            world_id: world_id.clone(),
            shape: Shape::Definite(vec![1]),
            num_procs_per_host: 1,
            env: Environment::Local,
            labels: HashMap::new(),
        })
        .unwrap();

    // The world is registered; a newly upserted world starts out unhealthy.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 1);
    assert!(snapshot.worlds.contains_key(&world_id));
    assert!(matches!(
        snapshot.worlds.get(&world_id).unwrap().status,
        WorldStatus::Unhealthy(_)
    ));

    // Attach a second client to play the supervisor: bind a
    // `ProcSupervisionMessage` port on it and build an
    // `ActorRef<ProcSupervisor>` that points at it.
    let supervisor = system.attach().await.unwrap();
    let (_sup_tx, _sup_rx) = supervisor.bind_actor_port::<ProcSupervisionMessage>();
    let sup_ref = ActorRef::<ProcSupervisor>::attest(supervisor.self_id().clone());

    // Forwarder shared by both procs: a dial router whose default sender
    // dials the system's address.
    let system_sender = BoxedMailboxSender::new(MailboxClient::new(
        channel::dial(server_handle.local_addr().clone()).unwrap(),
    ));
    let proc_forwarder =
        BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

    // Bootstrap proc 0 of the world and attach a client to it.
    // NOTE(review): the 300ms duration is presumably the supervision update
    // interval — confirm against `ProcActor::bootstrap_for_proc`.
    let proc_0 = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
    let _proc_actor_0 = ProcActor::bootstrap_for_proc(
        proc_0.clone(),
        world_id.clone(),
        ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
        server_handle.local_addr().clone(),
        sup_ref.clone(),
        Duration::from_millis(300),
        HashMap::new(),
        ProcLifecycleMode::ManagedBySystem,
    )
    .await
    .unwrap();
    let proc_0_client = proc_0.attach("client").unwrap();
    let (proc_0_undeliverable_tx, _proc_0_undeliverable_rx) = proc_0_client.open_port();

    // Same for proc 1.
    let proc_1 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
    let proc_actor_1 = ProcActor::bootstrap_for_proc(
        proc_1.clone(),
        world_id.clone(),
        ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
        server_handle.local_addr().clone(),
        sup_ref.clone(),
        Duration::from_millis(300),
        HashMap::new(),
        ProcLifecycleMode::ManagedBySystem,
    )
    .await
    .unwrap();
    let proc_1_client = proc_1.attach("client").unwrap();
    let (proc_1_undeliverable_tx, mut _proc_1_undeliverable_rx) = proc_1_client.open_port();

    // Spawn "ping" on proc 0 and "pong" on proc 1, each given the port
    // opened on its proc's client above (by its name, the port for
    // undeliverable messages).
    let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None);
    let ping_handle = proc_0
        .spawn::<PingPongActor>("ping", ping_params)
        .await
        .unwrap();
    let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None);
    let pong_handle = proc_1
        .spawn::<PingPongActor>("pong", pong_params)
        .await
        .unwrap();

    // Stop proc 1's mailbox (presumably its mailbox server) and wait for it
    // to finish, so that messages addressed to "pong" can no longer be
    // delivered.
    proc_actor_1.mailbox.stop("from testing");
    proc_actor_1.mailbox.await.unwrap().unwrap();

    // Stopping the mailbox alone is not expected to make the world
    // unhealthy: it must be reported as live here.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 1);
    assert!(snapshot.worlds.contains_key(&world_id));
    assert_eq!(
        snapshot.worlds.get(&world_id).unwrap().status,
        WorldStatus::Live
    );

    // Start a ping-pong game with a TTL of 1 from "ping" to "pong".
    let ttl = 1_u64;
    let (game_over, on_game_over) = proc_0_client.open_once_port::<bool>();
    ping_handle
        .send(PingPongMessage(ttl, pong_handle.bind(), game_over.bind()))
        .unwrap();

    // The game must not complete: waiting for its result is expected to
    // time out after 4s.
    assert!(
        RealClock
            .timeout(tokio::time::Duration::from_secs(4), on_game_over.recv())
            .await
            .is_err()
    );

    // By now the world is expected to be reported as unhealthy.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 1);
    assert!(matches!(
        snapshot.worlds.get(&world_id).unwrap().status,
        WorldStatus::Unhealthy(_)
    ));
}
2351
2352 #[tokio::test]
2353 async fn test_stop_fast() -> Result<()> {
2354 let server_handle = System::serve(
2355 ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
2356 Duration::from_secs(2), Duration::from_secs(2), )
2359 .await?;
2360 let system_actor_handle = server_handle.system_actor_handle();
2361 let mut system = System::new(server_handle.local_addr().clone());
2362 let client = system.attach().await?;
2363
2364 let (client_tx, client_rx) = client.open_once_port::<()>();
2366 system_actor_handle.send(SystemMessage::Stop {
2367 worlds: None,
2368 proc_timeout: Duration::from_secs(5),
2369 reply_port: client_tx.bind(),
2370 })?;
2371 client_rx.recv().await?;
2372
2373 let mut sys_status_rx = system_actor_handle.status();
2375 {
2376 let received = sys_status_rx.borrow_and_update();
2377 assert_eq!(*received, ActorStatus::Stopped);
2378 }
2379
2380 Ok(())
2381 }
2382
2383 #[tokio::test]
2384 async fn test_update_sim_address() {
2385 simnet::start();
2386
2387 let src_id = id!(proc[0].actor);
2388 let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap());
2389 let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap());
2390 let (_, mut rx) = channel::serve::<MessageEnvelope>(src_addr.clone()).unwrap();
2391
2392 let router = ReportingRouter::new();
2393
2394 router
2395 .router
2396 .bind(src_id.proc_id().clone().into(), src_addr);
2397 router.router.bind(id!(proc[1]).into(), dst_addr);
2398
2399 router.post_update_address(&MessageEnvelope::new(
2400 src_id,
2401 PortId(id!(proc[1].actor), 9999u64),
2402 Serialized::serialize(&1u64).unwrap(),
2403 Attrs::new(),
2404 ));
2405
2406 let envelope = rx.recv().await.unwrap();
2407 let admin_msg = envelope
2408 .data()
2409 .deserialized::<MailboxAdminMessage>()
2410 .unwrap();
2411 let MailboxAdminMessage::UpdateAddress {
2412 addr: ChannelAddr::Sim(addr),
2413 ..
2414 } = admin_msg
2415 else {
2416 panic!("Expected sim address");
2417 };
2418
2419 assert_eq!(addr.src().clone().unwrap().to_string(), "unix:@src");
2420 assert_eq!(addr.addr().to_string(), "unix:@dst");
2421 }
2422}