hyperactor_mesh/
resource.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines a set of common message types used for managing resources
10//! in hyperactor meshes.
11
12pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::PortRef;
31use hyperactor::RefClient;
32use hyperactor::RemoteMessage;
33use hyperactor::Unbind;
34use hyperactor::mailbox::PortReceiver;
35use hyperactor::message::Bind;
36use hyperactor::message::Bindings;
37use hyperactor::message::Unbind;
38use hyperactor_config::attrs::Attrs;
39use ndslice::Region;
40use ndslice::ViewExt;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44
45use crate::bootstrap;
46use crate::v1::Name;
47use crate::v1::StatusOverlay;
48
/// The current lifecycle status of a resource.
///
/// `PartialOrd`/`Ord` are derived, so statuses compare by variant
/// declaration order first (`NotExist` lowest, `Timeout` highest) and
/// by payload second.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display
)]
pub enum Status {
    /// The resource does not exist.
    NotExist,
    /// The resource is being created.
    Initializing,
    /// The resource is running.
    Running,
    /// The resource is being stopped.
    Stopping,
    /// The resource is stopped.
    Stopped,
    /// The resource has failed, with an error message.
    /// Displayed as `Failed(<message>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// The resource has been declared failed after a timeout.
    /// Displayed as `Timeout(<duration>)`, using the duration's `Debug` form.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
}
82
83impl Status {
84    /// Returns whether the status is a terminating status.
85    pub fn is_terminating(&self) -> bool {
86        matches!(
87            self,
88            Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
89        )
90    }
91
92    /// Tells whether the status represents a failure. A failure is both terminating
93    /// (the resource is not running), but also means abnormal exit (the resource
94    /// did not stop cleanly).
95    pub fn is_failure(&self) -> bool {
96        matches!(self, Self::Failed(_) | Self::Timeout(_))
97    }
98
99    pub fn is_healthy(&self) -> bool {
100        matches!(self, Status::Initializing | Status::Running)
101    }
102}
103
104impl From<bootstrap::ProcStatus> for Status {
105    fn from(status: bootstrap::ProcStatus) -> Self {
106        use bootstrap::ProcStatus;
107        match status {
108            ProcStatus::Starting => Status::Initializing,
109            ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
110            ProcStatus::Stopping { .. } => Status::Stopping,
111            ProcStatus::Stopped { .. } => Status::Stopped,
112            ProcStatus::Failed { reason } => Status::Failed(reason),
113            ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
114        }
115    }
116}
117
/// Data type used to communicate ranks.
/// Implements [`Bind`] and [`Unbind`]; the comm actor replaces
/// instances with the delivered rank.
///
/// The wrapped value is `None` until a rank has been assigned, which is
/// also what `Default` produces.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
wirevalue::register_type!(Rank);
124
125impl Rank {
126    /// Create a new rank with the provided value.
127    pub fn new(rank: usize) -> Self {
128        Self(Some(rank))
129    }
130
131    /// Unwrap the rank; panics if not set.
132    pub fn unwrap(&self) -> usize {
133        self.0.unwrap()
134    }
135}
136
/// Unbinding a rank records the rank itself in `bindings`.
impl Unbind for Rank {
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
        // The whole `Rank` value is pushed; the matching `Bind` impl pops a
        // `Rank` back off the front.
        bindings.push_back(self)
    }
}
142
143impl Bind for Rank {
144    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
145        let bound = bindings.try_pop_front::<Rank>()?;
146        self.0 = bound.0;
147        Ok(())
148    }
149}
150
/// Get the status of a resource across the mesh.
///
/// This message is cast to all ranks; each rank replies with a sparse
/// status **overlay**. The comm reducer merges overlays (right-wins)
/// and the accumulator applies them to produce **full StatusMesh
/// snapshots** on the receiver side.
///
/// [`GetRankStatus::wait`] consumes those snapshots until every rank has
/// reported.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// The name of the resource.
    pub name: Name,
    /// Sparse status updates (overlays) from a rank.
    // NOTE(review): `#[binding(include)]` presumably makes the derived
    // `Bind`/`Unbind` impls cover this port -- confirm against the derive docs.
    #[binding(include)]
    pub reply: PortRef<StatusOverlay>,
}
176
177impl GetRankStatus {
178    pub async fn wait(
179        mut rx: PortReceiver<crate::v1::StatusMesh>,
180        num_ranks: usize,
181        max_idle_time: Duration,
182        region: Region, // used only for fallback
183    ) -> Result<crate::v1::StatusMesh, crate::v1::StatusMesh> {
184        debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
185
186        let mut alarm = hyperactor::time::Alarm::new();
187        alarm.arm(max_idle_time);
188
189        // Fallback snapshot if we time out before receiving anything.
190        let mut snapshot =
191            crate::v1::StatusMesh::from_single(region, crate::resource::Status::NotExist);
192
193        loop {
194            let mut sleeper = alarm.sleeper();
195            tokio::select! {
196                _ = sleeper.sleep() => return Err(snapshot),
197                next = rx.recv() => {
198                    match next {
199                        Ok(mesh) => { snapshot = mesh; }   // latest-wins snapshot
200                        Err(_)   => return Err(snapshot),
201                    }
202                }
203            }
204
205            alarm.arm(max_idle_time);
206
207            // Completion: once every rank (among the first
208            // `num_ranks`) has reported at least something (i.e.
209            // moved off NotExist).
210            if snapshot
211                .values()
212                .take(num_ranks)
213                .all(|s| !matches!(s, crate::resource::Status::NotExist))
214            {
215                break Ok(snapshot);
216            }
217        }
218    }
219}
220
/// The state of a resource: its name, its lifecycle [`Status`], and an
/// optional resource-specific payload of type `S`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// The name of the resource.
    pub name: Name,
    /// Its status.
    pub status: Status,
    /// Optionally, a resource-defined state.
    pub state: Option<S>,
}
231
232impl<S: Serialize> fmt::Display for State<S> {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        // Use serde_json to serialize the struct to a compact JSON string
235        match serde_json::to_string(self) {
236            Ok(json) => write!(f, "{}", json),
237            Err(e) => write!(f, "<state: serde_json error: {}>", e),
238        }
239    }
240}
241
/// Create or update a resource according to a spec.
///
/// `S` is the resource-specific specification type carried in `spec`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// The name of the resource to create or update.
    pub name: Name,
    /// The rank of the resource, when available.
    // NOTE(review): `#[binding(include)]` presumably routes this field through
    // the derived `Bind`/`Unbind` impls so the rank can be filled in on
    // delivery (see the `Rank` docs) -- confirm against the derive docs.
    #[binding(include)]
    pub rank: Rank,
    /// The specification of the resource.
    pub spec: S,
}
264
/// Stop the resource with the given name.
///
/// The message carries only the resource name and no reply port.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// The name of the resource to stop.
    pub name: Name,
}
282
/// Stop all resources owned by the receiver of this message.
///
/// No reply is sent; this only issues the stop command. Use
/// [`GetRankStatus`] to determine whether the resources have stopped.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {}
299
/// Retrieve the current state of the resource.
///
/// `S` is the resource-defined state type carried inside the replied
/// [`State`]. `Bind`, `Unbind` and `Clone` are implemented by hand for this
/// generic message rather than derived.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// The name of the resource.
    pub name: Name,
    /// A reply containing the state.
    #[reply]
    pub reply: PortRef<State<S>>,
}
309
310// Cannot derive Bind and Unbind for this generic, implement manually.
311impl<S> Unbind for GetState<S>
312where
313    S: RemoteMessage,
314    S: Unbind,
315{
316    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
317        self.reply.unbind(bindings)
318    }
319}
320
321impl<S> Bind for GetState<S>
322where
323    S: RemoteMessage,
324    S: Bind,
325{
326    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
327        self.reply.bind(bindings)
328    }
329}
330
331impl<S> Clone for GetState<S>
332where
333    S: RemoteMessage,
334{
335    fn clone(&self) -> Self {
336        Self {
337            name: self.name.clone(),
338            reply: self.reply.clone(),
339        }
340    }
341}
342
/// List the set of resources managed by the controller.
///
/// The controller answers on `reply` with the names of its resources.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct List {
    /// List of resource names managed by this controller.
    #[reply]
    pub reply: PortRef<Vec<Name>>,
}
wirevalue::register_type!(List);
351
/// A trait that bundles a set of types that together define a resource.
///
/// Both associated types must be named, serializable in both directions,
/// `Send + Sync`, and `Debug`.
pub trait Resource {
    /// The specification type for this resource.
    type Spec: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;

