hyperactor_mesh/
resource.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines a set of common message types used for managing resources
10//! in hyperactor meshes.
11
12pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::Named;
31use hyperactor::PortRef;
32use hyperactor::RefClient;
33use hyperactor::RemoteMessage;
34use hyperactor::Unbind;
35use hyperactor::attrs::Attrs;
36use hyperactor::mailbox::PortReceiver;
37use hyperactor::message::Bind;
38use hyperactor::message::Bindings;
39use hyperactor::message::Unbind;
40use ndslice::Region;
41use ndslice::ViewExt;
42use serde::Deserialize;
43use serde::Serialize;
44
45use crate::bootstrap;
46use crate::v1::Name;
47use crate::v1::StatusOverlay;
48
/// The current lifecycle status of a resource.
///
/// `PartialOrd` and `Ord` are derived, so statuses compare by the order in
/// which the variants are declared below (`NotExist` is the smallest,
/// `Timeout` the largest). Reordering variants changes comparison results.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display
)]
pub enum Status {
    /// The resource does not exist.
    NotExist,
    /// The resource is being created.
    Initializing,
    /// The resource is running.
    Running,
    /// The resource is being stopped.
    Stopping,
    /// The resource is stopped.
    Stopped,
    /// The resource has failed, with an error message.
    // Displayed as `Failed(<message>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// The resource has been declared failed after a timeout.
    // Displayed using the `Debug` rendering of the duration.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
}
82
83impl Status {
84    /// Returns whether the status is a terminating status.
85    pub fn is_terminating(&self) -> bool {
86        matches!(
87            self,
88            Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
89        )
90    }
91
92    /// Tells whether the status represents a failure. A failure is both terminating
93    /// (the resource is not running), but also means abnormal exit (the resource
94    /// did not stop cleanly).
95    pub fn is_failure(&self) -> bool {
96        matches!(self, Self::Failed(_) | Self::Timeout(_))
97    }
98
99    pub fn is_healthy(&self) -> bool {
100        matches!(self, Status::Initializing | Status::Running)
101    }
102}
103
104impl From<bootstrap::ProcStatus> for Status {
105    fn from(status: bootstrap::ProcStatus) -> Self {
106        use bootstrap::ProcStatus;
107        match status {
108            ProcStatus::Starting => Status::Initializing,
109            ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
110            ProcStatus::Stopping { .. } => Status::Stopping,
111            ProcStatus::Stopped { .. } => Status::Stopped,
112            ProcStatus::Failed { reason } => Status::Failed(reason),
113            ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
114        }
115    }
116}
117
/// Data type used to communicate ranks.
/// Implements [`Bind`] and [`Unbind`]; the comm actor replaces
/// instances with the delivered rank.
///
/// The wrapped value is `None` while no rank has been set; that is also
/// what the derived `Default` produces.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
123
124impl Rank {
125    /// Create a new rank with the provided value.
126    pub fn new(rank: usize) -> Self {
127        Self(Some(rank))
128    }
129
130    /// Unwrap the rank; panics if not set.
131    pub fn unwrap(&self) -> usize {
132        self.0.unwrap()
133    }
134}
135
impl Unbind for Rank {
    /// Push this rank onto the back of `bindings`, returning whatever
    /// `Bindings::push_back` returns.
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
        bindings.push_back(self)
    }
}
141
142impl Bind for Rank {
143    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
144        let bound = bindings.try_pop_front::<Rank>()?;
145        self.0 = bound.0;
146        Ok(())
147    }
148}
149
/// Get the status of a resource across the mesh.
///
/// This message is cast to all ranks; each rank replies with a sparse
/// status **overlay**. The comm reducer merges overlays (right-wins)
/// and the accumulator applies them to produce **full StatusMesh
/// snapshots** on the receiver side.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// The name of the resource.
    pub name: Name,
    /// Sparse status updates (overlays) from a rank.
    // The only field marked for inclusion in the derived bind/unbind.
    #[binding(include)]
    pub reply: PortRef<StatusOverlay>,
}
175
impl GetRankStatus {
    /// Wait on `rx` until each of the first `num_ranks` values in the
    /// latest snapshot is something other than [`Status::NotExist`].
    ///
    /// Every mesh received on `rx` replaces the previous snapshot wholesale.
    /// The idle alarm is armed with `max_idle_time` before the loop and
    /// re-armed after each received message.
    ///
    /// Returns `Ok(snapshot)` once the completion condition holds. Returns
    /// `Err(snapshot)` with the most recent snapshot if the alarm's sleeper
    /// completes first or if `rx.recv()` fails. If nothing was ever
    /// received, that snapshot is a mesh over `region` filled with
    /// `NotExist`.
    ///
    /// In debug builds this asserts that `region.num_ranks() == num_ranks`.
    pub async fn wait(
        mut rx: PortReceiver<crate::v1::StatusMesh>,
        num_ranks: usize,
        max_idle_time: Duration,
        region: Region, // used only for fallback
    ) -> Result<crate::v1::StatusMesh, crate::v1::StatusMesh> {
        debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");

        let mut alarm = hyperactor::time::Alarm::new();
        alarm.arm(max_idle_time);

        // Fallback snapshot if we time out before receiving anything.
        let mut snapshot =
            crate::v1::StatusMesh::from_single(region, crate::resource::Status::NotExist);

        loop {
            // A fresh sleeper is taken from the alarm on every iteration.
            let mut sleeper = alarm.sleeper();
            tokio::select! {
                _ = sleeper.sleep() => return Err(snapshot),
                next = rx.recv() => {
                    match next {
                        Ok(mesh) => { snapshot = mesh; }   // latest-wins snapshot
                        Err(_)   => return Err(snapshot),
                    }
                }
            }

            // A message arrived, so restart the idle window.
            alarm.arm(max_idle_time);

            // Completion: once every rank (among the first
            // `num_ranks`) has reported at least something (i.e.
            // moved off NotExist).
            if snapshot
                .values()
                .take(num_ranks)
                .all(|s| !matches!(s, crate::resource::Status::NotExist))
            {
                break Ok(snapshot);
            }
        }
    }
}
219
/// The state of a resource: its name, its lifecycle status, and an
/// optional payload whose type `S` is chosen by the resource.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// The name of the resource.
    pub name: Name,
    /// Its status.
    pub status: Status,
    /// Optionally, a resource-defined state.
    pub state: Option<S>,
}
230
231impl<S: Serialize> fmt::Display for State<S> {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        // Use serde_json to serialize the struct to a compact JSON string
234        match serde_json::to_string(self) {
235            Ok(json) => write!(f, "{}", json),
236            Err(e) => write!(f, "<state: serde_json error: {}>", e),
237        }
238    }
239}
240
/// Create or update a resource according to a spec.
///
/// `S` is the resource's spec type.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// The name of the resource to create or update.
    pub name: Name,
    /// The rank of the resource, when available.
    // The only field marked for inclusion in the derived bind/unbind.
    #[binding(include)]
    pub rank: Rank,
    /// The specification of the resource.
    pub spec: S,
}
263
/// Stop the named resource.
///
/// The message carries only the resource name; it has no spec and no
/// reply port.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// The name of the resource to stop.
    pub name: Name,
}
281
/// Stop all resources owned by the receiver of this message.
///
/// There is no reply; this only issues the stop command. Use
/// [`GetRankStatus`] to determine whether the resources have stopped.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {}
298
/// Retrieve the current state of the resource.
///
/// `S` is the resource-defined state type carried in the reply.
/// `Bind`, `Unbind` and `Clone` are implemented by hand for this type
/// rather than derived.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// The name of the resource.
    pub name: Name,
    /// A reply containing the state.
    #[reply]
    pub reply: PortRef<State<S>>,
}
308
309// Cannot derive Bind and Unbind for this generic, implement manually.
310impl<S> Unbind for GetState<S>
311where
312    S: RemoteMessage,
313    S: Unbind,
314{
315    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
316        self.reply.unbind(bindings)
317    }
318}
319
320impl<S> Bind for GetState<S>
321where
322    S: RemoteMessage,
323    S: Bind,
324{
325    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
326        self.reply.bind(bindings)
327    }
328}
329
330impl<S> Clone for GetState<S>
331where
332    S: RemoteMessage,
333{
334    fn clone(&self) -> Self {
335        Self {
336            name: self.name.clone(),
337            reply: self.reply.clone(),
338        }
339    }
340}
341
/// A trait that bundles a set of types that together define a resource.
///
/// Both associated types must be nameable, serializable, deserializable
/// from owned data, `Send + Sync`, and `Debug`.
pub trait Resource {
    /// The specification type for this resource.
    type Spec: Named + Serialize + for<'de> Deserialize<'de> + Send + Sync + std::fmt::Debug;

