1use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use serde::Deserialize;
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19
20use crate as hyperactor; use crate::Named;
22use crate::config;
23use crate::data::Serialized;
24use crate::reference::Index;
25
/// An accumulator folds a sequence of updates into a single state value.
pub trait Accumulator {
    /// The accumulated state that updates are folded into.
    type State;
    /// The type of a single update applied to the state.
    type Update;

    /// Applies `update` to `state` in place.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot apply the update.
    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;

    /// Returns the [`ReducerSpec`] of the reducer associated with this
    /// accumulator, or `None` if it has none.
    fn reducer_spec(&self) -> Option<ReducerSpec>;
}
40
/// Serializable description of a reducer: which reducer type to use and the
/// optional parameters to build it with.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub struct ReducerSpec {
    /// Type hash identifying the reducer. The accumulators in this file fill
    /// it with `<SomeReducer<T> as Named>::typehash()`.
    pub typehash: u64,
    /// Optional serialized parameters for building the reducer. The
    /// accumulators in this file always set it to `None`.
    pub builder_params: Option<Serialized>,
}
49
/// Options that tune reducer behavior.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Default)]
pub struct ReducerOpts {
    /// Optional override for the maximum update interval. When `None`,
    /// [`ReducerOpts::max_update_interval`] falls back to the global
    /// `config::SPLIT_MAX_BUFFER_AGE` setting.
    pub max_update_interval: Option<Duration>,
}
57
58impl ReducerOpts {
59 pub(crate) fn max_update_interval(&self) -> Duration {
60 self.max_update_interval
61 .unwrap_or(config::global::get(config::SPLIT_MAX_BUFFER_AGE))
62 }
63}
64
/// A reducer that combines two updates of the same type into one.
pub trait CommReducer {
    /// The type of the updates being combined.
    type Update;

