1pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::PortRef;
31use hyperactor::RefClient;
32use hyperactor::RemoteMessage;
33use hyperactor::Unbind;
34use hyperactor::mailbox::PortReceiver;
35use hyperactor::message::Bind;
36use hyperactor::message::Bindings;
37use hyperactor::message::Unbind;
38use hyperactor_config::attrs::Attrs;
39use ndslice::Region;
40use ndslice::ViewExt;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44
45use crate::StatusOverlay;
46use crate::bootstrap;
47use crate::bootstrap::BootstrapCommand;
48use crate::bootstrap::ProcBind;
49use crate::host_mesh::host_agent::ProcState;
50use crate::mesh_id::ResourceId;
51use crate::proc_agent::ActorSpec;
52use crate::proc_agent::ActorState;
53
/// Status of a resource, as carried by the messages in this module.
///
/// `PartialOrd`/`Ord` are derived, so statuses compare by variant declaration
/// order first (and by payload for `Failed`/`Timeout`).
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display,
    Bind,
    Unbind
)]
pub enum Status {
    /// The resource does not exist. [`GetRankStatus::wait`] also uses this
    /// value as the placeholder for ranks that have not reported yet.
    NotExist,
    /// The resource is initializing; counted as healthy by [`Status::is_healthy`].
    Initializing,
    /// The resource is running; counted as healthy by [`Status::is_healthy`].
    Running,
    /// The resource is stopping: terminating, but not yet terminated.
    Stopping,
    /// The resource has stopped; a terminated, non-failure status.
    Stopped,
    /// The resource failed; the payload is a free-form reason. Displayed as
    /// `Failed(<reason>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// The resource timed out; displayed as `Timeout(<duration:?>)`.
    /// NOTE(review): the meaning of the duration (elapsed vs. limit) is not
    /// visible here — confirm against the producers of this variant.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
    /// The status is unknown. It is neither healthy nor terminating according
    /// to the predicates on [`Status`].
    Unknown,
}
91
92impl Status {
93 pub fn is_terminating(&self) -> bool {
95 matches!(
96 self,
97 Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
98 )
99 }
100
101 pub fn is_failure(&self) -> bool {
105 matches!(self, Self::Failed(_) | Self::Timeout(_))
106 }
107
108 pub fn is_terminated(&self) -> bool {
111 matches!(
112 self,
113 Status::Stopped | Status::Failed(_) | Status::Timeout(_)
114 )
115 }
116
117 pub fn is_healthy(&self) -> bool {
118 matches!(self, Status::Initializing | Status::Running)
119 }
120
121 pub fn clamp_min(self, floor: Status) -> Status {
128 if floor.is_terminating() && !self.is_terminating() {
129 floor
130 } else {
131 self
132 }
133 }
134}
135
136impl From<bootstrap::ProcStatus> for Status {
137 fn from(status: bootstrap::ProcStatus) -> Self {
138 use bootstrap::ProcStatus;
139 match status {
140 ProcStatus::Starting => Status::Initializing,
141 ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
142 ProcStatus::Stopping { .. } => Status::Stopping,
143 ProcStatus::Stopped { .. } => Status::Stopped,
144 ProcStatus::Failed { reason } => Status::Failed(reason),
145 ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
146 }
147 }
148}
149
150impl From<crate::host::LocalProcStatus> for Status {
151 fn from(status: crate::host::LocalProcStatus) -> Self {
152 match status {
153 crate::host::LocalProcStatus::Stopping => Status::Stopping,
154 crate::host::LocalProcStatus::Stopped => Status::Stopped,
155 }
156 }
157}
158
/// An optional rank carried inside messages.
///
/// Defaults to `Rank(None)`. The `Bind` implementation in this file
/// overwrites the inner value with one popped from the message bindings.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
wirevalue::register_type!(Rank);
165
166impl Rank {
167 pub fn new(rank: usize) -> Self {
169 Self(Some(rank))
170 }
171
172 pub fn unwrap(&self) -> usize {
174 self.0.unwrap()
175 }
176}
177
178impl Unbind for Rank {
179 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
180 bindings.push_back(self)
181 }
182}
183
184impl Bind for Rank {
185 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
186 let bound = bindings.try_pop_front::<Rank>()?;
187 self.0 = bound.0;
188 Ok(())
189 }
190}
191
/// Message asking for the status of the resource `id`, answered with a
/// [`StatusOverlay`] sent to `reply`.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// Identifier of the resource being queried.
    pub id: ResourceId,
    /// Port on which the status overlay is delivered. Marked
    /// `#[binding(include)]` so the port takes part in bind/unbind.
    #[binding(include)]
    pub reply: PortRef<StatusOverlay>,
}
217
/// Message like [`GetRankStatus`], with an additional `min_status` threshold.
///
/// NOTE(review): presumably the handler defers its reply until the resource
/// has reached at least `min_status` — the handler is not in this file, so
/// confirm there.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct WaitRankStatus {
    /// Identifier of the resource being queried.
    pub id: ResourceId,
    /// Status threshold associated with this request (see the note above).
    pub min_status: Status,
    /// Port on which the status overlay is delivered. Marked
    /// `#[binding(include)]` so the port takes part in bind/unbind.
    #[binding(include)]
    pub reply: PortRef<StatusOverlay>,
}
244
245impl GetRankStatus {
246 pub async fn wait(
247 mut rx: PortReceiver<crate::StatusMesh>,
248 num_ranks: usize,
249 max_idle_time: Duration,
250 region: Region, ) -> Result<crate::StatusMesh, crate::StatusMesh> {
252 debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
253
254 let mut alarm = hyperactor::time::Alarm::new();
255 alarm.arm(max_idle_time);
256
257 let mut snapshot =
259 crate::StatusMesh::from_single(region, crate::resource::Status::NotExist);
260
261 loop {
262 let mut sleeper = alarm.sleeper();
263 tokio::select! {
264 _ = sleeper.sleep() => return Err(snapshot),
265 next = rx.recv() => {
266 match next {
267 Ok(mesh) => { snapshot = mesh; } Err(_) => return Err(snapshot),
269 }
270 }
271 }
272
273 alarm.arm(max_idle_time);
274
275 if snapshot
279 .values()
280 .take(num_ranks)
281 .all(|s| !matches!(s, crate::resource::Status::NotExist))
282 {
283 break Ok(snapshot);
284 }
285 }
286 }
287}
288
/// The state of a resource `id`, generic over the resource-specific payload
/// `S`.
///
/// Registered for wire transport with `ActorState` and `ProcState` payloads.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialEq,
    Eq,
    Handler,
    Bind,
    Unbind
)]
pub struct State<S> {
    /// Identifier of the resource this state describes.
    pub id: ResourceId,
    /// Status of the resource.
    pub status: Status,
    /// Resource-specific state, when there is one.
    pub state: Option<S>,
    /// Generation counter for this state.
    /// NOTE(review): presumably increases with each update — confirm with the
    /// producers of `State`.
    pub generation: u64,
    /// Wall-clock timestamp attached to this state.
    pub timestamp: std::time::SystemTime,
}
wirevalue::register_type!(State<ActorState>);
wirevalue::register_type!(State<ProcState>);
316
317impl<S: Serialize> fmt::Display for State<S> {
318 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
319 match serde_json::to_string(self) {
321 Ok(json) => write!(f, "{}", json),
322 Err(e) => write!(f, "<state: serde_json error: {}>", e),
323 }
324 }
325}
326
/// Message that creates or updates the resource `id` from the given `spec`.
///
/// Registered for wire transport with `ProcSpec` and `ActorSpec` specs.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// Identifier of the resource to create or update.
    pub id: ResourceId,
    /// Rank associated with the message. Marked `#[binding(include)]`, so it
    /// is rewritten from the message bindings when the message is bound.
    #[binding(include)]
    pub rank: Rank,
    /// Specification of the resource.
    pub spec: S,
}
wirevalue::register_type!(CreateOrUpdate<ProcSpec>);
wirevalue::register_type!(CreateOrUpdate<ActorSpec>);
351
/// Message requesting that the resource `id` be stopped.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// Identifier of the resource to stop.
    pub id: ResourceId,
    /// Free-form reason for the stop.
    pub reason: String,
}
wirevalue::register_type!(Stop);
372
/// Message requesting a stop without naming a specific resource.
///
/// NOTE(review): presumably stops every resource managed by the receiver —
/// the handler is not in this file, so confirm there.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {
    /// Free-form reason for the stop.
    pub reason: String,
}
wirevalue::register_type!(StopAll);
393
/// Message requesting the [`State`] of the resource `id`.
///
/// `Bind`, `Unbind` and `Clone` are implemented by hand further down in this
/// file rather than derived. Registered for wire transport with `ProcState`
/// and `ActorState` payloads.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// Identifier of the resource being queried.
    pub id: ResourceId,
    /// Reply port (marked `#[reply]`) on which the state is delivered.
    #[reply]
    pub reply: PortRef<State<S>>,
}
wirevalue::register_type!(GetState<ProcState>);
wirevalue::register_type!(GetState<ActorState>);
405
406impl<S> Unbind for GetState<S>
408where
409 S: RemoteMessage,
410 S: Unbind,
411{
412 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
413 self.reply.unbind(bindings)
414 }
415}
416
417impl<S> Bind for GetState<S>
418where
419 S: RemoteMessage,
420 S: Bind,
421{
422 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
423 self.reply.bind(bindings)
424 }
425}
426
427impl<S> Clone for GetState<S>
428where
429 S: RemoteMessage,
430{
431 fn clone(&self) -> Self {
432 Self {
433 id: self.id.clone(),
434 reply: self.reply.clone(),
435 }
436 }
437}
438
/// A [`GetState`] request bundled with an expiry time.
///
/// NOTE(review): presumably receiving this message keeps the resource alive
/// until `expires_after` — the handler is not in this file, so confirm there.
/// Registered for wire transport with `ProcState` and `ActorState` payloads.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct KeepaliveGetState<S> {
    /// Wall-clock time associated with the keepalive (see the note above).
    pub expires_after: std::time::SystemTime,
    /// The wrapped state request; bind/unbind delegate to it.
    pub get_state: GetState<S>,
}
wirevalue::register_type!(KeepaliveGetState<ProcState>);
wirevalue::register_type!(KeepaliveGetState<ActorState>);
450
451impl<S> Unbind for KeepaliveGetState<S>
453where
454 S: RemoteMessage,
455 S: Unbind,
456{
457 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
458 self.get_state.unbind(bindings)
459 }
460}
461
462impl<S> Bind for KeepaliveGetState<S>
463where
464 S: RemoteMessage,
465 S: Bind,
466{
467 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
468 self.get_state.bind(bindings)
469 }
470}
471
472impl<S> Clone for KeepaliveGetState<S>
473where
474 S: RemoteMessage,
475{
476 fn clone(&self) -> Self {
477 Self {
478 expires_after: self.expires_after,
479 get_state: self.get_state.clone(),
480 }
481 }
482}
483
/// Message subscribing `subscriber` to [`State`] values of the resource `id`.
///
/// NOTE(review): presumably the receiver sends a state on every change — the
/// handler is not in this file, so confirm there. Registered for wire
/// transport with `ActorState` and `ProcState` payloads.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct StreamState<S> {
    /// Identifier of the resource whose state is streamed.
    pub id: ResourceId,
    /// Port that receives the states; the only field touched by bind/unbind.
    pub subscriber: PortRef<State<S>>,
}
wirevalue::register_type!(StreamState<ActorState>);
wirevalue::register_type!(StreamState<ProcState>);
496
497impl<S> Unbind for StreamState<S>
499where
500 S: RemoteMessage,
501 S: Unbind,
502{
503 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
504 self.subscriber.unbind(bindings)
505 }
506}
507
508impl<S> Bind for StreamState<S>
509where
510 S: RemoteMessage,
511 S: Bind,
512{
513 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
514 self.subscriber.bind(bindings)
515 }
516}
517
518impl<S> Clone for StreamState<S>
519where
520 S: RemoteMessage,
521{
522 fn clone(&self) -> Self {
523 Self {
524 id: self.id.clone(),
525 subscriber: self.subscriber.clone(),
526 }
527 }
528}
529
/// Message requesting a list of resource identifiers.
///
/// NOTE(review): presumably the identifiers of all resources managed by the
/// receiver — the handler is not in this file, so confirm there.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct List {
    /// Reply port (marked `#[reply]`) on which the identifiers are delivered.
    #[reply]
    pub reply: PortRef<Vec<ResourceId>>,
}
wirevalue::register_type!(List);
538
539pub trait Resource {
541 type Spec: typeuri::Named
543 + Serialize
544 + for<'de> Deserialize<'de>
545 + Send
546 + Sync
547 + std::fmt::Debug;
548
549 type State: typeuri::Named
551 + Serialize
552 + for<'de> Deserialize<'de>
553 + Send
554 + Sync
555 + std::fmt::Debug;
556}
557
// Declares the `Controller<R>` behavior for a resource type `R`, listing the
// messages `CreateOrUpdate<R::Spec>`, `GetState<R::State>` and `Stop`.
// NOTE(review): the exact items generated depend on `hyperactor::behavior!`,
// which is not visible here — confirm against the macro's documentation.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
565
/// A sparse mapping from ranks to values, stored as a list of
/// `(rank range, value)` intervals.
///
/// The merge and rank-counting code in this file assumes the intervals are
/// sorted by start rank and do not overlap one another.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    /// The intervals, each a half-open rank range paired with its value.
    intervals: Vec<(Range<usize>, T)>,
}
574
575impl<T: PartialEq> PartialEq for RankedValues<T> {
576 fn eq(&self, other: &Self) -> bool {
577 self.intervals == other.intervals
578 }
579}
580
581impl<T: Eq> Eq for RankedValues<T> {}
582
583impl<T> Default for RankedValues<T> {
584 fn default() -> Self {
585 Self {
586 intervals: Vec::new(),
587 }
588 }
589}
590
591impl<T> RankedValues<T> {
592 pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
594 self.intervals.iter()
595 }
596
597 pub fn rank(&self, value: usize) -> usize {
600 self.iter()
601 .take_while(|(ranks, _)| ranks.start <= value)
602 .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
603 .sum()
604 }
605}
606
607impl<T: Clone> RankedValues<T> {
608 pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
609 assert_eq!(self.rank(until), until, "insufficient rank");
610 self.iter()
611 .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
612 }
613}
614
615impl<T: Hash + Eq + Clone> RankedValues<T> {
616 pub fn invert(&self) -> ValuesByRank<T> {
618 let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
619 for (range, value) in self.iter() {
620 inverted
621 .entry(value.clone())
622 .or_default()
623 .push(range.clone());
624 }
625 ValuesByRank { values: inverted }
626 }
627}
628
629impl<T: Eq + Clone> RankedValues<T> {
630 pub fn merge_from(&mut self, other: Self) {
638 let mut left_iter = take(&mut self.intervals).into_iter();
639 let mut right_iter = other.intervals.into_iter();
640
641 let mut left = left_iter.next();
642 let mut right = right_iter.next();
643
644 while left.is_some() && right.is_some() {
645 let (left_ranks, left_value) = left.as_mut().unwrap();
646 let (right_ranks, right_value) = right.as_mut().unwrap();
647
648 if left_ranks.is_overlapping(right_ranks) {
649 if left_value == right_value {
650 let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
651 let (_, value) = replace(&mut right, right_iter.next()).unwrap();
652 left_ranks.start = ranks.end;
653 if left_ranks.is_empty() {
654 left = left_iter.next();
655 }
656 self.append(ranks, value);
657 } else if left_ranks.start < right_ranks.start {
658 let ranks = left_ranks.start..right_ranks.start;
659 left_ranks.start = ranks.end;
660 self.append(ranks, left_value.clone());
662 } else {
663 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
664 left_ranks.start = ranks.end;
665 if left_ranks.is_empty() {
666 left = left_iter.next();
667 }
668 self.append(ranks, value);
669 }
670 } else if left_ranks.start < right_ranks.start {
671 let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
672 self.append(ranks, value);
673 } else {
674 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
675 self.append(ranks, value);
676 }
677 }
678
679 while let Some((left_ranks, left_value)) = left {
680 self.append(left_ranks, left_value);
681 left = left_iter.next();
682 }
683 while let Some((right_ranks, right_value)) = right {
684 self.append(right_ranks, right_value);
685 right = right_iter.next();
686 }
687 }
688
689 pub fn merge_into(self, other: &mut Self) {
691 other.merge_from(self);
692 }
693
694 fn append(&mut self, range: Range<usize>, value: T) {
695 if let Some(last) = self.intervals.last_mut()
696 && last.0.end == range.start
697 && last.1 == value
698 {
699 last.0.end = range.end;
700 } else {
701 self.intervals.push((range, value));
702 }
703 }
704}
705
706impl RankedValues<Status> {
707 pub fn first_terminating(&self) -> Option<(usize, Status)> {
708 self.intervals
709 .iter()
710 .find(|(_, status)| status.is_terminating())
711 .map(|(range, status)| (range.start, status.clone()))
712 }
713
714 pub fn first_failed(&self) -> Option<(usize, Status)> {
715 self.intervals
716 .iter()
717 .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
718 .map(|(range, status)| (range.start, status.clone()))
719 }
720}
721
722impl<T> From<(usize, T)> for RankedValues<T> {
723 fn from((rank, value): (usize, T)) -> Self {
724 Self {
725 intervals: vec![(rank..rank + 1, value)],
726 }
727 }
728}
729
730impl<T> From<(Range<usize>, T)> for RankedValues<T> {
731 fn from((range, value): (Range<usize>, T)) -> Self {
732 Self {
733 intervals: vec![(range, value)],
734 }
735 }
736}
737
/// The inverse view of a [`RankedValues`]: each distinct value mapped to the
/// rank ranges that carry it. Produced by [`RankedValues::invert`] and
/// dereferences to the underlying map.
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    /// Rank ranges keyed by value.
    values: HashMap<T, Vec<Range<usize>>>,
}
744
745impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
746 fn eq(&self, other: &Self) -> bool {
747 self.values == other.values
748 }
749}
750
751impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
752
753impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
754 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
755 let mut first_value = true;
756 for (value, ranges) in self.iter() {
757 if first_value {
758 first_value = false;
759 } else {
760 write!(f, ";")?;
761 }
762 write!(f, "{}=", value)?;
763 let mut first_range = true;
764 for range in ranges.iter() {
765 if first_range {
766 first_range = false;
767 } else {
768 write!(f, ",")?;
769 }
770 write!(f, "{}..{}", range.start, range.end)?;
771 }
772 }
773 Ok(())
774 }
775}
776
777impl<T> Deref for ValuesByRank<T> {
778 type Target = HashMap<T, Vec<Range<usize>>>;
779
780 fn deref(&self) -> &Self::Target {
781 &self.values
782 }
783}
784
785impl<T> DerefMut for ValuesByRank<T> {
786 fn deref_mut(&mut self) -> &mut Self::Target {
787 &mut self.values
788 }
789}
790
/// Test-only constructor: collects intervals verbatim, without sorting or
/// coalescing them.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals = Vec::from_iter(iter);
        Self { intervals }
    }
}
801
/// Specification of a proc; the spec type registered for
/// `CreateOrUpdate<ProcSpec>`. Every field has a default, so
/// `ProcSpec::default()` is a valid, empty spec.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Configuration attributes.
    /// NOTE(review): presumably overrides applied on top of the proc's own
    /// configuration — confirm against the consumer of this spec.
    pub(crate) client_config_override: Attrs,
    /// Optional process binding settings (`ProcBind`).
    pub(crate) proc_bind: Option<ProcBind>,
    /// Optional bootstrap command for the proc.
    pub(crate) bootstrap_command: Option<BootstrapCommand>,
    /// Optional identifier of a host mesh associated with the proc.
    pub(crate) host_mesh_id: Option<crate::mesh_id::HostMeshId>,
}
wirevalue::register_type!(ProcSpec);
819
#[cfg(test)]
mod tests {
    use super::*;