    /// The state type for this resource.
    type State: typeuri::Named
        + Serialize
        + for<'de> Deserialize<'de>
        + Send
        + Sync
        + std::fmt::Debug;
}
370
// A behavior defining the interface for a mesh controller.
// For a resource `R`, the listed messages are `CreateOrUpdate` (with the
// resource's spec), `GetState` (with the resource's state), and `Stop`.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
378
/// RankedValues compactly represents rank-indexed values of type T.
/// It stores contiguous values in a set of intervals; thus it is
/// efficient and compact when the cardinality of T-typed values is
/// low.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    /// Half-open rank ranges paired with the value shared by every rank in
    /// the range.
    // NOTE(review): `rank` and `merge_from` read as if this list is sorted by
    // range start and free of overlaps -- confirm that every constructor
    // upholds this.
    intervals: Vec<(Range<usize>, T)>,
}
387
388impl<T: PartialEq> PartialEq for RankedValues<T> {
389    fn eq(&self, other: &Self) -> bool {
390        self.intervals == other.intervals
391    }
392}
393
394impl<T: Eq> Eq for RankedValues<T> {}
395
396impl<T> Default for RankedValues<T> {
397    fn default() -> Self {
398        Self {
399            intervals: Vec::new(),
400        }
401    }
402}
403
404impl<T> RankedValues<T> {
405    /// Iterate over contiguous rank intervals of values.
406    pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
407        self.intervals.iter()
408    }
409
410    /// The (set) rank of the RankedValues is the number of values stored with
411    /// rank less than `value`.
412    pub fn rank(&self, value: usize) -> usize {
413        self.iter()
414            .take_while(|(ranks, _)| ranks.start <= value)
415            .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
416            .sum()
417    }
418}
419
420impl<T: Clone> RankedValues<T> {
421    pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
422        assert_eq!(self.rank(until), until, "insufficient rank");
423        self.iter()
424            .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
425    }
426}
427
428impl<T: Hash + Eq + Clone> RankedValues<T> {
429    /// Invert this ranked values into a [`ValuesByRank<T>`].
430    pub fn invert(&self) -> ValuesByRank<T> {
431        let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
432        for (range, value) in self.iter() {
433            inverted
434                .entry(value.clone())
435                .or_default()
436                .push(range.clone());
437        }
438        ValuesByRank { values: inverted }
439    }
440}
441
442impl<T: Eq + Clone> RankedValues<T> {
443    /// Merge `other` into this set of ranked values. Values in `other` that overlap
444    /// with `self` take prededence.
445    ///
446    /// This currently uses a simple algorithm that merges the full set of RankedValues.
447    /// This remains efficient when the cardinality of T-typed values is low. However,
448    /// it does not efficiently merge high cardinality value sets. Consider using interval
449    /// trees or bitmap techniques like Roaring Bitmaps in these cases.
450    pub fn merge_from(&mut self, other: Self) {
451        let mut left_iter = take(&mut self.intervals).into_iter();
452        let mut right_iter = other.intervals.into_iter();
453
454        let mut left = left_iter.next();
455        let mut right = right_iter.next();
456
457        while left.is_some() && right.is_some() {
458            let (left_ranks, left_value) = left.as_mut().unwrap();
459            let (right_ranks, right_value) = right.as_mut().unwrap();
460
461            if left_ranks.is_overlapping(right_ranks) {
462                if left_value == right_value {
463                    let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
464                    let (_, value) = replace(&mut right, right_iter.next()).unwrap();
465                    left_ranks.start = ranks.end;
466                    if left_ranks.is_empty() {
467                        left = left_iter.next();
468                    }
469                    self.append(ranks, value);
470                } else if left_ranks.start < right_ranks.start {
471                    let ranks = left_ranks.start..right_ranks.start;
472                    left_ranks.start = ranks.end;
473                    // TODO: get rid of clone
474                    self.append(ranks, left_value.clone());
475                } else {
476                    let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
477                    left_ranks.start = ranks.end;
478                    if left_ranks.is_empty() {
479                        left = left_iter.next();
480                    }
481                    self.append(ranks, value);
482                }
483            } else if left_ranks.start < right_ranks.start {
484                let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
485                self.append(ranks, value);
486            } else {
487                let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
488                self.append(ranks, value);
489            }
490        }
491
492        while let Some((left_ranks, left_value)) = left {
493            self.append(left_ranks, left_value);
494            left = left_iter.next();
495        }
496        while let Some((right_ranks, right_value)) = right {
497            self.append(right_ranks, right_value);
498            right = right_iter.next();
499        }
500    }
501
502    /// Merge the contents of this RankedValues into another RankedValues.
503    pub fn merge_into(self, other: &mut Self) {
504        other.merge_from(self);
505    }
506
507    fn append(&mut self, range: Range<usize>, value: T) {
508        if let Some(last) = self.intervals.last_mut()
509            && last.0.end == range.start
510            && last.1 == value
511        {
512            last.0.end = range.end;
513        } else {
514            self.intervals.push((range, value));
515        }
516    }
517}
518
519impl RankedValues<Status> {
520    pub fn first_terminating(&self) -> Option<(usize, Status)> {
521        self.intervals
522            .iter()
523            .find(|(_, status)| status.is_terminating())
524            .map(|(range, status)| (range.start, status.clone()))
525    }
526
527    pub fn first_failed(&self) -> Option<(usize, Status)> {
528        self.intervals
529            .iter()
530            .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
531            .map(|(range, status)| (range.start, status.clone()))
532    }
533}
534
535impl<T> From<(usize, T)> for RankedValues<T> {
536    fn from((rank, value): (usize, T)) -> Self {
537        Self {
538            intervals: vec![(rank..rank + 1, value)],
539        }
540    }
541}
542
543impl<T> From<(Range<usize>, T)> for RankedValues<T> {
544    fn from((range, value): (Range<usize>, T)) -> Self {
545        Self {
546            intervals: vec![(range, value)],
547        }
548    }
549}
550
/// An inverted index of RankedValues, providing all ranks for
/// which each unique T-typed value appears.
///
/// Dereferences to the underlying `HashMap`, so map methods are available
/// directly on this type.
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    /// For each value, the rank ranges in which it occurs.
    values: HashMap<T, Vec<Range<usize>>>,
}
557
558impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
559    fn eq(&self, other: &Self) -> bool {
560        self.values == other.values
561    }
562}
563
564impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
565
566impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
567    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568        let mut first_value = true;
569        for (value, ranges) in self.iter() {
570            if first_value {
571                first_value = false;
572            } else {
573                write!(f, ";")?;
574            }
575            write!(f, "{}=", value)?;
576            let mut first_range = true;
577            for range in ranges.iter() {
578                if first_range {
579                    first_range = false;
580                } else {
581                    write!(f, ",")?;
582                }
583                write!(f, "{}..{}", range.start, range.end)?;
584            }
585        }
586        Ok(())
587    }
588}
589
/// Read access to the underlying value-to-ranges map.
impl<T> Deref for ValuesByRank<T> {
    type Target = HashMap<T, Vec<Range<usize>>>;