    /// Combines `left` and `right` into a single update.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot combine the two updates.
    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
}
77
78pub trait ErasedCommReducer {
80 fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized>;
82
83 fn reduce_updates(
86 &self,
87 updates: Vec<Serialized>,
88 ) -> Result<Serialized, (anyhow::Error, Vec<Serialized>)> {
89 if updates.is_empty() {
90 return Err((anyhow::anyhow!("empty updates"), updates));
91 }
92 if updates.len() == 1 {
93 return Ok(updates.into_iter().next().expect("checked above"));
94 }
95
96 let mut iter = updates.iter();
97 let first = iter.next().unwrap();
98 let second = iter.next().unwrap();
99 let init = match self.reduce_erased(first, second) {
100 Ok(v) => v,
101 Err(e) => return Err((e, updates)),
102 };
103 let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
104 Ok(v) => v,
105 Err(e) => return Err((e, updates)),
106 };
107 Ok(reduced)
108 }
109
110 fn typehash(&self) -> u64;
112}
113
114impl<R, T> ErasedCommReducer for R
115where
116 R: CommReducer<Update = T> + Named,
117 T: Serialize + DeserializeOwned + Named,
118{
119 fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized> {
120 let left = left.deserialized::<T>()?;
121 let right = right.deserialized::<T>()?;
122 let result = self.reduce(left, right)?;
123 Ok(Serialized::serialize(&result)?)
124 }
125
126 fn typehash(&self) -> u64 {
127 R::typehash()
128 }
129}
130
/// Registry entry describing how to build one reducer type. Entries are
/// collected through `inventory` and looked up by `resolve_reducer`.
pub struct ReducerFactory {
    /// Returns the type hash of the reducer this factory builds;
    /// `resolve_reducer` uses it as the lookup key.
    pub typehash_f: fn() -> u64,
    /// Builds the reducer from optional serialized parameters.
    pub builder_f: fn(
        Option<Serialized>,
    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
}
144
145inventory::collect!(ReducerFactory);
146
147inventory::submit! {
148 ReducerFactory {
149 typehash_f: <SumReducer<i64> as Named>::typehash,
150 builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
151 }
152}
153inventory::submit! {
154 ReducerFactory {
155 typehash_f: <SumReducer<u64> as Named>::typehash,
156 builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
157 }
158}
159inventory::submit! {
160 ReducerFactory {
161 typehash_f: <MaxReducer::<i64> as Named>::typehash,
162 builder_f: |_| Ok(Box::new(MaxReducer::<i64>(PhantomData))),
163 }
164}
165inventory::submit! {
166 ReducerFactory {
167 typehash_f: <MaxReducer::<u64> as Named>::typehash,
168 builder_f: |_| Ok(Box::new(MaxReducer::<u64>(PhantomData))),
169 }
170}
171inventory::submit! {
172 ReducerFactory {
173 typehash_f: <MinReducer::<i64> as Named>::typehash,
174 builder_f: |_| Ok(Box::new(MinReducer::<i64>(PhantomData))),
175 }
176}
177inventory::submit! {
178 ReducerFactory {
179 typehash_f: <MinReducer::<u64> as Named>::typehash,
180 builder_f: |_| Ok(Box::new(MinReducer::<u64>(PhantomData))),
181 }
182}
183inventory::submit! {
184 ReducerFactory {
185 typehash_f: <WatermarkUpdateReducer::<i64> as Named>::typehash,
186 builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<i64>(PhantomData))),
187 }
188}
189inventory::submit! {
190 ReducerFactory {
191 typehash_f: <WatermarkUpdateReducer::<u64> as Named>::typehash,
192 builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<u64>(PhantomData))),
193 }
194}
195
196pub(crate) fn resolve_reducer(
199 typehash: u64,
200 builder_params: Option<Serialized>,
201) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
202 static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
203 let factories = FACTORY_MAP.get_or_init(|| {
204 let mut map = HashMap::new();
205 for factory in inventory::iter::<ReducerFactory> {
206 map.insert((factory.typehash_f)(), factory);
207 }
208 map
209 });
210
211 factories
212 .get(&typehash)
213 .map(|f| (f.builder_f)(builder_params))
214 .transpose()
215}
216
217#[derive(Named)]
218struct SumReducer<T>(PhantomData<T>);
219
220impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
221 type Update = T;
222
223 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
224 Ok(left + right)
225 }
226}
227
228struct SumAccumulator<T>(PhantomData<T>);
231
232impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
233 type State = T;
234 type Update = T;
235
236 fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
237 *state = *state + update;
238 Ok(())
239 }
240
241 fn reducer_spec(&self) -> Option<ReducerSpec> {
242 Some(ReducerSpec {
243 typehash: <SumReducer<T> as Named>::typehash(),
244 builder_params: None,
245 })
246 }
247}
248
249pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
251-> impl Accumulator<State = T, Update = T> {
252 SumAccumulator(PhantomData)
253}
254
255#[derive(Named)]
256struct MaxReducer<T>(PhantomData<T>);
257
258impl<T: Ord> CommReducer for MaxReducer<T> {
259 type Update = T;
260
261 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
262 Ok(std::cmp::max(left, right))
263 }
264}
265
/// State of the max accumulator: the largest value accumulated so far, or
/// nothing if no update has been accumulated yet.
#[derive(Debug, Clone)]
pub struct Max<T>(Option<T>);

// Written by hand instead of derived: `#[derive(Default)]` would add a
// `T: Default` bound even though the empty state never needs a `T` value.
impl<T> Default for Max<T> {
    /// Returns the empty state, holding no value.
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Max<T> {
    /// Returns a reference to the current maximum.
    ///
    /// # Panics
    ///
    /// Panics if the state holds no value yet.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
278
279struct MaxAccumulator<T>(PhantomData<T>);
281
282impl<T: Ord + Copy + Named + 'static> Accumulator for MaxAccumulator<T> {
283 type State = Max<T>;
284 type Update = T;
285
286 fn accumulate(&self, state: &mut Self::State, update: T) -> anyhow::Result<()> {
287 match state.0.as_mut() {
288 Some(s) => *s = std::cmp::max(*s, update),
289 None => *state = Max(Some(update)),
290 }
291 Ok(())
292 }
293
294 fn reducer_spec(&self) -> Option<ReducerSpec> {
295 Some(ReducerSpec {
296 typehash: <MaxReducer<T> as Named>::typehash(),
297 builder_params: None,
298 })
299 }
300}
301
302pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>, Update = T> {
305 MaxAccumulator(PhantomData::<T>)
306}
307
308#[derive(Named)]
309struct MinReducer<T>(PhantomData<T>);
310
311impl<T: Ord> CommReducer for MinReducer<T> {
312 type Update = T;
313
314 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
315 Ok(std::cmp::min(left, right))
316 }
317}
318
/// State of the min accumulator: the smallest value accumulated so far, or
/// nothing if no update has been accumulated yet.
#[derive(Debug, Clone)]
pub struct Min<T>(Option<T>);

