1use core::slice::GetDisjointMutIndex as _;
13use std::fmt::Debug;
14use std::hash::Hash;
15use std::marker::PhantomData;
16use std::mem::replace;
17use std::mem::take;
18use std::ops::Range;
19
20use enum_as_inner::EnumAsInner;
21use hyperactor::Bind;
22use hyperactor::HandleClient;
23use hyperactor::Handler;
24use hyperactor::Named;
25use hyperactor::PortRef;
26use hyperactor::RefClient;
27use hyperactor::RemoteMessage;
28use hyperactor::Unbind;
29use hyperactor::accum::Accumulator;
30use hyperactor::accum::CommReducer;
31use hyperactor::accum::ReducerFactory;
32use hyperactor::accum::ReducerSpec;
33use hyperactor::message::Bind;
34use hyperactor::message::Bindings;
35use hyperactor::message::Unbind;
36use serde::Deserialize;
37use serde::Serialize;
38
39use crate::v1::Name;
40
/// Status of a named resource, as carried in `State::status` and returned
/// per rank by `GetRankStatus`.
///
/// Declaration order is significant: the derived `PartialOrd`/`Ord` compare
/// variants in the order they are declared, i.e.
/// `NotExist < Initializing < Running < Stopping < Stopped < Failed(_)`.
/// `Stopping`, `Stopped` and `Failed` are the "terminating" statuses
/// (see `Status::is_terminating`).
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    PartialOrd,
    Ord,
    PartialEq,
    Eq,
    Hash,
    EnumAsInner
)]
pub enum Status {
    /// The resource does not exist.
    NotExist,
    /// The resource is being initialized.
    Initializing,
    /// The resource is running.
    Running,
    /// The resource is stopping.
    Stopping,
    /// The resource has stopped.
    Stopped,
    /// The resource has failed; the `String` presumably describes the
    /// failure — confirm with producers of this status.
    Failed(String),
}
69
70impl Status {
71 pub fn is_terminating(&self) -> bool {
73 matches!(self, Status::Stopping | Status::Stopped | Status::Failed(_))
74 }
75}
76
/// An optional rank.
///
/// The derived `Default` is `Rank(None)`. The `Bind` impl overwrites the inner
/// value with a `Rank` popped from message bindings, and the `Unbind` impl
/// pushes this rank onto them.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq, Default)]
pub struct Rank(pub Option<usize>);
82
83impl Rank {
84 pub fn new(rank: usize) -> Self {
86 Self(Some(rank))
87 }
88
89 pub fn unwrap(&self) -> usize {
91 self.0.unwrap()
92 }
93}
94
95impl Unbind for Rank {
96 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
97 bindings.push_back(self)
98 }
99}
100
101impl Bind for Rank {
102 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
103 let bound = bindings.try_pop_front::<Rank>()?;
104 self.0 = bound.0;
105 Ok(())
106 }
107}
108
/// Message requesting the per-rank `Status` of the resource called `name`.
///
/// The answer is delivered on `reply` as a `RankedValues<Status>`: rank
/// intervals, each paired with a status.
#[derive(
    Clone,
    Debug,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct GetRankStatus {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port that receives the per-rank statuses.
    // NOTE(review): `#[binding(include)]` presumably opts this field into the
    // derived `Bind`/`Unbind` impls — confirm against the hyperactor macros.
    #[binding(include)]
    pub reply: PortRef<RankedValues<Status>>,
}
130
/// State of a named resource, as sent in reply to `GetState`: its current
/// `status` plus an optional resource-specific payload of type `S`.
#[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq, Eq)]
pub struct State<S> {
    /// Name of the resource.
    pub name: Name,
    /// Current status of the resource.
    pub status: Status,
    /// Resource-specific state, or `None` when there is none to report.
    pub state: Option<S>,
}
141
/// Message asking for the resource `name` to be created, or updated if it
/// already exists, according to `spec`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Named,
    Handler,
    HandleClient,
    RefClient,
    Bind,
    Unbind
)]
pub struct CreateOrUpdate<S> {
    /// Name of the resource.
    pub name: Name,
    /// Rank associated with this request; presumably the recipient's rank,
    /// filled in through `Rank`'s `Bind` impl when the message is cast —
    /// confirm.
    // NOTE(review): `#[binding(include)]` presumably opts this field into the
    // derived `Bind`/`Unbind` impls — confirm against the hyperactor macros.
    #[binding(include)]
    pub rank: Rank,
    /// Resource-specific specification.
    pub spec: S,
}
164
/// Message requesting the `State<S>` of the resource called `name`.
///
/// `Clone`, `Bind` and `Unbind` are implemented by hand for this type rather
/// than derived.
#[derive(Debug, Serialize, Deserialize, Named, Handler, HandleClient, RefClient)]
pub struct GetState<S> {
    /// Name of the resource being queried.
    pub name: Name,
    /// Port that receives the resource's `State<S>`.
    // NOTE(review): `#[reply]` presumably marks this as the reply port for the
    // generated client methods — confirm against the hyperactor macros.
    #[reply]
    pub reply: PortRef<State<S>>,
}
174
175impl<S> Unbind for GetState<S>
177where
178 S: RemoteMessage,
179 S: Unbind,
180{
181 fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> {
182 self.reply.unbind(bindings)
183 }
184}
185
186impl<S> Bind for GetState<S>
187where
188 S: RemoteMessage,
189 S: Bind,
190{
191 fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> {
192 self.reply.bind(bindings)
193 }
194}
195
196impl<S> Clone for GetState<S>
197where
198 S: RemoteMessage,
199{
200 fn clone(&self) -> Self {
201 Self {
202 name: self.name.clone(),
203 reply: self.reply.clone(),
204 }
205 }
206}
207
/// A mapping from ranks to values, stored compactly as a list of
/// `(rank range, value)` intervals.
///
/// `rank` and `merge_from` expect the intervals to be sorted by start and
/// non-overlapping. Ranks not covered by any interval have no value.
#[derive(Debug, Clone, Named, Serialize, Deserialize)]
pub struct RankedValues<T> {
    intervals: Vec<(Range<usize>, T)>,
}
216
217impl<T> Default for RankedValues<T> {
218 fn default() -> Self {
219 Self {
220 intervals: Vec::new(),
221 }
222 }
223}
224
225impl<T> RankedValues<T> {
226 pub fn iter(&self) -> impl Iterator<Item = &(Range<usize>, T)> + '_ {
228 self.intervals.iter()
229 }
230
231 pub fn rank(&self, value: usize) -> usize {
234 self.iter()
235 .take_while(|(ranks, _)| ranks.start <= value)
236 .map(|(ranks, _)| ranks.end.min(value) - ranks.start)
237 .sum()
238 }
239}
240
241impl<T: Eq + Clone> RankedValues<T> {
242 pub fn merge_from(&mut self, other: Self) {
250 let mut left_iter = take(&mut self.intervals).into_iter();
251 let mut right_iter = other.intervals.into_iter();
252
253 let mut left = left_iter.next();
254 let mut right = right_iter.next();
255
256 while left.is_some() && right.is_some() {
257 let (left_ranks, left_value) = left.as_mut().unwrap();
258 let (right_ranks, right_value) = right.as_mut().unwrap();
259
260 if left_ranks.is_overlapping(right_ranks) {
261 if left_value == right_value {
262 let ranks = left_ranks.start.min(right_ranks.start)..right_ranks.end;
263 let (_, value) = replace(&mut right, right_iter.next()).unwrap();
264 left_ranks.start = ranks.end;
265 if left_ranks.is_empty() {
266 left = left_iter.next();
267 }
268 self.append(ranks, value);
269 } else if left_ranks.start < right_ranks.start {
270 let ranks = left_ranks.start..right_ranks.start;
271 left_ranks.start = ranks.end;
272 self.append(ranks, left_value.clone());
274 } else {
275 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
276 left_ranks.start = ranks.end;
277 if left_ranks.is_empty() {
278 left = left_iter.next();
279 }
280 self.append(ranks, value);
281 }
282 } else if left_ranks.start < right_ranks.start {
283 let (ranks, value) = replace(&mut left, left_iter.next()).unwrap();
284 self.append(ranks, value);
285 } else {
286 let (ranks, value) = replace(&mut right, right_iter.next()).unwrap();
287 self.append(ranks, value);
288 }
289 }
290
291 while let Some((left_ranks, left_value)) = left {
292 self.append(left_ranks, left_value);
293 left = left_iter.next();
294 }
295 while let Some((right_ranks, right_value)) = right {
296 self.append(right_ranks, right_value);
297 right = right_iter.next();
298 }
299 }
300
301 fn append(&mut self, range: Range<usize>, value: T) {
302 if let Some(last) = self.intervals.last_mut()
303 && last.0.end == range.start
304 && last.1 == value
305 {
306 last.0.end = range.end;
307 } else {
308 self.intervals.push((range, value));
309 }
310 }
311}
312
313impl<T> From<(usize, T)> for RankedValues<T> {
314 fn from((rank, value): (usize, T)) -> Self {
315 Self {
316 intervals: vec![(rank..rank + 1, value)],
317 }
318 }
319}
320
#[cfg(test)]
// Test-only: intervals are collected verbatim, with no sorting, merging or
// validation, so callers must supply sorted, disjoint intervals themselves.
impl<T> FromIterator<(Range<usize>, T)> for RankedValues<T> {
    fn from_iter<I: IntoIterator<Item = (Range<usize>, T)>>(iter: I) -> Self {
        let mut intervals = Vec::new();
        intervals.extend(iter);
        Self { intervals }
    }
}
331
332impl<T: Eq + Clone + Named> Accumulator for RankedValues<T> {
333 type State = Self;
334 type Update = Self;
335
336 fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
337 state.merge_from(update);
338 Ok(())
339 }
340
341 fn reducer_spec(&self) -> Option<ReducerSpec> {
342 None
343 }
349}
350
/// Reducer that combines two `RankedValues<T>` updates with `merge_from`.
/// The `PhantomData` field only carries the type parameter `T`.
#[derive(Named)]
struct RankedValuesReducer<T>(std::marker::PhantomData<T>);
353
354impl<T: Hash + Eq + Ord + Clone> CommReducer for RankedValuesReducer<T> {
355 type Update = RankedValues<T>;
356
357 fn reduce(&self, mut left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
358 left.merge_from(right);
359 Ok(left)
360 }
361}
362
// Registers a `ReducerFactory` for `RankedValuesReducer<Status>`, keyed by
// that type's `Named` typehash. The builder ignores its argument and always
// constructs a fresh reducer. Presumably this lets the reducer be looked up
// by typehash at runtime — confirm against hyperactor's registry.
// Only the `Status` instantiation is registered here.
hyperactor::submit! {
    ReducerFactory {
        typehash_f: <RankedValuesReducer<Status> as Named>::typehash,
        builder_f: |_| Ok(Box::new(RankedValuesReducer::<Status>(PhantomData))),
    }
}
371
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranked_values_merge() {
        #[derive(PartialEq, Debug, Eq, Clone)]
        enum Side {
            Left,
            Right,
            Both,
        }
        use Side::{Both, Left, Right};

        // Existing values.
        let mut base: RankedValues<Side> = vec![
            (0..10, Left),
            (15..20, Left),
            (30..50, Both),
            (60..70, Both),
        ]
        .into_iter()
        .collect();

        // Incoming values; these take precedence where ranks overlap.
        let update: RankedValues<Side> = vec![
            (9..12, Right),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            (50..60, Both),
        ]
        .into_iter()
        .collect();

        base.merge_from(update);

        let expected: Vec<(std::ops::Range<usize>, Side)> = vec![
            (0..9, Left),
            (9..12, Right),
            (15..20, Left),
            (25..30, Right),
            (30..40, Both),
            (40..50, Right),
            // 50..60 and 60..70 are adjacent and equal, so they coalesce.
            (50..70, Both),
        ];
        let merged: Vec<_> = base.iter().cloned().collect();
        assert_eq!(merged, expected);

        // `rank(v)` counts the present ranks strictly below `v`.
        for (value, expected_rank) in [(5, 5), (10, 10), (16, 13), (70, 62), (100, 62)] {
            assert_eq!(base.rank(value), expected_rank, "rank({value})");
        }
    }
}