1use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use serde::Deserialize;
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19use typeuri::Named;
20
21use crate::config;
23use crate::reference::Index;
24
/// An accumulator folds a stream of updates into a piece of state.
pub trait Accumulator {
    /// The type of the state that updates are folded into.
    type State;
    /// The type of a single update.
    type Update;

    /// Folds `update` into `state` in place.
    ///
    /// Returns an error if the update cannot be applied.
    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;

    /// Returns the [`ReducerSpec`] associated with this accumulator, or `None`
    /// if it has none.
    // NOTE(review): presumably the spec names a reducer that may pre-combine
    // updates before they reach `accumulate` — confirm against the callers.
    fn reducer_spec(&self) -> Option<ReducerSpec>;
}
39
/// Serializable description of a reducer: a typehash plus optional builder
/// parameters.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
pub struct ReducerSpec {
    /// Typehash of the reducer type. The accumulators in this file fill it
    /// with the `Named::typehash()` of their matching reducer.
    pub typehash: u64,
    /// Optional serialized parameters for building the reducer.
    // NOTE(review): presumably handed to `resolve_reducer` as its
    // `builder_params` argument — confirm against the callers.
    pub builder_params: Option<wirevalue::Any>,
}
wirevalue::register_type!(ReducerSpec);
49
/// Optional tuning knobs for update intervals.
///
/// Both fields are optional; the accessor methods on this type supply the
/// fallback values when a field is `None`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    typeuri::Named,
    Default
)]
pub struct ReducerOpts {
    /// Maximum update interval. When `None`, the `max_update_interval()`
    /// accessor falls back to the global `config::SPLIT_MAX_BUFFER_AGE` value.
    pub max_update_interval: Option<Duration>,
    /// Initial update interval. When `None`, the `initial_update_interval()`
    /// accessor falls back to 1 millisecond.
    pub initial_update_interval: Option<Duration>,
}
69
70impl ReducerOpts {
71 pub(crate) fn max_update_interval(&self) -> Duration {
72 self.max_update_interval
73 .unwrap_or(hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_AGE))
74 }
75
76 pub(crate) fn initial_update_interval(&self) -> Duration {
77 self.initial_update_interval
78 .unwrap_or(Duration::from_millis(1))
79 }
80}
81
/// A reducer that combines two updates of the same type into a single update.
// NOTE(review): the name suggests the operation is expected to be commutative
// (and likely associative) — confirm; nothing in this trait enforces it.
pub trait CommReducer {
    /// The type of the updates being combined.
    type Update;

