1use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14
15use serde::Deserialize;
16use serde::Serialize;
17use serde::de::DeserializeOwned;
18
19use crate as hyperactor; use crate::Named;
21use crate::data::Serialized;
22use crate::reference::Index;
23
/// Folds a stream of updates into an accumulated state.
///
/// The state is owned by the caller and handed in on every call; the
/// implementations visible in this file are stateless marker types.
pub trait Accumulator {
    /// The accumulated value that updates are folded into.
    type State;
    /// A single incoming update.
    type Update;

    /// Folds `update` into `state` in place, returning an error if the update
    /// cannot be applied.
    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;

    /// Describes the reducer associated with this accumulator, or `None` if
    /// there is none. The spec carries the same `typehash` / `builder_params`
    /// pair that `resolve_reducer` accepts.
    fn reducer_spec(&self) -> Option<ReducerSpec>;
}
38
/// Serializable description of a reducer: which reducer type to use and the
/// optional parameters to build it with.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub struct ReducerSpec {
    /// Typehash of the reducer type; the key `resolve_reducer` looks up.
    pub typehash: u64,
    /// Optional serialized parameters forwarded to the reducer's builder.
    pub builder_params: Option<Serialized>,
}
47
/// Combines two updates of the same type into a single update.
pub trait CommReducer {
    /// The update type this reducer operates on.
    type Update;

    /// Combines `left` and `right` into one update, or returns an error if
    /// they cannot be combined.
    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
}
60
61pub trait ErasedCommReducer {
63 fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized>;
65
66 fn reduce_updates(
69 &self,
70 updates: Vec<Serialized>,
71 ) -> Result<Serialized, (anyhow::Error, Vec<Serialized>)> {
72 if updates.is_empty() {
73 return Err((anyhow::anyhow!("empty updates"), updates));
74 }
75 if updates.len() == 1 {
76 return Ok(updates.into_iter().next().expect("checked above"));
77 }
78
79 let mut iter = updates.iter();
80 let first = iter.next().unwrap();
81 let second = iter.next().unwrap();
82 let init = match self.reduce_erased(first, second) {
83 Ok(v) => v,
84 Err(e) => return Err((e, updates)),
85 };
86 let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
87 Ok(v) => v,
88 Err(e) => return Err((e, updates)),
89 };
90 Ok(reduced)
91 }
92
93 fn typehash(&self) -> u64;
95}
96
97impl<R, T> ErasedCommReducer for R
98where
99 R: CommReducer<Update = T> + Named,
100 T: Serialize + DeserializeOwned + Named,
101{
102 fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized> {
103 let left = left.deserialized::<T>()?;
104 let right = right.deserialized::<T>()?;
105 let result = self.reduce(left, right)?;
106 Ok(Serialized::serialize(&result)?)
107 }
108
109 fn typehash(&self) -> u64 {
110 R::typehash()
111 }
112}
113
/// Registration record for a reducer type. Instances are collected through
/// `inventory` and indexed by typehash in `resolve_reducer`.
pub struct ReducerFactory {
    /// Returns the typehash under which the reducer is registered.
    pub typehash_f: fn() -> u64,
    /// Builds a boxed, type-erased reducer from the optional serialized
    /// builder parameters.
    pub builder_f: fn(
        Option<Serialized>,
    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
}
127
// Declares `ReducerFactory` as an inventory-collected type so that the
// registrations below can be iterated in `resolve_reducer`.
inventory::collect!(ReducerFactory);

// Built-in reducers, registered for `i64` and `u64`. None of the builders
// uses its parameters (the argument is ignored).

// Sum reducers.
inventory::submit! {
    ReducerFactory {
        typehash_f: <SumReducer<i64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
    }
}
inventory::submit! {
    ReducerFactory {
        typehash_f: <SumReducer<u64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
    }
}
// Max reducers.
inventory::submit! {
    ReducerFactory {
        typehash_f: <MaxReducer::<i64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(MaxReducer::<i64>(PhantomData))),
    }
}
inventory::submit! {
    ReducerFactory {
        typehash_f: <MaxReducer::<u64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(MaxReducer::<u64>(PhantomData))),
    }
}
// Min reducers.
inventory::submit! {
    ReducerFactory {
        typehash_f: <MinReducer::<i64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(MinReducer::<i64>(PhantomData))),
    }
}
inventory::submit! {
    ReducerFactory {
        typehash_f: <MinReducer::<u64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(MinReducer::<u64>(PhantomData))),
    }
}
// Watermark-update reducers.
inventory::submit! {
    ReducerFactory {
        typehash_f: <WatermarkUpdateReducer::<i64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<i64>(PhantomData))),
    }
}
inventory::submit! {
    ReducerFactory {
        typehash_f: <WatermarkUpdateReducer::<u64> as Named>::typehash,
        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<u64>(PhantomData))),
    }
}
178
179pub(crate) fn resolve_reducer(
182 typehash: u64,
183 builder_params: Option<Serialized>,
184) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
185 static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
186 let factories = FACTORY_MAP.get_or_init(|| {
187 let mut map = HashMap::new();
188 for factory in inventory::iter::<ReducerFactory> {
189 map.insert((factory.typehash_f)(), factory);
190 }
191 map
192 });
193
194 factories
195 .get(&typehash)
196 .map(|f| (f.builder_f)(builder_params))
197 .transpose()
198}
199
200#[derive(Named)]
201struct SumReducer<T>(PhantomData<T>);
202
203impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
204 type Update = T;
205
206 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
207 Ok(left + right)
208 }
209}
210
211struct SumAccumulator<T>(PhantomData<T>);
214
215impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
216 type State = T;
217 type Update = T;
218
219 fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
220 *state = *state + update;
221 Ok(())
222 }
223
224 fn reducer_spec(&self) -> Option<ReducerSpec> {
225 Some(ReducerSpec {
226 typehash: <SumReducer<T> as Named>::typehash(),
227 builder_params: None,
228 })
229 }
230}
231
232pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
234-> impl Accumulator<State = T, Update = T> {
235 SumAccumulator(PhantomData)
236}
237
238#[derive(Named)]
239struct MaxReducer<T>(PhantomData<T>);
240
241impl<T: Ord> CommReducer for MaxReducer<T> {
242 type Update = T;
243
244 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
245 Ok(std::cmp::max(left, right))
246 }
247}
248
/// State of the max accumulator: the largest update seen so far, or nothing
/// if no update has been accumulated yet (the `Default` value).
#[derive(Debug, Clone, Default)]
pub struct Max<T>(Option<T>);