    /// The state type for this resource.
    type State: Named + Serialize + for<'de> Deserialize<'de> + Send + Sync + std::fmt::Debug;
}
350
// A behavior defining the interface for a mesh controller.
// For a resource type `R`, it lists the messages `CreateOrUpdate<R::Spec>`,
// `GetState<R::State>` and `Stop`. `StopAll` and `GetRankStatus` are not
// part of this behavior.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
358
/// RankedValues compactly represents rank-indexed values of type T.
/// It stores contiguous values in a set of intervals; thus it is
/// efficient and compact when the cardinality of T-typed values is
/// low.
///
/// Ranks not covered by any interval simply have no value.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    // Half-open rank ranges paired with the value shared by those ranks.
    // `rank` and `merge_from` walk this list in order and appear to rely on
    // it being sorted by range start with no overlaps.
    intervals: Vec<(Range<usize>, T)>,
}
367
368impl<T: PartialEq> PartialEq for RankedValues<T> {
369    fn eq(&self, other: &Self) -> bool {
370        self.intervals == other.intervals
371    }
372}
373
// Marker impl: equality on RankedValues is a full equivalence whenever it is
// one for `T`.
impl<T: Eq> Eq for RankedValues<T> {}
375
376impl<T> Default for RankedValues<T> {
377    fn default() -> Self {
378        Self {
379            intervals: Vec::new(),
380        }
381    }
382}
383
384impl<T> RankedValues<T> {
385    /// Iterate over contiguous rank intervals of values.
386    pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
387        self.intervals.iter()
388    }
389
390    /// The (set) rank of the RankedValues is the number of values stored with
391    /// rank less than `value`.
392    pub fn rank(&self, value: usize) -> usize {
393        self.iter()
394            .take_while(|(ranks, _)| ranks.start <= value)
395            .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
396            .sum()
397    }
398}
399
400impl<T: Clone> RankedValues<T> {
401    pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
402        assert_eq!(self.rank(until), until, "insufficient rank");
403        self.iter()
404            .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
405    }
406}
407
408impl<T: Hash + Eq + Clone> RankedValues<T> {
409    /// Invert this ranked values into a [`ValuesByRank<T>`].
410    pub fn invert(&self) -> ValuesByRank<T> {
411        let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
412        for (range, value) in self.iter() {
413            inverted
414                .entry(value.clone())
415                .or_default()
416                .push(range.clone());
417        }
418        ValuesByRank { values: inverted }
419    }
420}
421
422impl<T: Eq + Clone> RankedValues<T> {
423    /// Merge `other` into this set of ranked values. Values in `other` that overlap
424    /// with `self` take prededence.
425    ///
426    /// This currently uses a simple algorithm that merges the full set of RankedValues.
427    /// This remains efficient when the cardinality of T-typed values is low. However,
428    /// it does not efficiently merge high cardinality value sets. Consider using interval
429    /// trees or bitmap techniques like Roaring Bitmaps in these cases.
430    pub fn merge_from(&mut self, other: Self) {
431        let mut left_iter = take(&mut self.intervals).into_iter();
432        let mut right_iter = other.intervals.into_iter();
433
434        let mut left = left_iter.next();
435        let mut right = right_iter.next();
436
437        while left.is_some() && right.is_some() {
438            let (left_ranks, left_value) = left.as_mut().unwrap();
439            let (right_ranks, right_value) = right.as_mut().unwrap();
440
441            if left_ranks.is_overlapping(right_ranks) {
442                if left_value == right_value {
443                    let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
444                    let (_, value) = replace(&mut right, right_iter.next()).unwrap();
445                    left_ranks.start = ranks.end;
446                    if left_ranks.is_empty() {
447                        left = left_iter.next();
448                    }
449                    self.append(ranks, value);
450                } else if left_ranks.start < right_ranks.start {
451                    let ranks = left_ranks.start..right_ranks.start;
452                    left_ranks.start = ranks.end;
453                    // TODO: get rid of clone
454                    self.append(ranks, left_value.clone());
455                } else {
456                    let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
457                    left_ranks.start = ranks.end;
458                    if left_ranks.is_empty() {
459                        left = left_iter.next();
460                    }
461                    self.append(ranks, value);
462                }
463            } else if left_ranks.start < right_ranks.start {
464                let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
465                self.append(ranks, value);
466            } else {
467                let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
468                self.append(ranks, value);
469            }
470        }
471
472        while let Some((left_ranks, left_value)) = left {
473            self.append(left_ranks, left_value);
474            left = left_iter.next();
475        }
476        while let Some((right_ranks, right_value)) = right {
477            self.append(right_ranks, right_value);
478            right = right_iter.next();
479        }
480    }
481
482    /// Merge the contents of this RankedValues into another RankedValues.
483    pub fn merge_into(self, other: &mut Self) {
484        other.merge_from(self);
485    }
486
487    fn append(&mut self, range: Range<usize>, value: T) {
488        if let Some(last) = self.intervals.last_mut()
489            && last.0.end == range.start
490            && last.1 == value
491        {
492            last.0.end = range.end;
493        } else {
494            self.intervals.push((range, value));
495        }
496    }
497}
498
499impl RankedValues<Status> {
500    pub fn first_terminating(&self) -> Option<(usize, Status)> {
501        self.intervals
502            .iter()
503            .find(|(_, status)| status.is_terminating())
504            .map(|(range, status)| (range.start, status.clone()))
505    }
506
507    pub fn first_failed(&self) -> Option<(usize, Status)> {
508        self.intervals
509            .iter()
510            .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
511            .map(|(range, status)| (range.start, status.clone()))
512    }
513}
514
515impl<T> From<(usize, T)> for RankedValues<T> {
516    fn from((rank, value): (usize, T)) -> Self {
517        Self {
518            intervals: vec![(rank..rank + 1, value)],
519        }
520    }
521}
522
523impl<T> From<(Range<usize>, T)> for RankedValues<T> {
524    fn from((range, value): (Range<usize>, T)) -> Self {
525        Self {
526            intervals: vec![(range, value)],
527        }
528    }
529}
530
/// An inverted index of RankedValues, providing all ranks for
/// which each unique T-typed value appears.
///
/// Produced by [`RankedValues::invert`]. Dereferences to the underlying
/// `HashMap`.
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    // For each distinct value, the half-open rank ranges that hold it.
    values: HashMap<T, Vec<Range<usize>>>,
}
537
538impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
539    fn eq(&self, other: &Self) -> bool {
540        self.values == other.values
541    }
542}
543
// Marker impl: equality on ValuesByRank is a full equivalence.
impl<T: Eq + Hash> Eq for ValuesByRank<T> {}
545
546impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
547    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548        let mut first_value = true;
549        for (value, ranges) in self.iter() {
550            if first_value {
551                first_value = false;
552            } else {
553                write!(f, ";")?;
554            }
555            write!(f, "{}=", value)?;
556            let mut first_range = true;
557            for range in ranges.iter() {
558                if first_range {
559                    first_range = false;
560                } else {
561                    write!(f, ",")?;
562                }
563                write!(f, "{}..{}", range.start, range.end)?;
564            }
565        }
566        Ok(())
567    }
568}
569
// Read access to the underlying map, so `HashMap` methods such as `iter`
// and `get` can be called directly on a `ValuesByRank`.
impl<T> Deref for ValuesByRank<T> {
    type Target = HashMap<T, Vec<Range<usize>>>;

