1pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Named;
31use hyperactor::PortRef;
32use hyperactor::RefClient;
33use hyperactor::RemoteMessage;
34use hyperactor::Unbind;
35use hyperactor::attrs::Attrs;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::message::Bind;
38use hyperactor::message::Bindings;
39use hyperactor::message::Unbind;
40use ndslice::Region;
41use ndslice::ViewExt;
42use serde::Deserialize;
43use serde::Serialize;
44
45use crate::bootstrap;
46use crate::v1::Name;
47use crate::v1::StatusOverlay;
48
/// The lifecycle status of a resource.
///
/// `PartialOrd`/`Ord` are derived, so statuses compare in variant
/// declaration order (`NotExist` < `Initializing` < ... < `Timeout`), with
/// payloads compared only between equal variants. Do not reorder variants
/// without considering that ordering (and serialization).
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display
)]
pub enum Status {
    /// No status has been observed for the resource; also used as the
    /// placeholder for ranks that have not yet reported.
    NotExist,
    /// The resource is coming up (`ProcStatus::Starting` maps here).
    Initializing,
    /// The resource is up (`ProcStatus::Running` and `ProcStatus::Ready` map here).
    Running,
    /// The resource is shutting down.
    Stopping,
    /// The resource has stopped.
    Stopped,
    /// The resource failed; the payload is a human-readable reason.
    /// Displayed as `Failed(<reason>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// The resource timed out. Displayed as `Timeout(<duration:?>)`.
    /// NOTE(review): what the duration measures (elapsed vs. allowed) is not
    /// visible here — confirm with producers.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
}
82
83impl Status {
84 pub fn is_terminating(&self) -> bool {
86 matches!(
87 self,
88 Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
89 )
90 }
91
92 pub fn is_failure(&self) -> bool {
96 matches!(self, Self::Failed(_) | Self::Timeout(_))
97 }
98
99 pub fn is_healthy(&self) -> bool {
100 matches!(self, Status::Initializing | Status::Running)
101 }
102}
103
104impl From<bootstrap::ProcStatus> for Status {
105 fn from(status: bootstrap::ProcStatus) -> Self {
106 use bootstrap::ProcStatus;
107 match status {
108 ProcStatus::Starting => Status::Initializing,
109 ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
110 ProcStatus::Stopping { .. } => Status::Stopping,
111 ProcStatus::Stopped { .. } => Status::Stopped,
112 ProcStatus::Failed { reason } => Status::Failed(reason),
113 ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
114 }
115 }
116}
117
/// A rank slot carried through message bindings.
///
/// `Rank(None)` (also the `Default`) means no rank is present. The `Unbind`
/// impl pushes the rank into the message bindings, and the `Bind` impl
/// overwrites it with the `Rank` popped from the bindings.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
123
124impl Rank {
125 pub fn new(rank: usize) -> Self {
127 Self(Some(rank))
128 }
129
130 pub fn unwrap(&self) -> usize {
132 self.0.unwrap()
133 }
134}
135
136impl Unbind for Rank {
137 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
138 bindings.push_back(self)
139 }
140}
141
142impl Bind for Rank {
143 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
144 let bound = bindings.try_pop_front::<Rank>()?;
145 self.0 = bound.0;
146 Ok(())
147 }
148}
149
/// Request for the status of the named resource, answered with a
/// [`StatusOverlay`] on `reply`.
///
/// NOTE(review): the overlay is presumably keyed by rank (judging by the
/// name) — confirm against the handler.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port on which the status overlay is sent. `#[binding(include)]` makes
    /// this field part of the message's bind/unbind.
    #[binding(include)]
    pub reply: PortRef<StatusOverlay>,
}
175
176impl GetRankStatus {
177 pub async fn wait(
178 mut rx: PortReceiver<crate::v1::StatusMesh>,
179 num_ranks: usize,
180 max_idle_time: Duration,
181 region: Region, ) -> Result<crate::v1::StatusMesh, crate::v1::StatusMesh> {
183 debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
184
185 let mut alarm = hyperactor::time::Alarm::new();
186 alarm.arm(max_idle_time);
187
188 let mut snapshot =
190 crate::v1::StatusMesh::from_single(region, crate::resource::Status::NotExist);
191
192 loop {
193 let mut sleeper = alarm.sleeper();
194 tokio::select! {
195 _ = sleeper.sleep() => return Err(snapshot),
196 next = rx.recv() => {
197 match next {
198 Ok(mesh) => { snapshot = mesh; } Err(_) => return Err(snapshot),
200 }
201 }
202 }
203
204 alarm.arm(max_idle_time);
205
206 if snapshot
210 .values()
211 .take(num_ranks)
212 .all(|s| !matches!(s, crate::resource::Status::NotExist))
213 {
214 break Ok(snapshot);
215 }
216 }
217 }
218}
219
/// A snapshot of a named resource: its lifecycle [`Status`] plus optional
/// resource-specific state.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// Name of the resource.
    pub name: Name,
    /// Lifecycle status of the resource.
    pub status: Status,
    /// Resource-specific state; `None` when none is provided.
    pub state: Option<S>,
}
230
231impl<S: Serialize> fmt::Display for State<S> {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 match serde_json::to_string(self) {
235 Ok(json) => write!(f, "{}", json),
236 Err(e) => write!(f, "<state: serde_json error: {}>", e),
237 }
238 }
239}
240
/// Request to create the named resource, or update it, using `spec`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// Name of the resource to create or update.
    pub name: Name,
    /// Rank slot; `#[binding(include)]` makes it part of the message's
    /// bind/unbind, so the receiver sees the bound rank.
    #[binding(include)]
    pub rank: Rank,
    /// Resource-specific specification (a `Resource::Spec` in the
    /// `Controller` behavior).
    pub spec: S,
}
263
/// Request to stop the named resource.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// Name of the resource to stop.
    pub name: Name,
}
281
/// Request to stop all resources; it carries no payload.
///
/// NOTE(review): "all" presumably means all resources owned by the
/// receiving controller — confirm with the handler.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {}
298
/// Request for the current [`State`] of the named resource, answered on
/// `reply`.
///
/// `Clone`, `Bind`, and `Unbind` are implemented by hand rather than derived,
/// presumably so that `S` itself need not implement `Clone` — confirm.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// Name of the resource to query.
    pub name: Name,
    /// Port on which the state is returned.
    #[reply]
    pub reply: PortRef<State<S>>,
}
308
309impl<S> Unbind for GetState<S>
311where
312 S: RemoteMessage,
313 S: Unbind,
314{
315 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
316 self.reply.unbind(bindings)
317 }
318}
319
320impl<S> Bind for GetState<S>
321where
322 S: RemoteMessage,
323 S: Bind,
324{
325 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
326 self.reply.bind(bindings)
327 }
328}
329
330impl<S> Clone for GetState<S>
331where
332 S: RemoteMessage,
333{
334 fn clone(&self) -> Self {
335 Self {
336 name: self.name.clone(),
337 reply: self.reply.clone(),
338 }
339 }
340}
341
342pub trait Resource {
344 type Spec: Named + Serialize + for<'de> Deserialize<'de> + Send + Sync + std::fmt::Debug;
346
347 type State: Named + Serialize + for<'de> Deserialize<'de> + Send + Sync + std::fmt::Debug;
349}
350
// Declares the `Controller<R>` behavior for a resource type `R` as the
// messages listed below: `CreateOrUpdate<R::Spec>`, `GetState<R::State>`, and
// `Stop`. NOTE(review): this is presumably the message set any controller
// actor for `R` must handle — confirm against `hyperactor::behavior!`.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
358
/// A sparse assignment of values to ranks, stored as `(range, value)`
/// intervals.
///
/// `rank` and `merge_from` rely on the intervals being sorted by start and
/// non-overlapping. `append` coalesces an adjacent range that carries an
/// equal value into the previous interval.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    /// Each entry assigns its value to every rank in its range.
    intervals: Vec<(Range<usize>, T)>,
}
367
368impl<T: PartialEq> PartialEq for RankedValues<T> {
369 fn eq(&self, other: &Self) -> bool {
370 self.intervals == other.intervals
371 }
372}
373
374impl<T: Eq> Eq for RankedValues<T> {}
375
376impl<T> Default for RankedValues<T> {
377 fn default() -> Self {
378 Self {
379 intervals: Vec::new(),
380 }
381 }
382}
383
384impl<T> RankedValues<T> {
385 pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
387 self.intervals.iter()
388 }
389
390 pub fn rank(&self, value: usize) -> usize {
393 self.iter()
394 .take_while(|(ranks, _)| ranks.start <= value)
395 .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
396 .sum()
397 }
398}
399
400impl<T: Clone> RankedValues<T> {
401 pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
402 assert_eq!(self.rank(until), until, "insufficient rank");
403 self.iter()
404 .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
405 }
406}
407
408impl<T: Hash + Eq + Clone> RankedValues<T> {
409 pub fn invert(&self) -> ValuesByRank<T> {
411 let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
412 for (range, value) in self.iter() {
413 inverted
414 .entry(value.clone())
415 .or_default()
416 .push(range.clone());
417 }
418 ValuesByRank { values: inverted }
419 }
420}
421
422impl<T: Eq + Clone> RankedValues<T> {
423 pub fn merge_from(&mut self, other: Self) {
431 let mut left_iter = take(&mut self.intervals).into_iter();
432 let mut right_iter = other.intervals.into_iter();
433
434 let mut left = left_iter.next();
435 let mut right = right_iter.next();
436
437 while left.is_some() && right.is_some() {
438 let (left_ranks, left_value) = left.as_mut().unwrap();
439 let (right_ranks, right_value) = right.as_mut().unwrap();
440
441 if left_ranks.is_overlapping(right_ranks) {
442 if left_value == right_value {
443 let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
444 let (_, value) = replace(&mut right, right_iter.next()).unwrap();
445 left_ranks.start = ranks.end;
446 if left_ranks.is_empty() {
447 left = left_iter.next();
448 }
449 self.append(ranks, value);
450 } else if left_ranks.start < right_ranks.start {
451 let ranks = left_ranks.start..right_ranks.start;
452 left_ranks.start = ranks.end;
453 self.append(ranks, left_value.clone());
455 } else {
456 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
457 left_ranks.start = ranks.end;
458 if left_ranks.is_empty() {
459 left = left_iter.next();
460 }
461 self.append(ranks, value);
462 }
463 } else if left_ranks.start < right_ranks.start {
464 let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
465 self.append(ranks, value);
466 } else {
467 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
468 self.append(ranks, value);
469 }
470 }
471
472 while let Some((left_ranks, left_value)) = left {
473 self.append(left_ranks, left_value);
474 left = left_iter.next();
475 }
476 while let Some((right_ranks, right_value)) = right {
477 self.append(right_ranks, right_value);
478 right = right_iter.next();
479 }
480 }
481
482 pub fn merge_into(self, other: &mut Self) {
484 other.merge_from(self);
485 }
486
487 fn append(&mut self, range: Range<usize>, value: T) {
488 if let Some(last) = self.intervals.last_mut()
489 && last.0.end == range.start
490 && last.1 == value
491 {
492 last.0.end = range.end;
493 } else {
494 self.intervals.push((range, value));
495 }
496 }
497}
498
499impl RankedValues<Status> {
500 pub fn first_terminating(&self) -> Option<(usize, Status)> {
501 self.intervals
502 .iter()
503 .find(|(_, status)| status.is_terminating())
504 .map(|(range, status)| (range.start, status.clone()))
505 }
506
507 pub fn first_failed(&self) -> Option<(usize, Status)> {
508 self.intervals
509 .iter()
510 .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
511 .map(|(range, status)| (range.start, status.clone()))
512 }
513}
514
515impl<T> From<(usize, T)> for RankedValues<T> {
516 fn from((rank, value): (usize, T)) -> Self {
517 Self {
518 intervals: vec![(rank..rank + 1, value)],
519 }
520 }
521}
522
523impl<T> From<(Range<usize>, T)> for RankedValues<T> {
524 fn from((range, value): (Range<usize>, T)) -> Self {
525 Self {
526 intervals: vec![(range, value)],
527 }
528 }
529}
530
/// The inverse of [`RankedValues`]: maps each distinct value to the rank
/// ranges that carry it. Produced by [`RankedValues::invert`].
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    values: HashMap<T, Vec<Range<usize>>>,
}
537
538impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
539 fn eq(&self, other: &Self) -> bool {
540 self.values == other.values
541 }
542}
543
544impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
545
546impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
547 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548 let mut first_value = true;
549 for (value, ranges) in self.iter() {
550 if first_value {
551 first_value = false;
552 } else {
553 write!(f, ";")?;
554 }
555 write!(f, "{}=", value)?;
556 let mut first_range = true;
557 for range in ranges.iter() {
558 if first_range {
559 first_range = false;
560 } else {
561 write!(f, ",")?;
562 }
563 write!(f, "{}..{}", range.start, range.end)?;
564 }
565 }
566 Ok(())
567 }
568}
569
570impl<T> Deref for ValuesByRank<T> {
571 type Target = HashMap<T, Vec<Range<usize>>>;
572
573 fn deref(&self) -> &Self::Target {
574 &self.values
575 }
576}
577
578impl<T> DerefMut for ValuesByRank<T> {
579 fn deref_mut(&mut self) -> &mut Self::Target {
580 &mut self.values
581 }
582}
583
/// Test-only construction from `(range, value)` pairs. The pairs are stored
/// as given: no sorting, merging, or coalescing is performed.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals = Vec::from_iter(iter);
        Self { intervals }
    }
}
594
/// Specification for a proc resource.
///
/// NOTE(review): presumably used as the `Resource::Spec` for procs — confirm.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Configuration attributes supplied by the client. NOTE(review): judging
    /// by the name, these override the proc's own configuration — confirm
    /// where they are applied.
    pub(crate) client_config_override: Attrs,
}
602
603impl ProcSpec {
604 pub(crate) fn new(client_config_override: Attrs) -> Self {
605 Self {
606 client_config_override,
607 }
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of the stored intervals, for comparison in assertions.
    fn intervals<T: Clone>(values: &RankedValues<T>) -> Vec<(Range<usize>, T)> {
        values.iter().cloned().collect()
    }

    #[test]
    fn test_ranked_values_merge() {
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut merged: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        let updates: RankedValues<Side> = [
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        merged.merge_from(updates);

        // Right-hand values win on overlap; equal neighbours coalesce.
        let expected = vec![
            (0..9, Left),
            (9..12, Right),
            (15..20, Left),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..70, Both),
        ];
        assert_eq!(intervals(&merged), expected);

        // `rank(n)` counts the covered ranks strictly below `n`.
        for (upto, covered) in [(5, 5), (10, 10), (16, 13), (70, 62), (100, 62)] {
            assert_eq!(merged.rank(upto), covered);
        }
    }

    #[test]
    fn test_equality() {
        let numbers = RankedValues::from((0..10, 123));
        assert_eq!(numbers, RankedValues::from((0..10, 123)));

        let failed = || RankedValues::from((0..10, Status::Failed("foo".to_string())));
        assert_eq!(failed(), failed());
    }

    #[test]
    fn test_default_through_merging() {
        // A dense default, overlaid with sparse values.
        let mut merged = RankedValues::from((0..50, 0));
        let ones: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        merged.merge_from(ones);

        assert_eq!(
            intervals(&merged),
            vec![
                (0..10, 1),
                (10..15, 0),
                (15..20, 1),
                (20..30, 0),
                (30..50, 1)
            ]
        );
    }
}