impl<T> Max<T> {
    /// Returns the largest value accumulated so far.
    ///
    /// # Panics
    ///
    /// Panics if no update has been accumulated yet, i.e. the state is still
    /// empty.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been initialized.")
    }
}
261
262struct MaxAccumulator<T>(PhantomData<T>);
264
265impl<T: Ord + Copy + Named + 'static> Accumulator for MaxAccumulator<T> {
266 type State = Max<T>;
267 type Update = T;
268
269 fn accumulate(&self, state: &mut Self::State, update: T) -> anyhow::Result<()> {
270 match state.0.as_mut() {
271 Some(s) => *s = std::cmp::max(*s, update),
272 None => *state = Max(Some(update)),
273 }
274 Ok(())
275 }
276
277 fn reducer_spec(&self) -> Option<ReducerSpec> {
278 Some(ReducerSpec {
279 typehash: <MaxReducer<T> as Named>::typehash(),
280 builder_params: None,
281 })
282 }
283}
284
285pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>, Update = T> {
288 MaxAccumulator(PhantomData::<T>)
289}
290
291#[derive(Named)]
292struct MinReducer<T>(PhantomData<T>);
293
294impl<T: Ord> CommReducer for MinReducer<T> {
295 type Update = T;
296
297 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
298 Ok(std::cmp::min(left, right))
299 }
300}
301
/// State of the min accumulator: the smallest update seen so far, or nothing
/// if no update has been accumulated yet (the `Default` value).
#[derive(Debug, Clone, Default)]
pub struct Min<T>(Option<T>);

