hyperactor_mesh/
resource.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines a set of common message types used for managing resources
10//! in hyperactor meshes.
11
12pub mod mesh;
13
14use core::slice::GetDisjointMutIndex as _;
15use std::collections::HashMap;
16use std::fmt;
17use std::fmt::Debug;
18use std::hash::Hash;
19use std::mem::replace;
20use std::mem::take;
21use std::ops::Deref;
22use std::ops::DerefMut;
23use std::ops::Range;
24use std::time::Duration;
25
26use enum_as_inner::EnumAsInner;
27use hyperactor::Bind;
28use hyperactor::HandleClient;
29use hyperactor::Handler;
30use hyperactor::RefClient;
31use hyperactor::RemoteMessage;
32use hyperactor::Unbind;
33use hyperactor::mailbox::PortReceiver;
34use hyperactor::message::Bind;
35use hyperactor::message::Bindings;
36use hyperactor::message::Unbind;
37use hyperactor::reference as hyperactor_reference;
38use hyperactor_config::attrs::Attrs;
39use ndslice::Region;
40use ndslice::ViewExt;
41use serde::Deserialize;
42use serde::Serialize;
43use typeuri::Named;
44
45use crate::Name;
46use crate::StatusOverlay;
47use crate::bootstrap;
48use crate::host_mesh::host_agent::ProcState;
49use crate::proc_agent::ActorSpec;
50use crate::proc_agent::ActorState;
51
/// The current lifecycle status of a resource.
///
/// Note: the derived `PartialOrd`/`Ord` compare variants by their declaration
/// order, so reordering the variants below changes how statuses compare.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner,
    strum::Display
)]
pub enum Status {
    /// The resource does not exist.
    NotExist,
    /// The resource is being created.
    Initializing,
    /// The resource is running.
    Running,
    /// The resource is being stopped.
    Stopping,
    /// The resource is stopped.
    Stopped,
    /// The resource has failed, with an error message.
    /// Displayed as `Failed(<message>)`.
    #[strum(to_string = "Failed({0})")]
    Failed(String),
    /// The resource has been declared failed after a timeout.
    /// Displayed as `Timeout(<duration>)`, with the duration in `Debug` form.
    #[strum(to_string = "Timeout({0:?})")]
    Timeout(Duration),
    /// The resource exists but its status is not known.
    Unknown,
}
87
88impl Status {
89    /// Returns whether the status is a terminating status.
90    pub fn is_terminating(&self) -> bool {
91        matches!(
92            self,
93            Status::Stopping | Status::Stopped | Status::Failed(_) | Status::Timeout(_)
94        )
95    }
96
97    /// Tells whether the status represents a failure. A failure is both terminating
98    /// (the resource is not running), but also means abnormal exit (the resource
99    /// did not stop cleanly).
100    pub fn is_failure(&self) -> bool {
101        matches!(self, Self::Failed(_) | Self::Timeout(_))
102    }
103
104    pub fn is_healthy(&self) -> bool {
105        matches!(self, Status::Initializing | Status::Running)
106    }
107}
108
109impl From<bootstrap::ProcStatus> for Status {
110    fn from(status: bootstrap::ProcStatus) -> Self {
111        use bootstrap::ProcStatus;
112        match status {
113            ProcStatus::Starting => Status::Initializing,
114            ProcStatus::Running { .. } | ProcStatus::Ready { .. } => Status::Running,
115            ProcStatus::Stopping { .. } => Status::Stopping,
116            ProcStatus::Stopped { .. } => Status::Stopped,
117            ProcStatus::Failed { reason } => Status::Failed(reason),
118            ProcStatus::Killed { .. } => Status::Failed(format!("{}", status)),
119        }
120    }
121}
122
123/// Data type used to communicate ranks.
124/// Implements [`Bind`] and [`Unbind`]; the comm actor replaces
125/// instances with the delivered rank.
126#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
127pub struct Rank(pub Option<usize>);
128wirevalue::register_type!(Rank);
129
130impl Rank {
131    /// Create a new rank with the provided value.
132    pub fn new(rank: usize) -> Self {
133        Self(Some(rank))
134    }
135
136    /// Unwrap the rank; panics if not set.
137    pub fn unwrap(&self) -> usize {
138        self.0.unwrap()
139    }
140}
141
142impl Unbind for Rank {
143    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
144        bindings.push_back(self)
145    }
146}
147
148impl Bind for Rank {
149    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
150        let bound = bindings.try_pop_front::<Rank>()?;
151        self.0 = bound.0;
152        Ok(())
153    }
154}
155
/// Get the status of a resource across the mesh.
///
/// This message is cast to all ranks; each rank replies with a sparse
/// status **overlay**. The comm reducer merges overlays (right-wins)
/// and the accumulator applies them to produce **full StatusMesh
/// snapshots** on the receiver side.
///
/// [`GetRankStatus::wait`] collects those snapshots on the receiver side.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// The name of the resource.
    pub name: Name,
    /// Port on which each rank sends its sparse status update (overlay).
    // NOTE(review): `#[binding(include)]` presumably makes the derived
    // `Bind`/`Unbind` impls visit this port; confirm against the derive macro.
    #[binding(include)]
    pub reply: hyperactor_reference::PortRef<StatusOverlay>,
}
181
182impl GetRankStatus {
183    pub async fn wait(
184        mut rx: PortReceiver<crate::StatusMesh>,
185        num_ranks: usize,
186        max_idle_time: Duration,
187        region: Region, // used only for fallback
188    ) -> Result<crate::StatusMesh, crate::StatusMesh> {
189        debug_assert_eq!(region.num_ranks(), num_ranks, "region/num_ranks mismatch");
190
191        let mut alarm = hyperactor::time::Alarm::new();
192        alarm.arm(max_idle_time);
193
194        // Fallback snapshot if we time out before receiving anything.
195        let mut snapshot =
196            crate::StatusMesh::from_single(region, crate::resource::Status::NotExist);
197
198        loop {
199            let mut sleeper = alarm.sleeper();
200            tokio::select! {
201                _ = sleeper.sleep() => return Err(snapshot),
202                next = rx.recv() => {
203                    match next {
204                        Ok(mesh) => { snapshot = mesh; }   // latest-wins snapshot
205                        Err(_)   => return Err(snapshot),
206                    }
207                }
208            }
209
210            alarm.arm(max_idle_time);
211
212            // Completion: once every rank (among the first
213            // `num_ranks`) has reported at least something (i.e.
214            // moved off NotExist).
215            if snapshot
216                .values()
217                .take(num_ranks)
218                .all(|s| !matches!(s, crate::resource::Status::NotExist))
219            {
220                break Ok(snapshot);
221            }
222        }
223    }
224}
225
226/// The state of a resource.
227#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
228pub struct State<S> {
229    /// The name of the resource.
230    pub name: Name,
231    /// Its status.
232    pub status: Status,
233    /// Optionally, a resource-defined state.
234    pub state: Option<S>,
235}
236
237impl<S: Serialize> fmt::Display for State<S> {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        // Use serde_json to serialize the struct to a compact JSON string
240        match serde_json::to_string(self) {
241            Ok(json) => write!(f, "{}", json),
242            Err(e) => write!(f, "<state: serde_json error: {}>", e),
243        }
244    }
245}
246
/// Create or update a resource according to a spec.
///
/// `S` is the resource-specific spec type. The concrete instantiations used
/// here are passed to `wirevalue::register_type!` right after the definition.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// The name of the resource to create or update.
    pub name: Name,
    /// The rank of the resource, when available. Marked `#[binding(include)]`
    /// (presumably so the derived `Bind`/`Unbind` visit it); per [`Rank`]'s
    /// docs, the comm actor replaces it with the delivered rank.
    #[binding(include)]
    pub rank: Rank,
    /// The specification of the resource.
    pub spec: S,
}
wirevalue::register_type!(CreateOrUpdate<ProcSpec>);
wirevalue::register_type!(CreateOrUpdate<ActorSpec>);
271
/// Stop the named resource.
///
/// Unlike [`CreateOrUpdate`], this carries no spec: only the resource name and
/// a reason.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct Stop {
    /// The name of the resource to stop.
    pub name: Name,
    /// The reason for stopping the resource.
    pub reason: String,
}
wirevalue::register_type!(Stop);
292
/// Stop all resources owned by the receiver of this message.
///
/// There is no reply; this only issues the stop command. Use
/// [`GetRankStatus`] to determine whether the resources have stopped.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct StopAll {
    /// The reason for stopping.
    pub reason: String,
}
wirevalue::register_type!(StopAll);
313
/// Retrieve the current state of the resource.
///
/// `Bind`, `Unbind` and `Clone` are implemented by hand for this generic type
/// rather than derived.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// The name of the resource.
    pub name: Name,
    /// Port on which the resource's [`State`] is sent back.
    #[reply]
    pub reply: hyperactor_reference::PortRef<State<S>>,
}
wirevalue::register_type!(GetState<ProcState>);
wirevalue::register_type!(GetState<ActorState>);
325
326// Cannot derive Bind and Unbind for this generic, implement manually.
327impl<S> Unbind for GetState<S>
328where
329    S: RemoteMessage,
330    S: Unbind,
331{
332    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
333        self.reply.unbind(bindings)
334    }
335}
336
337impl<S> Bind for GetState<S>
338where
339    S: RemoteMessage,
340    S: Bind,
341{
342    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
343        self.reply.bind(bindings)
344    }
345}
346
347impl<S> Clone for GetState<S>
348where
349    S: RemoteMessage,
350{
351    fn clone(&self) -> Self {
352        Self {
353            name: self.name.clone(),
354            reply: self.reply.clone(),
355        }
356    }
357}
358
/// Same as GetState, but additionally tells the receiver that the owner is still alive.
/// If the receiver does not receive this message for a while, it might assume the owner is dead.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct KeepaliveGetState<S> {
    /// The time at which the actor should be considered expired if no further
    /// keepalive is received.
    pub expires_after: std::time::SystemTime,
    /// The wrapped state query; its reply port carries the response.
    pub get_state: GetState<S>,
}
wirevalue::register_type!(KeepaliveGetState<ProcState>);
wirevalue::register_type!(KeepaliveGetState<ActorState>);
370
371// Cannot derive Bind and Unbind for this generic, implement manually.
372impl<S> Unbind for KeepaliveGetState<S>
373where
374    S: RemoteMessage,
375    S: Unbind,
376{
377    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
378        self.get_state.unbind(bindings)
379    }
380}
381
382impl<S> Bind for KeepaliveGetState<S>
383where
384    S: RemoteMessage,
385    S: Bind,
386{
387    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
388        self.get_state.bind(bindings)
389    }
390}
391
392impl<S> Clone for KeepaliveGetState<S>
393where
394    S: RemoteMessage,
395{
396    fn clone(&self) -> Self {
397        Self {
398            expires_after: self.expires_after.clone(),
399            get_state: self.get_state.clone(),
400        }
401    }
402}
403
/// List the set of resources managed by the controller.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct List {
    /// Port on which the names of the resources managed by this controller
    /// are sent back.
    #[reply]
    pub reply: hyperactor_reference::PortRef<Vec<Name>>,
}
wirevalue::register_type!(List);
412
413/// A trait that bundles a set of types that together define a resource.
414pub trait Resource {
415    /// The spec specification for this resource.
416    type Spec: typeuri::Named
417        + Serialize
418        + for<'de> Deserialize<'de>
419        + Send
420        + Sync
421        + std::fmt::Debug;
422
423    /// The state for this resource.
424    type State: typeuri::Named
425        + Serialize
426        + for<'de> Deserialize<'de>
427        + Send
428        + Sync
429        + std::fmt::Debug;
430}
431
// A behavior defining the interface for a mesh controller. For a resource type
// `R`, it lists the messages a `Controller<R>` accepts: `CreateOrUpdate<R::Spec>`,
// `GetState<R::State>`, and `Stop`.
hyperactor::behavior!(
    Controller<R: Resource>,
    CreateOrUpdate<R::Spec>,
    GetState<R::State>,
    Stop,
);
439
440/// RankedValues compactly represents rank-indexed values of type T.
441/// It stores contiguous values in a set of intervals; thus it is
442/// efficient and compact when the cardinality of T-typed values is
443/// low.
444#[derive(Debug, Clone, Named, Serialize, Deserialize)]
445pub struct RankedValues<T> {
446    intervals: Vec<(Range<usize>, T)>,
447}
448
449impl<T: PartialEq> PartialEq for RankedValues<T> {
450    fn eq(&self, other: &Self) -> bool {
451        self.intervals == other.intervals
452    }
453}
454
455impl<T: Eq> Eq for RankedValues<T> {}
456
457impl<T> Default for RankedValues<T> {
458    fn default() -> Self {
459        Self {
460            intervals: Vec::new(),
461        }
462    }
463}
464
465impl<T> RankedValues<T> {
466    /// Iterate over contiguous rank intervals of values.
467    pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
468        self.intervals.iter()
469    }
470
471    /// The (set) rank of the RankedValues is the number of values stored with
472    /// rank less than `value`.
473    pub fn rank(&self, value: usize) -> usize {
474        self.iter()
475            .take_while(|(ranks, _)| ranks.start <= value)
476            .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
477            .sum()
478    }
479}
480
481impl<T: Clone> RankedValues<T> {
482    pub fn materialized_iter(&self, until: usize) -> impl Iterator<Item = &T> + '_ {
483        assert_eq!(self.rank(until), until, "insufficient rank");
484        self.iter()
485            .flat_map(|(range, value)| std::iter::repeat_n(value, range.end - range.start))
486    }
487}
488
489impl<T: Hash + Eq + Clone> RankedValues<T> {
490    /// Invert this ranked values into a [`ValuesByRank<T>`].
491    pub fn invert(&self) -> ValuesByRank<T> {
492        let mut inverted: HashMap<T, Vec<Range<usize>>> = HashMap::new();
493        for (range, value) in self.iter() {
494            inverted
495                .entry(value.clone())
496                .or_default()
497                .push(range.clone());
498        }
499        ValuesByRank { values: inverted }
500    }
501}
502
503impl<T: Eq + Clone> RankedValues<T> {
504    /// Merge `other` into this set of ranked values. Values in `other` that overlap
505    /// with `self` take prededence.
506    ///
507    /// This currently uses a simple algorithm that merges the full set of RankedValues.
508    /// This remains efficient when the cardinality of T-typed values is low. However,
509    /// it does not efficiently merge high cardinality value sets. Consider using interval
510    /// trees or bitmap techniques like Roaring Bitmaps in these cases.
511    pub fn merge_from(&mut self, other: Self) {
512        let mut left_iter = take(&mut self.intervals).into_iter();
513        let mut right_iter = other.intervals.into_iter();
514
515        let mut left = left_iter.next();
516        let mut right = right_iter.next();
517
518        while left.is_some() && right.is_some() {
519            let (left_ranks, left_value) = left.as_mut().unwrap();
520            let (right_ranks, right_value) = right.as_mut().unwrap();
521
522            if left_ranks.is_overlapping(right_ranks) {
523                if left_value == right_value {
524                    let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
525                    let (_, value) = replace(&mut right, right_iter.next()).unwrap();
526                    left_ranks.start = ranks.end;
527                    if left_ranks.is_empty() {
528                        left = left_iter.next();
529                    }
530                    self.append(ranks, value);
531                } else if left_ranks.start < right_ranks.start {
532                    let ranks = left_ranks.start..right_ranks.start;
533                    left_ranks.start = ranks.end;
534                    // TODO: get rid of clone
535                    self.append(ranks, left_value.clone());
536                } else {
537                    let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
538                    left_ranks.start = ranks.end;
539                    if left_ranks.is_empty() {
540                        left = left_iter.next();
541                    }
542                    self.append(ranks, value);
543                }
544            } else if left_ranks.start < right_ranks.start {
545                let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
546                self.append(ranks, value);
547            } else {
548                let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
549                self.append(ranks, value);
550            }
551        }
552
553        while let Some((left_ranks, left_value)) = left {
554            self.append(left_ranks, left_value);
555            left = left_iter.next();
556        }
557        while let Some((right_ranks, right_value)) = right {
558            self.append(right_ranks, right_value);
559            right = right_iter.next();
560        }
561    }
562
563    /// Merge the contents of this RankedValues into another RankedValues.
564    pub fn merge_into(self, other: &mut Self) {
565        other.merge_from(self);
566    }
567
568    fn append(&mut self, range: Range<usize>, value: T) {
569        if let Some(last) = self.intervals.last_mut()
570            && last.0.end == range.start
571            && last.1 == value
572        {
573            last.0.end = range.end;
574        } else {
575            self.intervals.push((range, value));
576        }
577    }
578}
579
580impl RankedValues<Status> {
581    pub fn first_terminating(&self) -> Option<(usize, Status)> {
582        self.intervals
583            .iter()
584            .find(|(_, status)| status.is_terminating())
585            .map(|(range, status)| (range.start, status.clone()))
586    }
587
588    pub fn first_failed(&self) -> Option<(usize, Status)> {
589        self.intervals
590            .iter()
591            .find(|(_, status)| matches!(status, Status::Failed(_) | Status::Timeout(_)))
592            .map(|(range, status)| (range.start, status.clone()))
593    }
594}
595
596impl<T> From<(usize, T)> for RankedValues<T> {
597    fn from((rank, value): (usize, T)) -> Self {
598        Self {
599            intervals: vec![(rank..rank + 1, value)],
600        }
601    }
602}
603
604impl<T> From<(Range<usize>, T)> for RankedValues<T> {
605    fn from((range, value): (Range<usize>, T)) -> Self {
606        Self {
607            intervals: vec![(range, value)],
608        }
609    }
610}
611
/// An inverted index of RankedValues, providing all ranks for
/// which each unique T-typed value appears.
#[derive(Clone, Debug)]
pub struct ValuesByRank<T> {
    values: HashMap<T, Vec<Range<usize>>>,
}

