1pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::RefClient;
31use hyperactor::RemoteMessage;
32use hyperactor::Unbind;
33use hyperactor::mailbox::PortReceiver;
34use hyperactor::message::Bind;
35use hyperactor::message::Bindings;
36use hyperactor::message::Unbind;
37use hyperactor::reference as hyperactor_reference;
38use hyperactor_config::attrs::Attrs;
39use ndslice::Region;
40use ndslice::ViewExt;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44
45use crate::Name;
46use crate::StatusOverlay;
47use crate::bootstrap;
48use crate::bootstrap::BootstrapCommand;
49use crate::bootstrap::ProcBind;
50use crate::host_mesh::host_agent::ProcState;
51use crate::proc_agent::ActorSpec;
52use crate::proc_agent::ActorState;
53
/// Lifecycle status of a resource (procs are converted into it via
/// `From<bootstrap::ProcStatus>` and `From<hyperactor::host::LocalProcStatus>`).
///
/// `Ord`/`PartialOrd` are derived, so comparisons follow variant declaration
/// order (`NotExist` < `Initializing` < ... < `Unknown`). Do not reorder the
/// variants without auditing callers that compare statuses.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display,
    Bind,
    Unbind
)]
pub enum Status {
    /// Placeholder for a rank that has not reported a status yet (e.g. the
    /// initial snapshot in `GetRankStatus::wait`).
    NotExist,
    /// Mapped from `ProcStatus::Starting`; counts as healthy.
    Initializing,
    /// Mapped from `ProcStatus::Running`/`Ready`; counts as healthy.
    Running,
    /// Terminating but not yet terminated.
    Stopping,
    /// Terminated without failure.
    Stopped,
    /// Terminated with a failure; displays as `Failed(<reason>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// Terminated by a timeout; counts as a failure. Displays as
    /// `Timeout(<duration Debug>)`.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
    /// Neither healthy nor terminating according to the predicates on `Status`.
    Unknown,
}
91
92impl Status {
93 pub fn is_terminating(&self) -> bool {
95 matches!(
96 self,
97 Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
98 )
99 }
100
101 pub fn is_failure(&self) -> bool {
105 matches!(self, Self::Failed(_) | Self::Timeout(_))
106 }
107
108 pub fn is_terminated(&self) -> bool {
111 matches!(
112 self,
113 Status::Stopped | Status::Failed(_) | Status::Timeout(_)
114 )
115 }
116
117 pub fn is_healthy(&self) -> bool {
118 matches!(self, Status::Initializing | Status::Running)
119 }
120
121 pub fn clamp_min(self, floor: Status) -> Status {
128 if floor.is_terminating() && !self.is_terminating() {
129 floor
130 } else {
131 self
132 }
133 }
134}
135
136impl From<bootstrap::ProcStatus> for Status {
137 fn from(status: bootstrap::ProcStatus) -> Self {
138 use bootstrap::ProcStatus;
139 match status {
140 ProcStatus::Starting => Status::Initializing,
141 ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
142 ProcStatus::Stopping { .. } => Status::Stopping,
143 ProcStatus::Stopped { .. } => Status::Stopped,
144 ProcStatus::Failed { reason } => Status::Failed(reason),
145 ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
146 }
147 }
148}
149
150impl From<hyperactor::host::LocalProcStatus> for Status {
151 fn from(status: hyperactor::host::LocalProcStatus) -> Self {
152 match status {
153 hyperactor::host::LocalProcStatus::Stopping => Status::Stopping,
154 hyperactor::host::LocalProcStatus::Stopped => Status::Stopped,
155 }
156 }
157}
158
/// A rank that may be unbound (`None`).
///
/// The `Bind`/`Unbind` impls below push and pop the whole `Rank` through
/// message `Bindings`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
wirevalue::register_type!(Rank);
165
166impl Rank {
167 pub fn new(rank: usize) -> Self {
169 Self(Some(rank))
170 }
171
172 pub fn unwrap(&self) -> usize {
174 self.0.unwrap()
175 }
176}
177
178impl Unbind for Rank {
179 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
180 bindings.push_back(self)
181 }
182}
183
184impl Bind for Rank {
185 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
186 let bound = bindings.try_pop_front::<Rank>()?;
187 self.0 = bound.0;
188 Ok(())
189 }
190}
191
/// Request for the status of the resource named `name`.
///
/// The reply is a `StatusOverlay`. NOTE(review): presumably each receiving
/// rank replies with an overlay covering its own rank; confirm against the
/// handler.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port that receives the status reply; marked `#[binding(include)]`
    /// so it participates in the derived `Bind`/`Unbind`.
    #[binding(include)]
    pub reply: hyperactor_reference::PortRef<StatusOverlay>,
}
217
/// Like `GetRankStatus`, but carries a minimum status.
///
/// NOTE(review): presumably the handler defers its reply until the
/// resource's status reaches at least `min_status`; confirm against the
/// handler.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct WaitRankStatus {
    /// Name of the resource being waited on.
    pub name: Name,
    /// Status threshold for the reply (see the note on the struct).
    pub min_status: Status,
    /// Port that receives the status reply; marked `#[binding(include)]`
    /// so it participates in the derived `Bind`/`Unbind`.
    #[binding(include)]
    pub reply: hyperactor_reference::PortRef<StatusOverlay>,
}
244
245impl GetRankStatus {
246 pub async fn wait(
247 mut rx: PortReceiver<crate::StatusMesh>,
248 num_ranks: usize,
249 max_idle_time: Duration,
250 region: Region, ) -> Result<crate::StatusMesh, crate::StatusMesh> {
252 debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
253
254 let mut alarm = hyperactor::time::Alarm::new();
255 alarm.arm(max_idle_time);
256
257 let mut snapshot =
259 crate::StatusMesh::from_single(region, crate::resource::Status::NotExist);
260
261 loop {
262 let mut sleeper = alarm.sleeper();
263 tokio::select! {
264 _ = sleeper.sleep() => return Err(snapshot),
265 next = rx.recv() => {
266 match next {
267 Ok(mesh) => { snapshot = mesh; } Err(_) => return Err(snapshot),
269 }
270 }
271 }
272
273 alarm.arm(max_idle_time);
274
275 if snapshot
279 .values()
280 .take(num_ranks)
281 .all(|s| !matches!(s, crate::resource::Status::NotExist))
282 {
283 break Ok(snapshot);
284 }
285 }
286 }
287}
288
/// Snapshot of a named resource: its lifecycle `Status` plus optional
/// resource-specific state `S`. Its `Display` impl renders it as JSON.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialEq,
    Eq,
    Handler,
    Bind,
    Unbind
)]
pub struct State<S> {
    /// Name of the resource this state describes.
    pub name: Name,
    /// Lifecycle status of the resource.
    pub status: Status,
    /// Resource-specific state, if any.
    pub state: Option<S>,
    /// NOTE(review): presumably a version counter for this state that
    /// increases on change; confirm against producers.
    pub generation: u64,
    /// Wall-clock time attached to this snapshot. NOTE(review): presumably
    /// when it was produced; confirm.
    pub timestamp: std::time::SystemTime,
}
wirevalue::register_type!(State<ActorState>);
315
316impl<S: Serialize> fmt::Display for State<S> {
317 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
318 match serde_json::to_string(self) {
320 Ok(json) => write!(f, "{}", json),
321 Err(e) => write!(f, "<state: serde_json error: {}>", e),
322 }
323 }
324}
325
/// Request to create or update the resource named `name` from `spec`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// Name of the resource.
    pub name: Name,
    /// Rank the request applies to. Marked `#[binding(include)]` so it
    /// participates in the derived `Bind`/`Unbind` (see the `Rank` impls).
    #[binding(include)]
    pub rank: Rank,
    /// Resource-specific specification.
    pub spec: S,
}
wirevalue::register_type!(CreateOrUpdate<ProcSpec>);
wirevalue::register_type!(CreateOrUpdate<ActorSpec>);
350
/// Request to stop the resource named `name`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// Name of the resource to stop.
    pub name: Name,
    /// Why the resource is being stopped.
    pub reason: String,
}
wirevalue::register_type!(Stop);
371
/// Request to stop every resource. NOTE(review): presumably scoped to the
/// resources managed by the receiving actor; confirm against the handler.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {
    /// Why the resources are being stopped.
    pub reason: String,
}
wirevalue::register_type!(StopAll);
392
/// Request for the `State<S>` of the resource named `name`.
///
/// `Bind`, `Unbind` and `Clone` are implemented by hand below.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port that receives the resource's state (the `#[reply]` port).
    #[reply]
    pub reply: hyperactor_reference::PortRef<State<S>>,
}
wirevalue::register_type!(GetState<ProcState>);
wirevalue::register_type!(GetState<ActorState>);
404
405impl<S> Unbind for GetState<S>
407where
408 S: RemoteMessage,
409 S: Unbind,
410{
411 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
412 self.reply.unbind(bindings)
413 }
414}
415
416impl<S> Bind for GetState<S>
417where
418 S: RemoteMessage,
419 S: Bind,
420{
421 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
422 self.reply.bind(bindings)
423 }
424}
425
426impl<S> Clone for GetState<S>
427where
428 S: RemoteMessage,
429{
430 fn clone(&self) -> Self {
431 Self {
432 name: self.name.clone(),
433 reply: self.reply.clone(),
434 }
435 }
436}
437
/// A `GetState` request paired with an `expires_after` deadline.
///
/// NOTE(review): presumably the receiver treats the request (or its sender)
/// as stale after `expires_after`; confirm against the handler.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct KeepaliveGetState<S> {
    /// Wall-clock deadline associated with this request.
    pub expires_after: std::time::SystemTime,
    /// The wrapped state request.
    pub get_state: GetState<S>,
}
wirevalue::register_type!(KeepaliveGetState<ProcState>);
wirevalue::register_type!(KeepaliveGetState<ActorState>);
449
450impl<S> Unbind for KeepaliveGetState<S>
452where
453 S: RemoteMessage,
454 S: Unbind,
455{
456 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
457 self.get_state.unbind(bindings)
458 }
459}
460
461impl<S> Bind for KeepaliveGetState<S>
462where
463 S: RemoteMessage,
464 S: Bind,
465{
466 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
467 self.get_state.bind(bindings)
468 }
469}
470
471impl<S> Clone for KeepaliveGetState<S>
472where
473 S: RemoteMessage,
474{
475 fn clone(&self) -> Self {
476 Self {
477 expires_after: self.expires_after.clone(),
478 get_state: self.get_state.clone(),
479 }
480 }
481}
482
/// Request that `subscriber` receive `State<S>` values for the resource
/// named `name`. NOTE(review): presumably a stream of updates rather than a
/// single reply; confirm against the handler.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct StreamState<S> {
    /// Name of the resource to observe.
    pub name: Name,
    /// Port that receives the states.
    pub subscriber: hyperactor_reference::PortRef<State<S>>,
}
wirevalue::register_type!(StreamState<ActorState>);
494
495impl<S> Unbind for StreamState<S>
497where
498 S: RemoteMessage,
499 S: Unbind,
500{
501 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
502 self.subscriber.unbind(bindings)
503 }
504}
505
506impl<S> Bind for StreamState<S>
507where
508 S: RemoteMessage,
509 S: Bind,
510{
511 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
512 self.subscriber.bind(bindings)
513 }
514}
515
516impl<S> Clone for StreamState<S>
517where
518 S: RemoteMessage,
519{
520 fn clone(&self) -> Self {
521 Self {
522 name: self.name.clone(),
523 subscriber: self.subscriber.clone(),
524 }
525 }
526}
527
/// Request for the names of the resources known to the receiver.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct List {
    /// Port that receives the list of names.
    #[reply]
    pub reply: hyperactor_reference::PortRef<Vec<Name>>,
}
wirevalue::register_type!(List);
536
/// A kind of resource managed through the messages in this module.
///
/// `Spec` is the payload of `CreateOrUpdate`, and `State` is the
/// resource-specific payload of `GetState` (see the `Controller` behavior
/// declared below).
pub trait Resource {
    /// Specification used to create or update a resource of this kind.
    type Spec: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;