impl<T> Min<T> {
    /// Returns the smallest value accumulated so far.
    ///
    /// # Panics
    ///
    /// Panics if no update has been accumulated yet, i.e. the state is still
    /// empty.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been initialized.")
    }
}
314
315struct MinAccumulator<T>(PhantomData<T>);
317
318impl<T: Ord + Copy + Named + 'static> Accumulator for MinAccumulator<T> {
319 type State = Min<T>;
320 type Update = T;
321
322 fn accumulate(&self, state: &mut Min<T>, update: T) -> anyhow::Result<()> {
323 match state.0.as_mut() {
324 Some(s) => *s = std::cmp::min(*s, update),
325 None => *state = Min(Some(update)),
326 }
327 Ok(())
328 }
329
330 fn reducer_spec(&self) -> Option<ReducerSpec> {
331 Some(ReducerSpec {
332 typehash: <MinReducer<T> as Named>::typehash(),
333 builder_params: None,
334 })
335 }
336}
337
338pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>, Update = T> {
341 MinAccumulator(PhantomData)
342}
343
344#[derive(Default, Debug, Clone, Serialize, Deserialize, Named)]
347pub struct WatermarkUpdate<T>(HashMap<Index, T>);
348
349impl<T: Ord> WatermarkUpdate<T> {
350 pub fn get(&self) -> &T {
354 self.0
355 .values()
356 .min()
357 .expect("watermark should have been intialized.")
358 }
359}
360
361impl<T: PartialEq> WatermarkUpdate<T> {
362 fn merge(old: Self, new: Self) -> Self {
364 let mut map = old.0;
365 for (k, v) in new.0 {
366 map.insert(k, v);
367 }
368 Self(map)
369 }
370}
371
372impl<T> From<(Index, T)> for WatermarkUpdate<T> {
373 fn from((rank, value): (Index, T)) -> Self {
374 let mut map = HashMap::with_capacity(1);
375 map.insert(rank, value);
376 Self(map)
377 }
378}
379
380#[derive(Named)]
383struct WatermarkUpdateReducer<T>(PhantomData<T>);
384
385impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
386 type Update = WatermarkUpdate<T>;
387
388 fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
389 Ok(WatermarkUpdate::merge(left, right))
390 }
391}
392
393struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
394
395impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {
396 type State = WatermarkUpdate<T>;
397 type Update = WatermarkUpdate<T>;
398
399 fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
400 let current = std::mem::replace(&mut *state, WatermarkUpdate(HashMap::new()));
401 *state = WatermarkUpdate::merge(current, update);
403 Ok(())
404 }
405
406 fn reducer_spec(&self) -> Option<ReducerSpec> {
407 Some(ReducerSpec {
408 typehash: <WatermarkUpdateReducer<T> as Named>::typehash(),
409 builder_params: None,
410 })
411 }
412}
413
414pub fn low_watermark<T: Ord + Copy + Named + 'static>()
422-> impl Accumulator<State = WatermarkUpdate<T>, Update = WatermarkUpdate<T>> {
423 LowWatermarkUpdateAccumulator(PhantomData)
424}
425
426#[cfg(test)]
427mod tests {
428 use std::fmt::Debug;
429
430 use maplit::hashmap;
431
432 use super::*;
433 use crate::Named;
434
435 fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<Serialized> {
436 values
437 .into_iter()
438 .map(|n| Serialized::serialize(&n).unwrap())
439 .collect()
440 }
441
442 #[test]
443 fn test_comm_reducer_numeric() {
444 let u64_numbers: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
445 let i64_numbers: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);
446 {
447 let typehash = <MaxReducer<u64> as Named>::typehash();
448 assert_eq!(
449 resolve_reducer(typehash, None)
450 .unwrap()
451 .unwrap()
452 .reduce_updates(u64_numbers.clone())
453 .unwrap()
454 .deserialized::<u64>()
455 .unwrap(),
456 1100u64,
457 );
458
459 let typehash = <MinReducer<u64> as Named>::typehash();
460 assert_eq!(
461 resolve_reducer(typehash, None)
462 .unwrap()
463 .unwrap()
464 .reduce_updates(u64_numbers.clone())
465 .unwrap()
466 .deserialized::<u64>()
467 .unwrap(),
468 1u64,
469 );
470
471 let typehash = <SumReducer<u64> as Named>::typehash();
472 assert_eq!(
473 resolve_reducer(typehash, None)
474 .unwrap()
475 .unwrap()
476 .reduce_updates(u64_numbers)
477 .unwrap()
478 .deserialized::<u64>()
479 .unwrap(),
480 1104u64,
481 );
482 }
483
484 {
485 let typehash = <MaxReducer<i64> as Named>::typehash();
486 assert_eq!(
487 resolve_reducer(typehash, None)
488 .unwrap()
489 .unwrap()
490 .reduce_updates(i64_numbers.clone())
491 .unwrap()
492 .deserialized::<i64>()
493 .unwrap(),
494 110i64,
495 );
496
497 let typehash = <MinReducer<i64> as Named>::typehash();
498 assert_eq!(
499 resolve_reducer(typehash, None)
500 .unwrap()
501 .unwrap()
502 .reduce_updates(i64_numbers.clone())
503 .unwrap()
504 .deserialized::<i64>()
505 .unwrap(),
506 -123i64,
507 );
508
509 let typehash = <SumReducer<i64> as Named>::typehash();
510 assert_eq!(
511 resolve_reducer(typehash, None)
512 .unwrap()
513 .unwrap()
514 .reduce_updates(i64_numbers)
515 .unwrap()
516 .deserialized::<i64>()
517 .unwrap(),
518 20i64,
519 );
520 }
521 }
522
523 #[test]
524 fn test_comm_reducer_watermark() {
525 let u64_updates = serialize::<WatermarkUpdate<u64>>(
526 vec![
527 (1, 1),
528 (0, 2),
529 (0, 1),
530 (3, 35),
531 (0, 9),
532 (1, 10),
533 (3, 32),
534 (3, 0),
535 (3, 321),
536 ]
537 .into_iter()
538 .map(|(k, v)| WatermarkUpdate::from((k, v)))
539 .collect(),
540 );
541 let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
542 vec![
543 (0, 2),
544 (1, 1),
545 (3, 35),
546 (0, 1),
547 (1, -10),
548 (3, 32),
549 (3, 0),
550 (3, -99),
551 (0, -9),
552 ]
553 .into_iter()
554 .map(WatermarkUpdate::from)
555 .collect(),
556 );
557
558 fn verify<T: PartialEq + DeserializeOwned + Debug>(
559 updates: Vec<Serialized>,
560 expected: HashMap<Index, T>,
561 ) {
562 let typehash = <WatermarkUpdateReducer<i64> as Named>::typehash();
563 assert_eq!(
564 resolve_reducer(typehash, None)
565 .unwrap()
566 .unwrap()
567 .reduce_updates(updates)
568 .unwrap()
569 .deserialized::<WatermarkUpdate<T>>()
570 .unwrap()
571 .0,
572 expected,
573 );
574 }
575
576 verify::<i64>(
577 i64_updates,
578 hashmap! {
579 0 => -9,
580 1 => -10,
581 3 => -99,
582 },
583 );
584
585 verify::<u64>(
586 u64_updates,
587 hashmap! {
588 0 => 9,
589 1 => 10,
590 3 => 321,
591 },
592 );
593 }
594
595 #[test]
596 fn test_accum_reducer_numeric() {
597 assert_eq!(
598 sum::<u64>().reducer_spec().unwrap().typehash,
599 <SumReducer::<u64> as Named>::typehash(),
600 );
601 assert_eq!(
602 sum::<i64>().reducer_spec().unwrap().typehash,
603 <SumReducer::<i64> as Named>::typehash(),
604 );
605
606 assert_eq!(
607 min::<u64>().reducer_spec().unwrap().typehash,
608 <MinReducer::<u64> as Named>::typehash(),
609 );
610 assert_eq!(
611 min::<i64>().reducer_spec().unwrap().typehash,
612 <MinReducer::<i64> as Named>::typehash(),
613 );
614
615 assert_eq!(
616 max::<u64>().reducer_spec().unwrap().typehash,
617 <MaxReducer::<u64> as Named>::typehash(),
618 );
619 assert_eq!(
620 max::<i64>().reducer_spec().unwrap().typehash,
621 <MaxReducer::<i64> as Named>::typehash(),
622 );
623 }
624
625 #[test]
626 fn test_accum_reducer_watermark() {
627 fn verify<T: Ord + Copy + Named>() {
628 assert_eq!(
629 low_watermark::<T>().reducer_spec().unwrap().typehash,
630 <WatermarkUpdateReducer::<T> as Named>::typehash(),
631 );
632 }
633 verify::<u64>();
634 verify::<i64>();
635 }
636
637 #[test]
638 fn test_watermark_accumulator() {
639 let accumulator = low_watermark::<u64>();
640 let ranks_values_expectations = [
641 (0, 1003, 1003),
643 (1, 1002, 1002),
644 (2, 1001, 1001),
645 (0, 100, 100),
647 (1, 101, 100),
648 (2, 102, 100),
649 (0, 100, 100),
651 (1, 101, 100),
652 (2, 102, 100),
653 (0, 1000, 101),
655 (1, 1100, 102),
657 (2, 1200, 1000),
659 (0, 1001, 1001),
661 (1, 1101, 1001),
662 (2, 1201, 1001),
663 (2, 102, 102),
665 (1, 101, 101),
666 (0, 100, 100),
667 ];
668 let mut state = WatermarkUpdate(HashMap::new());
669 for (rank, value, expected) in ranks_values_expectations {
670 accumulator
671 .accumulate(&mut state, WatermarkUpdate::from((rank, value)))
672 .unwrap();
673 assert_eq!(state.get(), &expected, "rank is {rank}; value is {value}");
674 }
675 }
676}