1use std::collections::BTreeSet;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::collections::hash_map::Entry;
15use std::fmt::Display;
16use std::fmt::Formatter;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::LazyLock;
20use std::time::SystemTime;
21
22use async_trait::async_trait;
23use dashmap::DashMap;
24use enum_as_inner::EnumAsInner;
25use hyperactor::Actor;
26use hyperactor::ActorHandle;
27use hyperactor::ActorId;
28use hyperactor::ActorRef;
29use hyperactor::Context;
30use hyperactor::HandleClient;
31use hyperactor::Instance;
32use hyperactor::Named;
33use hyperactor::OncePortRef;
34use hyperactor::PortHandle;
35use hyperactor::PortRef;
36use hyperactor::ProcId;
37use hyperactor::RefClient;
38use hyperactor::WorldId;
39use hyperactor::actor::Handler;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::channel::sim::SimAddr;
42use hyperactor::clock::Clock;
43use hyperactor::clock::ClockKind;
44use hyperactor::id;
45use hyperactor::mailbox::BoxedMailboxSender;
46use hyperactor::mailbox::DialMailboxRouter;
47use hyperactor::mailbox::MailboxSender;
48use hyperactor::mailbox::MailboxSenderError;
49use hyperactor::mailbox::MessageEnvelope;
50use hyperactor::mailbox::PortSender;
51use hyperactor::mailbox::Undeliverable;
52use hyperactor::mailbox::mailbox_admin_message::MailboxAdminMessage;
53use hyperactor::mailbox::monitored_return_handle;
54use hyperactor::proc::Proc;
55use hyperactor::reference::Index;
56use serde::Deserialize;
57use serde::Serialize;
58use tokio::time::Duration;
59use tokio::time::Instant;
60
61use super::proc_actor::ProcMessage;
62use crate::proc_actor::Environment;
63use crate::proc_actor::ProcActor;
64use crate::proc_actor::ProcStopResult;
65use crate::supervision::ProcStatus;
66use crate::supervision::ProcSupervisionMessage;
67use crate::supervision::ProcSupervisionState;
68use crate::supervision::WorldSupervisionMessage;
69use crate::supervision::WorldSupervisionState;
70
/// Per-proc information exposed in a [`WorldSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorldSnapshotProcInfo {
    /// Labels the proc supplied when it joined the system.
    pub labels: HashMap<String, String>,
}
77
78impl From<&ProcInfo> for WorldSnapshotProcInfo {
79 fn from(proc_info: &ProcInfo) -> Self {
80 Self {
81 labels: proc_info.labels.clone(),
82 }
83 }
84}
85
/// A point-in-time view of a single world, as returned in a [`SystemSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorldSnapshot {
    /// IDs of the host procs that have joined this world (never filtered).
    pub host_procs: HashSet<ProcId>,

    /// The world's procs that passed the snapshot's proc-label filter, keyed by proc ID.
    pub procs: HashMap<ProcId, WorldSnapshotProcInfo>,

    /// The world's status at the time of the snapshot.
    pub status: WorldStatus,

    /// Labels attached to the world.
    pub labels: HashMap<String, String>,
}
107
108impl WorldSnapshot {
109 fn from_world_filtered(world: &World, filter: &SystemSnapshotFilter) -> Self {
110 WorldSnapshot {
111 host_procs: world.state.host_map.keys().map(|h| &h.0).cloned().collect(),
112 procs: world
113 .state
114 .procs
115 .iter()
116 .map_while(|(k, v)| filter.proc_matches(v).then_some((k.clone(), v.into())))
117 .collect(),
118 status: world.state.status.clone(),
119 labels: world.labels.clone(),
120 }
121 }
122}
123
/// A snapshot of the worlds managed by the system actor.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
pub struct SystemSnapshot {
    /// Snapshots of the worlds that matched the request's filter, keyed by world ID.
    pub worlds: HashMap<WorldId, WorldSnapshot>,
    /// Value of `hyperactor_telemetry::env::execution_id()` when the snapshot was taken.
    pub execution_id: String,
}
132
/// Selects which worlds and procs are included in a [`SystemSnapshot`].
///
/// Empty fields match everything; [`SystemSnapshotFilter::all`] matches every
/// world and proc.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named, Default)]
pub struct SystemSnapshotFilter {
    /// If non-empty, only worlds whose ID is listed here are included.
    pub worlds: Vec<WorldId>,
    /// Every key/value pair here must appear in a world's labels for it to match.
    pub world_labels: HashMap<String, String>,
    /// Every key/value pair here must appear in a proc's labels for it to be included.
    pub proc_labels: HashMap<String, String>,
}
143
144impl SystemSnapshotFilter {
145 pub fn all() -> Self {
147 Self {
148 worlds: Vec::new(),
149 world_labels: HashMap::new(),
150 proc_labels: HashMap::new(),
151 }
152 }
153
154 fn world_matches(&self, world: &World) -> bool {
156 if !self.worlds.is_empty() && !self.worlds.contains(&world.world_id) {
157 return false;
158 }
159 Self::labels_match(&self.world_labels, &world.labels)
160 }
161
162 fn proc_matches(&self, proc_info: &ProcInfo) -> bool {
163 Self::labels_match(&self.proc_labels, &proc_info.labels)
164 }
165
166 fn labels_match(
168 filter_labels: &HashMap<String, String>,
169 labels: &HashMap<String, String>,
170 ) -> bool {
171 filter_labels.is_empty()
172 || filter_labels
173 .iter()
174 .all(|(k, v)| labels.contains_key(k) && labels.get(k).unwrap() == v)
175 }
176}
177
/// Self-message that drives the system actor's world-health maintenance.
/// The first one is scheduled with zero delay from `SystemActor::init`.
/// NOTE(review): its handler is not visible here; presumably it re-schedules
/// itself periodically — confirm.
#[derive(Debug, Clone, PartialEq)]
struct MaintainWorldHealth;
182
/// How a proc's lifecycle relates to the system actor.
#[derive(Named, Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProcLifecycleMode {
    /// Not managed: no supervision state is tracked for the proc, and it is
    /// not added to a world when it joins.
    Detached,
    /// Managed by the system: supervised, and added to its world (as a host or
    /// a regular proc) when it joins.
    ManagedBySystem,
    /// Supervised, but not added to a world when it joins.
    /// NOTE(review): the name suggests this proc manages the system itself — confirm.
    ManagingSystem,
}
194
195impl ProcLifecycleMode {
196 pub fn is_managed(&self) -> bool {
198 matches!(
199 self,
200 ProcLifecycleMode::ManagedBySystem | ProcLifecycleMode::ManagingSystem
201 )
202 }
203}
204
/// Messages handled by the [`SystemActor`].
#[derive(
    hyperactor::Handler,
    HandleClient,
    RefClient,
    Named,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq
)]
pub enum SystemMessage {
    /// Sent by a proc to join the system.
    Join {
        /// The world the proc belongs to; for host procs this is the host
        /// world, whose name carries the `SHADOW_PREFIX`.
        world_id: WorldId,
        /// The joining proc's ID.
        proc_id: ProcId,
        /// Port on which the system actor sends `ProcMessage`s (e.g. `Joined`,
        /// `SpawnProc`) to the proc.
        proc_message_port: PortRef<ProcMessage>,
        /// Address at which the proc can be reached; bound in the system router.
        proc_addr: ChannelAddr,
        /// Labels attached to the proc.
        labels: HashMap<String, String>,
        /// How the proc's lifecycle relates to the system.
        lifecycle_mode: ProcLifecycleMode,
    },

    /// Creates a world, or fills in the parameters of a world that is still
    /// awaiting creation because procs joined it first.
    UpsertWorld {
        /// The (non-host) world to create or update.
        world_id: WorldId,
        /// The world's shape; the product of its dimensions is the number of procs.
        shape: Shape,
        /// Number of procs scheduled onto each host.
        num_procs_per_host: usize,
        /// Environment sent to hosts when they are asked to spawn procs.
        env: Environment,
        /// Labels to attach to the world; existing labels cannot be overwritten.
        labels: HashMap<String, String>,
    },

    /// Requests a snapshot of the worlds (and procs) matching `filter`.
    #[log_level(debug)]
    Snapshot {
        /// Selects the worlds and procs included in the snapshot.
        filter: SystemSnapshotFilter,
        /// Port that receives the snapshot.
        #[reply]
        ret: OncePortRef<SystemSnapshot>,
    },

    /// Stops the given worlds, or the whole system.
    Stop {
        /// Worlds to stop; `None` stops the system actor and all worlds.
        worlds: Option<Vec<WorldId>>,
        /// NOTE(review): presumably how long to wait for procs to stop; its use
        /// is outside the visible code — confirm.
        proc_timeout: Duration,
        /// Port that is sent `()` once the stop request has been handled.
        reply_port: OncePortRef<()>,
    },
}
275
/// Errors produced by the system actor while managing worlds and hosts.
#[derive(thiserror::Error, Debug)]
pub enum SystemActorError {
    /// A proc tried to join a world that has not been created.
    #[error("procs cannot join uncreated world {0}")]
    UnknownWorldId(WorldId),