    /// Resource-specific state reported back through `State<Self::State>`.
    type State: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;
}
555
// Declares the `Controller<R>` behavior for a resource kind `R` over the
// message types `CreateOrUpdate<R::Spec>`, `GetState<R::State>` and `Stop`.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
563
/// Values assigned to ranks, stored as `(range, value)` intervals.
///
/// The methods on this type assume the intervals are sorted by start and do
/// not overlap. For example, `rank` stops scanning at the first interval
/// that starts past its argument, and `merge_from` walks both inputs in
/// order.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    intervals: Vec<(Range<usize>, T)>,
}
572
573impl<T: PartialEq> PartialEq for RankedValues<T> {
574 fn eq(&self, other: &Self) -> bool {
575 self.intervals == other.intervals
576 }
577}
578
579impl<T: Eq> Eq for RankedValues<T> {}
580
581impl<T> Default for RankedValues<T> {
582 fn default() -> Self {
583 Self {
584 intervals: Vec::new(),
585 }
586 }
587}
588
589impl<T> RankedValues<T> {
590 pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
592 self.intervals.iter()
593 }
594
595 pub fn rank(&self, value: usize) -> usize {
598 self.iter()
599 .take_while(|(ranks, _)| ranks.start <= value)
600 .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
601 .sum()
602 }
603}
604
605impl<T: Clone> RankedValues<T> {
606 pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
607 assert_eq!(self.rank(until), until, "insufficient rank");
608 self.iter()
609 .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
610 }
611}
612
613impl<T: Hash + Eq + Clone> RankedValues<T> {
614 pub fn invert(&self) -> ValuesByRank<T> {
616 let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
617 for (range, value) in self.iter() {
618 inverted
619 .entry(value.clone())
620 .or_default()
621 .push(range.clone());
622 }
623 ValuesByRank { values: inverted }
624 }
625}
626
627impl<T: Eq + Clone> RankedValues<T> {
628 pub fn merge_from(&mut self, other: Self) {
636 let mut left_iter = take(&mut self.intervals).into_iter();
637 let mut right_iter = other.intervals.into_iter();
638
639 let mut left = left_iter.next();
640 let mut right = right_iter.next();
641
642 while left.is_some() && right.is_some() {
643 let (left_ranks, left_value) = left.as_mut().unwrap();
644 let (right_ranks, right_value) = right.as_mut().unwrap();
645
646 if left_ranks.is_overlapping(right_ranks) {
647 if left_value == right_value {
648 let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
649 let (_, value) = replace(&mut right, right_iter.next()).unwrap();
650 left_ranks.start = ranks.end;
651 if left_ranks.is_empty() {
652 left = left_iter.next();
653 }
654 self.append(ranks, value);
655 } else if left_ranks.start < right_ranks.start {
656 let ranks = left_ranks.start..right_ranks.start;
657 left_ranks.start = ranks.end;
658 self.append(ranks, left_value.clone());
660 } else {
661 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
662 left_ranks.start = ranks.end;
663 if left_ranks.is_empty() {
664 left = left_iter.next();
665 }
666 self.append(ranks, value);
667 }
668 } else if left_ranks.start < right_ranks.start {
669 let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
670 self.append(ranks, value);
671 } else {
672 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
673 self.append(ranks, value);
674 }
675 }
676
677 while let Some((left_ranks, left_value)) = left {
678 self.append(left_ranks, left_value);
679 left = left_iter.next();
680 }
681 while let Some((right_ranks, right_value)) = right {
682 self.append(right_ranks, right_value);
683 right = right_iter.next();
684 }
685 }
686
687 pub fn merge_into(self, other: &mut Self) {
689 other.merge_from(self);
690 }
691
692 fn append(&mut self, range: Range<usize>, value: T) {
693 if let Some(last) = self.intervals.last_mut()
694 && last.0.end == range.start
695 && last.1 == value
696 {
697 last.0.end = range.end;
698 } else {
699 self.intervals.push((range, value));
700 }
701 }
702}
703
704impl RankedValues<Status> {
705 pub fn first_terminating(&self) -> Option<(usize, Status)> {
706 self.intervals
707 .iter()
708 .find(|(_, status)| status.is_terminating())
709 .map(|(range, status)| (range.start, status.clone()))
710 }
711
712 pub fn first_failed(&self) -> Option<(usize, Status)> {
713 self.intervals
714 .iter()
715 .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
716 .map(|(range, status)| (range.start, status.clone()))
717 }
718}
719
720impl<T> From<(usize, T)> for RankedValues<T> {
721 fn from((rank, value): (usize, T)) -> Self {
722 Self {
723 intervals: vec![(rank..rank + 1, value)],
724 }
725 }
726}
727
728impl<T> From<(Range<usize>, T)> for RankedValues<T> {
729 fn from((range, value): (Range<usize>, T)) -> Self {
730 Self {
731 intervals: vec![(range, value)],
732 }
733 }
734}
735
/// Inverse view of a `RankedValues`: each value maps to the rank ranges that
/// hold it. Produced by `RankedValues::invert`.
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    values: HashMap<T, Vec<Range<usize>>>,
}
742
743impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
744 fn eq(&self, other: &Self) -> bool {
745 self.values == other.values
746 }
747}
748
749impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
750
751impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
752 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
753 let mut first_value = true;
754 for (value, ranges) in self.iter() {
755 if first_value {
756 first_value = false;
757 } else {
758 write!(f, ";")?;
759 }
760 write!(f, "{}=", value)?;
761 let mut first_range = true;
762 for range in ranges.iter() {
763 if first_range {
764 first_range = false;
765 } else {
766 write!(f, ",")?;
767 }
768 write!(f, "{}..{}", range.start, range.end)?;
769 }
770 }
771 Ok(())
772 }
773}
774
775impl<T> Deref for ValuesByRank<T> {
776 type Target = HashMap<T, Vec<Range<usize>>>;
777
778 fn deref(&self) -> &Self::Target {
779 &self.values
780 }
781}
782
783impl<T> DerefMut for ValuesByRank<T> {
784 fn deref_mut(&mut self) -> &mut Self::Target {
785 &mut self.values
786 }
787}
788
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    /// Test-only constructor: the intervals are stored exactly as yielded,
    /// with no sorting, validation or coalescing.
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals: Vec<_> = iter.into_iter().collect();
        Self { intervals }
    }
}
799
/// Specification for a proc resource; sent as `CreateOrUpdate<ProcSpec>`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Configuration attributes overriding the client's configuration.
    /// NOTE(review): presumably applied to the spawned proc; confirm.
    pub(crate) client_config_override: Attrs,
    /// Optional `ProcBind` for the proc; `None` when unset (the `Default`).
    pub(crate) proc_bind: Option<ProcBind>,
    /// Optional bootstrap command for the proc. NOTE(review): presumably
    /// `None` falls back to a default command; confirm.
    pub(crate) bootstrap_command: Option<BootstrapCommand>,
    /// Optional name of the host mesh associated with this proc.
    pub(crate) host_mesh_name: Option<crate::Name>,
}
wirevalue::register_type!(ProcSpec);
817
#[cfg(test)]
mod tests {
    use super::*;

