1pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::RefClient;
31use hyperactor::RemoteMessage;
32use hyperactor::Unbind;
33use hyperactor::mailbox::PortReceiver;
34use hyperactor::message::Bind;
35use hyperactor::message::Bindings;
36use hyperactor::message::Unbind;
37use hyperactor::reference as hyperactor_reference;
38use hyperactor_config::attrs::Attrs;
39use ndslice::Region;
40use ndslice::ViewExt;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44
45use crate::Name;
46use crate::StatusOverlay;
47use crate::bootstrap;
48use crate::host_mesh::host_agent::ProcState;
49use crate::proc_agent::ActorSpec;
50use crate::proc_agent::ActorState;
51
/// Lifecycle status of a resource handled by the messages in this module.
///
/// The derived `Ord` follows variant declaration order. `Display` comes from
/// `strum`: data-less variants print their variant name, while `Failed` and
/// `Timeout` use the format strings given in their `strum` attributes.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display
)]
pub enum Status {
    /// No such resource (per the variant name). [`GetRankStatus::wait`] also
    /// uses this as the placeholder for ranks that have not reported yet.
    NotExist,
    /// Starting up; produced from `ProcStatus::Starting`.
    Initializing,
    /// Up and running; produced from `ProcStatus::Running` and `ProcStatus::Ready`.
    Running,
    /// Shutting down; produced from `ProcStatus::Stopping`.
    Stopping,
    /// Stopped; produced from `ProcStatus::Stopped`.
    Stopped,
    /// Failed. The payload describes the failure: the `reason` of a failed
    /// proc, or the `Display` text of a killed one.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// Timed out. NOTE(review): the payload presumably is the timeout that
    /// elapsed — confirm against producers (none are visible in this file).
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
    /// Status could not be determined (not produced anywhere in this file).
    Unknown,
}
87
88impl Status {
89 pub fn is_terminating(&self) -> bool {
91 matches!(
92 self,
93 Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
94 )
95 }
96
97 pub fn is_failure(&self) -> bool {
101 matches!(self, Self::Failed(_) | Self::Timeout(_))
102 }
103
104 pub fn is_healthy(&self) -> bool {
105 matches!(self, Status::Initializing | Status::Running)
106 }
107}
108
109impl From<bootstrap::ProcStatus> for Status {
110 fn from(status: bootstrap::ProcStatus) -> Self {
111 use bootstrap::ProcStatus;
112 match status {
113 ProcStatus::Starting => Status::Initializing,
114 ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
115 ProcStatus::Stopping { .. } => Status::Stopping,
116 ProcStatus::Stopped { .. } => Status::Stopped,
117 ProcStatus::Failed { reason } => Status::Failed(reason),
118 ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
119 }
120 }
121}
122
/// An optionally-bound rank.
///
/// `Rank::default()` holds `None`. A concrete rank comes from [`Rank::new`],
/// or from `Bind::bind`, which replaces the value with a `Rank` popped from the
/// message bindings.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
wirevalue::register_type!(Rank);
129
130impl Rank {
131 pub fn new(rank: usize) -> Self {
133 Self(Some(rank))
134 }
135
136 pub fn unwrap(&self) -> usize {
138 self.0.unwrap()
139 }
140}
141
142impl Unbind for Rank {
143 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
144 bindings.push_back(self)
145 }
146}
147
148impl Bind for Rank {
149 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
150 let bound = bindings.try_pop_front::<Rank>()?;
151 self.0 = bound.0;
152 Ok(())
153 }
154}
155
/// Message asking the receiving actor for the status of the resource `name`;
/// the answer is sent to `reply` as a [`StatusOverlay`].
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// Name of the resource whose status is requested.
    pub name: Name,
    /// Port receiving the status overlay. `#[binding(include)]` presumably
    /// opts this field into the derived `Bind`/`Unbind` impls.
    #[binding(include)]
    pub reply: hyperactor_reference::PortRef<StatusOverlay>,
}
181
182impl GetRankStatus {
183 pub async fn wait(
184 mut rx: PortReceiver<crate::StatusMesh>,
185 num_ranks: usize,
186 max_idle_time: Duration,
187 region: Region, ) -> Result<crate::StatusMesh, crate::StatusMesh> {
189 debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
190
191 let mut alarm = hyperactor::time::Alarm::new();
192 alarm.arm(max_idle_time);
193
194 let mut snapshot =
196 crate::StatusMesh::from_single(region, crate::resource::Status::NotExist);
197
198 loop {
199 let mut sleeper = alarm.sleeper();
200 tokio::select! {
201 _ = sleeper.sleep() => return Err(snapshot),
202 next = rx.recv() => {
203 match next {
204 Ok(mesh) => { snapshot = mesh; } Err(_) => return Err(snapshot),
206 }
207 }
208 }
209
210 alarm.arm(max_idle_time);
211
212 if snapshot
216 .values()
217 .take(num_ranks)
218 .all(|s| !matches!(s, crate::resource::Status::NotExist))
219 {
220 break Ok(snapshot);
221 }
222 }
223 }
224}
225
/// Snapshot of a named resource: its lifecycle [`Status`] plus an optional
/// resource-specific payload of type `S`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// Name of the resource.
    pub name: Name,
    /// Lifecycle status of the resource.
    pub status: Status,
    /// Resource-specific state, if available.
    pub state: Option<S>,
}
236
237impl<S: Serialize> fmt::Display for State<S> {
238 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239 match serde_json::to_string(self) {
241 Ok(json) => write!(f, "{}", json),
242 Err(e) => write!(f, "<state: serde_json error: {}>", e),
243 }
244 }
245}
246
/// Message asking a controller to create the resource `name` from `spec`, or
/// to update it (per the message name; the semantics live in the handler).
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// Name of the resource to create or update.
    pub name: Name,
    /// Target rank. Included in binding: when the message is bound, it is
    /// replaced by a `Rank` popped from the bindings (see `impl Bind for Rank`).
    #[binding(include)]
    pub rank: Rank,
    /// Resource specification.
    pub spec: S,
}
wirevalue::register_type!(CreateOrUpdate<ProcSpec>);
wirevalue::register_type!(CreateOrUpdate<ActorSpec>);
271
/// Message asking a controller to stop the resource `name`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// Name of the resource to stop.
    pub name: Name,
    /// Free-form reason for the stop.
    pub reason: String,
}
wirevalue::register_type!(Stop);
292
/// Message asking the receiver to stop all of its resources (presumably every
/// resource it manages; confirm against the handler).
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {
    /// Free-form reason for the stop.
    pub reason: String,
}
wirevalue::register_type!(StopAll);
313
/// Message requesting the current [`State`] of the resource `name`; the reply
/// is sent to `reply`.
///
/// `Clone`, `Bind`, and `Unbind` are implemented by hand rather than derived.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// Name of the resource to query.
    pub name: Name,
    /// Port the resource state is sent to.
    #[reply]
    pub reply: hyperactor_reference::PortRef<State<S>>,
}
wirevalue::register_type!(GetState<ProcState>);
wirevalue::register_type!(GetState<ActorState>);
325
326impl<S> Unbind for GetState<S>
328where
329 S: RemoteMessage,
330 S: Unbind,
331{
332 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
333 self.reply.unbind(bindings)
334 }
335}
336
337impl<S> Bind for GetState<S>
338where
339 S: RemoteMessage,
340 S: Bind,
341{
342 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
343 self.reply.bind(bindings)
344 }
345}
346
347impl<S> Clone for GetState<S>
348where
349 S: RemoteMessage,
350{
351 fn clone(&self) -> Self {
352 Self {
353 name: self.name.clone(),
354 reply: self.reply.clone(),
355 }
356 }
357}
358
/// A [`GetState`] request paired with an expiry time.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct KeepaliveGetState<S> {
    /// Wall-clock expiry of this request. NOTE(review): how expiry is enforced
    /// is up to the handler and not visible here.
    pub expires_after: std::time::SystemTime,
    /// The wrapped state request.
    pub get_state: GetState<S>,
}
wirevalue::register_type!(KeepaliveGetState<ProcState>);
wirevalue::register_type!(KeepaliveGetState<ActorState>);
370
371impl<S> Unbind for KeepaliveGetState<S>
373where
374 S: RemoteMessage,
375 S: Unbind,
376{
377 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
378 self.get_state.unbind(bindings)
379 }
380}
381
382impl<S> Bind for KeepaliveGetState<S>
383where
384 S: RemoteMessage,
385 S: Bind,
386{
387 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
388 self.get_state.bind(bindings)
389 }
390}
391
392impl<S> Clone for KeepaliveGetState<S>
393where
394 S: RemoteMessage,
395{
396 fn clone(&self) -> Self {
397 Self {
398 expires_after: self.expires_after.clone(),
399 get_state: self.get_state.clone(),
400 }
401 }
402}
403
/// Message requesting resource names from the receiver (presumably all the
/// resources it manages); the names are sent to `reply`.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct List {
    /// Port the list of names is sent to.
    #[reply]
    pub reply: hyperactor_reference::PortRef<Vec<Name>>,
}
wirevalue::register_type!(List);
412
/// A kind of resource managed through this module's messages.
///
/// `Spec` is the payload of `CreateOrUpdate<Spec>` and `State` is the payload
/// returned by `GetState<State>` (see the `Controller` behavior).
pub trait Resource {
    /// Specification used to create or update the resource.
    type Spec: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;