impl<T: Eq + Hash> PartialEq for ValuesByRank<T> {
    fn eq(&self, other: &Self) -> bool {
        self.values == other.values
    }
}

impl<T: Eq + Hash> Eq for ValuesByRank<T> {}

/// Formats as `value=start..end,start..end;value=start..end`.
///
/// Entries are ordered by the first rank at which each value appears, so the
/// output is deterministic even though `HashMap` iteration order is not.
impl<T: fmt::Display> fmt::Display for ValuesByRank<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut entries: Vec<_> = self.values.iter().collect();
        entries.sort_by_key(|(_, ranges)| ranges.first().map(|range| range.start));

        for (i, (value, ranges)) in entries.into_iter().enumerate() {
            if i > 0 {
                f.write_str(";")?;
            }
            write!(f, "{}=", value)?;
            for (j, range) in ranges.iter().enumerate() {
                if j > 0 {
                    f.write_str(",")?;
                }
                write!(f, "{}..{}", range.start, range.end)?;
            }
        }
        Ok(())
    }
}

impl<T> Deref for ValuesByRank<T> {
    type Target = HashMap<T, Vec<Range<usize>>>;

    fn deref(&self) -> &Self::Target {
        &self.values
    }
}

impl<T> DerefMut for ValuesByRank<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.values
    }
}
664
/// Enabled for test only because the input iterator must already be
/// well-formed: intervals are stored exactly as yielded, with no sorting,
/// coalescing, or overlap checks.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals = Vec::from_iter(iter);
        Self { intervals }
    }
}
675
676/// Spec for a host mesh agent to use when spawning a new proc.
677#[derive(Clone, Debug, Serialize, Deserialize, Named, Default)]
678pub(crate) struct ProcSpec {
679    /// Config values to set on the spawned proc's global config,
680    /// at the `ClientOverride` layer.
681    pub(crate) client_config_override: Attrs,
682}
683wirevalue::register_type!(ProcSpec);
684
685impl ProcSpec {
686    pub(crate) fn new(client_config_override: Attrs) -> Self {
687        Self {
688            client_config_override,
689        }
690    }
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranked_values_merge() {
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut merged = RankedValues::from_iter([
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]);
        let overlay = RankedValues::from_iter([
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]);

        merged.merge_from(overlay);

        let expected = vec![
            (0..9, Left),
            (9..12, Right),
            (15..20, Left),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            // Adjacent intervals with equal values are coalesced.
            (50..70, Both),
        ];
        assert_eq!(merged.iter().cloned().collect::<Vec<_>>(), expected);

        for (value, expected_rank) in [(5, 5), (10, 10), (16, 13), (70, 62), (100, 62)] {
            assert_eq!(merged.rank(value), expected_rank, "rank({value})");
        }
    }

    #[test]
    fn test_equality() {
        let ints = || RankedValues::from((0..10, 123));
        assert_eq!(ints(), ints());

        let failed = || RankedValues::from((0..10, Status::Failed("foo".to_string())));
        assert_eq!(failed(), failed());
    }

    #[test]
    fn test_default_through_merging() {
        let ones: RankedValues<usize> =
            [(0..10, 1), (15..20, 1), (30..50, 1)].into_iter().collect();

        let mut merged = RankedValues::from((0..50, 0));
        merged.merge_from(ones);

        let expected = vec![
            (0..10, 1),
            (10..15, 0),
            (15..20, 1),
            (20..30, 0),
            (30..50, 1),
        ];
        assert_eq!(merged.iter().cloned().collect::<Vec<_>>(), expected);
    }
}