    /// Combines `left` and `right` into one update, or returns an error if
    /// they cannot be combined.
    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
}
94
95pub trait ErasedCommReducer {
97 fn reduce_erased(
99 &self,
100 left: &wirevalue::Any,
101 right: &wirevalue::Any,
102 ) -> anyhow::Result<wirevalue::Any>;
103
104 fn reduce_updates(
107 &self,
108 updates: Vec<wirevalue::Any>,
109 ) -> Result<wirevalue::Any, (anyhow::Error, Vec<wirevalue::Any>)> {
110 if updates.is_empty() {
111 return Err((anyhow::anyhow!("empty updates"), updates));
112 }
113 if updates.len() == 1 {
114 return Ok(updates.into_iter().next().expect("checked above"));
115 }
116
117 let mut iter = updates.iter();
118 let first = iter.next().unwrap();
119 let second = iter.next().unwrap();
120 let init = match self.reduce_erased(first, second) {
121 Ok(v) => v,
122 Err(e) => return Err((e, updates)),
123 };
124 let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
125 Ok(v) => v,
126 Err(e) => return Err((e, updates)),
127 };
128 Ok(reduced)
129 }
130
131 fn typehash(&self) -> u64;
133}
134
135impl<R, T> ErasedCommReducer for R
136where
137 R: CommReducer<Update = T> + Named,
138 T: Serialize + DeserializeOwned + Named,
139{
140 fn reduce_erased(
141 &self,
142 left: &wirevalue::Any,
143 right: &wirevalue::Any,
144 ) -> anyhow::Result<wirevalue::Any> {
145 let left = left.deserialized::<T>()?;
146 let right = right.deserialized::<T>()?;
147 let result = self.reduce(left, right)?;
148 Ok(wirevalue::Any::serialize(&result)?)
149 }
150
151 fn typehash(&self) -> u64 {
152 R::typehash()
153 }
154}
155
/// Registration record for a reducer type, collected through `inventory`.
pub struct ReducerFactory {
    /// Returns the typehash under which this factory is looked up
    /// (see `resolve_reducer`).
    pub typehash_f: fn() -> u64,
    /// Builds a boxed, type-erased reducer from optional serialized
    /// parameters; may fail with an error.
    pub builder_f: fn(
        Option<wirevalue::Any>,
    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
}

inventory::collect!(ReducerFactory);
171
172inventory::submit! {
173 ReducerFactory {
174 typehash_f: <SumReducer<i64> as Named>::typehash,
175 builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
176 }
177}
178inventory::submit! {
179 ReducerFactory {
180 typehash_f: <SumReducer<u64> as Named>::typehash,
181 builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
182 }
183}
184inventory::submit! {
185 ReducerFactory {
186 typehash_f: <MaxReducer::<i64> as Named>::typehash,
187 builder_f: |_| Ok(Box::new(MaxReducer::<i64>(PhantomData))),
188 }
189}
190inventory::submit! {
191 ReducerFactory {
192 typehash_f: <MaxReducer::<u64> as Named>::typehash,
193 builder_f: |_| Ok(Box::new(MaxReducer::<u64>(PhantomData))),
194 }
195}
196inventory::submit! {
197 ReducerFactory {
198 typehash_f: <MinReducer::<i64> as Named>::typehash,
199 builder_f: |_| Ok(Box::new(MinReducer::<i64>(PhantomData))),
200 }
201}
202inventory::submit! {
203 ReducerFactory {
204 typehash_f: <MinReducer::<u64> as Named>::typehash,
205 builder_f: |_| Ok(Box::new(MinReducer::<u64>(PhantomData))),
206 }
207}
208inventory::submit! {
209 ReducerFactory {
210 typehash_f: <WatermarkUpdateReducer::<i64> as Named>::typehash,
211 builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<i64>(PhantomData))),
212 }
213}
214inventory::submit! {
215 ReducerFactory {
216 typehash_f: <WatermarkUpdateReducer::<u64> as Named>::typehash,
217 builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<u64>(PhantomData))),
218 }
219}
220
221pub(crate) fn resolve_reducer(
224 typehash: u64,
225 builder_params: Option<wirevalue::Any>,
226) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
227 static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
228 let factories = FACTORY_MAP.get_or_init(|| {
229 let mut map = HashMap::new();
230 for factory in inventory::iter::<ReducerFactory> {
231 map.insert((factory.typehash_f)(), factory);
232 }
233 map
234 });
235
236 factories
237 .get(&typehash)
238 .map(|f| (f.builder_f)(builder_params))
239 .transpose()
240}
241
242#[derive(typeuri::Named)]
243struct SumReducer<T>(PhantomData<T>);
244
245impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
246 type Update = T;
247
248 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
249 Ok(left + right)
250 }
251}
252
253struct SumAccumulator<T>(PhantomData<T>);
256
257impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
258 type State = T;
259 type Update = T;
260
261 fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
262 *state = *state + update;
263 Ok(())
264 }
265
266 fn reducer_spec(&self) -> Option<ReducerSpec> {
267 Some(ReducerSpec {
268 typehash: <SumReducer<T> as Named>::typehash(),
269 builder_params: None,
270 })
271 }
272}
273
274pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
276-> impl Accumulator<State = T, Update = T> {
277 SumAccumulator(PhantomData)
278}
279
280#[derive(typeuri::Named)]
281struct MaxReducer<T>(PhantomData<T>);
282
283impl<T: Ord> CommReducer for MaxReducer<T> {
284 type Update = T;
285
286 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
287 Ok(std::cmp::max(left, right))
288 }
289}
290
/// State holding the largest value seen so far; empty (`None`) until a value
/// has been recorded.
#[derive(Debug, Clone)]
pub struct Max<T>(Option<T>);

/// The default state is empty. Implemented by hand so that `Max<T>` is
/// `Default` for every `T`, not only for `T: Default` as the derive would
/// require.
impl<T> Default for Max<T> {
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Max<T> {
    /// Returns a reference to the current maximum.
    ///
    /// # Panics
    ///
    /// Panics if no value has been recorded yet, i.e. the state is still
    /// empty.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been initialized.")
    }
}
303
304struct MaxAccumulator<T>(PhantomData<T>);
306
307impl<T: Ord + Copy + Named + 'static> Accumulator for MaxAccumulator<T> {
308 type State = Max<T>;
309 type Update = T;
310
311 fn accumulate(&self, state: &mut Self::State, update: T) -> anyhow::Result<()> {
312 match state.0.as_mut() {
313 Some(s) => *s = std::cmp::max(*s, update),
314 None => *state = Max(Some(update)),
315 }
316 Ok(())
317 }
318
319 fn reducer_spec(&self) -> Option<ReducerSpec> {
320 Some(ReducerSpec {
321 typehash: <MaxReducer<T> as Named>::typehash(),
322 builder_params: None,
323 })
324 }
325}
326
327pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>, Update = T> {
330 MaxAccumulator(PhantomData::<T>)
331}
332
333#[derive(typeuri::Named)]
334struct MinReducer<T>(PhantomData<T>);
335
336impl<T: Ord> CommReducer for MinReducer<T> {
337 type Update = T;
338
339 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
340 Ok(std::cmp::min(left, right))
341 }
342}
343
/// State holding the smallest value seen so far; empty (`None`) until a value
/// has been recorded.
#[derive(Debug, Clone)]
pub struct Min<T>(Option<T>);

