hyperactor_mesh/
resource.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module defines a set of common message types used for managing resources
//! in hyperactor meshes.
11
12use core::slice::GetDisjointMutIndex as _;
13use std::fmt::Debug;
14use std::hash::Hash;
15use std::marker::PhantomData;
16use std::mem::replace;
17use std::mem::take;
18use std::ops::Range;
19
20use enum_as_inner::EnumAsInner;
21use hyperactor::Bind;
22use hyperactor::HandleClient;
23use hyperactor::Handler;
24use hyperactor::Named;
25use hyperactor::PortRef;
26use hyperactor::RefClient;
27use hyperactor::RemoteMessage;
28use hyperactor::Unbind;
29use hyperactor::accum::Accumulator;
30use hyperactor::accum::CommReducer;
31use hyperactor::accum::ReducerFactory;
32use hyperactor::accum::ReducerSpec;
33use hyperactor::message::Bind;
34use hyperactor::message::Bindings;
35use hyperactor::message::Unbind;
36use serde::Deserialize;
37use serde::Serialize;
38
39use crate::v1::Name;
40
/// The current lifecycle status of a resource.
///
/// `PartialOrd`/`Ord` are derived, so statuses compare according to the
/// declaration order of the variants below; reordering variants changes
/// comparison results.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner
)]
pub enum Status {
    /// The resource does not exist.
    NotExist,
    /// The resource is being created.
    Initializing,
    /// The resource is running.
    Running,
    /// The resource is being stopped.
    Stopping,
    /// The resource is stopped.
    Stopped,
    /// The resource has failed, with an error message.
    Failed(String),
}
69
70impl Status {
71    /// Returns whether the status is a terminating status.
72    pub fn is_terminating(&self) -> bool {
73        matches!(self, Status::Stopping | Status::Stopped | Status::Failed(_))
74    }
75}
76
/// Data type used to communicate ranks.
/// Implements [`Bind`] and [`Unbind`]; the comm actor replaces
/// instances with the delivered rank.
///
/// The wrapped value is `None` until a rank has been set; the derived
/// `Default` therefore yields an unset rank.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
82
83impl Rank {
84    /// Create a new rank with the provided value.
85    pub fn new(rank: usize) -> Self {
86        Self(Some(rank))
87    }
88
89    /// Unwrap the rank; panics if not set.
90    pub fn unwrap(&self) -> usize {
91        self.0.unwrap()
92    }
93}
94
impl Unbind for Rank {
    /// Unbinds this rank by pushing it onto the back of `bindings`.
    ///
    /// Any error returned by `Bindings::push_back` is propagated unchanged.
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
        bindings.push_back(self)
    }
}
100
impl Bind for Rank {
    /// Binds this rank by popping a [`Rank`] from the front of `bindings`
    /// and overwriting the wrapped value with the popped one.
    ///
    /// Returns an error (leaving `self` untouched) if `try_pop_front` fails.
    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
        let bound = bindings.try_pop_front::<Rank>()?;
        // Copy only the inner value; `self` keeps its identity.
        self.0 = bound.0;
        Ok(())
    }
}
108
/// Get the status of a resource at a rank. This message is designed to be
/// cast and efficiently accumulated.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// The name of the resource.
    pub name: Name,
    /// The port on which the status is reported, expressed as rank-indexed
    /// [`RankedValues<Status>`].
    // NOTE(review): `#[binding(include)]` presumably makes this port part of
    // the derived `Bind`/`Unbind` implementations -- confirm against the
    // derive macro's documentation.
    #[binding(include)]
    pub reply: PortRef<RankedValues<Status>>,
}
130
/// The state of a resource: its name, its lifecycle [`Status`], and an
/// optional resource-specific payload of type `S`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// The name of the resource.
    pub name: Name,
    /// Its status.
    pub status: Status,
    /// Optionally, a resource-defined state.
    pub state: Option<S>,
}
141
/// Create or update a resource according to a spec.
///
/// `S` is the resource-defined specification type.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// The name of the resource to create or update.
    pub name: Name,
    /// The rank of the resource, when available.
    // NOTE(review): `#[binding(include)]` presumably routes this field through
    // the derived `Bind`/`Unbind`, which is how `Rank` gets replaced with the
    // delivered rank -- confirm against the derive macro's documentation.
    #[binding(include)]
    pub rank: Rank,
    /// The specification of the resource.
    pub spec: S,
}
164
/// Retrieve the current state of the resource.
///
/// `Clone`, `Bind` and `Unbind` are not derived for this type; they are
/// implemented by hand elsewhere in this file.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// The name of the resource.
    pub name: Name,
    /// A reply containing the state.
    #[reply]
    pub reply: PortRef<State<S>>,
}
174
// Cannot derive Bind and Unbind for this generic, implement manually.
impl<S> Unbind for GetState<S>
where
    S: RemoteMessage,
    S: Unbind,
{
    /// Unbinds only the `reply` port; `name` contributes nothing to
    /// `bindings`.
    fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
        self.reply.unbind(bindings)
    }
}
185
impl<S> Bind for GetState<S>
where
    S: RemoteMessage,
    S: Bind,
{
    /// Binds only the `reply` port from `bindings`; `name` is left
    /// untouched. This mirrors the hand-written `Unbind` implementation.
    fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
        self.reply.bind(bindings)
    }
}
195
196impl<S> Clone for GetState<S>
197where
198    S: RemoteMessage,
199{
200    fn clone(&self) -> Self {
201        Self {
202            name: self.name.clone(),
203            reply: self.reply.clone(),
204        }
205    }
206}
207
/// RankedValues compactly represents rank-indexed values of type T.
/// It stores contiguous values in a set of intervals; thus it is
/// efficient and compact when the cardinality of T-typed values is
/// low.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    /// Half-open rank ranges paired with the value shared by every rank in
    /// the range.
    // NOTE(review): the merge logic appears to assume these are sorted by
    // start and non-overlapping -- nothing in the type enforces it.
    intervals: Vec<(Range<usize>, T)>,
}
216
217impl<T> Default for RankedValues<T> {
218    fn default() -> Self {
219        Self {
220            intervals: Vec::new(),
221        }
222    }
223}
224
225impl<T> RankedValues<T> {
226    /// Iterate over contiguous rank intervals of values.
227    pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
228        self.intervals.iter()
229    }
230
231    /// The (set) rank of the RankedValues is the number of values stored with
232    /// rank less than `value`.
233    pub fn rank(&self, value: usize) -> usize {
234        self.iter()
235            .take_while(|(ranks, _)| ranks.start <= value)
236            .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
237            .sum()
238    }
239}
240
241impl<T: Eq + Clone> RankedValues<T> {
242    /// Merge `other` into this set of ranked values. Values in `other` that overlap
243    /// with `self` take prededence.
244    ///
245    /// This currently uses a simple algorithm that merges the full set of RankedValues.
246    /// This remains efficient when the cardinality of T-typed values is low. However,
247    /// it does not efficiently merge high cardinality value sets. Consider using interval
248    /// trees or bitmap techniques like Roaring Bitmaps in these cases.
249    pub fn merge_from(&mut self, other: Self) {
250        let mut left_iter = take(&mut self.intervals).into_iter();
251        let mut right_iter = other.intervals.into_iter();
252
253        let mut left = left_iter.next();
254        let mut right = right_iter.next();
255
256        while left.is_some() && right.is_some() {
257            let (left_ranks, left_value) = left.as_mut().unwrap();
258            let (right_ranks, right_value) = right.as_mut().unwrap();
259
260            if left_ranks.is_overlapping(right_ranks) {
261                if left_value == right_value {
262                    let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
263                    let (_, value) = replace(&mut right, right_iter.next()).unwrap();
264                    left_ranks.start = ranks.end;
265                    if left_ranks.is_empty() {
266                        left = left_iter.next();
267                    }
268                    self.append(ranks, value);
269                } else if left_ranks.start < right_ranks.start {
270                    let ranks = left_ranks.start..right_ranks.start;
271                    left_ranks.start = ranks.end;
272                    // TODO: get rid of clone
273                    self.append(ranks, left_value.clone());
274                } else {
275                    let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
276                    left_ranks.start = ranks.end;
277                    if left_ranks.is_empty() {
278                        left = left_iter.next();
279                    }
280                    self.append(ranks, value);
281                }
282            } else if left_ranks.start < right_ranks.start {
283                let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
284                self.append(ranks, value);
285            } else {
286                let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
287                self.append(ranks, value);
288            }
289        }
290
291        while let Some((left_ranks, left_value)) = left {
292            self.append(left_ranks, left_value);
293            left = left_iter.next();
294        }
295        while let Some((right_ranks, right_value)) = right {
296            self.append(right_ranks, right_value);
297            right = right_iter.next();
298        }
299    }
300
301    fn append(&mut self, range: Range<usize>, value: T) {
302        if let Some(last) = self.intervals.last_mut()
303            && last.0.end == range.start
304            && last.1 == value
305        {
306            last.0.end = range.end;
307        } else {
308            self.intervals.push((range, value));
309        }
310    }
311}
312
313impl<T> From<(usize, T)> for RankedValues<T> {
314    fn from((rank, value): (usize, T)) -> Self {
315        Self {
316            intervals: vec![(rank..rank + 1, value)],
317        }
318    }
319}
320
/// Collects `(range, value)` pairs verbatim, without any validation.
///
/// Test-only: nothing here checks that the input intervals are well-formed,
/// so the caller has to guarantee it.
#[cfg(test)]
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let intervals: Vec<(Range<usize>, T)> = iter.into_iter().collect();
        Self { intervals }
    }
}
331
/// Accumulates [`RankedValues`] updates by merging each update into the
/// running state with [`RankedValues::merge_from`], so an update overrides
/// the state wherever their ranks overlap.
impl<T: Eq + Clone + Named> Accumulator for RankedValues<T> {
    type State = Self;
    type Update = Self;

