1pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::PortRef;
31use hyperactor::RefClient;
32use hyperactor::RemoteMessage;
33use hyperactor::Unbind;
34use hyperactor::mailbox::PortReceiver;
35use hyperactor::message::Bind;
36use hyperactor::message::Bindings;
37use hyperactor::message::Unbind;
38use hyperactor_config::attrs::Attrs;
39use ndslice::Region;
40use ndslice::ViewExt;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44
45use crate::bootstrap;
46use crate::v1::Name;
47use crate::v1::StatusOverlay;
48
/// Lifecycle status of a resource.
///
/// `Display` is derived through `strum`; the two payload-carrying variants
/// override their rendering with `#[strum(to_string = ...)]`. The derived
/// `PartialOrd`/`Ord` follow the declaration order of the variants.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display
)]
pub enum Status {
    /// The resource does not exist. `GetRankStatus::wait` also uses this as
    /// the placeholder for ranks that have not reported yet.
    NotExist,
    /// The resource is starting up (`ProcStatus::Starting` maps here).
    Initializing,
    /// The resource is running (`ProcStatus::Running` and
    /// `ProcStatus::Ready` both map here).
    Running,
    /// The resource is in the process of stopping.
    Stopping,
    /// The resource has stopped.
    Stopped,
    /// The resource failed; the payload is a textual reason and is rendered
    /// as `Failed(<reason>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// A timeout was hit; the payload is rendered with its `Debug` form as
    /// `Timeout(<duration>)`.
    // NOTE(review): what exactly the duration measures is not visible in
    // this file -- confirm against the code that constructs this variant.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
}
82
83impl Status {
84 pub fn is_terminating(&self) -> bool {
86 matches!(
87 self,
88 Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
89 )
90 }
91
92 pub fn is_failure(&self) -> bool {
96 matches!(self, Self::Failed(_) | Self::Timeout(_))
97 }
98
99 pub fn is_healthy(&self) -> bool {
100 matches!(self, Status::Initializing | Status::Running)
101 }
102}
103
104impl From<bootstrap::ProcStatus> for Status {
105 fn from(status: bootstrap::ProcStatus) -> Self {
106 use bootstrap::ProcStatus;
107 match status {
108 ProcStatus::Starting => Status::Initializing,
109 ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
110 ProcStatus::Stopping { .. } => Status::Stopping,
111 ProcStatus::Stopped { .. } => Status::Stopped,
112 ProcStatus::Failed { reason } => Status::Failed(reason),
113 ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
114 }
115 }
116}
117
/// An optional rank index.
///
/// The derived `Default` yields the unset value (`Rank(None)`). The `Bind`
/// implementation for this type overwrites the inner value with whatever
/// `Rank` it pops from the bindings.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
wirevalue::register_type!(Rank);
124
125impl Rank {
126 pub fn new(rank: usize) -> Self {
128 Self(Some(rank))
129 }
130
131 pub fn unwrap(&self) -> usize {
133 self.0.unwrap()
134 }
135}
136
/// Unbinding a rank appends the whole `Rank` value to the back of `bindings`.
impl Unbind for Rank {
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
        bindings.push_back(self)
    }
}
142
143impl Bind for Rank {
144 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
145 let bound = bindings.try_pop_front::<Rank>()?;
146 self.0 = bound.0;
147 Ok(())
148 }
149}
150
/// Message asking for the rank status of the resource called `name`.
// NOTE(review): the exact delivery semantics come from the derived
// `Handler`/`Bind`/`Unbind` implementations, which are not visible here.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port on which the `StatusOverlay` answer is sent. It is marked
    /// `#[binding(include)]`, so it takes part in the derived bind/unbind.
    #[binding(include)]
    pub reply: PortRef<StatusOverlay>,
}
176
177impl GetRankStatus {
178 pub async fn wait(
179 mut rx: PortReceiver<crate::v1::StatusMesh>,
180 num_ranks: usize,
181 max_idle_time: Duration,
182 region: Region, ) -> Result<crate::v1::StatusMesh, crate::v1::StatusMesh> {
184 debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
185
186 let mut alarm = hyperactor::time::Alarm::new();
187 alarm.arm(max_idle_time);
188
189 let mut snapshot =
191 crate::v1::StatusMesh::from_single(region, crate::resource::Status::NotExist);
192
193 loop {
194 let mut sleeper = alarm.sleeper();
195 tokio::select! {
196 _ = sleeper.sleep() => return Err(snapshot),
197 next = rx.recv() => {
198 match next {
199 Ok(mesh) => { snapshot = mesh; } Err(_) => return Err(snapshot),
201 }
202 }
203 }
204
205 alarm.arm(max_idle_time);
206
207 if snapshot
211 .values()
212 .take(num_ranks)
213 .all(|s| !matches!(s, crate::resource::Status::NotExist))
214 {
215 break Ok(snapshot);
216 }
217 }
218 }
219}
220
/// Snapshot of a named resource: its status plus an optional,
/// resource-specific state payload of type `S`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// Name of the resource this snapshot describes.
    pub name: Name,
    /// Status of the resource.
    pub status: Status,
    /// Resource-specific state, if any.
    // NOTE(review): the conditions under which this is `None` are not
    // visible in this file.
    pub state: Option<S>,
}
231
232impl<S: Serialize> fmt::Display for State<S> {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 match serde_json::to_string(self) {
236 Ok(json) => write!(f, "{}", json),
237 Err(e) => write!(f, "<state: serde_json error: {}>", e),
238 }
239 }
240}
241
/// Message carrying the specification `spec` for the resource called `name`.
// NOTE(review): judging by the type name this creates the resource or
// updates an existing one; the handler is not visible here -- confirm.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// Name of the resource.
    pub name: Name,
    /// Rank associated with this message. It is marked `#[binding(include)]`,
    /// so it takes part in the derived bind/unbind.
    #[binding(include)]
    pub rank: Rank,
    /// The resource specification.
    pub spec: S,
}
264
/// Message naming a single resource to stop.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// Name of the resource to stop.
    pub name: Name,
}
282
/// Field-less message.
// NOTE(review): by its name this presumably asks the receiver to stop all of
// its resources; the handler is not visible in this file -- confirm.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {}
299
/// Message asking for the [`State`] of the resource called `name`.
///
/// `Clone`, `Bind` and `Unbind` are implemented by hand for this type rather
/// than derived.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port on which the resulting `State<S>` is sent.
    #[reply]
    pub reply: PortRef<State<S>>,
}
309
310impl<S> Unbind for GetState<S>
312where
313 S: RemoteMessage,
314 S: Unbind,
315{
316 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
317 self.reply.unbind(bindings)
318 }
319}
320
321impl<S> Bind for GetState<S>
322where
323 S: RemoteMessage,
324 S: Bind,
325{
326 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
327 self.reply.bind(bindings)
328 }
329}
330
331impl<S> Clone for GetState<S>
332where
333 S: RemoteMessage,
334{
335 fn clone(&self) -> Self {
336 Self {
337 name: self.name.clone(),
338 reply: self.reply.clone(),
339 }
340 }
341}
342
/// Message asking for a list of names.
// NOTE(review): presumably the names of the resources known to the receiver;
// the handler is not visible in this file -- confirm.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct List {
    /// Port on which the list of names is sent.
    #[reply]
    pub reply: PortRef<Vec<Name>>,
}
wirevalue::register_type!(List);
351
/// Bundles the two types that describe a kind of resource: its specification
/// and its state. Both must be named, serializable, deserializable,
/// `Send + Sync` and `Debug`.
pub trait Resource {
    /// The specification type; used as the payload of `CreateOrUpdate` in the
    /// `Controller` behavior.
    type Spec: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;