// Written by hand instead of derived: `#[derive(Default)]` would add a
// `T: Default` bound even though the empty state never needs a `T` value.
impl<T> Default for Min<T> {
    /// Returns the empty state, holding no value.
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Min<T> {
    /// Returns a reference to the current minimum.
    ///
    /// # Panics
    ///
    /// Panics if the state holds no value yet.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
331
332struct MinAccumulator<T>(PhantomData<T>);
334
335impl<T: Ord + Copy + Named + 'static> Accumulator for MinAccumulator<T> {
336 type State = Min<T>;
337 type Update = T;
338
339 fn accumulate(&self, state: &mut Min<T>, update: T) -> anyhow::Result<()> {
340 match state.0.as_mut() {
341 Some(s) => *s = std::cmp::min(*s, update),
342 None => *state = Min(Some(update)),
343 }
344 Ok(())
345 }
346
347 fn reducer_spec(&self) -> Option<ReducerSpec> {
348 Some(ReducerSpec {
349 typehash: <MinReducer<T> as Named>::typehash(),
350 builder_params: None,
351 })
352 }
353}
354
355pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>, Update = T> {
358 MinAccumulator(PhantomData)
359}
360
361#[derive(Default, Debug, Clone, Serialize, Deserialize, Named)]
364pub struct WatermarkUpdate<T>(HashMap<Index, T>);
365
366impl<T: Ord> WatermarkUpdate<T> {
367 pub fn get(&self) -> &T {
371 self.0
372 .values()
373 .min()
374 .expect("watermark should have been intialized.")
375 }
376}
377
378impl<T: PartialEq> WatermarkUpdate<T> {
379 fn merge(old: Self, new: Self) -> Self {
381 let mut map = old.0;
382 for (k, v) in new.0 {
383 map.insert(k, v);
384 }
385 Self(map)
386 }
387}
388
389impl<T> From<(Index, T)> for WatermarkUpdate<T> {
390 fn from((rank, value): (Index, T)) -> Self {
391 let mut map = HashMap::with_capacity(1);
392 map.insert(rank, value);
393 Self(map)
394 }
395}
396
397#[derive(Named)]
400struct WatermarkUpdateReducer<T>(PhantomData<T>);
401
402impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
403 type Update = WatermarkUpdate<T>;
404
405 fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
406 Ok(WatermarkUpdate::merge(left, right))
407 }
408}
409
410struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
411
412impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {
413 type State = WatermarkUpdate<T>;
414 type Update = WatermarkUpdate<T>;
415
416 fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
417 let current = std::mem::replace(&mut *state, WatermarkUpdate(HashMap::new()));
418 *state = WatermarkUpdate::merge(current, update);
420 Ok(())
421 }
422
423 fn reducer_spec(&self) -> Option<ReducerSpec> {
424 Some(ReducerSpec {
425 typehash: <WatermarkUpdateReducer<T> as Named>::typehash(),
426 builder_params: None,
427 })
428 }
429}
430
431pub fn low_watermark<T: Ord + Copy + Named + 'static>()
439-> impl Accumulator<State = WatermarkUpdate<T>, Update = WatermarkUpdate<T>> {
440 LowWatermarkUpdateAccumulator(PhantomData)
441}
442
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;

    use super::*;
    use crate::Named;

    /// Serializes every value, panicking if any serialization fails.
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<Serialized> {
        let mut serialized = Vec::with_capacity(values.len());
        for value in values {
            serialized.push(Serialized::serialize(&value).unwrap());
        }
        serialized
    }

    /// Resolves the reducer registered under `typehash`, reduces `updates`
    /// with it, and deserializes the result as `T`.
    fn reduce_as<T: DeserializeOwned + Named>(typehash: u64, updates: Vec<Serialized>) -> T {
        let reducer = resolve_reducer(typehash, None).unwrap().unwrap();
        let reduced = reducer.reduce_updates(updates).unwrap();
        reduced.deserialized::<T>().unwrap()
    }

    /// Returns the typehash advertised by the accumulator's reducer spec.
    fn spec_typehash(accumulator: impl Accumulator) -> u64 {
        accumulator.reducer_spec().unwrap().typehash
    }

    #[test]
    fn test_comm_reducer_numeric() {
        let u64_numbers = serialize(vec![1u64, 3u64, 1100u64]);
        let i64_numbers = serialize(vec![-123i64, 33i64, 110i64]);

        // Unsigned inputs.
        let largest: u64 = reduce_as(<MaxReducer<u64> as Named>::typehash(), u64_numbers.clone());
        assert_eq!(largest, 1100u64);

        let smallest: u64 = reduce_as(<MinReducer<u64> as Named>::typehash(), u64_numbers.clone());
        assert_eq!(smallest, 1u64);

        let total: u64 = reduce_as(<SumReducer<u64> as Named>::typehash(), u64_numbers);
        assert_eq!(total, 1104u64);

        // Signed inputs.
        let largest: i64 = reduce_as(<MaxReducer<i64> as Named>::typehash(), i64_numbers.clone());
        assert_eq!(largest, 110i64);

        let smallest: i64 = reduce_as(<MinReducer<i64> as Named>::typehash(), i64_numbers.clone());
        assert_eq!(smallest, -123i64);

        let total: i64 = reduce_as(<SumReducer<i64> as Named>::typehash(), i64_numbers);
        assert_eq!(total, 20i64);
    }

    #[test]
    fn test_comm_reducer_watermark() {
        /// Reduces `updates` with the watermark reducer for `T` and compares
        /// the resulting per-rank map with `expected`.
        fn verify<T: PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<Serialized>,
            expected: HashMap<Index, T>,
        ) {
            let typehash = <WatermarkUpdateReducer<T> as Named>::typehash();
            let reduced: WatermarkUpdate<T> = reduce_as(typehash, updates);
            assert_eq!(reduced.0, expected);
        }

        let u64_updates = serialize::<WatermarkUpdate<u64>>(
            vec![
                (1, 1),
                (0, 2),
                (0, 1),
                (3, 35),
                (0, 9),
                (1, 10),
                (3, 32),
                (3, 0),
                (3, 321),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );
        let i64_updates = serialize::<WatermarkUpdate<i64>>(
            vec![
                (0, 2),
                (1, 1),
                (3, 35),
                (0, 1),
                (1, -10),
                (3, 32),
                (3, 0),
                (3, -99),
                (0, -9),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );

        // The last update for each rank wins.
        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    #[test]
    fn test_accum_reducer_numeric() {
        assert_eq!(
            spec_typehash(sum::<u64>()),
            <SumReducer<u64> as Named>::typehash(),
        );
        assert_eq!(
            spec_typehash(sum::<i64>()),
            <SumReducer<i64> as Named>::typehash(),
        );

        assert_eq!(
            spec_typehash(min::<u64>()),
            <MinReducer<u64> as Named>::typehash(),
        );
        assert_eq!(
            spec_typehash(min::<i64>()),
            <MinReducer<i64> as Named>::typehash(),
        );

        assert_eq!(
            spec_typehash(max::<u64>()),
            <MaxReducer<u64> as Named>::typehash(),
        );
        assert_eq!(
            spec_typehash(max::<i64>()),
            <MaxReducer<i64> as Named>::typehash(),
        );
    }

    #[test]
    fn test_accum_reducer_watermark() {
        fn verify<T: Ord + Copy + Named>() {
            assert_eq!(
                spec_typehash(low_watermark::<T>()),
                <WatermarkUpdateReducer<T> as Named>::typehash(),
            );
        }
        verify::<u64>();
        verify::<i64>();
    }

    #[test]
    fn test_watermark_accumulator() {
        let accumulator = low_watermark::<u64>();
        // Each row is (rank, reported value, expected watermark after the
        // report), where the watermark is the minimum over the latest value
        // of every rank seen so far.
        let ranks_values_expectations = [
            (0, 1003, 1003),
            (1, 1002, 1002),
            (2, 1001, 1001),
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            (0, 1000, 101),
            (1, 1100, 102),
            (2, 1200, 1000),
            (0, 1001, 1001),
            (1, 1101, 1001),
            (2, 1201, 1001),
            (2, 102, 102),
            (1, 101, 101),
            (0, 100, 100),
        ];
        let mut state = WatermarkUpdate(HashMap::new());
        for (rank, value, expected) in ranks_values_expectations {
            let update = WatermarkUpdate::from((rank, value));
            accumulator.accumulate(&mut state, update).unwrap();
            assert_eq!(state.get(), &expected, "rank is {rank}; value is {value}");
        }
    }
}