    /// Sending a message to a host (e.g. `ProcMessage::SpawnProc`) failed.
    #[error("failed to spawn procs")]
    SpawnProcsFailed(#[from] MailboxSenderError),

    /// The host's world name does not carry the required `SHADOW_PREFIX`.
    #[error("invalid host {0}: does not start with prefix '{SHADOW_PREFIX}'")]
    InvalidHostPrefix(HostId),

    /// A host with this ID has already joined the world.
    #[error("host ID {0} already exists in world")]
    DuplicatedHostId(HostId),

    /// The referenced host has not joined the world.
    #[error("host {0} does not exist in world")]
    HostNotExist(HostId),
}
299
/// The shape of a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Shape {
    /// A known shape; the number of procs is the product of the dimensions.
    Definite(Vec<usize>),
    /// Shape not yet known, e.g. for a placeholder world created because a
    /// proc joined before the world was upserted.
    Unknown,
}
310
/// A world managed by the system actor: its scheduling parameters, labels, and
/// current membership and status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct World {
    /// The world's ID; never a host-world ID (`World::new` rejects the `SHADOW_PREFIX`).
    world_id: WorldId,
    /// Parameters used to assign procs to hosts.
    scheduler_params: SchedulerParams,
    /// Labels attached to the world.
    labels: HashMap<String, String>,
    /// Hosts, procs, and status.
    state: WorldState,
}
333
/// A host proc that has joined a world; the world's procs are scheduled onto it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Host {
    /// Number of procs assigned to this host so far.
    num_procs_assigned: usize,
    /// Port used to send `ProcMessage`s (such as `SpawnProc`) to the host.
    proc_message_port: PortRef<ProcMessage>,
    /// The host's rank; determines which proc ranks it is assigned.
    host_rank: usize,
}
340
341impl Host {
342 fn new(proc_message_port: PortRef<ProcMessage>, host_rank: usize) -> Self {
343 Self {
344 num_procs_assigned: 0,
345 proc_message_port,
346 host_rank,
347 }
348 }
349
350 fn get_assigned_procs(
351 &mut self,
352 world_id: &WorldId,
353 scheduler_params: &mut SchedulerParams,
354 ) -> Vec<ProcId> {
355 let mut proc_ids = Vec::new();
357
358 let num_saturated_hosts =
362 scheduler_params.num_procs() / scheduler_params.num_procs_per_host;
363 let num_scheduled = if self.host_rank == num_saturated_hosts {
368 scheduler_params.num_procs() % scheduler_params.num_procs_per_host
369 } else {
370 scheduler_params.num_procs_per_host
371 };
372
373 scheduler_params.num_procs_scheduled += num_scheduled;
374
375 for _ in 0..num_scheduled {
376 let rank =
387 self.host_rank * scheduler_params.num_procs_per_host + self.num_procs_assigned;
388 let proc_id = ProcId::Ranked(world_id.clone(), rank);
389 proc_ids.push(proc_id);
390 self.num_procs_assigned += 1;
391 }
392
393 proc_ids
394 }
395}
396
/// Parameters used to schedule a world's procs onto its hosts.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SchedulerParams {
    /// Shape of the world; determines the total number of procs.
    shape: Shape,
    /// Running count of procs handed out by `Host::get_assigned_procs`.
    num_procs_scheduled: usize,
    /// Number of procs placed on each fully populated host.
    num_procs_per_host: usize,
    /// NOTE(review): not read in the visible code; presumably the next rank to
    /// hand out — confirm.
    next_rank: Index,
    /// Environment sent to hosts in `ProcMessage::SpawnProc`.
    env: Environment,
}
405
406impl SchedulerParams {
407 fn num_procs(&self) -> usize {
408 match &self.shape {
409 Shape::Definite(v) => v.iter().product(),
410 Shape::Unknown => unimplemented!(),
411 }
412 }
413}
414
/// ID of a host ("shadow") world: the real world's name prefixed with
/// [`SHADOW_PREFIX`]. NOTE(review): not used in the visible code — confirm intent.
pub type HostWorldId = WorldId;
/// Name prefix that marks a world as a host world. `World::get_real_world_id`
/// strips it, and `World::new` rejects ordinary worlds whose name starts with it.
static SHADOW_PREFIX: &str = "host";
418
/// ID of a host proc: a ranked proc whose world name starts with
/// [`SHADOW_PREFIX`]. Construct via [`HostId::new`] or `HostId::try_from`,
/// which validate the prefix.
#[derive(
    Debug,
    Serialize,
    Deserialize,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Hash,
    Ord
)]
pub struct HostId(ProcId);
432impl HostId {
433 pub fn new(proc_id: ProcId) -> Result<Self, anyhow::Error> {
435 if !proc_id
436 .world_name()
437 .expect("proc must be ranked for world_name check")
438 .starts_with(SHADOW_PREFIX)
439 {
440 anyhow::bail!(
441 "proc_id {} is not a valid HostId because it does not start with {}",
442 proc_id,
443 SHADOW_PREFIX
444 )
445 }
446 Ok(Self(proc_id))
447 }
448}
449
450impl TryFrom<ProcId> for HostId {
451 type Error = anyhow::Error;
452
453 fn try_from(proc_id: ProcId) -> Result<Self, anyhow::Error> {
454 if !proc_id
455 .world_name()
456 .expect("proc must be ranked for world_name check")
457 .starts_with(SHADOW_PREFIX)
458 {
459 anyhow::bail!(
460 "proc_id {} is not a valid HostId because it does not start with {}",
461 proc_id,
462 SHADOW_PREFIX
463 )
464 }
465 Ok(Self(proc_id))
466 }
467}
468
469impl std::fmt::Display for HostId {
470 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471 self.0.fmt(f)
472 }
473}
474
/// The hosts of a world, keyed by host ID.
type HostMap = HashMap<HostId, Host>;
476
/// What the system actor records about a (non-host) proc that joined a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct ProcInfo {
    /// Port for sending `ProcMessage`s to the proc.
    port_ref: PortRef<ProcMessage>,
    /// Labels the proc supplied when joining.
    labels: HashMap<String, String>,
}
482
/// Membership and status of a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WorldState {
    /// Hosts that have joined the world.
    host_map: HostMap,
    /// Non-host procs that have joined the world, keyed by proc ID.
    procs: HashMap<ProcId, ProcInfo>,
    /// Current status of the world.
    status: WorldStatus,
}
489
/// Lifecycle status of a world.
#[derive(Debug, Clone, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub enum WorldStatus {
    /// The world exists only as a placeholder because a proc or host joined
    /// before it was upserted; its shape is still unknown.
    AwaitingCreation,

    /// The world has at least as many procs as its shape requires (and was
    /// healthy when last evaluated).
    Live,

    /// The world is not ready or not healthy. The timestamp is the system time
    /// at which this status was assigned.
    Unhealthy(SystemTime),
}
505
506impl Display for WorldStatus {
507 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
508 match self {
509 WorldStatus::AwaitingCreation => write!(f, "Awaiting Creation"),
510 WorldStatus::Live => write!(f, "Live"),
511 WorldStatus::Unhealthy(_) => write!(f, "Unhealthy"),
512 }
513 }
514}
515
516impl WorldState {
517 fn get_host_map_mut(&mut self) -> &mut HostMap {
519 &mut self.host_map
520 }
521
522 fn get_host_map(&self) -> &HostMap {
524 &self.host_map
525 }
526}
527
528impl World {
529 fn new(
530 world_id: WorldId,
531 shape: Shape,
532 state: WorldState,
533 num_procs_per_host: usize,
534 env: Environment,
535 labels: HashMap<String, String>,
536 ) -> Result<Self, anyhow::Error> {
537 if world_id.name().starts_with(SHADOW_PREFIX) {
538 anyhow::bail!(
539 "world name {} cannot start with {}!",
540 world_id,
541 SHADOW_PREFIX
542 )
543 }
544 tracing::info!("creating world {}", world_id,);
545 Ok(Self {
546 world_id,
547 scheduler_params: SchedulerParams {
548 shape,
549 num_procs_per_host,
550 num_procs_scheduled: 0,
551 next_rank: 0,
552 env,
553 },
554 state,
555 labels,
556 })
557 }
558
559 fn get_real_world_id(proc_world_id: &WorldId) -> WorldId {
560 WorldId(
561 proc_world_id
562 .name()
563 .strip_prefix(SHADOW_PREFIX)
564 .unwrap_or(proc_world_id.name())
565 .to_string(),
566 )
567 }
568
569 fn is_host_world(world_id: &WorldId) -> bool {
570 world_id.name().starts_with(SHADOW_PREFIX)
571 }
572
573 #[allow(clippy::result_large_err)] fn get_port_ref_from_host(
575 &self,
576 host_id: &HostId,
577 ) -> Result<PortRef<ProcMessage>, SystemActorError> {
578 let host_map = self.state.get_host_map();
579 match host_map.get(host_id) {
581 Some(h) => Ok(h.proc_message_port.clone()),
582 None => Err(SystemActorError::HostNotExist(host_id.clone())),
583 }
584 }
585
586 #[allow(clippy::result_large_err)] fn add_proc(
589 &mut self,
590 proc_id: ProcId,
591 proc_message_port: PortRef<ProcMessage>,
592 labels: HashMap<String, String>,
593 ) -> Result<(), SystemActorError> {
594 self.state.procs.insert(
595 proc_id,
596 ProcInfo {
597 port_ref: proc_message_port,
598 labels,
599 },
600 );
601 if self.state.status.is_unhealthy()
602 && self.state.procs.len() >= self.scheduler_params.num_procs()
603 {
604 self.state.status = WorldStatus::Live;
605 tracing::info!(
606 "world {}: ready to serve with {} procs",
607 self.world_id,
608 self.state.procs.len()
609 );
610 }
611 Ok(())
612 }
613
614 async fn on_host_join(
618 &mut self,
619 host_id: HostId,
620 proc_message_port: PortRef<ProcMessage>,
621 router: &DialMailboxRouter,
622 ) -> Result<(), SystemActorError> {
623 let mut host_entry = match self.state.host_map.entry(host_id.clone()) {
624 Entry::Occupied(_) => {
625 return Err(SystemActorError::DuplicatedHostId(host_id));
626 }
627 Entry::Vacant(entry) => entry.insert_entry(Host::new(
628 proc_message_port.clone(),
629 host_id
630 .0
631 .rank()
632 .expect("host proc must be ranked for rank access"),
633 )),
634 };
635
636 if self.state.status == WorldStatus::AwaitingCreation {
637 return Ok(());
638 }
639
640 let proc_ids = host_entry
641 .get_mut()
642 .get_assigned_procs(&self.world_id, &mut self.scheduler_params);
643
644 router.serialize_and_send(
645 &proc_message_port,
646 ProcMessage::SpawnProc {
647 env: self.scheduler_params.env.clone(),
648 world_id: self.world_id.clone(),
649 proc_ids,
650 world_size: self.scheduler_params.num_procs(),
651 },
652 monitored_return_handle(),
653 )?;
654 Ok(())
655 }
656
657 #[allow(clippy::result_large_err)] fn get_hosts_to_procs(&mut self) -> Result<HashMap<HostId, Vec<ProcId>>, SystemActorError> {
659 let mut host_proc_map: HashMap<HostId, Vec<ProcId>> = HashMap::new();
661 let host_map = self.state.get_host_map_mut();
662 for (host_id, host) in host_map {
664 if host.num_procs_assigned == self.scheduler_params.num_procs_per_host {
666 continue;
667 }
668 let host_procs = host.get_assigned_procs(&self.world_id, &mut self.scheduler_params);
669 if host_procs.is_empty() {
670 continue;
671 }
672 host_proc_map.insert(host_id.clone(), host_procs);
673 }
674 Ok(host_proc_map)
675 }
676
677 async fn on_create(&mut self, router: &DialMailboxRouter) -> Result<(), anyhow::Error> {
678 let host_procs_map = self.get_hosts_to_procs()?;
679 for (host_id, procs_ids) in host_procs_map {
680 if procs_ids.is_empty() {
681 continue;
682 }
683
684 let world_id = procs_ids
686 .first()
687 .unwrap()
688 .clone()
689 .into_ranked()
690 .expect("proc must be ranked for world_id access")
691 .0
692 .clone();
693 tracing::info!("spawning procs for host {:?}", host_id);
695 router.serialize_and_send(
696 &self.get_port_ref_from_host(&host_id)?,
698 ProcMessage::SpawnProc {
699 env: self.scheduler_params.env.clone(),
700 world_id,
701 proc_ids: procs_ids,
703 world_size: self.scheduler_params.num_procs(),
704 },
705 monitored_return_handle(),
706 )?;
707 }
708 Ok(())
709 }
710}
711
/// A mailbox router that, besides routing, pushes destination addresses to
/// sending procs (see `post_update_address`). It also remembers which procs
/// sent to which, so that address changes can be broadcast.
#[derive(Debug, Clone)]
pub struct ReportingRouter {
    /// The underlying router that performs address lookup and delivery.
    router: DialMailboxRouter,
    /// For each destination proc, the set of procs that have sent messages to
    /// it. Behind an `Arc`, so clones of the router share it.
    address_cache: Arc<DashMap<ProcId, HashSet<ProcId>>>,
}
723
724impl MailboxSender for ReportingRouter {
725 fn post(
726 &self,
727 envelope: MessageEnvelope,
728 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
729 ) {
730 let ReportingRouter { router, .. } = self;
731 self.post_update_address(&envelope);
732 router.post(envelope, return_handle);
733 }
734}
735
736impl ReportingRouter {
737 fn new() -> Self {
738 Self {
739 router: DialMailboxRouter::new(),
740 address_cache: Arc::new(DashMap::new()),
741 }
742 }
743 fn post_update_address(&self, envelope: &MessageEnvelope) {
744 let system_proc_id = id!(system[0]);
745 if envelope.sender().proc_id() == &id!(unknown[0])
758 || envelope.sender().proc_id().world_id() == Some(&id!(user))
759 || envelope.sender().proc_id() == &system_proc_id
760 || envelope.dest().actor_id().proc_id() == &system_proc_id
761 || envelope.sender().proc_id() == envelope.dest().actor_id().proc_id()
762 {
763 return;
764 }
765 let (dst_proc_id, dst_proc_addr) = self.dest_proc_id_and_address(envelope);
766 let Some(dst_proc_addr) = dst_proc_addr else {
767 tracing::warn!("unknown address for {}", &dst_proc_id);
768 return;
769 };
770
771 let sender_proc_id = envelope.sender().proc_id();
772 self.upsert_address_cache(sender_proc_id, &dst_proc_id);
773 let sender_address = self.router.lookup_addr(envelope.sender());
776 let dst_proc_addr =
777 if let (Some(ChannelAddr::Sim(sender_sim_addr)), ChannelAddr::Sim(dest_sim_addr)) =
778 (sender_address, &dst_proc_addr)
779 {
780 ChannelAddr::Sim(
781 SimAddr::new_with_src(
782 sender_sim_addr.addr().clone(),
784 dest_sim_addr.addr().clone(),
786 )
787 .unwrap(),
788 )
789 } else {
790 dst_proc_addr
791 };
792 self.serialize_and_send(
793 &self.proc_port_ref(sender_proc_id),
794 MailboxAdminMessage::UpdateAddress {
795 proc_id: dst_proc_id,
796 addr: dst_proc_addr,
797 },
798 monitored_return_handle(),
799 )
800 .expect("unexpected serialization failure")
801 }
802
803 fn broadcast_addr(&self, dst_proc_id: &ProcId, dst_proc_addr: ChannelAddr) {
806 if let Some(r) = self.address_cache.get(dst_proc_id) {
807 for sender_proc_id in r.value() {
808 tracing::info!(
809 "broadcasting address change to {} for {}: {}",
810 sender_proc_id,
811 dst_proc_id,
812 dst_proc_addr
813 );
814 self.serialize_and_send(
815 &self.proc_port_ref(sender_proc_id),
816 MailboxAdminMessage::UpdateAddress {
817 proc_id: dst_proc_id.clone(),
818 addr: dst_proc_addr.clone(),
819 },
820 monitored_return_handle(),
821 )
822 .expect("unexpected serialization failure")
823 }
824 }
825 }
826
827 fn upsert_address_cache(&self, src_proc_id: &ProcId, dst_proc_id: &ProcId) {
828 self.address_cache
829 .entry(dst_proc_id.clone())
830 .and_modify(|src_proc_ids| {
831 src_proc_ids.insert(src_proc_id.clone());
832 })
833 .or_insert({
834 let mut set = HashSet::new();
835 set.insert(src_proc_id.clone());
836 set
837 });
838 }
839
840 fn dest_proc_id_and_address(
841 &self,
842 envelope: &MessageEnvelope,
843 ) -> (ProcId, Option<ChannelAddr>) {
844 let dest_proc_port_id = envelope.dest();
845 let dest_proc_actor_id = dest_proc_port_id.actor_id();
846 let dest_proc_id = dest_proc_actor_id.proc_id();
847 let dest_proc_addr = self.router.lookup_addr(dest_proc_actor_id);
848 (dest_proc_id.clone(), dest_proc_addr)
849 }
850
851 fn proc_port_ref(&self, proc_id: &ProcId) -> PortRef<MailboxAdminMessage> {
852 let proc_actor_id = ActorId(proc_id.clone(), "proc".to_string(), 0);
853 let proc_actor_ref = ActorRef::<ProcActor>::attest(proc_actor_id);
854 proc_actor_ref.port::<MailboxAdminMessage>()
855 }
856}
857
/// Parameters for constructing a [`SystemActor`].
#[derive(Debug, Clone)]
pub struct SystemActorParams {
    /// Router used by the system proc (see `SystemActor::bootstrap_with_clock`).
    mailbox_router: ReportingRouter,