    fn deref(&self) -> &Self::Target {
        &self.values
    }
}

/// Mutable access to the underlying value-to-ranges map.
impl<T> DerefMut for ValuesByRank<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.values
    }
}
603
/// Collects intervals verbatim, without validating or normalizing them.
/// Enabled for tests only, because the caller has to guarantee that the
/// input iterator is well-formed.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals = Vec::from_iter(iter);
        Self { intervals }
    }
}
614
/// Spec for a host mesh agent to use when spawning a new proc.
///
/// `Default` yields a spec whose override attributes are `Attrs::default()`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Config values to set on the spawned proc's global config,
    /// at the `ClientOverride` layer.
    pub(crate) client_config_override: Attrs,
}
wirevalue::register_type!(ProcSpec);
623
624impl ProcSpec {
625    pub(crate) fn new(client_config_override: Attrs) -> Self {
626        Self {
627            client_config_override,
628        }
629    }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Merging prefers the right-hand side on overlap, keeps non-overlapping
    /// intervals from both sides, and coalesces adjacent equal values.
    #[test]
    fn test_ranked_values_merge() {
        // Marker values that record which side an interval came from.
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut left: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        let right: RankedValues<Side> = [
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        left.merge_from(right);
        assert_eq!(
            left.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..9, Left),
                (9..12, Right),
                (15..20, Left),
                (25..30, Right),
                (30..40, Both),
                (40..50, Right),
                // Merge consecutive:
                (50..70, Both)
            ]
        );

        // `rank(v)` counts stored ranks below `v`; the gaps 12..15 and 20..25
        // are not stored, so the count falls behind `v` past them.
        assert_eq!(left.rank(5), 5);
        assert_eq!(left.rank(10), 10);
        assert_eq!(left.rank(16), 13);
        assert_eq!(left.rank(70), 62);
        assert_eq!(left.rank(100), 62);
    }

    /// Equality compares ranges and values, including payload-carrying
    /// statuses.
    #[test]
    fn test_equality() {
        assert_eq!(
            RankedValues::from((0..10, 123)),
            RankedValues::from((0..10, 123))
        );
        assert_eq!(
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
        );
    }

    /// Merging sparse values into a full-range default fills the gaps with
    /// the default value.
    #[test]
    fn test_default_through_merging() {
        let values: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        let mut default = RankedValues::from((0..50, 0));
        default.merge_from(values);

        assert_eq!(
            default.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..10, 1),
                (10..15, 0),
                (15..20, 1),
                (20..30, 0),
                (30..50, 1)
            ]
        );
    }
}