    /// The state type; used as the payload of `GetState` in the `Controller`
    /// behavior.
    type State: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;
}
370
// Declares the `Controller<R>` behavior for a resource type `R`, listing the
// messages `CreateOrUpdate<R::Spec>`, `GetState<R::State>` and `Stop`.
// NOTE(review): what `hyperactor::behavior!` expands to is not visible in
// this file.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
378
/// A set of values attached to rank intervals.
///
/// Each entry pairs a half-open rank range with the value that applies to
/// every rank in that range.
// NOTE(review): `rank` stops scanning at the first interval that starts
// after the queried rank, and `merge_from` walks both inputs front to back,
// so the entries are evidently expected to be sorted by range start and
// non-overlapping. Nothing here enforces that -- confirm with the callers.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    /// `(rank range, value)` entries.
    intervals: Vec<(Range<usize>, T)>,
}
387
388impl<T: PartialEq> PartialEq for RankedValues<T> {
389 fn eq(&self, other: &Self) -> bool {
390 self.intervals == other.intervals
391 }
392}
393
394impl<T: Eq> Eq for RankedValues<T> {}
395
396impl<T> Default for RankedValues<T> {
397 fn default() -> Self {
398 Self {
399 intervals: Vec::new(),
400 }
401 }
402}
403
404impl<T> RankedValues<T> {
405 pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
407 self.intervals.iter()
408 }
409
410 pub fn rank(&self, value: usize) -> usize {
413 self.iter()
414 .take_while(|(ranks, _)| ranks.start <= value)
415 .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
416 .sum()
417 }
418}
419
420impl<T: Clone> RankedValues<T> {
421 pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
422 assert_eq!(self.rank(until), until, "insufficient rank");
423 self.iter()
424 .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
425 }
426}
427
428impl<T: Hash + Eq + Clone> RankedValues<T> {
429 pub fn invert(&self) -> ValuesByRank<T> {
431 let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
432 for (range, value) in self.iter() {
433 inverted
434 .entry(value.clone())
435 .or_default()
436 .push(range.clone());
437 }
438 ValuesByRank { values: inverted }
439 }
440}
441
442impl<T: Eq + Clone> RankedValues<T> {
443 pub fn merge_from(&mut self, other: Self) {
451 let mut left_iter = take(&mut self.intervals).into_iter();
452 let mut right_iter = other.intervals.into_iter();
453
454 let mut left = left_iter.next();
455 let mut right = right_iter.next();
456
457 while left.is_some() && right.is_some() {
458 let (left_ranks, left_value) = left.as_mut().unwrap();
459 let (right_ranks, right_value) = right.as_mut().unwrap();
460
461 if left_ranks.is_overlapping(right_ranks) {
462 if left_value == right_value {
463 let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
464 let (_, value) = replace(&mut right, right_iter.next()).unwrap();
465 left_ranks.start = ranks.end;
466 if left_ranks.is_empty() {
467 left = left_iter.next();
468 }
469 self.append(ranks, value);
470 } else if left_ranks.start < right_ranks.start {
471 let ranks = left_ranks.start..right_ranks.start;
472 left_ranks.start = ranks.end;
473 self.append(ranks, left_value.clone());
475 } else {
476 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
477 left_ranks.start = ranks.end;
478 if left_ranks.is_empty() {
479 left = left_iter.next();
480 }
481 self.append(ranks, value);
482 }
483 } else if left_ranks.start < right_ranks.start {
484 let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
485 self.append(ranks, value);
486 } else {
487 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
488 self.append(ranks, value);
489 }
490 }
491
492 while let Some((left_ranks, left_value)) = left {
493 self.append(left_ranks, left_value);
494 left = left_iter.next();
495 }
496 while let Some((right_ranks, right_value)) = right {
497 self.append(right_ranks, right_value);
498 right = right_iter.next();
499 }
500 }
501
502 pub fn merge_into(self, other: &mut Self) {
504 other.merge_from(self);
505 }
506
507 fn append(&mut self, range: Range<usize>, value: T) {
508 if let Some(last) = self.intervals.last_mut()
509 && last.0.end == range.start
510 && last.1 == value
511 {
512 last.0.end = range.end;
513 } else {
514 self.intervals.push((range, value));
515 }
516 }
517}
518
519impl RankedValues<Status> {
520 pub fn first_terminating(&self) -> Option<(usize, Status)> {
521 self.intervals
522 .iter()
523 .find(|(_, status)| status.is_terminating())
524 .map(|(range, status)| (range.start, status.clone()))
525 }
526
527 pub fn first_failed(&self) -> Option<(usize, Status)> {
528 self.intervals
529 .iter()
530 .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
531 .map(|(range, status)| (range.start, status.clone()))
532 }
533}
534
535impl<T> From<(usize, T)> for RankedValues<T> {
536 fn from((rank, value): (usize, T)) -> Self {
537 Self {
538 intervals: vec![(rank..rank + 1, value)],
539 }
540 }
541}
542
543impl<T> From<(Range<usize>, T)> for RankedValues<T> {
544 fn from((range, value): (Range<usize>, T)) -> Self {
545 Self {
546 intervals: vec![(range, value)],
547 }
548 }
549}
550
/// The inverse view of [`RankedValues`]: for each distinct value, the list
/// of rank ranges that carry it. Produced by `RankedValues::invert`, and
/// dereferences to the underlying `HashMap`.
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    /// Map from value to the rank ranges holding that value.
    values: HashMap<T, Vec<Range<usize>>>,
}
557
558impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
559 fn eq(&self, other: &Self) -> bool {
560 self.values == other.values
561 }
562}
563
564impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
565
566impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
567 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568 let mut first_value = true;
569 for (value, ranges) in self.iter() {
570 if first_value {
571 first_value = false;
572 } else {
573 write!(f, ";")?;
574 }
575 write!(f, "{}=", value)?;
576 let mut first_range = true;
577 for range in ranges.iter() {
578 if first_range {
579 first_range = false;
580 } else {
581 write!(f, ",")?;
582 }
583 write!(f, "{}..{}", range.start, range.end)?;
584 }
585 }
586 Ok(())
587 }
588}
589
590impl<T> Deref for ValuesByRank<T> {
591 type Target = HashMap<T, Vec<Range<usize>>>;
592
593 fn deref(&self) -> &Self::Target {
594 &self.values
595 }
596}
597
598impl<T> DerefMut for ValuesByRank<T> {
599 fn deref_mut(&mut self) -> &mut Self::Target {
600 &mut self.values
601 }
602}
603
/// Test-only convenience: collects `(range, value)` pairs verbatim, with no
/// sorting, merging or validation.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals = Vec::from_iter(iter);
        Self { intervals }
    }
}
614
/// Crate-internal specification type wrapping a set of configuration
/// attributes.
// NOTE(review): by its name this is the spec used for proc resources; the
// code consuming it is not visible in this file -- confirm.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Configuration attributes carried by the spec.
    // NOTE(review): presumably overrides applied on top of the client's
    // configuration, judging by the field name -- confirm at the use site.
    pub(crate) client_config_override: Attrs,
}
wirevalue::register_type!(ProcSpec);