    /// A supervised proc whose last update is older than this is marked `Expired`.
    supervision_update_timeout: Duration,

    /// NOTE(review): not read in the visible code; presumably how long an
    /// unhealthy world is kept before it is evicted — confirm.
    world_eviction_timeout: Duration,
}
869
870impl SystemActorParams {
871 pub fn new(supervision_update_timeout: Duration, world_eviction_timeout: Duration) -> Self {
873 Self {
874 mailbox_router: ReportingRouter::new(),
875 supervision_update_timeout,
876 world_eviction_timeout,
877 }
878 }
879}
880
/// Supervision state for every world known to the system actor.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SystemSupervisionState {
    /// Per-world supervision info, keyed by world ID. Host worlds are never tracked.
    supervision_map: HashMap<WorldId, WorldSupervisionInfo>,
    /// Heartbeat age after which an `Alive` proc is considered `Expired`.
    supervision_update_timeout: Duration,
}
889
/// Tracks when each proc last sent a supervision update.
#[derive(Debug, Clone, Default)]
struct HeartbeatRecord {
    /// `(last update time, proc)` pairs, ordered by time so the stalest procs
    /// come first.
    btree_index: BTreeSet<(Instant, ProcId)>,

    /// Last update time per proc; used to locate a proc's entry in `btree_index`.
    proc_last_update_time: HashMap<ProcId, Instant>,
}
900
901impl HeartbeatRecord {
902 fn update(&mut self, proc_id: &ProcId, clock: &impl Clock) {
904 if let Some(update_time) = self.proc_last_update_time.get(proc_id) {
906 self.btree_index
907 .remove(&(update_time.clone(), proc_id.clone()));
908 }
909
910 let now = clock.now();
912 self.proc_last_update_time
913 .insert(proc_id.clone(), now.clone());
914 self.btree_index.insert((now.clone(), proc_id.clone()));
915 }
916
917 fn mark_expired_procs(
920 &self,
921 state: &mut WorldSupervisionState,
922 clock: &impl Clock,
923 supervision_update_timeout: Duration,
924 ) {
925 let now = clock.now();
927 self.btree_index
928 .iter()
929 .take_while(|(last_update_time, _)| {
930 now > *last_update_time + supervision_update_timeout
931 })
932 .for_each(|(_, proc_id)| {
933 if let Some(proc_state) = state
934 .procs
935 .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
936 {
937 match proc_state.proc_health {
938 ProcStatus::Alive => proc_state.proc_health = ProcStatus::Expired,
939 _ => (),
941 }
942 }
943 });
944 }
945}
946
/// Supervision information for a single world.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldSupervisionInfo {
    /// Latest known supervision state of the world's procs.
    state: WorldSupervisionState,

    /// Lifecycle mode each supervised proc joined with.
    lifecycle_mode: HashMap<ProcId, ProcLifecycleMode>,

    /// Heartbeat timestamps. Not serialized, so it is `Default` (empty) after
    /// deserialization.
    #[serde(skip)]
    heartbeat_record: HeartbeatRecord,
}
957
958impl WorldSupervisionInfo {
959 fn new() -> Self {
960 Self {
961 state: WorldSupervisionState {
962 procs: HashMap::new(),
963 },
964 lifecycle_mode: HashMap::new(),
965 heartbeat_record: HeartbeatRecord::default(),
966 }
967 }
968}
969
970impl SystemSupervisionState {
971 fn new(supervision_update_timeout: Duration) -> Self {
972 Self {
973 supervision_map: HashMap::new(),
974 supervision_update_timeout,
975 }
976 }
977
978 fn create(
980 &mut self,
981 proc_state: ProcSupervisionState,
982 lifecycle_mode: ProcLifecycleMode,
983 clock: &impl Clock,
984 ) {
985 if World::is_host_world(&proc_state.world_id) {
986 return;
987 }
988
989 let world = self
990 .supervision_map
991 .entry(proc_state.world_id.clone())
992 .or_insert_with(WorldSupervisionInfo::new);
993 world
994 .lifecycle_mode
995 .insert(proc_state.proc_id.clone(), lifecycle_mode);
996
997 self.update(proc_state, clock);
998 }
999
1000 fn update(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
1002 if World::is_host_world(&proc_state.world_id) {
1003 return;
1004 }
1005
1006 let world = self
1007 .supervision_map
1008 .entry(proc_state.world_id.clone())
1009 .or_insert_with(WorldSupervisionInfo::new);
1010
1011 world.heartbeat_record.update(&proc_state.proc_id, clock);
1012
1013 if let Some(info) = world.state.procs.get_mut(
1015 &proc_state
1016 .proc_id
1017 .rank()
1018 .expect("proc must be ranked for proc state update"),
1019 ) {
1020 match info.proc_health {
1021 ProcStatus::Alive => info.proc_health = proc_state.proc_health,
1022 _ => (),
1024 }
1025 info.failed_actors.extend(proc_state.failed_actors);
1026 } else {
1027 world.state.procs.insert(
1028 proc_state
1029 .proc_id
1030 .rank()
1031 .expect("proc must be ranked for rank access"),
1032 proc_state,
1033 );
1034 }
1035 }
1036
1037 fn report(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
1039 if World::is_host_world(&proc_state.world_id) {
1040 return;
1041 }
1042
1043 let proc_id = proc_state.proc_id.clone();
1044 match self.supervision_map.entry(proc_state.world_id.clone()) {
1045 Entry::Occupied(mut world_supervision_info) => {
1046 match world_supervision_info
1047 .get_mut()
1048 .state
1049 .procs
1050 .entry(proc_id.rank().expect("proc must be ranked for rank access"))
1051 {
1052 Entry::Occupied(_) => {
1053 self.update(proc_state, clock);
1054 }
1055 Entry::Vacant(_) => {
1056 tracing::error!("supervision not enabled for proc {}", &proc_id);
1057 }
1058 }
1059 }
1060 Entry::Vacant(_) => {
1061 tracing::error!("supervision not enabled for proc {}", &proc_id);
1062 }
1063 }
1064 }
1065
1066 fn get_world_with_failures(
1069 &mut self,
1070 world_id: &WorldId,
1071 clock: &impl Clock,
1072 ) -> Option<WorldSupervisionState> {
1073 if let Some(world) = self.supervision_map.get_mut(world_id) {
1074 world.heartbeat_record.mark_expired_procs(
1075 &mut world.state,
1076 clock,
1077 self.supervision_update_timeout,
1078 );
1079 let mut world_state_copy = world.state.clone();
1081 world_state_copy
1083 .procs
1084 .retain(|_, proc_state| !proc_state.is_healthy());
1085 return Some(world_state_copy);
1086 }
1087 None
1088 }
1089
1090 fn is_world_healthy(&mut self, world_id: &WorldId, clock: &impl Clock) -> bool {
1091 self.get_world_with_failures(world_id, clock)
1092 .is_none_or(|state| WorldSupervisionState::is_healthy(&state))
1093 }
1094}
1095
/// Progress of stopping a world's procs.
/// NOTE(review): populated by the `stop` handler, which is only partly visible
/// here; the field descriptions below are inferred from their names — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldStoppingState {
    /// Procs that have presumably been asked to stop but have not yet confirmed.
    stopping_procs: HashSet<ProcId>,
    /// Procs presumably confirmed as stopped.
    stopped_procs: HashSet<ProcId>,
}
1101
/// Internal shutdown step of the system actor.
/// NOTE(review): its sender and handler are outside the visible code — confirm.
#[derive(Debug, Clone, PartialEq, EnumAsInner)]
enum SystemStopMessage {
    /// Presumably stops the system actor itself.
    StopSystemActor,
    /// Presumably evicts the listed worlds (see `SystemActor::evict_world`).
    EvictWorlds(Vec<WorldId>),
}
1108
/// The system actor. It tracks worlds, hosts, and procs, schedules procs onto
/// hosts, routes messages through a [`ReportingRouter`], and aggregates proc
/// supervision state.
#[derive(Debug, Clone)]
#[hyperactor::export(
    handlers = [
        SystemMessage,
        ProcSupervisionMessage,
        WorldSupervisionMessage,
    ],
)]
pub struct SystemActor {
    /// Construction parameters, including the mailbox router.
    params: SystemActorParams,
    /// Supervision state of all managed procs.
    supervision_state: SystemSupervisionState,
    /// Known worlds, keyed by their real (non-host) world ID.
    worlds: HashMap<WorldId, World>,
    /// Worlds currently being stopped.
    worlds_to_stop: HashMap<WorldId, WorldStoppingState>,
    /// Set once a `Stop` for the whole system (`worlds: None`) is received.
    shutting_down: bool,
}
1129
1130pub static SYSTEM_WORLD: LazyLock<WorldId> = LazyLock::new(|| id!(system));
1132
1133static SYSTEM_ACTOR_ID: LazyLock<ActorId> = LazyLock::new(|| id!(system[0].root));
1135
1136pub static SYSTEM_ACTOR_REF: LazyLock<ActorRef<SystemActor>> =
1138 LazyLock::new(|| ActorRef::attest(id!(system[0].root)));
1139
1140impl SystemActor {
1141 fn add_new_world(&mut self, world_id: WorldId) -> Result<(), anyhow::Error> {
1143 let world_state = WorldState {
1144 host_map: HashMap::new(),
1145 procs: HashMap::new(),
1146 status: WorldStatus::AwaitingCreation,
1147 };
1148 let world = World::new(
1149 world_id.clone(),
1150 Shape::Unknown,
1151 world_state,
1152 0,
1153 Environment::Local,
1154 HashMap::new(),
1155 )?;
1156 self.worlds.insert(world_id.clone(), world);
1157 Ok(())
1158 }
1159
1160 fn router(&self) -> &ReportingRouter {
1161 &self.params.mailbox_router
1162 }
1163
1164 pub async fn bootstrap(
1168 params: SystemActorParams,
1169 ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1170 Self::bootstrap_with_clock(params, ClockKind::default()).await
1171 }
1172
1173 pub async fn bootstrap_with_clock(
1177 params: SystemActorParams,
1178 clock: ClockKind,
1179 ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1180 let system_proc = Proc::new_with_clock(
1181 SYSTEM_ACTOR_ID.proc_id().clone(),
1182 BoxedMailboxSender::new(params.mailbox_router.clone()),
1183 clock,
1184 );
1185 let actor_handle = system_proc
1186 .spawn::<SystemActor>(SYSTEM_ACTOR_ID.name(), params)
1187 .await?;
1188
1189 Ok((actor_handle, system_proc))
1190 }
1191
1192 fn evict_world(&mut self, world_id: &WorldId) {
1194 self.worlds.remove(world_id);
1195 self.supervision_state.supervision_map.remove(world_id);
1196 self.params
1198 .mailbox_router
1199 .router
1200 .unbind(&world_id.clone().into());
1201 }
1202}
1203
1204#[async_trait]
1205impl Actor for SystemActor {
1206 type Params = SystemActorParams;
1207
1208 async fn new(params: SystemActorParams) -> Result<Self, anyhow::Error> {
1209 let supervision_update_timeout = params.supervision_update_timeout.clone();
1210 Ok(Self {
1211 params,
1212 supervision_state: SystemSupervisionState::new(supervision_update_timeout),
1213 worlds: HashMap::new(),
1214 worlds_to_stop: HashMap::new(),
1215 shutting_down: false,
1216 })
1217 }
1218
1219 async fn init(&mut self, cx: &Instance<Self>) -> Result<(), anyhow::Error> {
1220 cx.self_message_with_delay(MaintainWorldHealth {}, Duration::from_secs(0))?;
1222 Ok(())
1223 }
1224
1225 async fn handle_undeliverable_message(
1226 &mut self,
1227 _cx: &Instance<Self>,
1228 Undeliverable(envelope): Undeliverable<MessageEnvelope>,
1229 ) -> Result<(), anyhow::Error> {
1230 let to = envelope.dest().clone();
1231 let from = envelope.sender().clone();
1232 tracing::info!(
1233 "a message from {} to {} was undeliverable and returned to the system actor",
1234 from,
1235 to,
1236 );
1237
1238 let proc_id = to.actor_id().proc_id();
1242 let world_id = proc_id
1243 .world_id()
1244 .expect("proc must be ranked for world_id access");
1245 if let Some(world) = &mut self.supervision_state.supervision_map.get_mut(world_id) {
1246 if let Some(proc) = world
1247 .state
1248 .procs
1249 .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
1250 {
1251 match proc.proc_health {
1252 ProcStatus::Alive => proc.proc_health = ProcStatus::ConnectionFailure,
1253 _ => (),
1256 }
1257 } else {
1258 tracing::error!(
1259 "can't update proc {} status because there isn't one",
1260 proc_id
1261 );
1262 }
1263 } else {
1264 tracing::error!(
1265 "can't update world {} status because there isn't one",
1266 world_id
1267 );
1268 }
1269 Ok(())
1270 }
1271}
1272
1273#[async_trait]
1285#[hyperactor::forward(SystemMessage)]
1286impl SystemMessageHandler for SystemActor {
1287 async fn join(
1288 &mut self,
1289 cx: &Context<Self>,
1290 world_id: WorldId,
1291 proc_id: ProcId,
1292 proc_message_port: PortRef<ProcMessage>,
1293 channel_addr: ChannelAddr,
1294 labels: HashMap<String, String>,
1295 lifecycle_mode: ProcLifecycleMode,
1296 ) -> Result<(), anyhow::Error> {
1297 tracing::info!("received join for proc {} in world {}", proc_id, world_id);
1298 self.router()
1300 .router
1301 .bind(proc_id.clone().into(), channel_addr.clone());
1302
1303 self.router().broadcast_addr(&proc_id, channel_addr.clone());
1304
1305 self.router().serialize_and_send(
1307 &proc_message_port,
1308 ProcMessage::Joined(),
1309 monitored_return_handle(),
1310 )?;
1311
1312 if lifecycle_mode.is_managed() {
1313 self.supervision_state.create(
1314 ProcSupervisionState {
1315 world_id: world_id.clone(),
1316 proc_id: proc_id.clone(),
1317 proc_addr: channel_addr.clone(),
1318 proc_health: ProcStatus::Alive,
1319 failed_actors: Vec::new(),
1320 },
1321 lifecycle_mode.clone(),
1322 cx.clock(),
1323 );
1324 }
1325
1326 if lifecycle_mode != ProcLifecycleMode::ManagedBySystem {
1329 tracing::info!("ignoring join for proc {} in world {}", proc_id, world_id);
1330 return Ok(());
1331 }
1332
1333 let world_id = World::get_real_world_id(&world_id);
1334 if !self.worlds.contains_key(&world_id) {
1335 self.add_new_world(world_id.clone())?;
1336 }
1337 let world = self
1338 .worlds
1339 .get_mut(&world_id)
1340 .ok_or(anyhow::anyhow!("failed to get world from map"))?;
1341
1342 match HostId::try_from(proc_id.clone()) {
1343 Ok(host_id) => {
1344 tracing::info!("{}: adding host {}", world_id, host_id);
1345 return world
1346 .on_host_join(
1347 host_id,
1348 proc_message_port,
1349 &self.params.mailbox_router.router,
1350 )
1351 .await
1352 .map_err(anyhow::Error::from);
1353 }
1354 Err(_) => {
1357 tracing::info!("proc {} joined to world {}", &proc_id, &world_id,);
1358 if let Err(e) = world.add_proc(proc_id.clone(), proc_message_port, labels) {
1362 tracing::warn!(
1363 "failed to add proc {} to world {}: {}",
1364 &proc_id,
1365 &world_id,
1366 e
1367 );
1368 }
1369 }
1370 };
1371 Ok(())
1372 }
1373
1374 async fn upsert_world(
1375 &mut self,
1376 cx: &Context<Self>,
1377 world_id: WorldId,
1378 shape: Shape,
1379 num_procs_per_host: usize,
1380 env: Environment,
1381 labels: HashMap<String, String>,
1382 ) -> Result<(), anyhow::Error> {
1383 tracing::info!("received upsert_world for world {}!", world_id);
1384 match self.worlds.get_mut(&world_id) {
1385 Some(world) => {
1386 tracing::info!("found existing world {}!", world_id);
1387 match &world.state.status {
1388 WorldStatus::AwaitingCreation => {
1389 world.scheduler_params.shape = shape;
1390 world.scheduler_params.num_procs_per_host = num_procs_per_host;
1391 world.scheduler_params.env = env;
1392 world.state = WorldState {
1393 host_map: world.state.host_map.clone(),
1394 procs: world.state.procs.clone(),
1395 status: if world.state.procs.len() < world.scheduler_params.num_procs()
1396 || !self
1397 .supervision_state
1398 .is_world_healthy(&world_id, cx.clock())
1399 {
1400 WorldStatus::Unhealthy(cx.clock().system_time_now())
1401 } else {
1402 WorldStatus::Live
1403 },
1404 };
1405 for (k, v) in labels {
1406 if world.labels.contains_key(&k) {
1407 anyhow::bail!("cannot overwrite world label: {}", k);
1408 }
1409 world.labels.insert(k.clone(), v.clone());
1410 }
1411 }
1412 _ => {
1413 anyhow::bail!("cannot modify world {}: already exists", world.world_id)
1414 }
1415 }
1416
1417 world.on_create(&self.params.mailbox_router.router).await?;
1418 tracing::info!(
1419 "modified parameters to world {} with shape: {:?} and labels {:?}",
1420 &world.world_id,
1421 world.scheduler_params.shape,
1422 world.labels
1423 );
1424 }
1425 None => {
1426 let world = World::new(
1427 world_id.clone(),
1428 shape.clone(),
1429 WorldState {
1430 host_map: HashMap::new(),
1431 procs: HashMap::new(),
1432 status: WorldStatus::Unhealthy(cx.clock().system_time_now()),
1433 },
1434 num_procs_per_host,
1435 env,
1436 labels,
1437 )?;
1438 tracing::info!("new world {} added with shape: {:?}", world_id, &shape);
1439 self.worlds.insert(world_id, world);
1440 }
1441 };
1442 Ok(())
1443 }
1444
1445 async fn snapshot(
1446 &mut self,
1447 _cx: &Context<Self>,
1448 filter: SystemSnapshotFilter,
1449 ) -> Result<SystemSnapshot, anyhow::Error> {
1450 let world_snapshots = self
1451 .worlds
1452 .iter()
1453 .filter(|(_, world)| filter.world_matches(world))
1454 .map(|(world_id, world)| {
1455 (
1456 world_id.clone(),
1457 WorldSnapshot::from_world_filtered(world, &filter),
1458 )
1459 })
1460 .collect();
1461 Ok(SystemSnapshot {
1462 worlds: world_snapshots,
1463 execution_id: hyperactor_telemetry::env::execution_id(),
1464 })
1465 }
1466
1467 async fn stop(
1468 &mut self,
1469 cx: &Context<Self>,
1470 worlds: Option<Vec<WorldId>>,
1471 proc_timeout: Duration,
1472 reply_port: OncePortRef<()>,
1473 ) -> Result<(), anyhow::Error> {
1474 match &worlds {
1477 Some(world_ids) => {
1478 tracing::info!("stopping worlds: {:?}", world_ids);
1479 }
1480 None => {
1481 tracing::info!("stopping system actor and all worlds");
1482 self.shutting_down = true;
1483 }
1484 }
1485
1486 if self.worlds.is_empty() && self.shutting_down {
1488 cx.stop()?;
1489 reply_port.send(cx, ())?;
1490 return Ok(());
1491 }
1492
1493 let mut world_ids = vec![];
1494 match &worlds {
1495 Some(worlds) => {
1496 world_ids.extend(worlds.clone().into_iter().collect::<Vec<_>>());
1498 }
1499 None => {
1500 world_ids.extend(
1502 self.worlds
1503 .keys()
1504 .filter(|x| x.name() != "user")
1505 .cloned()
1506 .collect::<Vec<_>>(),
1507 );
1508 }
1509 }
1510
1511 for world_id in &world_ids {
1512 if self.worlds_to_stop.contains_key(world_id) || !self.worlds.contains_key(world_id) {
1513 continue;
1515 }
1516 self.worlds_to_stop.insert(
1517 world_id.clone(),
1518 WorldStoppingState {
1519 stopping_procs: HashSet::new(),
1520 stopped_procs: HashSet::new(),
1521 },
1522 );
1523 }
1524
1525 let all_procs = self
1526 .worlds
1527 .iter()
1528 .filter(|(world_id, _)| match &worlds {
1529 Some(worlds_ids) => worlds_ids.contains(world_id),
1530 None => true,
1531 })
1532 .flat_map(|(_, world)| {
1533 world
1534 .state
1535 .host_map
1536 .iter()
1537 .map(|(host_id, host)| (host_id.0.clone(), host.proc_message_port.clone()))
1538 .chain(
1539 world
1540 .state
1541 .procs
1542 .iter()
1543 .map(|(proc_id, info)| (proc_id.clone(), info.port_ref.clone())),
1544 )
1545 .collect::<Vec<_>>()
1546 })
1547 .collect::<HashMap<_, _>>();
1548
1549 for (proc_id, port) in all_procs.into_iter() {
1553 let stopping_state = self
1554 .worlds_to_stop
1555 .get_mut(&World::get_real_world_id(
1556 proc_id
1557 .world_id()
1558 .expect("proc must be ranked for world_id access"),
1559 ))
1560 .unwrap();
1561 if !stopping_state.stopping_procs.insert(proc_id) {
1562 continue;
1563 }
1564
1565 let reply_to = cx.port::<ProcStopResult>().bind().into_once();
1571 port.send(
1572 cx,
1573 ProcMessage::Stop {
1574 timeout: proc_timeout,
1575 reply_to,
1576 },
1577 )?;
1578 }
1579
1580 let stop_msg = match &worlds {
1581 Some(_) => SystemStopMessage::EvictWorlds(world_ids.clone()),
1582 None => SystemStopMessage::StopSystemActor {},
1583 };
1584
1585 cx.self_message_with_delay(stop_msg, Duration::from_secs(8))?;
1587
1588 reply_port.send(cx, ())?;
1589 Ok(())
1590 }
1591}
1592
1593#[async_trait]
1594impl Handler<MaintainWorldHealth> for SystemActor {
1595 async fn handle(&mut self, cx: &Context<Self>, _: MaintainWorldHealth) -> anyhow::Result<()> {
1596 let mut next_check_delay = self.params.world_eviction_timeout;
1600 tracing::debug!("Checking world state. Got {} worlds", self.worlds.len());
1601
1602 for world in self.worlds.values_mut() {
1603 if world.state.status == WorldStatus::AwaitingCreation {
1604 continue;
1605 }
1606
1607 let Some(state) = self
1608 .supervision_state
1609 .get_world_with_failures(&world.world_id, cx.clock())
1610 else {
1611 tracing::debug!("world {} does not have failures, skipping.", world.world_id);
1612 continue;
1613 };
1614
1615 if state.is_healthy() {
1616 tracing::debug!(
1617 "world {} with procs {:?} is healthy, skipping.",
1618 world.world_id,
1619 state
1620 .procs
1621 .values()
1622 .map(|p| p.proc_id.clone())
1623 .collect::<Vec<_>>()
1624 );
1625 continue;
1626 }
1627 for (_, proc_state) in state.procs.iter() {
1629 if proc_state.proc_health == ProcStatus::Alive {
1630 tracing::debug!("proc {} is still alive.", proc_state.proc_id);
1631 continue;
1632 }
1633 if self
1634 .supervision_state
1635 .supervision_map
1636 .get(&world.world_id)
1637 .and_then(|world| world.lifecycle_mode.get(&proc_state.proc_id))
1638 .map_or(true, |mode| *mode != ProcLifecycleMode::ManagingSystem)
1639 {
1640 tracing::debug!(
1641 "proc {} with state {} does not manage system.",
1642 proc_state.proc_id,
1643 proc_state.proc_health
1644 );
1645 continue;
1646 }
1647
1648 tracing::error!(
1649 "proc {} is unhealthy, stop the system as the proc manages the system",
1650 proc_state.proc_id
1651 );
1652
1653 let (tx, _) = cx.open_once_port::<()>();
1655 cx.port().send(SystemMessage::Stop {
1656 worlds: None,
1657 proc_timeout: Duration::from_secs(5),
1658 reply_port: tx.bind(),
1659 })?;
1660 }
1661
1662 if world.state.status == WorldStatus::Live {
1663 world.state.status = WorldStatus::Unhealthy(cx.clock().system_time_now());
1664 }
1665
1666 match &world.state.status {
1667 WorldStatus::Unhealthy(last_unhealthy_time) => {
1668 let elapsed = last_unhealthy_time
1669 .elapsed()
1670 .inspect_err(|err| {
1671 tracing::error!(
1672 "failed to get elapsed time for unhealthy world {}: {}",
1673 world.world_id,
1674 err
1675 )
1676 })
1677 .unwrap_or_else(|_| Duration::from_secs(0));
1678
1679 if elapsed < self.params.world_eviction_timeout {
1680 next_check_delay = std::cmp::min(
1682 next_check_delay,
1683 self.params.world_eviction_timeout - elapsed,
1684 );
1685 } else {
1686 next_check_delay = Duration::from_secs(0);
1687 }
1688 }
1689 _ => {
1690 tracing::error!(
1691 "find a failed world {} with healthy state {}",
1692 world.world_id,
1693 world.state.status
1694 );
1695 continue;
1696 }
1697 }
1698 }
1699 cx.self_message_with_delay(MaintainWorldHealth {}, next_check_delay)?;
1700
1701 Ok(())
1702 }
1703}
1704
1705#[async_trait]
1706impl Handler<ProcSupervisionMessage> for SystemActor {
1707 async fn handle(
1708 &mut self,
1709 cx: &Context<Self>,
1710 msg: ProcSupervisionMessage,
1711 ) -> anyhow::Result<()> {
1712 match msg {
1713 ProcSupervisionMessage::Update(state, reply_port) => {
1714 self.supervision_state.report(state, cx.clock());
1715 let _ = reply_port.send(cx, ());
1716 }
1717 }
1718 Ok(())
1719 }
1720}
1721
1722#[async_trait]
1723impl Handler<WorldSupervisionMessage> for SystemActor {
1724 async fn handle(
1725 &mut self,
1726 cx: &Context<Self>,
1727 msg: WorldSupervisionMessage,
1728 ) -> anyhow::Result<()> {
1729 match msg {
1730 WorldSupervisionMessage::State(world_id, reply_port) => {
1731 let world_state = self
1732 .supervision_state
1733 .get_world_with_failures(&world_id, cx.clock());
1734 let _ = reply_port.send(cx, world_state);
1736 }
1737 }
1738 Ok(())
1739 }
1740}
1741
1742#[async_trait]
1745impl Handler<ProcStopResult> for SystemActor {
1746 async fn handle(&mut self, cx: &Context<Self>, msg: ProcStopResult) -> anyhow::Result<()> {
1747 fn stopping_proc_msg<'a>(sprocs: impl Iterator<Item = &'a ProcId>) -> String {
1748 let sprocs = sprocs.collect::<Vec<_>>();
1749 if sprocs.is_empty() {
1750 return "no procs left".to_string();
1751 }
1752 let msg = sprocs
1753 .iter()
1754 .take(3)
1755 .map(|proc_id| proc_id.to_string())
1756 .collect::<Vec<_>>()
1757 .join(", ");
1758 if sprocs.len() > 3 {
1759 format!("remaining procs: {} and {} more", msg, sprocs.len() - 3)
1760 } else {
1761 format!("remaining procs: {}", msg)
1762 }
1763 }
1764 let mut world_stopped = false;
1765 let world_id = &msg
1766 .proc_id
1767 .clone()
1768 .into_ranked()
1769 .expect("proc must be ranked for world_id access")
1770 .0;
1771 if let Some(stopping_state) = self.worlds_to_stop.get_mut(world_id) {
1772 stopping_state.stopped_procs.insert(msg.proc_id.clone());
1773 tracing::debug!(
1774 "received stop response from {}: {} stopped actors, {} aborted actors: {}",
1775 msg.proc_id,
1776 msg.actors_stopped,
1777 msg.actors_aborted,
1778 stopping_proc_msg(
1779 stopping_state
1780 .stopping_procs
1781 .difference(&stopping_state.stopped_procs)
1782 ),
1783 );
1784 world_stopped =
1785 stopping_state.stopping_procs.len() == stopping_state.stopped_procs.len();
1786 } else {
1787 tracing::warn!(
1788 "received stop response from {} but no inflight stopping request is found, possibly late response",
1789 msg.proc_id
1790 );
1791 }
1792
1793 if world_stopped {
1794 self.evict_world(world_id);
1795 self.worlds_to_stop.remove(world_id);
1796 }
1797
1798 if self.shutting_down && self.worlds.is_empty() {
1799 cx.stop()?;
1800 }
1801
1802 Ok(())
1803 }
1804}
1805
1806#[async_trait]
1807impl Handler<SystemStopMessage> for SystemActor {
1808 async fn handle(
1809 &mut self,
1810 cx: &Context<Self>,
1811 message: SystemStopMessage,
1812 ) -> anyhow::Result<()> {
1813 match message {
1814 SystemStopMessage::EvictWorlds(world_ids) => {
1815 for world_id in &world_ids {
1816 if self.worlds_to_stop.contains_key(world_id) {
1817 tracing::warn!(
1818 "Waiting for world to stop timed out, evicting world anyways: {:?}",
1819 world_id
1820 );
1821 self.evict_world(world_id);
1822 }
1823 }
1824 }
1825 SystemStopMessage::StopSystemActor => {
1826 if self.worlds_to_stop.is_empty() {
1827 tracing::warn!(
1828 "waiting for all worlds to stop timed out, stopping the system actor and evicting the these worlds anyways: {:?}",
1829 self.worlds_to_stop.keys()
1830 );
1831 } else {
1832 tracing::warn!(
1833 "waiting for all worlds to stop timed out, stopping the system actor"
1834 );
1835 }
1836
1837 cx.stop()?;
1838 }
1839 }
1840 Ok(())
1841 }
1842}
1843
1844#[cfg(test)]
1845mod tests {
1846 use std::assert_matches::assert_matches;
1847
1848 use anyhow::Result;
1849 use hyperactor::PortId;
1850 use hyperactor::actor::ActorStatus;
1851 use hyperactor::attrs::Attrs;
1852 use hyperactor::channel;
1853 use hyperactor::channel::ChannelTransport;
1854 use hyperactor::channel::Rx;
1855 use hyperactor::clock::Clock;
1856 use hyperactor::clock::RealClock;
1857 use hyperactor::data::Serialized;
1858 use hyperactor::mailbox::Mailbox;
1859 use hyperactor::mailbox::MailboxSender;
1860 use hyperactor::mailbox::MailboxServer;
1861 use hyperactor::mailbox::MessageEnvelope;
1862 use hyperactor::mailbox::PortHandle;
1863 use hyperactor::mailbox::PortReceiver;
1864 use hyperactor::mailbox::monitored_return_handle;
1865 use hyperactor::simnet;
1866 use hyperactor::test_utils::pingpong::PingPongActorParams;
1867
1868 use super::*;
1869 use crate::System;
1870 use crate::supervision::WorldSupervisionMessageClient;
1871
1872 struct MockHostActor {
1873 local_proc_id: ProcId,
1874 local_proc_addr: ChannelAddr,
1875 local_proc_message_port: PortHandle<ProcMessage>,
1876 local_proc_message_receiver: PortReceiver<ProcMessage>,
1877 }
1878
1879 async fn spawn_mock_host_actor(proc_world_id: WorldId, host_id: usize) -> MockHostActor {
1880 let local_proc_id = ProcId::Ranked(
1882 WorldId(format!("{}{}", SHADOW_PREFIX, proc_world_id.name())),
1883 host_id,
1884 );
1885 let (local_proc_addr, local_proc_rx) =
1886 channel::serve::<MessageEnvelope>(ChannelAddr::any(ChannelTransport::Local))
1887 .await
1888 .unwrap();
1889 let local_proc_mbox = Mailbox::new_detached(local_proc_id.actor_id("test".to_string(), 0));
1890 let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1891 let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1892 MockHostActor {
1893 local_proc_id,
1894 local_proc_addr,
1895 local_proc_message_port,
1896 local_proc_message_receiver,
1897 }
1898 }
1899
1900 async fn mock_proc_actor(
1902 name: &str,
1903 idx: usize,
1904 ) -> (
1905 WorldId,
1906 ProcId,
1907 ChannelAddr,
1908 Mailbox,
1909 PortHandle<ProcMessage>,
1910 PortReceiver<ProcMessage>,
1911 ) {
1912 let world_id = WorldId(name.to_string());
1913 let local_proc_id = ProcId::Ranked(world_id.clone(), idx);
1915 let (local_proc_addr, local_proc_rx) =
1916 channel::serve(ChannelAddr::any(ChannelTransport::Local))
1917 .await
1918 .unwrap();
1919 let local_proc_actor_id = local_proc_id.actor_id("proc", 0);
1921 let local_proc_mbox = Mailbox::new_detached(local_proc_actor_id);
1922 let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1923 let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1924 (
1925 world_id,
1926 local_proc_id,
1927 local_proc_addr,
1928 local_proc_mbox,
1929 local_proc_message_port,
1930 local_proc_message_receiver,
1931 )
1932 }
1933
1934 #[tracing_test::traced_test]
1935 #[tokio::test]
1936 async fn test_join() {
1937 let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
1938 let (system_actor_handle, system_proc) = SystemActor::bootstrap(params).await.unwrap();
1939
1940 let (
1942 world_id,
1943 local_proc_id,
1944 local_proc_addr,
1945 local_proc_mbox,
1946 local_proc_message_port,
1947 mut local_proc_message_receiver,
1948 ) = mock_proc_actor("test", 0).await; system_actor_handle
1952 .send(SystemMessage::Join {
1953 proc_id: local_proc_id, world_id, proc_message_port: local_proc_message_port.bind(), proc_addr: local_proc_addr, labels: HashMap::new(),
1958 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
1959 })
1960 .unwrap();
1961
1962 assert_matches!(
1963 local_proc_message_receiver.recv().await.unwrap(),
1964 ProcMessage::Joined()
1965 );
1966
1967 let (local_u64_port, mut local_u64_receiver) = local_proc_mbox.open_port::<u64>();
1969 system_proc.post(
1970 MessageEnvelope::new(
1971 local_proc_mbox.actor_id().clone(),
1972 local_u64_port.bind().port_id().clone(),
1973 Serialized::serialize(&123u64).unwrap(),
1974 Attrs::new(),
1975 ),
1976 monitored_return_handle(),
1977 );
1978 assert_eq!(local_u64_receiver.recv().await.unwrap(), 123);
1979 }
1980
1981 #[tokio::test]
1982 async fn test_supervision_state() {
1983 let mut sv = SystemSupervisionState::new(Duration::from_secs(1));
1984 let world_id = id!(world);
1985 let proc_id_0 = world_id.proc_id(0);
1986 let clock = ClockKind::Real(RealClock);
1987 sv.create(
1988 ProcSupervisionState {
1989 world_id: world_id.clone(),
1990 proc_addr: ChannelAddr::any(ChannelTransport::Local),
1991 proc_id: proc_id_0.clone(),
1992 proc_health: ProcStatus::Alive,
1993 failed_actors: Vec::new(),
1994 },
1995 ProcLifecycleMode::ManagedBySystem,
1996 &clock,
1997 );
1998 let actor_id = id!(world[1].actor);
1999 let proc_id_1 = actor_id.proc_id();
2000 sv.create(
2001 ProcSupervisionState {
2002 world_id: world_id.clone(),
2003 proc_addr: ChannelAddr::any(ChannelTransport::Local),
2004 proc_id: proc_id_1.clone(),
2005 proc_health: ProcStatus::Alive,
2006 failed_actors: Vec::new(),
2007 },
2008 ProcLifecycleMode::ManagedBySystem,
2009 &clock,
2010 );
2011 let world_id = id!(world);
2012
2013 let unknown_world_id = id!(unknow_world);
2014 let failures = sv.get_world_with_failures(&unknown_world_id, &clock);
2015 assert!(failures.is_none());
2016
2017 let failures = sv.get_world_with_failures(&world_id, &clock);
2019 assert!(failures.is_some());
2020 assert_eq!(failures.unwrap().procs.len(), 0);
2021
2022 RealClock.sleep(Duration::from_secs(2)).await;
2024 sv.report(
2025 ProcSupervisionState {
2026 world_id: world_id.clone(),
2027 proc_addr: ChannelAddr::any(ChannelTransport::Local),
2028 proc_id: proc_id_1.clone(),
2029 proc_health: ProcStatus::Alive,
2030 failed_actors: Vec::new(),
2031 },
2032 &clock,
2033 );
2034 let failures = sv.get_world_with_failures(&world_id, &clock);
2035 let procs = failures.unwrap().procs;
2036 assert_eq!(procs.len(), 1);
2037 assert!(
2038 procs.contains_key(
2039 &proc_id_0
2040 .rank()
2041 .expect("proc must be ranked for rank access")
2042 )
2043 );
2044
2045 sv.report(
2047 ProcSupervisionState {
2048 world_id: world_id.clone(),
2049 proc_addr: ChannelAddr::any(ChannelTransport::Local),
2050 proc_id: proc_id_1.clone(),
2051 proc_health: ProcStatus::Alive,
2052 failed_actors: [(actor_id.clone(), ActorStatus::Failed("Actor failed".into()))]
2053 .to_vec(),
2054 },
2055 &clock,
2056 );
2057
2058 let failures = sv.get_world_with_failures(&world_id, &clock);
2059 let procs = failures.unwrap().procs;
2060 assert_eq!(procs.len(), 2);
2061 assert!(
2062 procs.contains_key(
2063 &proc_id_0
2064 .rank()
2065 .expect("proc must be ranked for rank access")
2066 )
2067 );
2068 assert!(
2069 procs.contains_key(
2070 &proc_id_1
2071 .rank()
2072 .expect("proc must be ranked for rank access")
2073 )
2074 );
2075 }
2076
2077 #[tokio::test]
2078 async fn test_supervision_timeout() {
2079 let timeout: Duration = Duration::from_secs(1);
2081
2082 let params = SystemActorParams::new(timeout, timeout);
2083 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2084
2085 let (
2087 world_id,
2088 client_proc_id,
2089 client_proc_addr,
2090 client_mailbox,
2091 client_proc_message_port,
2092 mut client_proc_message_receiver,
2093 ) = mock_proc_actor("client", 0).await;
2094
2095 system_actor_handle
2097 .send(SystemMessage::Join {
2098 world_id: world_id.clone(),
2099 proc_id: client_proc_id.clone(),
2100 proc_message_port: client_proc_message_port.bind(),
2101 proc_addr: client_proc_addr,
2102 labels: HashMap::new(),
2103 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2104 })
2105 .unwrap();
2106 assert_matches!(
2107 client_proc_message_receiver.recv().await.unwrap(),
2108 ProcMessage::Joined()
2109 );
2110
2111 let (client_tx, client_rx) = client_mailbox.open_once_port::<SystemSnapshot>();
2113 system_actor_handle
2114 .send(SystemMessage::Snapshot {
2115 filter: SystemSnapshotFilter::all(),
2116 ret: client_tx.bind(),
2117 })
2118 .unwrap();
2119 let ret = client_rx.recv().await.unwrap();
2120 assert_eq!(ret.worlds.len(), 1);
2121 assert_eq!(
2122 ret.worlds
2123 .get(
2124 client_proc_id
2125 .world_id()
2126 .expect("proc must be ranked for world_id access")
2127 )
2128 .unwrap()
2129 .status,
2130 WorldStatus::AwaitingCreation
2131 );
2132
2133 let (
2135 world_id,
2136 local_proc_id,
2137 local_proc_addr,
2138 _,
2139 local_proc_message_port,
2140 mut local_proc_message_receiver,
2141 ) = mock_proc_actor("unreacheable_proc", 1).await;
2142 system_actor_handle
2144 .send(SystemMessage::Join {
2145 proc_id: local_proc_id.clone(),
2146 world_id: world_id.clone(),
2147 proc_message_port: local_proc_message_port.bind(),
2148 proc_addr: local_proc_addr.clone(),
2149 labels: HashMap::new(),
2150 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2151 })
2152 .unwrap();
2153
2154 assert_matches!(
2155 local_proc_message_receiver.recv().await.unwrap(),
2156 ProcMessage::Joined()
2157 );
2158
2159 let unknown_world_id = id!(unknow_world);
2161 let (client_supervision_tx, client_supervision_receiver) =
2162 client_mailbox.open_once_port::<Option<WorldSupervisionState>>();
2163 let client_supervision_port_ref = client_supervision_tx.bind();
2164 system_actor_handle
2165 .send(WorldSupervisionMessage::State(
2166 unknown_world_id,
2167 client_supervision_port_ref,
2168 ))
2169 .unwrap();
2170 let msg = client_supervision_receiver.recv().await;
2171 assert_eq!(msg.unwrap(), None);
2172
2173 RealClock.sleep(2 * timeout).await;
2175
2176 let (client_supervision_tx, client_supervision_receiver) =
2178 client_mailbox.open_once_port::<Option<WorldSupervisionState>>();
2179 let client_supervision_port_ref = client_supervision_tx.bind();
2180
2181 system_actor_handle
2182 .send(WorldSupervisionMessage::State(
2183 World::get_real_world_id(&world_id),
2184 client_supervision_port_ref,
2185 ))
2186 .unwrap();
2187
2188 let msg = client_supervision_receiver.recv().await;
2190
2191 assert_eq!(
2192 msg.unwrap(),
2193 Some(WorldSupervisionState {
2194 procs: HashMap::from([(
2195 local_proc_id
2196 .rank()
2197 .expect("proc must be ranked for rank access"),
2198 ProcSupervisionState {
2199 world_id: world_id.clone(),
2200 proc_addr: local_proc_addr.clone(),
2201 proc_id: local_proc_id.clone(),
2202 proc_health: ProcStatus::Expired,
2203 failed_actors: Vec::new(),
2204 }
2205 )])
2206 })
2207 );
2208 }
2209
2210 #[tokio::test]
2211 async fn test_world_eviction() {
2212 let timeout: Duration = Duration::from_secs(2);
2214
2215 let params = SystemActorParams::new(timeout, timeout);
2216 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2217
2218 let (
2220 world_id,
2221 client_proc_id,
2222 client_proc_addr,
2223 client_mailbox,
2224 client_proc_message_port,
2225 mut client_proc_message_receiver,
2226 ) = mock_proc_actor("client", 0).await;
2227
2228 system_actor_handle
2230 .send(SystemMessage::Join {
2231 proc_id: client_proc_id.clone(),
2232 world_id: world_id.clone(),
2233 proc_message_port: client_proc_message_port.bind(),
2234 proc_addr: client_proc_addr,
2235 labels: HashMap::new(),
2236 lifecycle_mode: ProcLifecycleMode::Detached,
2237 })
2238 .unwrap();
2239 assert_matches!(
2240 client_proc_message_receiver.recv().await.unwrap(),
2241 ProcMessage::Joined()
2242 );
2243
2244 let (
2246 world_id,
2247 local_proc_id,
2248 local_proc_addr,
2249 _,
2250 local_proc_message_port,
2251 mut local_proc_message_receiver,
2252 ) = mock_proc_actor("unreacheable_proc", 1).await;
2253
2254 system_actor_handle
2256 .send(SystemMessage::UpsertWorld {
2257 world_id: local_proc_id
2258 .world_id()
2259 .expect("proc must be ranked for world_id access")
2260 .clone(),
2261 shape: Shape::Definite(vec![1]),
2262 num_procs_per_host: 1,
2263 env: Environment::Local,
2264 labels: HashMap::new(),
2265 })
2266 .unwrap();
2267
2268 system_actor_handle
2270 .send(SystemMessage::Join {
2271 proc_id: local_proc_id.clone(),
2272 world_id: world_id.clone(),
2273 proc_message_port: local_proc_message_port.bind(),
2274 proc_addr: local_proc_addr,
2275 labels: HashMap::new(),
2276 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2277 })
2278 .unwrap();
2279
2280 assert_matches!(
2281 local_proc_message_receiver.recv().await.unwrap(),
2282 ProcMessage::Joined()
2283 );
2284
2285 let snapshot = system_actor_handle
2287 .snapshot(&client_mailbox, SystemSnapshotFilter::all())
2288 .await
2289 .unwrap();
2290 assert_eq!(snapshot.worlds.len(), 1);
2291 assert_eq!(
2292 snapshot
2293 .worlds
2294 .get(
2295 local_proc_id
2296 .world_id()
2297 .expect("proc must be ranked for world_id access")
2298 )
2299 .unwrap()
2300 .status,
2301 WorldStatus::Live
2302 );
2303
2304 RealClock.sleep(2 * timeout).await;
2306
2307 let mut iter = 0;
2308 let mut state = system_actor_handle
2310 .state(
2311 &client_mailbox,
2312 local_proc_id
2313 .world_id()
2314 .expect("proc must be ranked for world_id access")
2315 .clone(),
2316 )
2317 .await
2318 .unwrap()
2319 .unwrap();
2320 while iter < 100 {
2321 if state.procs.values().any(|p| !p.is_healthy()) {
2322 break;
2323 }
2324 iter += 1;
2325 RealClock.sleep(Duration::from_millis(100)).await;
2327 state = system_actor_handle
2328 .state(
2329 &client_mailbox,
2330 local_proc_id
2331 .world_id()
2332 .expect("proc must be ranked for world_id access")
2333 .clone(),
2334 )
2335 .await
2336 .unwrap()
2337 .unwrap();
2338 }
2339 assert!(state.procs.values().any(|p| !p.is_healthy()));
2340 let mut snapshot = system_actor_handle
2342 .snapshot(&client_mailbox, SystemSnapshotFilter::all())
2343 .await
2344 .unwrap();
2345 assert!(snapshot.worlds.len() == 1);
2346
2347 let (client_tx, client_rx) = client_mailbox.open_once_port::<()>();
2349 let _ = system_actor_handle
2350 .stop(
2351 &client_mailbox,
2352 Some(vec![
2353 local_proc_id
2354 .world_id()
2355 .expect("proc must be ranked for world_id access")
2356 .clone(),
2357 ]),
2358 Duration::from_secs(2),
2359 client_tx.bind(),
2360 )
2361 .await;
2362 client_rx.recv().await.unwrap();
2363 RealClock.sleep(10 * timeout).await;
2365
2366 snapshot = system_actor_handle
2368 .snapshot(&client_mailbox, SystemSnapshotFilter::all())
2369 .await
2370 .unwrap();
2371
2372 assert_eq!(snapshot.worlds.len(), 0, "{:?}", snapshot);
2373 }
2374
2375 #[tokio::test]
2376 async fn test_proc_managing_system() {
2377 let timeout: Duration = Duration::from_secs(2);
2379
2380 let params = SystemActorParams::new(timeout, timeout);
2381 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2382 let mut sys_status_rx = system_actor_handle.status();
2383
2384 let (
2386 world_id,
2387 client_proc_id,
2388 client_proc_addr,
2389 client_mailbox,
2390 client_proc_message_port,
2391 mut client_proc_message_receiver,
2392 ) = mock_proc_actor("client", 0).await;
2393
2394 system_actor_handle
2396 .send(SystemMessage::Join {
2397 proc_id: client_proc_id.clone(),
2398 world_id: world_id.clone(),
2399 proc_message_port: client_proc_message_port.bind(),
2400 proc_addr: client_proc_addr,
2401 labels: HashMap::new(),
2402 lifecycle_mode: ProcLifecycleMode::Detached,
2403 })
2404 .unwrap();
2405 assert_matches!(
2406 client_proc_message_receiver.recv().await.unwrap(),
2407 ProcMessage::Joined()
2408 );
2409
2410 {
2411 let received = sys_status_rx.borrow_and_update();
2412 assert_eq!(*received, ActorStatus::Idle);
2413 }
2414
2415 let (
2417 world_id,
2418 local_proc_id,
2419 local_proc_addr,
2420 _,
2421 local_proc_message_port,
2422 mut local_proc_message_receiver,
2423 ) = mock_proc_actor("unreacheable_proc", 1).await;
2424
2425 system_actor_handle
2427 .send(SystemMessage::UpsertWorld {
2428 world_id: local_proc_id
2429 .world_id()
2430 .expect("proc must be ranked for world_id access")
2431 .clone(),
2432 shape: Shape::Definite(vec![1]),
2433 num_procs_per_host: 1,
2434 env: Environment::Local,
2435 labels: HashMap::new(),
2436 })
2437 .unwrap();
2438
2439 system_actor_handle
2441 .send(SystemMessage::Join {
2442 proc_id: local_proc_id.clone(),
2443 world_id: world_id.clone(),
2444 proc_message_port: local_proc_message_port.bind(),
2445 proc_addr: local_proc_addr,
2446 labels: HashMap::new(),
2447 lifecycle_mode: ProcLifecycleMode::ManagingSystem,
2448 })
2449 .unwrap();
2450
2451 assert_matches!(
2452 local_proc_message_receiver.recv().await.unwrap(),
2453 ProcMessage::Joined()
2454 );
2455
2456 RealClock.sleep(2 * timeout).await;
2460
2461 let mut iter = 0;
2462 let mut state = system_actor_handle
2464 .state(
2465 &client_mailbox,
2466 local_proc_id
2467 .world_id()
2468 .expect("proc must be ranked for world_id access")
2469 .clone(),
2470 )
2471 .await
2472 .unwrap()
2473 .unwrap();
2474 while iter < 100 {
2475 if state.procs.values().any(|p| !p.is_healthy()) {
2476 break;
2477 }
2478 iter += 1;
2479 RealClock.sleep(Duration::from_millis(100)).await;
2481 state = system_actor_handle
2482 .state(
2483 &client_mailbox,
2484 local_proc_id
2485 .world_id()
2486 .expect("proc must be ranked for world_id access")
2487 .clone(),
2488 )
2489 .await
2490 .unwrap()
2491 .unwrap();
2492 }
2493 assert!(state.procs.values().any(|p| !p.is_healthy()));
2494
2495 let (client_tx, client_rx) = client_mailbox.open_once_port::<()>();
2497 let _ = system_actor_handle
2498 .stop(
2499 &client_mailbox,
2500 Some(vec![
2501 local_proc_id
2502 .world_id()
2503 .expect("proc must be ranked for world_id access")
2504 .clone(),
2505 ]),
2506 Duration::from_secs(2),
2507 client_tx.bind(),
2508 )
2509 .await;
2510 client_rx.recv().await.unwrap();
2511 RealClock.sleep(10 * timeout).await;
2513 {
2514 assert!(sys_status_rx.borrow().has_changed());
2515 let received = sys_status_rx.borrow_and_update();
2516 assert_eq!(*received, ActorStatus::Stopped);
2517 }
2518 }
2519
2520 #[tracing_test::traced_test]
2521 #[tokio::test]
2522 async fn test_host_join_before_world() {
2523 let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
2525 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2526
2527 let mut host_actors: Vec<MockHostActor> = Vec::new();
2529
2530 let world_name = "test".to_string();
2531 let world_id = WorldId(world_name.clone());
2532 host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2533 host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2534
2535 for host_actor in host_actors.iter_mut() {
2536 system_actor_handle
2538 .send(SystemMessage::Join {
2539 proc_id: host_actor.local_proc_id.clone(),
2540 world_id: world_id.clone(),
2541 proc_message_port: host_actor.local_proc_message_port.bind(),
2542 proc_addr: host_actor.local_proc_addr.clone(),
2543 labels: HashMap::new(),
2544 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2545 })
2546 .unwrap();
2547
2548 assert_matches!(
2551 host_actor.local_proc_message_receiver.recv().await.unwrap(),
2552 ProcMessage::Joined()
2553 );
2554 }
2555
2556 let num_procs = 6;
2558 let shape = Shape::Definite(vec![2, 3]);
2559 system_actor_handle
2560 .send(SystemMessage::UpsertWorld {
2561 world_id: world_id.clone(),
2562 shape,
2563 num_procs_per_host: 3,
2564 env: Environment::Local,
2565 labels: HashMap::new(),
2566 })
2567 .unwrap();
2568
2569 let mut all_procs: Vec<ProcId> = Vec::new();
2570 for host_actor in host_actors.iter_mut() {
2571 let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2572 match m {
2573 ProcMessage::SpawnProc {
2574 env,
2575 world_id,
2576 mut proc_ids,
2577 world_size,
2578 } => {
2579 assert_eq!(world_id, WorldId(world_name.clone()));
2580 assert_eq!(env, Environment::Local);
2581 assert_eq!(world_size, num_procs);
2582 all_procs.append(&mut proc_ids);
2583 }
2584 _ => std::panic!("Unexpected message type!"),
2585 }
2586 }
2587 assert_eq!(all_procs.len(), num_procs);
2589 all_procs.sort();
2590 for (i, proc) in all_procs.iter().enumerate() {
2591 assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2592 }
2593 }
2594
2595 #[tokio::test]
2596 async fn test_host_join_after_world() {
2597 let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
2599 let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2600
2601 let world_name = "test".to_string();
2603 let world_id = WorldId(world_name.clone());
2604 let num_procs = 6;
2605 let shape = Shape::Definite(vec![2, 3]);
2606 system_actor_handle
2607 .send(SystemMessage::UpsertWorld {
2608 world_id: world_id.clone(),
2609 shape,
2610 num_procs_per_host: 3,
2611 env: Environment::Local,
2612 labels: HashMap::new(),
2613 })
2614 .unwrap();
2615
2616 let mut host_actors: Vec<MockHostActor> = Vec::new();
2618
2619 host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2620 host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2621
2622 for host_actor in host_actors.iter_mut() {
2623 system_actor_handle
2625 .send(SystemMessage::Join {
2626 proc_id: host_actor.local_proc_id.clone(),
2627 world_id: world_id.clone(),
2628 proc_message_port: host_actor.local_proc_message_port.bind(),
2629 proc_addr: host_actor.local_proc_addr.clone(),
2630 labels: HashMap::new(),
2631 lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2632 })
2633 .unwrap();
2634
2635 assert_matches!(
2638 host_actor.local_proc_message_receiver.recv().await.unwrap(),
2639 ProcMessage::Joined()
2640 );
2641 }
2642
2643 let mut all_procs: Vec<ProcId> = Vec::new();
2644 for host_actor in host_actors.iter_mut() {
2645 let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2646 match m {
2647 ProcMessage::SpawnProc {
2648 env,
2649 world_id,
2650 mut proc_ids,
2651 world_size,
2652 } => {
2653 assert_eq!(world_id, WorldId(world_name.clone()));
2654 assert_eq!(env, Environment::Local);
2655 assert_eq!(world_size, num_procs);
2656 all_procs.append(&mut proc_ids);
2657 }
2658 _ => std::panic!("Unexpected message type!"),
2659 }
2660 }
2661 assert_eq!(all_procs.len(), num_procs);
2663 all_procs.sort();
2664 for (i, proc) in all_procs.iter().enumerate() {
2665 assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2666 }
2667 }
2668
/// Exercises `SystemSnapshotFilter`'s world-id matching and its label matching.
#[test]
fn test_snapshot_filter() {
    // Builds a label map with a single key/value pair.
    let label = |key: &str, value: &str| HashMap::from([(key.to_string(), value.to_string())]);

    let test_world = World::new(
        WorldId("test_world".to_string()),
        Shape::Definite(vec![1]),
        WorldState {
            host_map: HashMap::new(),
            procs: HashMap::new(),
            status: WorldStatus::Live,
        },
        1,
        Environment::Local,
        label("foo", "bar"),
    )
    .unwrap();

    // The unrestricted filter matches the world. An empty first-argument label
    // map matches a non-empty one.
    assert!(SystemSnapshotFilter::all().world_matches(&test_world));
    assert!(SystemSnapshotFilter::labels_match(
        &HashMap::new(),
        &label("foo", "bar")
    ));

    // Restricting by world id only matches the listed world.
    let mut by_world = SystemSnapshotFilter::all();
    by_world.worlds = vec![WorldId("test_world".to_string())];
    assert!(by_world.world_matches(&test_world));
    by_world.worlds = vec![WorldId("unknow_world".to_string())];
    assert!(!by_world.world_matches(&test_world));

    // Same key with the same value matches; with a different value it does not.
    assert!(SystemSnapshotFilter::labels_match(
        &label("foo", "baz"),
        &label("foo", "baz"),
    ));
    assert!(!SystemSnapshotFilter::labels_match(
        &label("foo", "bar"),
        &label("foo", "baz"),
    ));
}
2706
/// Stops the `ProcActor` of the proc hosting "pong" and then starts a
/// ping-pong game from proc 0. It checks that:
/// - the game does not finish within 4 seconds;
/// - the world is reported `Live` right after the stop and `Unhealthy` after
///   that wait.
///
/// NOTE(review): the undeliverable receivers (`_proc_0_undeliverable_rx`,
/// `_proc_1_undeliverable_rx`) are never read, so this test does not assert
/// that the message is actually returned — only that the game does not finish.
#[tokio::test]
async fn test_undeliverable_message_return() {
    use hyperactor::mailbox::MailboxClient;
    use hyperactor::test_utils::pingpong::PingPongActor;
    use hyperactor::test_utils::pingpong::PingPongMessage;

    use crate::System;
    use crate::proc_actor::ProcActor;
    use crate::supervision::ProcSupervisor;

    // Take the global config lock and shorten MESSAGE_DELIVERY_TIMEOUT to 1s.
    // Both `config` and `_guard` stay alive until the end of the test, so the
    // override (presumably reverted when `_guard` drops) covers every step.
    let config = hyperactor::config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        Duration::from_secs(1),
    );

    // Serve a system on an arbitrary TCP address and attach a client to it.
    // NOTE(review): the two 2s values are positional `System::serve` arguments
    // (presumably supervision / eviction timeouts) — confirm against its signature.
    let server_handle = System::serve(
        ChannelAddr::any(ChannelTransport::Tcp),
        Duration::from_secs(2),
        Duration::from_secs(2),
    )
    .await
    .unwrap();
    let system_actor_handle = server_handle.system_actor_handle();
    let mut system = System::new(server_handle.local_addr().clone());
    let client = system.attach().await.unwrap();

    // A freshly served system has no worlds.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 0);

    // Declare a single-proc world. No proc has joined yet at this point, and the
    // snapshot reports the world as unhealthy.
    let world_id = id!(world);
    system_actor_handle
        .send(SystemMessage::UpsertWorld {
            world_id: world_id.clone(),
            shape: Shape::Definite(vec![1]),
            num_procs_per_host: 1,
            env: Environment::Local,
            labels: HashMap::new(),
        })
        .unwrap();

    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 1);
    assert!(snapshot.worlds.contains_key(&world_id));
    assert!(matches!(
        snapshot.worlds.get(&world_id).unwrap().status,
        WorldStatus::Unhealthy(_)
    ));

    // Stand-in proc supervisor: a bare attached mailbox with a
    // ProcSupervisionMessage port bound to its well-known port. `_sup_rx` is
    // never read, so supervision messages are accepted and ignored.
    let sup_mail = system.attach().await.unwrap();
    let (sup_tx, _sup_rx) = sup_mail.open_port::<ProcSupervisionMessage>();
    sup_tx.bind_to(ProcSupervisionMessage::port());
    let sup_ref = ActorRef::<ProcSupervisor>::attest(sup_mail.actor_id().clone());

    // Proc-side forwarder: a DialMailboxRouter whose default sender dials the
    // system server (presumably used for destinations with no explicit binding).
    let system_sender = BoxedMailboxSender::new(MailboxClient::new(
        channel::dial(server_handle.local_addr().clone()).unwrap(),
    ));
    let proc_forwarder =
        BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

    // Bootstrap proc 0 of the world with a ProcActor reporting to the stand-in
    // supervisor, then attach a client and open its undeliverable port.
    // NOTE(review): 300ms is a positional `bootstrap_for_proc` argument
    // (presumably the supervision update interval) — confirm.
    // NOTE(review): the world shape is [1], yet procs 0 and 1 are both
    // bootstrapped below — confirm this is intended.
    let proc_0 = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
    let _proc_actor_0 = ProcActor::bootstrap_for_proc(
        proc_0.clone(),
        world_id.clone(),
        ChannelAddr::any(ChannelTransport::Tcp),
        server_handle.local_addr().clone(),
        sup_ref.clone(),
        Duration::from_millis(300),
        HashMap::new(),
        ProcLifecycleMode::ManagedBySystem,
    )
    .await
    .unwrap();
    let proc_0_client = proc_0.attach("client").unwrap();
    let (proc_0_undeliverable_tx, _proc_0_undeliverable_rx) = proc_0_client.open_port();

    // Same setup for proc 1. Its ProcActor handle is kept so it can be stopped.
    let proc_1 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
    let proc_actor_1 = ProcActor::bootstrap_for_proc(
        proc_1.clone(),
        world_id.clone(),
        ChannelAddr::any(ChannelTransport::Tcp),
        server_handle.local_addr().clone(),
        sup_ref.clone(),
        Duration::from_millis(300),
        HashMap::new(),
        ProcLifecycleMode::ManagedBySystem,
    )
    .await
    .unwrap();
    let proc_1_client = proc_1.attach("client").unwrap();
    let (proc_1_undeliverable_tx, mut _proc_1_undeliverable_rx) = proc_1_client.open_port();

    // "ping" runs on proc 0 and "pong" on proc 1. Each is given its own proc's
    // client undeliverable port (presumably where returned messages are sent).
    let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None);
    let ping_handle = proc_0
        .spawn::<PingPongActor>("ping", ping_params)
        .await
        .unwrap();
    let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None);
    let pong_handle = proc_1
        .spawn::<PingPongActor>("pong", pong_params)
        .await
        .unwrap();

    // Stop proc 1's ProcActor and wait for its mailbox handle to complete.
    proc_actor_1.mailbox.stop("from testing");
    proc_actor_1.mailbox.await.unwrap().unwrap();

    // Immediately after the stop, the world is still reported as Live.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 1);
    assert!(snapshot.worlds.contains_key(&world_id));
    assert_eq!(
        snapshot.worlds.get(&world_id).unwrap().status,
        WorldStatus::Live
    );

    // Start a game with ttl 1, ping -> pong. The game-over port must NOT fire
    // within 4s, presumably because pong's proc actor has been stopped.
    let ttl = 1_u64;
    let (game_over, on_game_over) = proc_0_client.open_once_port::<bool>();
    ping_handle
        .send(PingPongMessage(ttl, pong_handle.bind(), game_over.bind()))
        .unwrap();

    assert!(
        RealClock
            .timeout(tokio::time::Duration::from_secs(4), on_game_over.recv())
            .await
            .is_err()
    );

    // After that wait, the system reports the world as unhealthy.
    let snapshot = system_actor_handle
        .snapshot(&client, SystemSnapshotFilter::all())
        .await
        .unwrap();
    assert_eq!(snapshot.worlds.len(), 1);
    assert!(matches!(
        snapshot.worlds.get(&world_id).unwrap().status,
        WorldStatus::Unhealthy(_)
    ));
}
2880
2881 #[tokio::test]
2882 async fn test_stop_fast() -> Result<()> {
2883 let server_handle = System::serve(
2884 ChannelAddr::any(ChannelTransport::Tcp),
2885 Duration::from_secs(2), Duration::from_secs(2), )
2888 .await?;
2889 let system_actor_handle = server_handle.system_actor_handle();
2890 let mut system = System::new(server_handle.local_addr().clone());
2891 let client = system.attach().await?;
2892
2893 let (client_tx, client_rx) = client.open_once_port::<()>();
2895 system_actor_handle.send(SystemMessage::Stop {
2896 worlds: None,
2897 proc_timeout: Duration::from_secs(5),
2898 reply_port: client_tx.bind(),
2899 })?;
2900 client_rx.recv().await?;
2901
2902 let mut sys_status_rx = system_actor_handle.status();
2904 {
2905 let received = sys_status_rx.borrow_and_update();
2906 assert_eq!(*received, ActorStatus::Stopped);
2907 }
2908
2909 Ok(())
2910 }
2911
2912 #[tokio::test]
2913 async fn test_update_sim_address() {
2914 simnet::start();
2915
2916 let src_id = id!(proc[0].actor);
2917 let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap());
2918 let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap());
2919 let (_, mut rx) = channel::serve::<MessageEnvelope>(src_addr.clone())
2920 .await
2921 .unwrap();
2922
2923 let router = ReportingRouter::new();
2924
2925 router
2926 .router
2927 .bind(src_id.proc_id().clone().into(), src_addr);
2928 router.router.bind(id!(proc[1]).into(), dst_addr);
2929
2930 router.post_update_address(&MessageEnvelope::new(
2931 src_id,
2932 PortId(id!(proc[1].actor), 9999u64),
2933 Serialized::serialize(&1u64).unwrap(),
2934 Attrs::new(),
2935 ));
2936
2937 let envelope = rx.recv().await.unwrap();
2938 let admin_msg = envelope
2939 .data()
2940 .deserialized::<MailboxAdminMessage>()
2941 .unwrap();
2942 let MailboxAdminMessage::UpdateAddress {
2943 addr: ChannelAddr::Sim(addr),
2944 ..
2945 } = admin_msg
2946 else {
2947 panic!("Expected sim address");
2948 };
2949
2950 assert_eq!(addr.src().clone().unwrap().to_string(), "unix:@src");
2951 assert_eq!(addr.addr().to_string(), "unix:@dst");
2952 }
2953}