    /// Merge `update` into `state`. The accumulator value itself (`self`) is
    /// not consulted. Always returns `Ok(())`.
    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
        state.merge_from(update);
        Ok(())
    }

    /// Currently returns `None`: no reducer is advertised (see the TODO
    /// below).
    fn reducer_spec(&self) -> Option<ReducerSpec> {
        None
        // TODO: make this work. When it is enabled, the comm actor simply halts.
        // Some(ReducerSpec {
        //     typehash: <RankedValuesReducer<T> as Named>::typehash(),
        //     builder_params: None,
        // })
    }
}
350
/// Reducer over [`RankedValues<T>`] updates. It carries no data; the
/// `PhantomData` only ties the reducer to the value type `T`.
#[derive(Named)]
struct RankedValuesReducer<T>(std::marker::PhantomData<T>);
353
impl<T: Hash + Eq + Ord + Clone> CommReducer for RankedValuesReducer<T> {
    type Update = RankedValues<T>;

    /// Reduce two updates into one by merging `right` into `left`; where the
    /// two overlap, `right` is the `other` argument of
    /// [`RankedValues::merge_from`]. Always returns `Ok`.
    fn reduce(&self, mut left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
        left.merge_from(right);
        Ok(left)
    }
}
362
// register for concrete types:
//
// Submits a `ReducerFactory` for `RankedValuesReducer<Status>`, keyed by that
// type's `Named::typehash`. The builder ignores its argument and returns a
// boxed `RankedValuesReducer::<Status>`. Only the `Status` instantiation is
// registered here.

