1use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use algebra::JoinSemilattice;
17use enum_as_inner::EnumAsInner;
18use serde::Deserialize;
19use serde::Serialize;
20use serde::de::DeserializeOwned;
21use typeuri::Named;
22
23use crate::config;
25use crate::reference;
26
/// Folds a sequence of updates into a piece of state.
///
/// Implementations in this module (`sum`, `join_semilattice`) also report,
/// through [`Accumulator::reducer_spec`], a reducer that can combine two
/// updates of the same type; the reducer is identified by typehash and can be
/// rebuilt with `resolve_reducer`.
pub trait Accumulator {
    /// The state that updates are folded into.
    type State;
    /// The type of a single update.
    type Update;

    /// Folds `update` into `state` in place.
    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;

    /// Describes the reducer for this accumulator's updates, or `None` if
    /// updates cannot be pre-combined.
    fn reducer_spec(&self) -> Option<ReducerSpec>;
}
41
/// Serializable description of a reducer.
///
/// The pair mirrors the arguments of `resolve_reducer`. That function
/// looks up a registered [`ReducerFactory`] by `typehash` and passes
/// `builder_params` to its builder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
pub struct ReducerSpec {
    /// Typehash of the reducer type, e.g. `<SumReducer<T> as Named>::typehash()`.
    pub typehash: u64,
    /// Optional parameters for the reducer's builder; the built-in
    /// accumulators in this module always set this to `None`.
    pub builder_params: Option<wirevalue::Any>,
}
wirevalue::register_type!(ReducerSpec);
51
/// Options for [`ReducerMode::Streaming`].
///
/// Unset fields fall back to defaults chosen by the `ReducerMode` accessors.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Default)]
pub struct StreamingReducerOpts {
    /// Maximum update interval. When `None`, the global
    /// `config::SPLIT_MAX_BUFFER_AGE` value is used instead.
    pub max_update_interval: Option<Duration>,
    /// Initial update interval. When `None`, 1 millisecond is used.
    pub initial_update_interval: Option<Duration>,
}
63
/// How reduced updates are delivered.
///
/// The default is `Streaming` with all options unset.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    EnumAsInner,
    typeuri::Named
)]
pub enum ReducerMode {
    /// Streaming delivery, with intervals controlled by the given options.
    Streaming(StreamingReducerOpts),
    /// One-shot mode; both update intervals are `Duration::MAX`.
    /// NOTE(review): the meaning of the `usize` payload is not visible in
    /// this module — presumably the number of updates to wait for; confirm.
    Once(usize),
}
80
81impl Default for ReducerMode {
82 fn default() -> Self {
83 ReducerMode::Streaming(StreamingReducerOpts::default())
84 }
85}
86
87impl ReducerMode {
88 pub(crate) fn max_update_interval(&self) -> Duration {
89 match self {
90 ReducerMode::Streaming(opts) => opts
91 .max_update_interval
92 .unwrap_or(hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_AGE)),
93 ReducerMode::Once(_) => Duration::MAX,
94 }
95 }
96
97 pub(crate) fn initial_update_interval(&self) -> Duration {
98 match self {
99 ReducerMode::Streaming(opts) => opts
100 .initial_update_interval
101 .unwrap_or(Duration::from_millis(1)),
102 ReducerMode::Once(_) => Duration::MAX,
103 }
104 }
105}
106
/// A reducer that combines two updates of the same type into one.
///
/// Judging by the name, implementations are expected to be commutative.
/// NOTE(review): this is not enforced by the type system — confirm the
/// required laws with callers.
pub trait CommReducer {
    /// The update type being reduced.
    type Update;