    /// Merging two interleaved interval lists: the right side wins on
    /// overlap, equal neighbours are coalesced, and `rank` counts the covered
    /// ranks of the merged result.
    #[test]
    fn test_ranked_values_merge() {
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut merged = RankedValues::from_iter([
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]);

        let incoming = RankedValues::from_iter([
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]);

        merged.merge_from(incoming);

        let expected = vec![
            (0..9, Left),
            (9..12, Right),
            (15..20, Left),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            // 50..60 from the right and 60..70 from the left are coalesced.
            (50..70, Both),
        ];
        assert_eq!(merged.iter().cloned().collect::<Vec<_>>(), expected);

        for (query, covered) in [(5, 5), (10, 10), (16, 13), (70, 62), (100, 62)] {
            assert_eq!(merged.rank(query), covered);
        }
    }

    /// Equal intervals and values compare equal, for plain and `Status`
    /// payloads alike.
    #[test]
    fn test_equality() {
        let numbers = || RankedValues::from((0..10, 123));
        assert_eq!(numbers(), numbers());

        let failures = || RankedValues::from((0..10, Status::Failed("foo".to_string())));
        assert_eq!(failures(), failures());
    }

    /// A single default interval is overridden wherever the merged-in values
    /// cover it and shows through in the gaps.
    #[test]
    fn test_default_through_merging() {
        let overrides = RankedValues::<usize>::from_iter([(0..10, 1), (15..20, 1), (30..50, 1)]);

        let mut merged = RankedValues::from((0..50, 0));
        merged.merge_from(overrides);

        let expected = vec![
            (0..10, 1),
            (10..15, 0),
            (15..20, 1),
            (20..30, 0),
            (30..50, 1),
        ];
        assert_eq!(merged.iter().cloned().collect::<Vec<_>>(), expected);
    }
}