impl ProcSpec {
    /// Creates a spec holding the given attributes.
    pub(crate) fn new(client_config_override: Attrs) -> Self {
        Self {
            client_config_override,
        }
    }
}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Merges two interval sets with partial overlaps, gaps and matching
    /// values, then checks `rank` on the merged result.
    #[test]
    fn test_ranked_values_merge() {
        // Marks which input an output interval is expected to come from.
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut left: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        let right: RankedValues<Side> = [
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        left.merge_from(right);
        assert_eq!(
            left.iter().cloned().collect::<Vec<_>>(),
            vec![
                // The right side wins from rank 9 on, so the left interval
                // is cut back to 0..9.
                (0..9, Left),
                (9..12, Right),
                (15..20, Left),
                (25..30, Right),
                (30..40, Both),
                (40..50, Right),
                // 50..60 (from the right) and 60..70 (from the left) touch
                // and carry equal values, so they are coalesced.
                (50..70, Both)
            ]
        );

        // `rank` counts covered ranks below the argument; the gaps 12..15
        // and 20..25 are not counted.
        assert_eq!(left.rank(5), 5);
        assert_eq!(left.rank(10), 10);
        assert_eq!(left.rank(16), 13);
        assert_eq!(left.rank(70), 62);
        assert_eq!(left.rank(100), 62);
    }

    /// `RankedValues` built from identical inputs compare equal.
    #[test]
    fn test_equality() {
        assert_eq!(
            RankedValues::from((0..10, 123)),
            RankedValues::from((0..10, 123))
        );
        assert_eq!(
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
        );
    }

    /// Merging sparse values into a full-range default fills the gaps with
    /// the default value.
    #[test]
    fn test_default_through_merging() {
        let values: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        let mut default = RankedValues::from((0..50, 0));
        default.merge_from(values);

        assert_eq!(
            default.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..10, 1),
                (10..15, 0),
                (15..20, 1),
                (20..30, 0),
                (30..50, 1)
            ]
        );
    }
}