1use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use algebra::JoinSemilattice;
17use enum_as_inner::EnumAsInner;
18use serde::Deserialize;
19use serde::Serialize;
20use serde::de::DeserializeOwned;
21use typeuri::Named;
22
23use crate::config;
25
26pub trait Accumulator {
28 type State;
30 type Update;
33
34 fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;
36
37 fn reducer_spec(&self) -> Option<ReducerSpec>;
39}
40
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
43pub struct ReducerSpec {
44 pub typehash: u64,
46 pub builder_params: Option<wirevalue::Any>,
48}
49wirevalue::register_type!(ReducerSpec);
50
51#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Default)]
53pub struct StreamingReducerOpts {
54 pub max_update_interval: Option<Duration>,
57 pub initial_update_interval: Option<Duration>,
61}
62
63#[derive(
65 Debug,
66 Clone,
67 PartialEq,
68 Serialize,
69 Deserialize,
70 EnumAsInner,
71 typeuri::Named
72)]
73pub enum ReducerMode {
74 Streaming(StreamingReducerOpts),
76 Once(usize),
78}
79
80impl Default for ReducerMode {
81 fn default() -> Self {
82 ReducerMode::Streaming(StreamingReducerOpts::default())
83 }
84}
85
86impl ReducerMode {
87 pub(crate) fn max_update_interval(&self) -> Duration {
88 match self {
89 ReducerMode::Streaming(opts) => opts
90 .max_update_interval
91 .unwrap_or(hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_AGE)),
92 ReducerMode::Once(_) => Duration::MAX,
93 }
94 }
95
96 pub(crate) fn initial_update_interval(&self) -> Duration {
97 match self {
98 ReducerMode::Streaming(opts) => opts
99 .initial_update_interval
100 .unwrap_or(Duration::from_millis(1)),
101 ReducerMode::Once(_) => Duration::MAX,
102 }
103 }
104}
105
106pub trait CommReducer {
112 type Update;
114
115 fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
117}
118
119pub trait ErasedCommReducer {
121 fn reduce_erased(
123 &self,
124 left: &wirevalue::Any,
125 right: &wirevalue::Any,
126 ) -> anyhow::Result<wirevalue::Any>;
127
128 fn reduce_updates(
131 &self,
132 updates: Vec<wirevalue::Any>,
133 ) -> Result<wirevalue::Any, (anyhow::Error, Vec<wirevalue::Any>)> {
134 if updates.is_empty() {
135 return Err((anyhow::anyhow!("empty updates"), updates));
136 }
137 if updates.len() == 1 {
138 return Ok(updates.into_iter().next().expect("checked above"));
139 }
140
141 let mut iter = updates.iter();
142 let first = iter.next().unwrap();
143 let second = iter.next().unwrap();
144 let init = match self.reduce_erased(first, second) {
145 Ok(v) => v,
146 Err(e) => return Err((e, updates)),
147 };
148 let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
149 Ok(v) => v,
150 Err(e) => return Err((e, updates)),
151 };
152 Ok(reduced)
153 }
154
155 fn typehash(&self) -> u64;
157}
158
159impl<R, T> ErasedCommReducer for R
160where
161 R: CommReducer<Update = T> + Named,
162 T: Serialize + DeserializeOwned + Named,
163{
164 fn reduce_erased(
165 &self,
166 left: &wirevalue::Any,
167 right: &wirevalue::Any,
168 ) -> anyhow::Result<wirevalue::Any> {
169 let left = left.deserialized::<T>()?;
170 let right = right.deserialized::<T>()?;
171 let result = self.reduce(left, right)?;
172 Ok(wirevalue::Any::serialize(&result)?)
173 }
174
175 fn typehash(&self) -> u64 {
176 R::typehash()
177 }
178}
179
180pub struct ReducerFactory {
185 pub typehash_f: fn() -> u64,
188 pub builder_f: fn(
190 Option<wirevalue::Any>,
191 ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
192}
193
194inventory::collect!(ReducerFactory);
195
196inventory::submit! {
197 ReducerFactory {
198 typehash_f: <SumReducer<i64> as Named>::typehash,
199 builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
200 }
201}
202inventory::submit! {
203 ReducerFactory {
204 typehash_f: <SumReducer<u64> as Named>::typehash,
205 builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
206 }
207}
208inventory::submit! {
209 ReducerFactory {
210 typehash_f: <SemilatticeReducer<Max<i64>> as Named>::typehash,
211 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<i64>>(PhantomData))),
212 }
213}
214inventory::submit! {
215 ReducerFactory {
216 typehash_f: <SemilatticeReducer<Max<u64>> as Named>::typehash,
217 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<u64>>(PhantomData))),
218 }
219}
220inventory::submit! {
221 ReducerFactory {
222 typehash_f: <SemilatticeReducer<Min<i64>> as Named>::typehash,
223 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<i64>>(PhantomData))),
224 }
225}
226inventory::submit! {
227 ReducerFactory {
228 typehash_f: <SemilatticeReducer<Min<u64>> as Named>::typehash,
229 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<u64>>(PhantomData))),
230 }
231}
232inventory::submit! {
233 ReducerFactory {
234 typehash_f: <SemilatticeReducer<WatermarkUpdate<i64>> as Named>::typehash,
235 builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<i64>>(PhantomData))),
236 }
237}
238inventory::submit! {
239 ReducerFactory {
240 typehash_f: <SemilatticeReducer<WatermarkUpdate<u64>> as Named>::typehash,
241 builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<u64>>(PhantomData))),
242 }
243}
244inventory::submit! {
245 ReducerFactory {
246 typehash_f: <SemilatticeReducer<GCounterUpdate> as Named>::typehash,
247 builder_f: |_| Ok(Box::new(SemilatticeReducer::<GCounterUpdate>(PhantomData))),
248 }
249}
250inventory::submit! {
251 ReducerFactory {
252 typehash_f: <SemilatticeReducer<PNCounterUpdate> as Named>::typehash,
253 builder_f: |_| Ok(Box::new(SemilatticeReducer::<PNCounterUpdate>(PhantomData))),
254 }
255}
256
257pub(crate) fn resolve_reducer(
260 typehash: u64,
261 builder_params: Option<wirevalue::Any>,
262) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
263 static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
264 let factories = FACTORY_MAP.get_or_init(|| {
265 let mut map = HashMap::new();
266 for factory in inventory::iter::<ReducerFactory> {
267 map.insert((factory.typehash_f)(), factory);
268 }
269 map
270 });
271
272 factories
273 .get(&typehash)
274 .map(|f| (f.builder_f)(builder_params))
275 .transpose()
276}
277
278#[derive(typeuri::Named)]
279struct SumReducer<T>(PhantomData<T>);
280
281impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
282 type Update = T;
283
284 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
285 Ok(left + right)
286 }
287}
288
289struct SumAccumulator<T>(PhantomData<T>);
292
293impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
294 type State = T;
295 type Update = T;
296
297 fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
298 *state = *state + update;
299 Ok(())
300 }
301
302 fn reducer_spec(&self) -> Option<ReducerSpec> {
303 Some(ReducerSpec {
304 typehash: <SumReducer<T> as Named>::typehash(),
305 builder_params: None,
306 })
307 }
308}
309
310pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
335-> impl Accumulator<State = T, Update = T> {
336 SumAccumulator(PhantomData)
337}
338
339#[derive(typeuri::Named)]
341struct SemilatticeReducer<L>(PhantomData<L>);
342
343impl<L: JoinSemilattice + Clone> CommReducer for SemilatticeReducer<L> {
344 type Update = L;
345
346 fn reduce(&self, left: L, right: L) -> anyhow::Result<L> {
347 Ok(left.join(&right))
348 }
349}
350
351struct SemilatticeAccumulator<L>(PhantomData<L>);
353
354impl<L: JoinSemilattice + Clone + Named + 'static> Accumulator for SemilatticeAccumulator<L> {
355 type State = L;
356 type Update = L;
357
358 fn accumulate(&self, state: &mut L, update: L) -> anyhow::Result<()> {
359 *state = state.join(&update);
360 Ok(())
361 }
362
363 fn reducer_spec(&self) -> Option<ReducerSpec> {
364 Some(ReducerSpec {
365 typehash: <SemilatticeReducer<L> as Named>::typehash(),
366 builder_params: None,
367 })
368 }
369}
370
371pub fn join_semilattice<L: JoinSemilattice + Clone + Named + 'static>()
385-> impl Accumulator<State = L, Update = L> {
386 SemilatticeAccumulator::<L>(PhantomData)
387}
388
389pub use algebra::Max;
391pub use algebra::Min;
393
394#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
416pub struct WatermarkUpdate<T>(algebra::LatticeMap<usize, algebra::LWW<T>>);
417
418impl<T: Ord + Clone> WatermarkUpdate<T> {
419 pub fn get(&self) -> &T {
424 self.0
425 .iter()
426 .map(|(_, lww)| &lww.value)
427 .min()
428 .expect("watermark should have been initialized")
429 }
430
431 pub fn get_rank(&self, rank: usize) -> Option<&T> {
433 self.0.get(&rank).map(|lww| &lww.value)
434 }
435
436 pub fn num_ranks(&self) -> usize {
438 self.0.len()
439 }
440}
441
442impl<T> From<(usize, T, u64)> for WatermarkUpdate<T> {
443 fn from((rank, value, timestamp): (usize, T, u64)) -> Self {
449 let mut map = algebra::LatticeMap::new();
450 map.insert(rank, algebra::LWW::new(value, timestamp, rank as u64));
452 Self(map)
453 }
454}
455
456impl<T: Clone + PartialEq> JoinSemilattice for WatermarkUpdate<T> {
457 fn join(&self, other: &Self) -> Self {
458 WatermarkUpdate(self.0.join(&other.0))
459 }
460}
461
462#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
474pub struct GCounterUpdate(algebra::LatticeMap<usize, Max<u64>>);
475wirevalue::register_type!(GCounterUpdate);
476
477impl GCounterUpdate {
478 pub fn get(&self) -> u64 {
480 self.0.iter().map(|(_, max)| max.0).sum()
481 }
482
483 pub fn get_rank(&self, rank: usize) -> Option<u64> {
485 self.0.get(&rank).map(|max| max.0)
486 }
487
488 pub fn num_ranks(&self) -> usize {
490 self.0.len()
491 }
492}
493
494impl From<(usize, u64)> for GCounterUpdate {
495 fn from((rank, count): (usize, u64)) -> Self {
497 let mut map = algebra::LatticeMap::new();
498 map.insert(rank, Max(count));
499 Self(map)
500 }
501}
502
503impl JoinSemilattice for GCounterUpdate {
504 fn join(&self, other: &Self) -> Self {
505 GCounterUpdate(self.0.join(&other.0))
506 }
507}
508
509#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
516pub struct PNCounterUpdate {
517 p: algebra::LatticeMap<usize, Max<u64>>,
518 n: algebra::LatticeMap<usize, Max<u64>>,
519}
520wirevalue::register_type!(PNCounterUpdate);
521
522impl PNCounterUpdate {
523 pub fn get(&self) -> i64 {
525 let p: u64 = self.p.iter().map(|(_, m)| m.0).sum();
526 let n: u64 = self.n.iter().map(|(_, m)| m.0).sum();
527 p as i64 - n as i64
528 }
529
530 pub fn inc(rank: usize, delta: u64) -> Self {
532 let mut p = algebra::LatticeMap::new();
533 p.insert(rank, Max(delta));
534 Self {
535 p,
536 n: algebra::LatticeMap::new(),
537 }
538 }
539
540 pub fn dec(rank: usize, delta: u64) -> Self {
542 let mut n = algebra::LatticeMap::new();
543 n.insert(rank, Max(delta));
544 Self {
545 p: algebra::LatticeMap::new(),
546 n,
547 }
548 }
549
550 pub fn num_inc_ranks(&self) -> usize {
552 self.p.len()
553 }
554
555 pub fn num_dec_ranks(&self) -> usize {
557 self.n.len()
558 }
559}
560
561impl JoinSemilattice for PNCounterUpdate {
562 fn join(&self, other: &Self) -> Self {
563 PNCounterUpdate {
564 p: self.p.join(&other.p),
565 n: self.n.join(&other.n),
566 }
567 }
568}
569
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;
    use typeuri::Named;

    use super::*;

    // Encodes each value into a `wirevalue::Any`, panicking if encoding fails.
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<wirevalue::Any> {
        values
            .into_iter()
            .map(|n| wirevalue::Any::serialize(&n).unwrap())
            .collect()
    }

    // Resolves the registered sum/max/min reducers for u64 and i64 by type
    // hash and checks how each one reduces three serialized values.
    #[test]
    fn test_comm_reducer_numeric() {
        let u64_numbers_sum: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
        let i64_numbers_sum: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);
        let u64_numbers_max: Vec<_> = serialize(vec![Max(1u64), Max(3u64), Max(1100u64)]);
        let i64_numbers_max: Vec<_> = serialize(vec![Max(-123i64), Max(33i64), Max(110i64)]);
        let u64_numbers_min: Vec<_> = serialize(vec![Min(1u64), Min(3u64), Min(1100u64)]);
        let i64_numbers_min: Vec<_> = serialize(vec![Min(-123i64), Min(33i64), Min(110i64)]);
        // Unsigned: max = 1100, min = 1, sum = 1 + 3 + 1100 = 1104.
        {
            let typehash = <SemilatticeReducer<Max<u64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers_max.clone())
                    .unwrap()
                    .deserialized::<Max<u64>>()
                    .unwrap(),
                Max(1100u64),
            );

            let typehash = <SemilatticeReducer<Min<u64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers_min.clone())
                    .unwrap()
                    .deserialized::<Min<u64>>()
                    .unwrap(),
                Min(1u64),
            );

            let typehash = <SumReducer<u64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers_sum)
                    .unwrap()
                    .deserialized::<u64>()
                    .unwrap(),
                1104u64,
            );
        }

        // Signed: max = 110, min = -123, sum = -123 + 33 + 110 = 20.
        {
            let typehash = <SemilatticeReducer<Max<i64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers_max.clone())
                    .unwrap()
                    .deserialized::<Max<i64>>()
                    .unwrap(),
                Max(110i64),
            );

            let typehash = <SemilatticeReducer<Min<i64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers_min.clone())
                    .unwrap()
                    .deserialized::<Min<i64>>()
                    .unwrap(),
                Min(-123i64),
            );

            let typehash = <SumReducer<i64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers_sum)
                    .unwrap()
                    .deserialized::<i64>()
                    .unwrap(),
                20i64,
            );
        }
    }

    // Reduces serialized (rank, value, timestamp) watermark updates and checks
    // that each rank keeps the value carried by its latest timestamp.
    #[test]
    fn test_comm_reducer_watermark() {
        let u64_updates = serialize::<WatermarkUpdate<u64>>(
            // (rank, value, timestamp)
            vec![
                (1, 1, 0),
                (0, 2, 1),
                (0, 1, 2),
                (3, 35, 3),
                (0, 9, 4), // rank 0: latest timestamp
                (1, 10, 5), // rank 1: latest timestamp
                (3, 32, 6),
                (3, 0, 7),
                (3, 321, 8), // rank 3: latest timestamp
            ]
            .into_iter()
            .map(|(k, v, ts)| WatermarkUpdate::from((k, v, ts)))
            .collect(),
        );
        let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
            // (rank, value, timestamp)
            vec![
                (0, 2, 0),
                (1, 1, 1),
                (3, 35, 2),
                (0, 1, 3),
                (1, -10, 4), // rank 1: latest timestamp
                (3, 32, 5),
                (3, 0, 6),
                (3, -99, 7), // rank 3: latest timestamp
                (0, -9, 8), // rank 0: latest timestamp
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );

        // Resolves the watermark reducer for `T`, reduces `updates`, and
        // compares every rank's surviving value against `expected`.
        fn verify<T: Ord + Clone + PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<wirevalue::Any>,
            expected: HashMap<usize, T>,
        ) {
            let typehash = <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash();
            let result = resolve_reducer(typehash, None)
                .unwrap()
                .unwrap()
                .reduce_updates(updates)
                .unwrap()
                .deserialized::<WatermarkUpdate<T>>()
                .unwrap();

            for (rank, expected_value) in &expected {
                // Each rank must hold the value from its latest-timestamped update.
                assert_eq!(
                    result.get_rank(*rank).unwrap(),
                    expected_value,
                    "Mismatch for rank {rank}"
                );
            }
            // No ranks beyond the expected ones may be present.
            assert_eq!(result.num_ranks(), expected.len());
        }

        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    // `sum` accumulators must advertise the matching `SumReducer` type hash,
    // and `join_semilattice` accumulators the matching `SemilatticeReducer`.
    #[test]
    fn test_accum_reducer_numeric() {
        assert_eq!(
            sum::<u64>().reducer_spec().unwrap().typehash,
            <SumReducer::<u64> as Named>::typehash(),
        );
        assert_eq!(
            sum::<i64>().reducer_spec().unwrap().typehash,
            <SumReducer::<i64> as Named>::typehash(),
        );

        assert_eq!(
            join_semilattice::<Min<u64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Min<u64>> as Named>::typehash(),
        );
        assert_eq!(
            join_semilattice::<Min<i64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Min<i64>> as Named>::typehash(),
        );

        assert_eq!(
            join_semilattice::<Max<u64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Max<u64>> as Named>::typehash(),
        );
        assert_eq!(
            join_semilattice::<Max<i64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Max<i64>> as Named>::typehash(),
        );
    }

    // Watermark accumulators must advertise the watermark reducer's type hash.
    #[test]
    fn test_accum_reducer_watermark() {
        fn verify<T: Clone + PartialEq + Named + 'static>() {
            assert_eq!(
                join_semilattice::<WatermarkUpdate<T>>()
                    .reducer_spec()
                    .unwrap()
                    .typehash,
                <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash(),
            );
        }
        verify::<u64>();
        verify::<i64>();
    }

    // Feeds (rank, value, timestamp) updates into one state and checks the
    // watermark (minimum across ranks) after every step.
    #[test]
    fn test_watermark_accumulator() {
        let accumulator = join_semilattice::<WatermarkUpdate<u64>>();
        let ranks_values_expectations = [
            // (rank, value, timestamp, expected watermark)
            (0, 1003, 0, 1003),
            // New ranks with lower values lower the minimum.
            (1, 1002, 1, 1002),
            (2, 1001, 2, 1001),
            (0, 100, 3, 100),
            // Rank 0 holds the minimum, so raising ranks 1 and 2 keeps it at 100.
            (1, 101, 4, 100),
            (2, 102, 5, 100),
            (0, 100, 6, 100),
            // Re-sending the same values with newer timestamps changes nothing.
            (1, 101, 7, 100),
            (2, 102, 8, 100),
            (0, 1000, 9, 101),
            // Raising rank 0 exposes rank 1 (101) as the minimum...
            (1, 1100, 10, 102),
            // ...then rank 2 (102)...
            (2, 1200, 11, 1000),
            // ...then rank 0 again once every rank is raised.
            (0, 1001, 12, 1001),
            // Raising ranks 1 and 2 further leaves rank 0 as the minimum.
            (1, 1101, 13, 1001),
            (2, 1201, 14, 1001),
            (2, 102, 15, 102),
            // A newer, lower value replaces the old one (latest timestamp wins).
            (1, 101, 16, 101),
            (0, 100, 17, 100),
        ];
        let mut state = WatermarkUpdate::default();
        for (rank, value, ts, expected) in ranks_values_expectations {
            accumulator
                .accumulate(&mut state, WatermarkUpdate::from((rank, value, ts)))
                .unwrap();
            assert_eq!(
                state.get(),
                &expected,
                "rank is {rank}; value is {value}; ts is {ts}"
            );
        }
    }

    // Reduces serialized GCounter updates: each rank keeps its largest count
    // and the total is the sum of those counts.
    #[test]
    fn test_comm_reducer_gcounter() {
        let updates = serialize::<GCounterUpdate>(vec![
            // (rank, count)
            GCounterUpdate::from((0, 10)),
            GCounterUpdate::from((1, 20)),
            GCounterUpdate::from((0, 15)), // supersedes rank 0's 10
            GCounterUpdate::from((2, 5)),
            GCounterUpdate::from((1, 25)), // supersedes rank 1's 20
        ]);

        let typehash = <SemilatticeReducer<GCounterUpdate> as Named>::typehash();
        let result = resolve_reducer(typehash, None)
            .unwrap()
            .unwrap()
            .reduce_updates(updates)
            .unwrap()
            .deserialized::<GCounterUpdate>()
            .unwrap();

        // Per-rank maxima.
        assert_eq!(result.get_rank(0), Some(15));
        assert_eq!(result.get_rank(1), Some(25));
        assert_eq!(result.get_rank(2), Some(5));
        assert_eq!(result.num_ranks(), 3);
        // 15 + 25 + 5.
        assert_eq!(result.get(), 45);
    }

    // GCounter accumulators must advertise the GCounter reducer's type hash.
    #[test]
    fn test_accum_reducer_gcounter() {
        assert_eq!(
            join_semilattice::<GCounterUpdate>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<GCounterUpdate> as Named>::typehash(),
        );
    }

    // Accumulates GCounter updates one at a time, checking the running total.
    #[test]
    fn test_gcounter_accumulator() {
        let accumulator = join_semilattice::<GCounterUpdate>();
        let ranks_counts_expectations: [(usize, u64, u64); 17] = [
            // (rank, count, expected total)
            (0, 1000, 1000),
            // Other ranks add their counts.
            (1, 100, 1100),
            (2, 10, 1110),
            // A larger count for rank 2 replaces its previous 10.
            (2, 20, 1120),
            (1, 200, 1220),
            (0, 2000, 2220),
            // Duplicate counts change nothing.
            (0, 2000, 2220),
            (1, 200, 2220),
            (2, 20, 2220),
            // Smaller counts are ignored (per-rank maximum).
            (0, 1, 2220),
            (1, 1, 2220),
            (2, 1, 2220),
            // Large jumps: 2000 + 200 + 5000, then 2000 + 6000 + 5000, ...
            (2, 5000, 7200),
            (1, 6000, 13000),
            (0, 10000, 21000),
            (0, 10001, 21001),
            (1, 6001, 21002),
        ];
        let mut state = GCounterUpdate::default();
        for (rank, count, expected) in ranks_counts_expectations {
            accumulator
                .accumulate(&mut state, GCounterUpdate::from((rank, count)))
                .unwrap();
            assert_eq!(state.get(), expected, "rank is {rank}; count is {count}");
        }
        // Final per-rank maxima; rank 3 never reported.
        assert_eq!(state.get_rank(0), Some(10001));
        assert_eq!(state.get_rank(1), Some(6001));
        assert_eq!(state.get_rank(2), Some(5000));
        assert_eq!(state.get_rank(3), None);
        assert_eq!(state.num_ranks(), 3);
    }

    // Accumulating the same GCounter updates forward and in reverse must
    // produce the same result.
    #[test]
    fn test_gcounter_commutativity() {
        let updates = [
            // (rank, count)
            GCounterUpdate::from((0, 10)),
            GCounterUpdate::from((1, 20)),
            GCounterUpdate::from((0, 15)),
            GCounterUpdate::from((2, 5)),
            GCounterUpdate::from((1, 25)),
        ];

        let accumulator = join_semilattice::<GCounterUpdate>();
        // Forward order.
        let mut forward = GCounterUpdate::default();
        for update in updates.iter().cloned() {
            accumulator.accumulate(&mut forward, update).unwrap();
        }

        // Reverse order.
        let mut reverse = GCounterUpdate::default();
        for update in updates.iter().rev().cloned() {
            accumulator.accumulate(&mut reverse, update).unwrap();
        }

        assert_eq!(forward.get(), reverse.get());
        // 15 + 25 + 5.
        assert_eq!(forward.get(), 45);
        assert_eq!(forward.get_rank(0), reverse.get_rank(0));
        assert_eq!(forward.get_rank(1), reverse.get_rank(1));
        assert_eq!(forward.get_rank(2), reverse.get_rank(2));
    }

    // Reduces serialized PNCounter updates: each side keeps per-rank maxima,
    // and the value is (sum of increments) - (sum of decrements).
    #[test]
    fn test_comm_reducer_pncounter() {
        let updates = serialize::<PNCounterUpdate>(vec![
            // (rank, delta)
            PNCounterUpdate::inc(0, 10),
            PNCounterUpdate::inc(1, 20),
            PNCounterUpdate::dec(0, 5),
            PNCounterUpdate::inc(0, 15), // supersedes rank 0's increment of 10
            PNCounterUpdate::dec(1, 8),
            PNCounterUpdate::dec(0, 7), // supersedes rank 0's decrement of 5
        ]);

        let typehash = <SemilatticeReducer<PNCounterUpdate> as Named>::typehash();
        let result = resolve_reducer(typehash, None)
            .unwrap()
            .unwrap()
            .reduce_updates(updates)
            .unwrap()
            .deserialized::<PNCounterUpdate>()
            .unwrap();

        // Increments: 15 + 20 = 35; decrements: 7 + 8 = 15; 35 - 15 = 20.
        assert_eq!(result.get(), 20);
        assert_eq!(result.num_inc_ranks(), 2);
        assert_eq!(result.num_dec_ranks(), 2);
    }

    // PNCounter accumulators must advertise the PNCounter reducer's type hash.
    #[test]
    fn test_accum_reducer_pncounter() {
        assert_eq!(
            join_semilattice::<PNCounterUpdate>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<PNCounterUpdate> as Named>::typehash(),
        );
    }

    // Accumulates increments/decrements step by step, checking the value.
    #[test]
    fn test_pncounter_accumulator() {
        let accumulator = join_semilattice::<PNCounterUpdate>();
        // Test-local description of one update: (rank, delta).
        #[derive(Clone, Copy, Debug)]
        enum Op {
            Inc(usize, u64),
            Dec(usize, u64),
        }
        use Op::*;

        let ops_expectations = [
            // (operation, expected value)
            // Initial increments on every rank.
            (Inc(0, 100), 100),
            (Inc(1, 50), 150),
            (Inc(2, 25), 175),
            // First decrements.
            (Dec(0, 10), 165),
            (Dec(1, 5), 160),
            (Dec(2, 2), 158),
            // Larger per-rank increment totals replace the old ones.
            (Inc(0, 200), 258),
            (Inc(1, 100), 308),
            (Inc(2, 50), 333),
            // Larger per-rank decrement totals replace the old ones.
            (Dec(0, 20), 323),
            (Dec(1, 15), 313),
            (Dec(2, 5), 310),
            // Duplicates leave the value unchanged.
            (Inc(0, 200), 310),
            (Dec(1, 15), 310),
            // Smaller values are ignored (per-rank maximum).
            (Inc(0, 1), 310),
            (Dec(0, 1), 310),
            // Decrements may push a rank's net contribution negative.
            (Dec(2, 60), 255),
            (Dec(1, 120), 150),
            (Inc(2, 60), 160),
            (Inc(0, 1000), 960),
            (Dec(2, 100), 920),
        ];

        let mut state = PNCounterUpdate::default();
        for (i, (op, expected)) in ops_expectations.iter().enumerate() {
            let update = match op {
                Inc(rank, delta) => PNCounterUpdate::inc(*rank, *delta),
                Dec(rank, delta) => PNCounterUpdate::dec(*rank, *delta),
            };
            accumulator.accumulate(&mut state, update).unwrap();
            assert_eq!(state.get(), *expected, "step {i}: {op:?}");
        }

        // Every rank recorded both increments and decrements.
        assert_eq!(state.num_inc_ranks(), 3);
        assert_eq!(state.num_dec_ranks(), 3);
    }

    // Accumulating the same PNCounter updates forward and in reverse must
    // produce the same result.
    #[test]
    fn test_pncounter_commutativity() {
        let updates = [
            // (rank, delta)
            PNCounterUpdate::inc(0, 10),
            PNCounterUpdate::inc(1, 20),
            PNCounterUpdate::dec(0, 5),
            PNCounterUpdate::inc(0, 15),
            PNCounterUpdate::dec(1, 8),
            PNCounterUpdate::dec(2, 3),
            PNCounterUpdate::inc(2, 12),
        ];

        let accumulator = join_semilattice::<PNCounterUpdate>();
        // Forward order.
        let mut forward = PNCounterUpdate::default();
        for update in updates.iter().cloned() {
            accumulator.accumulate(&mut forward, update).unwrap();
        }

        // Reverse order.
        let mut reverse = PNCounterUpdate::default();
        for update in updates.iter().rev().cloned() {
            accumulator.accumulate(&mut reverse, update).unwrap();
        }

        assert_eq!(forward.get(), reverse.get());
        // Increments: 15 + 20 + 12 = 47; decrements: 5 + 8 + 3 = 16.
        assert_eq!(forward.get(), 31);
        assert_eq!(forward.num_inc_ranks(), reverse.num_inc_ranks());
        assert_eq!(forward.num_dec_ranks(), reverse.num_dec_ranks());
    }
}