/// The default state is empty. Implemented by hand so that `Min<T>` is
/// `Default` for every `T`, not only for `T: Default` as the derive would
/// require.
impl<T> Default for Min<T> {
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Min<T> {
    /// Returns a reference to the current minimum.
    ///
    /// # Panics
    ///
    /// Panics if no value has been recorded yet, i.e. the state is still
    /// empty.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been initialized.")
    }
}
356
357struct MinAccumulator<T>(PhantomData<T>);
359
360impl<T: Ord + Copy + Named + 'static> Accumulator for MinAccumulator<T> {
361 type State = Min<T>;
362 type Update = T;
363
364 fn accumulate(&self, state: &mut Min<T>, update: T) -> anyhow::Result<()> {
365 match state.0.as_mut() {
366 Some(s) => *s = std::cmp::min(*s, update),
367 None => *state = Min(Some(update)),
368 }
369 Ok(())
370 }
371
372 fn reducer_spec(&self) -> Option<ReducerSpec> {
373 Some(ReducerSpec {
374 typehash: <MinReducer<T> as Named>::typehash(),
375 builder_params: None,
376 })
377 }
378}
379
380pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>, Update = T> {
383 MinAccumulator(PhantomData)
384}
385
/// A map from [`Index`] to a value of type `T`; the watermark reported by
/// `get` is the smallest value in the map.
// NOTE(review): the `Index` keys are called "rank" in the `From<(Index, T)>`
// impl, so each entry is presumably the latest value reported by one rank —
// confirm.
#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
pub struct WatermarkUpdate<T>(HashMap<Index, T>);
390
391impl<T: Ord> WatermarkUpdate<T> {
392 pub fn get(&self) -> &T {
396 self.0
397 .values()
398 .min()
399 .expect("watermark should have been intialized.")
400 }
401}
402
403impl<T: PartialEq> WatermarkUpdate<T> {
404 fn merge(old: Self, new: Self) -> Self {
406 let mut map = old.0;
407 for (k, v) in new.0 {
408 map.insert(k, v);
409 }
410 Self(map)
411 }
412}
413
414impl<T> From<(Index, T)> for WatermarkUpdate<T> {
415 fn from((rank, value): (Index, T)) -> Self {
416 let mut map = HashMap::with_capacity(1);
417 map.insert(rank, value);
418 Self(map)
419 }
420}
421
422#[derive(typeuri::Named)]
425struct WatermarkUpdateReducer<T>(PhantomData<T>);
426
427impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
428 type Update = WatermarkUpdate<T>;
429
430 fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
431 Ok(WatermarkUpdate::merge(left, right))
432 }
433}
434
435struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
436
437impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {
438 type State = WatermarkUpdate<T>;
439 type Update = WatermarkUpdate<T>;
440
441 fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
442 let current = std::mem::replace(&mut *state, WatermarkUpdate(HashMap::new()));
443 *state = WatermarkUpdate::merge(current, update);
445 Ok(())
446 }
447
448 fn reducer_spec(&self) -> Option<ReducerSpec> {
449 Some(ReducerSpec {
450 typehash: <WatermarkUpdateReducer<T> as Named>::typehash(),
451 builder_params: None,
452 })
453 }
454}
455
456pub fn low_watermark<T: Ord + Copy + Named + 'static>()
464-> impl Accumulator<State = WatermarkUpdate<T>, Update = WatermarkUpdate<T>> {
465 LowWatermarkUpdateAccumulator(PhantomData)
466}
467
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;
    use typeuri::Named;

    use super::*;

    /// Serializes every value into a `wirevalue::Any`, panicking on failure.
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<wirevalue::Any> {
        let mut encoded = Vec::with_capacity(values.len());
        for value in values {
            encoded.push(wirevalue::Any::serialize(&value).unwrap());
        }
        encoded
    }

    /// Resolves the reducer registered under `typehash`, reduces `updates`
    /// with it, and deserializes the result as `T`. Panics on any failure.
    fn reduce_with<T: DeserializeOwned + Named>(typehash: u64, updates: Vec<wirevalue::Any>) -> T {
        resolve_reducer(typehash, None)
            .unwrap()
            .unwrap()
            .reduce_updates(updates)
            .unwrap()
            .deserialized::<T>()
            .unwrap()
    }

    #[test]
    fn test_comm_reducer_numeric() {
        let u64_numbers = serialize(vec![1u64, 3u64, 1100u64]);
        let i64_numbers = serialize(vec![-123i64, 33i64, 110i64]);

        // Unsigned inputs.
        assert_eq!(
            reduce_with::<u64>(<MaxReducer<u64> as Named>::typehash(), u64_numbers.clone()),
            1100u64,
        );
        assert_eq!(
            reduce_with::<u64>(<MinReducer<u64> as Named>::typehash(), u64_numbers.clone()),
            1u64,
        );
        assert_eq!(
            reduce_with::<u64>(<SumReducer<u64> as Named>::typehash(), u64_numbers),
            1104u64,
        );

        // Signed inputs.
        assert_eq!(
            reduce_with::<i64>(<MaxReducer<i64> as Named>::typehash(), i64_numbers.clone()),
            110i64,
        );
        assert_eq!(
            reduce_with::<i64>(<MinReducer<i64> as Named>::typehash(), i64_numbers.clone()),
            -123i64,
        );
        assert_eq!(
            reduce_with::<i64>(<SumReducer<i64> as Named>::typehash(), i64_numbers),
            20i64,
        );
    }

    #[test]
    fn test_comm_reducer_watermark() {
        let u64_updates = serialize::<WatermarkUpdate<u64>>(
            vec![
                (1, 1),
                (0, 2),
                (0, 1),
                (3, 35),
                (0, 9),
                (1, 10),
                (3, 32),
                (3, 0),
                (3, 321),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );
        let i64_updates = serialize::<WatermarkUpdate<i64>>(
            vec![
                (0, 2),
                (1, 1),
                (3, 35),
                (0, 1),
                (1, -10),
                (3, 32),
                (3, 0),
                (3, -99),
                (0, -9),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );

        /// Reduces `updates` with the watermark reducer for `T` and checks the
        /// merged map against `expected`.
        fn verify<T: PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<wirevalue::Any>,
            expected: HashMap<Index, T>,
        ) {
            let typehash = <WatermarkUpdateReducer<T> as Named>::typehash();
            let reduced: WatermarkUpdate<T> = reduce_with(typehash, updates);
            assert_eq!(reduced.0, expected);
        }

        // The last value supplied for each key is the one that survives.
        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    #[test]
    fn test_accum_reducer_numeric() {
        // Each accumulator must advertise the typehash of its matching reducer.
        let cases = [
            (
                sum::<u64>().reducer_spec(),
                <SumReducer<u64> as Named>::typehash(),
            ),
            (
                sum::<i64>().reducer_spec(),
                <SumReducer<i64> as Named>::typehash(),
            ),
            (
                min::<u64>().reducer_spec(),
                <MinReducer<u64> as Named>::typehash(),
            ),
            (
                min::<i64>().reducer_spec(),
                <MinReducer<i64> as Named>::typehash(),
            ),
            (
                max::<u64>().reducer_spec(),
                <MaxReducer<u64> as Named>::typehash(),
            ),
            (
                max::<i64>().reducer_spec(),
                <MaxReducer<i64> as Named>::typehash(),
            ),
        ];
        for (spec, expected) in cases {
            assert_eq!(spec.unwrap().typehash, expected);
        }
    }

    #[test]
    fn test_accum_reducer_watermark() {
        fn verify<T: Ord + Copy + Named>() {
            let advertised = low_watermark::<T>().reducer_spec().unwrap().typehash;
            let expected = <WatermarkUpdateReducer<T> as Named>::typehash();
            assert_eq!(advertised, expected);
        }
        verify::<u64>();
        verify::<i64>();
    }

    #[test]
    fn test_watermark_accumulator() {
        let accumulator = low_watermark::<u64>();
        // Each row is (rank, reported value, expected watermark after the update).
        let ranks_values_expectations = [
            // First report from each rank: watermark is the minimum seen so far.
            (0, 1003, 1003),
            (1, 1002, 1002),
            (2, 1001, 1001),
            // Every rank reports a lower value; rank 0 holds the minimum.
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // Repeating the same values leaves the watermark unchanged.
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // Ranks advance one at a time; the slowest remaining rank decides.
            (0, 1000, 101),
            (1, 1100, 102),
            (2, 1200, 1000),
            // Another round of increases; rank 0 stays the minimum.
            (0, 1001, 1001),
            (1, 1101, 1001),
            (2, 1201, 1001),
            // Reported values may also go down, lowering the watermark.
            (2, 102, 102),
            (1, 101, 101),
            (0, 100, 100),
        ];
        let mut state = WatermarkUpdate::<u64>::default();
        for (rank, value, expected) in ranks_values_expectations {
            let update = WatermarkUpdate::from((rank, value));
            accumulator.accumulate(&mut state, update).unwrap();
            assert_eq!(state.get(), &expected, "rank is {rank}; value is {value}");
        }
    }
}