    /// `merge_from` gives the right-hand side precedence on overlap and
    /// coalesces adjacent equal-valued intervals. `rank` counts the covered
    /// ranks below its argument.
    #[test]
    fn test_ranked_values_merge() {
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut left: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        let right: RankedValues<Side> = [
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        left.merge_from(right);
        assert_eq!(
            left.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..9, Left),
                (9..12, Right),
                (15..20, Left),
                (25..30, Right),
                (30..40, Both),
                (40..50, Right),
                // Right's (50..60, Both) coalesces with left's (60..70, Both).
                (50..70, Both)
            ]
        );

        // Gaps at 12..15 and 20..25 are not counted.
        assert_eq!(left.rank(5), 5);
        assert_eq!(left.rank(10), 10);
        assert_eq!(left.rank(16), 13);
        assert_eq!(left.rank(70), 62);
        assert_eq!(left.rank(100), 62);
    }

    /// `RankedValues` built from identical intervals compare equal.
    #[test]
    fn test_equality() {
        assert_eq!(
            RankedValues::from((0..10, 123)),
            RankedValues::from((0..10, 123))
        );
        assert_eq!(
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
        );
    }

    /// Merging sparse values over a dense default leaves the default in the gaps.
    #[test]
    fn test_default_through_merging() {
        let values: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        let mut default = RankedValues::from((0..50, 0));
        default.merge_from(values);

        assert_eq!(
            default.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..10, 1),
                (10..15, 0),
                (15..20, 1),
                (20..30, 0),
                (30..50, 1)
            ]
        );
    }
}