    /// Combines `left` and `right` into a single update.
    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
}
119
120pub trait ErasedCommReducer {
122 fn reduce_erased(
124 &self,
125 left: &wirevalue::Any,
126 right: &wirevalue::Any,
127 ) -> anyhow::Result<wirevalue::Any>;
128
129 fn reduce_updates(
132 &self,
133 updates: Vec<wirevalue::Any>,
134 ) -> Result<wirevalue::Any, (anyhow::Error, Vec<wirevalue::Any>)> {
135 if updates.is_empty() {
136 return Err((anyhow::anyhow!("empty updates"), updates));
137 }
138 if updates.len() == 1 {
139 return Ok(updates.into_iter().next().expect("checked above"));
140 }
141
142 let mut iter = updates.iter();
143 let first = iter.next().unwrap();
144 let second = iter.next().unwrap();
145 let init = match self.reduce_erased(first, second) {
146 Ok(v) => v,
147 Err(e) => return Err((e, updates)),
148 };
149 let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
150 Ok(v) => v,
151 Err(e) => return Err((e, updates)),
152 };
153 Ok(reduced)
154 }
155
156 fn typehash(&self) -> u64;
158}
159
160impl<R, T> ErasedCommReducer for R
161where
162 R: CommReducer<Update = T> + Named,
163 T: Serialize + DeserializeOwned + Named,
164{
165 fn reduce_erased(
166 &self,
167 left: &wirevalue::Any,
168 right: &wirevalue::Any,
169 ) -> anyhow::Result<wirevalue::Any> {
170 let left = left.deserialized::<T>()?;
171 let right = right.deserialized::<T>()?;
172 let result = self.reduce(left, right)?;
173 Ok(wirevalue::Any::serialize(&result)?)
174 }
175
176 fn typehash(&self) -> u64 {
177 R::typehash()
178 }
179}
180
/// A registration entry that knows how to build one reducer type.
///
/// Entries are gathered with `inventory`. `resolve_reducer` indexes them by
/// the value returned from `typehash_f`.
pub struct ReducerFactory {
    /// Returns the typehash of the reducer this factory builds.
    pub typehash_f: fn() -> u64,
    /// Builds the reducer as a boxed trait object from optional builder
    /// parameters.
    pub builder_f: fn(
        Option<wirevalue::Any>,
    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
}

inventory::collect!(ReducerFactory);
196
197inventory::submit! {
198 ReducerFactory {
199 typehash_f: <SumReducer<i64> as Named>::typehash,
200 builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
201 }
202}
203inventory::submit! {
204 ReducerFactory {
205 typehash_f: <SumReducer<u64> as Named>::typehash,
206 builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
207 }
208}
209inventory::submit! {
210 ReducerFactory {
211 typehash_f: <SemilatticeReducer<Max<i64>> as Named>::typehash,
212 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<i64>>(PhantomData))),
213 }
214}
215inventory::submit! {
216 ReducerFactory {
217 typehash_f: <SemilatticeReducer<Max<u64>> as Named>::typehash,
218 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<u64>>(PhantomData))),
219 }
220}
221inventory::submit! {
222 ReducerFactory {
223 typehash_f: <SemilatticeReducer<Min<i64>> as Named>::typehash,
224 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<i64>>(PhantomData))),
225 }
226}
227inventory::submit! {
228 ReducerFactory {
229 typehash_f: <SemilatticeReducer<Min<u64>> as Named>::typehash,
230 builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<u64>>(PhantomData))),
231 }
232}
233inventory::submit! {
234 ReducerFactory {
235 typehash_f: <SemilatticeReducer<WatermarkUpdate<i64>> as Named>::typehash,
236 builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<i64>>(PhantomData))),
237 }
238}
239inventory::submit! {
240 ReducerFactory {
241 typehash_f: <SemilatticeReducer<WatermarkUpdate<u64>> as Named>::typehash,
242 builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<u64>>(PhantomData))),
243 }
244}
245inventory::submit! {
246 ReducerFactory {
247 typehash_f: <SemilatticeReducer<GCounterUpdate> as Named>::typehash,
248 builder_f: |_| Ok(Box::new(SemilatticeReducer::<GCounterUpdate>(PhantomData))),
249 }
250}
251inventory::submit! {
252 ReducerFactory {
253 typehash_f: <SemilatticeReducer<PNCounterUpdate> as Named>::typehash,
254 builder_f: |_| Ok(Box::new(SemilatticeReducer::<PNCounterUpdate>(PhantomData))),
255 }
256}
257
258pub(crate) fn resolve_reducer(
261 typehash: u64,
262 builder_params: Option<wirevalue::Any>,
263) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
264 static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
265 let factories = FACTORY_MAP.get_or_init(|| {
266 let mut map = HashMap::new();
267 for factory in inventory::iter::<ReducerFactory> {
268 map.insert((factory.typehash_f)(), factory);
269 }
270 map
271 });
272
273 factories
274 .get(&typehash)
275 .map(|f| (f.builder_f)(builder_params))
276 .transpose()
277}
278
279#[derive(typeuri::Named)]
280struct SumReducer<T>(PhantomData<T>);
281
282impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
283 type Update = T;
284
285 fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
286 Ok(left + right)
287 }
288}
289
290struct SumAccumulator<T>(PhantomData<T>);
293
294impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
295 type State = T;
296 type Update = T;
297
298 fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
299 *state = *state + update;
300 Ok(())
301 }
302
303 fn reducer_spec(&self) -> Option<ReducerSpec> {
304 Some(ReducerSpec {
305 typehash: <SumReducer<T> as Named>::typehash(),
306 builder_params: None,
307 })
308 }
309}
310
311pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
336-> impl Accumulator<State = T, Update = T> {
337 SumAccumulator(PhantomData)
338}
339
340#[derive(typeuri::Named)]
342struct SemilatticeReducer<L>(PhantomData<L>);
343
344impl<L: JoinSemilattice + Clone> CommReducer for SemilatticeReducer<L> {
345 type Update = L;
346
347 fn reduce(&self, left: L, right: L) -> anyhow::Result<L> {
348 Ok(left.join(&right))
349 }
350}
351
352struct SemilatticeAccumulator<L>(PhantomData<L>);
354
355impl<L: JoinSemilattice + Clone + Named + 'static> Accumulator for SemilatticeAccumulator<L> {
356 type State = L;
357 type Update = L;
358
359 fn accumulate(&self, state: &mut L, update: L) -> anyhow::Result<()> {
360 *state = state.join(&update);
361 Ok(())
362 }
363
364 fn reducer_spec(&self) -> Option<ReducerSpec> {
365 Some(ReducerSpec {
366 typehash: <SemilatticeReducer<L> as Named>::typehash(),
367 builder_params: None,
368 })
369 }
370}
371
372pub fn join_semilattice<L: JoinSemilattice + Clone + Named + 'static>()
386-> impl Accumulator<State = L, Update = L> {
387 SemilatticeAccumulator::<L>(PhantomData)
388}
389
390pub use algebra::Max;
392pub use algebra::Min;
394
/// A watermark built from per-rank last-writer-wins registers.
///
/// Each rank maps to an `algebra::LWW<T>`. Joining two updates merges the maps.
/// The tests show that, per rank, the value with the highest timestamp survives.
/// The watermark itself (`get`) is the minimum of the surviving values.
#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
pub struct WatermarkUpdate<T>(algebra::LatticeMap<reference::Index, algebra::LWW<T>>);
418
419impl<T: Ord + Clone> WatermarkUpdate<T> {
420 pub fn get(&self) -> &T {
425 self.0
426 .iter()
427 .map(|(_, lww)| &lww.value)
428 .min()
429 .expect("watermark should have been initialized")
430 }
431
432 pub fn get_rank(&self, rank: reference::Index) -> Option<&T> {
434 self.0.get(&rank).map(|lww| &lww.value)
435 }
436
437 pub fn num_ranks(&self) -> usize {
439 self.0.len()
440 }
441}
442
443impl<T> From<(reference::Index, T, u64)> for WatermarkUpdate<T> {
444 fn from((rank, value, timestamp): (reference::Index, T, u64)) -> Self {
450 let mut map = algebra::LatticeMap::new();
451 map.insert(rank, algebra::LWW::new(value, timestamp, rank as u64));
453 Self(map)
454 }
455}
456
457impl<T: Clone + PartialEq> JoinSemilattice for WatermarkUpdate<T> {
458 fn join(&self, other: &Self) -> Self {
459 WatermarkUpdate(self.0.join(&other.0))
460 }
461}
462
/// Grow-only counter: each rank reports its own total.
///
/// Totals are joined per rank with `Max`, so re-sending the same or a smaller
/// total has no effect. The counter's value (`get`) is the sum over ranks.
#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
pub struct GCounterUpdate(algebra::LatticeMap<reference::Index, Max<u64>>);
wirevalue::register_type!(GCounterUpdate);
477
478impl GCounterUpdate {
479 pub fn get(&self) -> u64 {
481 self.0.iter().map(|(_, max)| max.0).sum()
482 }
483
484 pub fn get_rank(&self, rank: reference::Index) -> Option<u64> {
486 self.0.get(&rank).map(|max| max.0)
487 }
488
489 pub fn num_ranks(&self) -> usize {
491 self.0.len()
492 }
493}
494
495impl From<(reference::Index, u64)> for GCounterUpdate {
496 fn from((rank, count): (reference::Index, u64)) -> Self {
498 let mut map = algebra::LatticeMap::new();
499 map.insert(rank, Max(count));
500 Self(map)
501 }
502}
503
504impl JoinSemilattice for GCounterUpdate {
505 fn join(&self, other: &Self) -> Self {
506 GCounterUpdate(self.0.join(&other.0))
507 }
508}
509
/// Increment/decrement counter built from two grow-only per-rank maps.
///
/// `p` holds each rank's increment total and `n` each rank's decrement total.
/// Both are joined per rank with `Max`. The value (`get`) is the sum of `p`
/// minus the sum of `n`.
#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
pub struct PNCounterUpdate {
    // Per-rank increment totals.
    p: algebra::LatticeMap<reference::Index, Max<u64>>,
    // Per-rank decrement totals.
    n: algebra::LatticeMap<reference::Index, Max<u64>>,
}
wirevalue::register_type!(PNCounterUpdate);
522
523impl PNCounterUpdate {
524 pub fn get(&self) -> i64 {
526 let p: u64 = self.p.iter().map(|(_, m)| m.0).sum();
527 let n: u64 = self.n.iter().map(|(_, m)| m.0).sum();
528 p as i64 - n as i64
529 }
530
531 pub fn inc(rank: reference::Index, delta: u64) -> Self {
533 let mut p = algebra::LatticeMap::new();
534 p.insert(rank, Max(delta));
535 Self {
536 p,
537 n: algebra::LatticeMap::new(),
538 }
539 }
540
541 pub fn dec(rank: reference::Index, delta: u64) -> Self {
543 let mut n = algebra::LatticeMap::new();
544 n.insert(rank, Max(delta));
545 Self {
546 p: algebra::LatticeMap::new(),
547 n,
548 }
549 }
550
551 pub fn num_inc_ranks(&self) -> usize {
553 self.p.len()
554 }
555
556 pub fn num_dec_ranks(&self) -> usize {
558 self.n.len()
559 }
560}
561
562impl JoinSemilattice for PNCounterUpdate {
563 fn join(&self, other: &Self) -> Self {
564 PNCounterUpdate {
565 p: self.p.join(&other.p),
566 n: self.n.join(&other.n),
567 }
568 }
569}
570
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;
    use typeuri::Named;

    use super::*;

    /// Serializes each value into a `wirevalue::Any`, panicking on failure.
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<wirevalue::Any> {
        values
            .into_iter()
            .map(|n| wirevalue::Any::serialize(&n).unwrap())
            .collect()
    }

    /// Resolves the registered sum/max/min reducers by typehash and checks
    /// that each reduces a three-element batch of u64 and i64 values.
    #[test]
    fn test_comm_reducer_numeric() {
        let u64_numbers_sum: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
        let i64_numbers_sum: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);
        let u64_numbers_max: Vec<_> = serialize(vec![Max(1u64), Max(3u64), Max(1100u64)]);
        let i64_numbers_max: Vec<_> = serialize(vec![Max(-123i64), Max(33i64), Max(110i64)]);
        let u64_numbers_min: Vec<_> = serialize(vec![Min(1u64), Min(3u64), Min(1100u64)]);
        let i64_numbers_min: Vec<_> = serialize(vec![Min(-123i64), Min(33i64), Min(110i64)]);
        // u64 reducers.
        {
            let typehash = <SemilatticeReducer<Max<u64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers_max.clone())
                    .unwrap()
                    .deserialized::<Max<u64>>()
                    .unwrap(),
                Max(1100u64),
            );

            let typehash = <SemilatticeReducer<Min<u64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers_min.clone())
                    .unwrap()
                    .deserialized::<Min<u64>>()
                    .unwrap(),
                Min(1u64),
            );

            let typehash = <SumReducer<u64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers_sum)
                    .unwrap()
                    .deserialized::<u64>()
                    .unwrap(),
                1104u64,
            );
        }

        // i64 reducers, including negative inputs.
        {
            let typehash = <SemilatticeReducer<Max<i64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers_max.clone())
                    .unwrap()
                    .deserialized::<Max<i64>>()
                    .unwrap(),
                Max(110i64),
            );

            let typehash = <SemilatticeReducer<Min<i64>> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers_min.clone())
                    .unwrap()
                    .deserialized::<Min<i64>>()
                    .unwrap(),
                Min(-123i64),
            );

            let typehash = <SumReducer<i64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers_sum)
                    .unwrap()
                    .deserialized::<i64>()
                    .unwrap(),
                20i64,
            );
        }
    }

    /// Reduces shuffled `(rank, value, timestamp)` watermark updates and
    /// checks that each rank keeps the value with the highest timestamp.
    #[test]
    fn test_comm_reducer_watermark() {
        let u64_updates = serialize::<WatermarkUpdate<u64>>(
            vec![
                (1, 1, 0),
                (0, 2, 1),
                (0, 1, 2),
                (3, 35, 3),
                (0, 9, 4),
                (1, 10, 5),
                (3, 32, 6),
                (3, 0, 7),
                (3, 321, 8),
            ]
            .into_iter()
            .map(|(k, v, ts)| WatermarkUpdate::from((k, v, ts)))
            .collect(),
        );
        let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
            vec![
                (0, 2, 0),
                (1, 1, 1),
                (3, 35, 2),
                (0, 1, 3),
                (1, -10, 4),
                (3, 32, 5),
                (3, 0, 6),
                (3, -99, 7),
                (0, -9, 8),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );

        /// Reduces `updates` with the watermark reducer and checks every
        /// rank's value, and the rank count, against `expected`.
        fn verify<T: Ord + Clone + PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<wirevalue::Any>,
            expected: HashMap<reference::Index, T>,
        ) {
            let typehash = <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash();
            let result = resolve_reducer(typehash, None)
                .unwrap()
                .unwrap()
                .reduce_updates(updates)
                .unwrap()
                .deserialized::<WatermarkUpdate<T>>()
                .unwrap();

            for (rank, expected_value) in &expected {
                assert_eq!(
                    result.get_rank(*rank).unwrap(),
                    expected_value,
                    "Mismatch for rank {rank}"
                );
            }
            assert_eq!(result.num_ranks(), expected.len());
        }

        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    /// The sum accumulators and min/max semilattice accumulators report the
    /// typehash of their matching reducer.
    #[test]
    fn test_accum_reducer_numeric() {
        assert_eq!(
            sum::<u64>().reducer_spec().unwrap().typehash,
            <SumReducer::<u64> as Named>::typehash(),
        );
        assert_eq!(
            sum::<i64>().reducer_spec().unwrap().typehash,
            <SumReducer::<i64> as Named>::typehash(),
        );

        assert_eq!(
            join_semilattice::<Min<u64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Min<u64>> as Named>::typehash(),
        );
        assert_eq!(
            join_semilattice::<Min<i64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Min<i64>> as Named>::typehash(),
        );

        assert_eq!(
            join_semilattice::<Max<u64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Max<u64>> as Named>::typehash(),
        );
        assert_eq!(
            join_semilattice::<Max<i64>>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<Max<i64>> as Named>::typehash(),
        );
    }

    /// The watermark accumulator reports the watermark reducer's typehash.
    #[test]
    fn test_accum_reducer_watermark() {
        fn verify<T: Clone + PartialEq + Named + 'static>() {
            assert_eq!(
                join_semilattice::<WatermarkUpdate<T>>()
                    .reducer_spec()
                    .unwrap()
                    .typehash,
                <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash(),
            );
        }
        verify::<u64>();
        verify::<i64>();
    }

    /// Feeds `(rank, value, timestamp, expected_watermark)` steps through the
    /// watermark accumulator. After each step, `get()` (the minimum over
    /// ranks' latest values) must equal the expectation.
    #[test]
    fn test_watermark_accumulator() {
        let accumulator = join_semilattice::<WatermarkUpdate<u64>>();
        let ranks_values_expectations = [
            (0, 1003, 0, 1003),
            (1, 1002, 1, 1002),
            (2, 1001, 2, 1001),
            (0, 100, 3, 100),
            (1, 101, 4, 100),
            (2, 102, 5, 100),
            (0, 100, 6, 100),
            (1, 101, 7, 100),
            (2, 102, 8, 100),
            (0, 1000, 9, 101),
            (1, 1100, 10, 102),
            (2, 1200, 11, 1000),
            (0, 1001, 12, 1001),
            (1, 1101, 13, 1001),
            (2, 1201, 14, 1001),
            (2, 102, 15, 102),
            (1, 101, 16, 101),
            (0, 100, 17, 100),
        ];
        let mut state = WatermarkUpdate::default();
        for (rank, value, ts, expected) in ranks_values_expectations {
            accumulator
                .accumulate(&mut state, WatermarkUpdate::from((rank, value, ts)))
                .unwrap();
            assert_eq!(
                state.get(),
                &expected,
                "rank is {rank}; value is {value}; ts is {ts}"
            );
        }
    }

    /// Reduces G-counter updates and checks that each rank keeps its largest
    /// total and that `get()` sums them.
    #[test]
    fn test_comm_reducer_gcounter() {
        let updates = serialize::<GCounterUpdate>(vec![
            GCounterUpdate::from((0, 10)),
            GCounterUpdate::from((1, 20)),
            GCounterUpdate::from((0, 15)),
            GCounterUpdate::from((2, 5)),
            GCounterUpdate::from((1, 25)),
        ]);

        let typehash = <SemilatticeReducer<GCounterUpdate> as Named>::typehash();
        let result = resolve_reducer(typehash, None)
            .unwrap()
            .unwrap()
            .reduce_updates(updates)
            .unwrap()
            .deserialized::<GCounterUpdate>()
            .unwrap();

        assert_eq!(result.get_rank(0), Some(15));
        assert_eq!(result.get_rank(1), Some(25));
        assert_eq!(result.get_rank(2), Some(5));
        assert_eq!(result.num_ranks(), 3);
        assert_eq!(result.get(), 45);
    }

    /// The G-counter accumulator reports the matching reducer's typehash.
    #[test]
    fn test_accum_reducer_gcounter() {
        assert_eq!(
            join_semilattice::<GCounterUpdate>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<GCounterUpdate> as Named>::typehash(),
        );
    }

    /// Steps `(rank, count, expected_total)` through the G-counter
    /// accumulator. Repeated or smaller counts for a rank leave the total
    /// unchanged; larger counts replace the rank's total.
    #[test]
    fn test_gcounter_accumulator() {
        let accumulator = join_semilattice::<GCounterUpdate>();
        let ranks_counts_expectations: [(reference::Index, u64, u64); 17] = [
            (0, 1000, 1000),
            (1, 100, 1100),
            (2, 10, 1110),
            (2, 20, 1120),
            (1, 200, 1220),
            (0, 2000, 2220),
            (0, 2000, 2220),
            (1, 200, 2220),
            (2, 20, 2220),
            (0, 1, 2220),
            (1, 1, 2220),
            (2, 1, 2220),
            (2, 5000, 7200),
            (1, 6000, 13000),
            (0, 10000, 21000),
            (0, 10001, 21001),
            (1, 6001, 21002),
        ];
        let mut state = GCounterUpdate::default();
        for (rank, count, expected) in ranks_counts_expectations {
            accumulator
                .accumulate(&mut state, GCounterUpdate::from((rank, count)))
                .unwrap();
            assert_eq!(state.get(), expected, "rank is {rank}; count is {count}");
        }
        assert_eq!(state.get_rank(0), Some(10001));
        assert_eq!(state.get_rank(1), Some(6001));
        assert_eq!(state.get_rank(2), Some(5000));
        assert_eq!(state.get_rank(3), None);
        assert_eq!(state.num_ranks(), 3);
    }

    /// Accumulating the same G-counter updates forwards and backwards yields
    /// the same per-rank totals and the same sum.
    #[test]
    fn test_gcounter_commutativity() {
        let updates = [
            GCounterUpdate::from((0, 10)),
            GCounterUpdate::from((1, 20)),
            GCounterUpdate::from((0, 15)),
            GCounterUpdate::from((2, 5)),
            GCounterUpdate::from((1, 25)),
        ];

        let accumulator = join_semilattice::<GCounterUpdate>();
        let mut forward = GCounterUpdate::default();
        for update in updates.iter().cloned() {
            accumulator.accumulate(&mut forward, update).unwrap();
        }

        let mut reverse = GCounterUpdate::default();
        for update in updates.iter().rev().cloned() {
            accumulator.accumulate(&mut reverse, update).unwrap();
        }

        assert_eq!(forward.get(), reverse.get());
        assert_eq!(forward.get(), 45);
        assert_eq!(forward.get_rank(0), reverse.get_rank(0));
        assert_eq!(forward.get_rank(1), reverse.get_rank(1));
        assert_eq!(forward.get_rank(2), reverse.get_rank(2));
    }

    /// Reduces PN-counter updates. The result keeps the largest increment and
    /// decrement total per rank: p = 15 + 20 and n = 7 + 8, giving 20.
    #[test]
    fn test_comm_reducer_pncounter() {
        let updates = serialize::<PNCounterUpdate>(vec![
            PNCounterUpdate::inc(0, 10),
            PNCounterUpdate::inc(1, 20),
            PNCounterUpdate::dec(0, 5),
            PNCounterUpdate::inc(0, 15),
            PNCounterUpdate::dec(1, 8),
            PNCounterUpdate::dec(0, 7),
        ]);

        let typehash = <SemilatticeReducer<PNCounterUpdate> as Named>::typehash();
        let result = resolve_reducer(typehash, None)
            .unwrap()
            .unwrap()
            .reduce_updates(updates)
            .unwrap()
            .deserialized::<PNCounterUpdate>()
            .unwrap();

        assert_eq!(result.get(), 20);
        assert_eq!(result.num_inc_ranks(), 2);
        assert_eq!(result.num_dec_ranks(), 2);
    }

    /// The PN-counter accumulator reports the matching reducer's typehash.
    #[test]
    fn test_accum_reducer_pncounter() {
        assert_eq!(
            join_semilattice::<PNCounterUpdate>()
                .reducer_spec()
                .unwrap()
                .typehash,
            <SemilatticeReducer<PNCounterUpdate> as Named>::typehash(),
        );
    }

    /// Steps increment/decrement ops through the PN-counter accumulator and
    /// checks `get()` after each one. Ops that do not raise a rank's
    /// increment or decrement total leave the value unchanged.
    #[test]
    fn test_pncounter_accumulator() {
        let accumulator = join_semilattice::<PNCounterUpdate>();
        #[derive(Clone, Copy, Debug)]
        enum Op {
            Inc(reference::Index, u64),
            Dec(reference::Index, u64),
        }
        use Op::*;

        let ops_expectations = [
            (Inc(0, 100), 100),
            (Inc(1, 50), 150),
            (Inc(2, 25), 175),
            (Dec(0, 10), 165),
            (Dec(1, 5), 160),
            (Dec(2, 2), 158),
            (Inc(0, 200), 258),
            (Inc(1, 100), 308),
            (Inc(2, 50), 333),
            (Dec(0, 20), 323),
            (Dec(1, 15), 313),
            (Dec(2, 5), 310),
            (Inc(0, 200), 310),
            (Dec(1, 15), 310),
            (Inc(0, 1), 310),
            (Dec(0, 1), 310),
            (Dec(2, 60), 255),
            (Dec(1, 120), 150),
            (Inc(2, 60), 160),
            (Inc(0, 1000), 960),
            (Dec(2, 100), 920),
        ];

        let mut state = PNCounterUpdate::default();
        for (i, (op, expected)) in ops_expectations.iter().enumerate() {
            let update = match op {
                Inc(rank, delta) => PNCounterUpdate::inc(*rank, *delta),
                Dec(rank, delta) => PNCounterUpdate::dec(*rank, *delta),
            };
            accumulator.accumulate(&mut state, update).unwrap();
            assert_eq!(state.get(), *expected, "step {i}: {op:?}");
        }

        assert_eq!(state.num_inc_ranks(), 3);
        assert_eq!(state.num_dec_ranks(), 3);
    }

    /// Accumulating the same PN-counter updates forwards and backwards yields
    /// the same value: p = 15 + 20 + 12 and n = 5 + 8 + 3, giving 31.
    #[test]
    fn test_pncounter_commutativity() {
        let updates = [
            PNCounterUpdate::inc(0, 10),
            PNCounterUpdate::inc(1, 20),
            PNCounterUpdate::dec(0, 5),
            PNCounterUpdate::inc(0, 15),
            PNCounterUpdate::dec(1, 8),
            PNCounterUpdate::dec(2, 3),
            PNCounterUpdate::inc(2, 12),
        ];

        let accumulator = join_semilattice::<PNCounterUpdate>();
        let mut forward = PNCounterUpdate::default();
        for update in updates.iter().cloned() {
            accumulator.accumulate(&mut forward, update).unwrap();
        }

        let mut reverse = PNCounterUpdate::default();
        for update in updates.iter().rev().cloned() {
            accumulator.accumulate(&mut reverse, update).unwrap();
        }

        assert_eq!(forward.get(), reverse.get());
        assert_eq!(forward.get(), 31);
        assert_eq!(forward.num_inc_ranks(), reverse.num_inc_ranks());
        assert_eq!(forward.num_dec_ranks(), reverse.num_dec_ranks());
    }
}