    /// Resource-specific state reported back to callers.
    type State: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;
}
431
// Declares the `Controller<R>` behavior over the listed message types
// (presumably the set of messages a controller for resource `R` must handle:
// create/update, state queries, and stop).
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
439
/// A sparse mapping from ranks to values, stored as a list of
/// `(range, value)` intervals.
///
/// `rank` and `merge_from` assume the intervals are sorted by start and do
/// not overlap.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    intervals: Vec<(Range<usize>, T)>,
}
448
449impl<T: PartialEq> PartialEq for RankedValues<T> {
450 fn eq(&self, other: &Self) -> bool {
451 self.intervals == other.intervals
452 }
453}
454
455impl<T: Eq> Eq for RankedValues<T> {}
456
457impl<T> Default for RankedValues<T> {
458 fn default() -> Self {
459 Self {
460 intervals: Vec::new(),
461 }
462 }
463}
464
465impl<T> RankedValues<T> {
466 pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
468 self.intervals.iter()
469 }
470
471 pub fn rank(&self, value: usize) -> usize {
474 self.iter()
475 .take_while(|(ranks, _)| ranks.start <= value)
476 .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
477 .sum()
478 }
479}
480
481impl<T: Clone> RankedValues<T> {
482 pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
483 assert_eq!(self.rank(until), until, "insufficient rank");
484 self.iter()
485 .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
486 }
487}
488
489impl<T: Hash + Eq + Clone> RankedValues<T> {
490 pub fn invert(&self) -> ValuesByRank<T> {
492 let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
493 for (range, value) in self.iter() {
494 inverted
495 .entry(value.clone())
496 .or_default()
497 .push(range.clone());
498 }
499 ValuesByRank { values: inverted }
500 }
501}
502
503impl<T: Eq + Clone> RankedValues<T> {
504 pub fn merge_from(&mut self, other: Self) {
512 let mut left_iter = take(&mut self.intervals).into_iter();
513 let mut right_iter = other.intervals.into_iter();
514
515 let mut left = left_iter.next();
516 let mut right = right_iter.next();
517
518 while left.is_some() && right.is_some() {
519 let (left_ranks, left_value) = left.as_mut().unwrap();
520 let (right_ranks, right_value) = right.as_mut().unwrap();
521
522 if left_ranks.is_overlapping(right_ranks) {
523 if left_value == right_value {
524 let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
525 let (_, value) = replace(&mut right, right_iter.next()).unwrap();
526 left_ranks.start = ranks.end;
527 if left_ranks.is_empty() {
528 left = left_iter.next();
529 }
530 self.append(ranks, value);
531 } else if left_ranks.start < right_ranks.start {
532 let ranks = left_ranks.start..right_ranks.start;
533 left_ranks.start = ranks.end;
534 self.append(ranks, left_value.clone());
536 } else {
537 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
538 left_ranks.start = ranks.end;
539 if left_ranks.is_empty() {
540 left = left_iter.next();
541 }
542 self.append(ranks, value);
543 }
544 } else if left_ranks.start < right_ranks.start {
545 let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
546 self.append(ranks, value);
547 } else {
548 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
549 self.append(ranks, value);
550 }
551 }
552
553 while let Some((left_ranks, left_value)) = left {
554 self.append(left_ranks, left_value);
555 left = left_iter.next();
556 }
557 while let Some((right_ranks, right_value)) = right {
558 self.append(right_ranks, right_value);
559 right = right_iter.next();
560 }
561 }
562
563 pub fn merge_into(self, other: &mut Self) {
565 other.merge_from(self);
566 }
567
568 fn append(&mut self, range: Range<usize>, value: T) {
569 if let Some(last) = self.intervals.last_mut()
570 && last.0.end == range.start
571 && last.1 == value
572 {
573 last.0.end = range.end;
574 } else {
575 self.intervals.push((range, value));
576 }
577 }
578}
579
580impl RankedValues<Status> {
581 pub fn first_terminating(&self) -> Option<(usize, Status)> {
582 self.intervals
583 .iter()
584 .find(|(_, status)| status.is_terminating())
585 .map(|(range, status)| (range.start, status.clone()))
586 }
587
588 pub fn first_failed(&self) -> Option<(usize, Status)> {
589 self.intervals
590 .iter()
591 .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
592 .map(|(range, status)| (range.start, status.clone()))
593 }
594}
595
596impl<T> From<(usize, T)> for RankedValues<T> {
597 fn from((rank, value): (usize, T)) -> Self {
598 Self {
599 intervals: vec![(rank..rank + 1, value)],
600 }
601 }
602}
603
604impl<T> From<(Range<usize>, T)> for RankedValues<T> {
605 fn from((range, value): (Range<usize>, T)) -> Self {
606 Self {
607 intervals: vec![(range, value)],
608 }
609 }
610}
611
/// The inverse of a [`RankedValues`]: each distinct value mapped to the rank
/// ranges that hold it. Produced by [`RankedValues::invert`].
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    values: HashMap<T, Vec<Range<usize>>>,
}
618
619impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
620 fn eq(&self, other: &Self) -> bool {
621 self.values == other.values
622 }
623}
624
625impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
626
627impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
628 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
629 let mut first_value = true;
630 for (value, ranges) in self.iter() {
631 if first_value {
632 first_value = false;
633 } else {
634 write!(f, ";")?;
635 }
636 write!(f, "{}=", value)?;
637 let mut first_range = true;
638 for range in ranges.iter() {
639 if first_range {
640 first_range = false;
641 } else {
642 write!(f, ",")?;
643 }
644 write!(f, "{}..{}", range.start, range.end)?;
645 }
646 }
647 Ok(())
648 }
649}
650
651impl<T> Deref for ValuesByRank<T> {
652 type Target = HashMap<T, Vec<Range<usize>>>;
653
654 fn deref(&self) -> &Self::Target {
655 &self.values
656 }
657}
658
659impl<T> DerefMut for ValuesByRank<T> {
660 fn deref_mut(&mut self) -> &mut Self::Target {
661 &mut self.values
662 }
663}
664
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    /// Test-only: collects `(range, value)` pairs verbatim, in iteration
    /// order, without sorting, validating, or coalescing them.
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals = Vec::from_iter(iter);
        Self { intervals }
    }
}
675
/// Specification for a proc resource; the payload of `CreateOrUpdate<ProcSpec>`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Configuration attributes to override. NOTE(review): presumably these
    /// are the client's config applied to the spawned proc — confirm.
    pub(crate) client_config_override: Attrs,
}
wirevalue::register_type!(ProcSpec);
684
685impl ProcSpec {
686 pub(crate) fn new(client_config_override: Attrs) -> Self {
687 Self {
688 client_config_override,
689 }
690 }
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranked_values_merge() {
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::{Both, Left, Right};

        let mut merged: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();
        let overlay: RankedValues<Side> = [
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        merged.merge_from(overlay);

        let intervals: Vec<_> = merged.iter().cloned().collect();
        assert_eq!(
            intervals,
            vec![
                (0..9, Left),
                (9..12, Right),
                (15..20, Left),
                (25..30, Right),
                (30..40, Both),
                (40..50, Right),
                (50..70, Both),
            ]
        );

        // Ranks 12..15 and 20..25 are gaps, so they do not count.
        for (bound, covered) in [(5, 5), (10, 10), (16, 13), (70, 62), (100, 62)] {
            assert_eq!(merged.rank(bound), covered, "rank({bound})");
        }
    }

    #[test]
    fn test_equality() {
        let ints = || RankedValues::from((0..10, 123));
        assert_eq!(ints(), ints());

        let failed = || RankedValues::from((0..10, Status::Failed("foo".to_string())));
        assert_eq!(failed(), failed());
    }

    #[test]
    fn test_default_through_merging() {
        let ones: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        let mut filled = RankedValues::from((0..50, 0));
        filled.merge_from(ones);

        let intervals: Vec<_> = filled.iter().cloned().collect();
        assert_eq!(
            intervals,
            vec![(0..10, 1), (10..15, 0), (15..20, 1), (20..30, 0), (30..50, 1)]
        );
    }
}
781}