    fn deref(&self) -> &Self::Target {
        &self.values
    }
}
577
// Mutable access to the underlying map. Nothing here validates what callers
// insert or remove.
impl<T> DerefMut for ValuesByRank<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.values
    }
}
583
/// Enabled for test only because we have to guarantee that the input
/// iterator is well-formed: the items are stored exactly as yielded, with
/// no sorting, overlap checking or coalescing.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals: Vec<(Range<usize>, T)> = iter.into_iter().collect();
        Self { intervals }
    }
}
594
/// Spec for a host mesh agent to use when spawning a new proc.
///
/// The derived `Default` uses `Attrs::default()` for the overrides.
#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
pub(crate) struct ProcSpec {
    /// Config values to set on the spawned proc's global config,
    /// at the `ClientOverride` layer.
    pub(crate) client_config_override: Attrs,
}
602
impl ProcSpec {
    /// Create a spec that carries the given client config overrides.
    pub(crate) fn new(client_config_override: Attrs) -> Self {
        Self {
            client_config_override,
        }
    }
}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Merges two interleaved interval lists and checks both the merged
    /// intervals and `rank` on the result.
    #[test]
    fn test_ranked_values_merge() {
        // Labels which input a value came from, so the expected output shows
        // which side won each range.
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut left: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        let right: RankedValues<Side> = [
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        left.merge_from(right);
        assert_eq!(
            left.iter().cloned().collect::<Vec<_>>(),
            vec![
                // Left is cut short where Right begins at 9.
                (0..9, Left),
                (9..12, Right),
                (15..20, Left),
                (25..30, Right),
                (30..40, Both),
                // Right overrides the tail of the left 30..50 interval.
                (40..50, Right),
                // Merge consecutive:
                (50..70, Both)
            ]
        );

        // Ranks 12..15 and 20..25 hold no value, so counts fall behind the
        // queried rank from 12 onwards.
        assert_eq!(left.rank(5), 5);
        assert_eq!(left.rank(10), 10);
        assert_eq!(left.rank(16), 13);
        assert_eq!(left.rank(70), 62);
        assert_eq!(left.rank(100), 62);
    }

    /// Equality holds for identical single-interval sets, for plain integers
    /// and for `Status` values.
    #[test]
    fn test_equality() {
        assert_eq!(
            RankedValues::from((0..10, 123)),
            RankedValues::from((0..10, 123))
        );
        assert_eq!(
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
            RankedValues::from((0..10, Status::Failed("foo".to_string()))),
        );
    }

    /// Merging sparse values over a full-range default fills the gaps with
    /// the default value.
    #[test]
    fn test_default_through_merging() {
        let values: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        let mut default = RankedValues::from((0..50, 0));
        default.merge_from(values);

        assert_eq!(
            default.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..10, 1),
                (10..15, 0),
                (15..20, 1),
                (20..30, 0),
                (30..50, 1)
            ]
        );
    }
}
697        );
698    }
699}