hyperactor::submit! {
    ReducerFactory {
        typehash_f: <RankedValuesReducer<Status> as Named>::typehash,
        builder_f: |_| Ok(Box::new(RankedValuesReducer::<Status>(PhantomData))),
    }
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Merges two hand-built interval sets and checks both the merged
    /// intervals and `rank` on the result.
    #[test]
    fn test_ranked_values_merge() {
        // Marker values: which side(s) a given interval appears on.
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::Both;
        use Side::Left;
        use Side::Right;

        let mut left: RankedValues<Side> = [
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        let right: RankedValues<Side> = [
            // Overlaps the tail of left's 0..10.
            (9..12, Right),
            (25..30, Right),
            // Same value as left's 30..50 on this sub-range.
            (30..40, Both),
            // Different value from left's 30..50 on this sub-range.
            (40..50, Right),
            // Adjacent to left's 60..70 with the same value.
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        left.merge_from(right);
        assert_eq!(
            left.iter().cloned().collect::<Vec<_>>(),
            vec![
                (0..9, Left),
                (9..12, Right),
                (15..20, Left),
                (25..30, Right),
                (30..40, Both),
                (40..50, Right),
                // Merge consecutive:
                (50..70, Both)
            ]
        );

        // `rank(v)` counts stored ranks below `v`; gaps (12..15, 20..25) do
        // not contribute, e.g. rank(16) = 9 + 3 + 1.
        assert_eq!(left.rank(5), 5);
        assert_eq!(left.rank(10), 10);
        assert_eq!(left.rank(16), 13);
        assert_eq!(left.rank(70), 62);
        assert_eq!(left.rank(100), 62);
    }
}