hyperactor/
accum.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines the accumulator trait and some common accumulators.
10
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use algebra::JoinSemilattice;
17use enum_as_inner::EnumAsInner;
18use serde::Deserialize;
19use serde::Serialize;
20use serde::de::DeserializeOwned;
21use typeuri::Named;
22
23// for macros
24use crate::config;
25use crate::reference;
26
27/// An accumulator is a object that accumulates updates into a state.
28pub trait Accumulator {
29    /// The type of the accumulated state.
30    type State;
31    /// The type of the updates sent to the accumulator. Updates will be
32    /// accumulated into type [Self::State].
33    type Update;
34
35    /// Accumulate an update into the current state.
36    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;
37
38    /// The specification used to build the reducer.
39    fn reducer_spec(&self) -> Option<ReducerSpec>;
40}
41
42/// Serializable information needed to build a comm reducer.
43#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
44pub struct ReducerSpec {
45    /// The typehash of the underlying [Self::Reducer] type.
46    pub typehash: u64,
47    /// The parameters used to build the reducer.
48    pub builder_params: Option<wirevalue::Any>,
49}
50wirevalue::register_type!(ReducerSpec);
51
52/// Options for streaming reducer mode.
53#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Default)]
54pub struct StreamingReducerOpts {
55    /// The maximum interval between updates. When unspecified, a default
56    /// interval is used.
57    pub max_update_interval: Option<Duration>,
58    /// The initial interval for the first update. When unspecified, defaults to 1ms.
59    /// This allows quick flushing of single messages while using exponential backoff
60    /// to reach max_update_interval for batched messages.
61    pub initial_update_interval: Option<Duration>,
62}
63
64/// The mode in which a reducer operates.
65#[derive(
66    Debug,
67    Clone,
68    PartialEq,
69    Serialize,
70    Deserialize,
71    EnumAsInner,
72    typeuri::Named
73)]
74pub enum ReducerMode {
75    /// Streaming mode: continuously reduce and emit updates based on buffer size/timeout.
76    Streaming(StreamingReducerOpts),
77    /// Once mode: accumulate exactly `n` values, emit a single reduced update, then tear down.
78    Once(usize),
79}
80
81impl Default for ReducerMode {
82    fn default() -> Self {
83        ReducerMode::Streaming(StreamingReducerOpts::default())
84    }
85}
86
87impl ReducerMode {
88    pub(crate) fn max_update_interval(&self) -> Duration {
89        match self {
90            ReducerMode::Streaming(opts) => opts
91                .max_update_interval
92                .unwrap_or(hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_AGE)),
93            ReducerMode::Once(_) => Duration::MAX,
94        }
95    }
96
97    pub(crate) fn initial_update_interval(&self) -> Duration {
98        match self {
99            ReducerMode::Streaming(opts) => opts
100                .initial_update_interval
101                .unwrap_or(Duration::from_millis(1)),
102            ReducerMode::Once(_) => Duration::MAX,
103        }
104    }
105}
106
107/// Commutative reducer for an accumulator. This is used to coallesce updates.
108/// For example, if the accumulator is a sum, its reducer calculates and returns
109/// the sum of 2 updates. This is helpful in split ports, where a large number
110/// of updates can be reduced into a smaller number of updates before being sent
111/// to the parent port.
112pub trait CommReducer {
113    /// The type of updates to be reduced.
114    type Update;
115
116    /// Reduce 2 updates into a single update.
117    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
118}
119
120/// Type erased version of [CommReducer].
121pub trait ErasedCommReducer {
122    /// Reduce 2 updates into a single update.
123    fn reduce_erased(
124        &self,
125        left: &wirevalue::Any,
126        right: &wirevalue::Any,
127    ) -> anyhow::Result<wirevalue::Any>;
128
129    /// Reducer an non-empty vector of updates. Return Error if the vector is
130    /// empty.
131    fn reduce_updates(
132        &self,
133        updates: Vec<wirevalue::Any>,
134    ) -> Result<wirevalue::Any, (anyhow::Error, Vec<wirevalue::Any>)> {
135        if updates.is_empty() {
136            return Err((anyhow::anyhow!("empty updates"), updates));
137        }
138        if updates.len() == 1 {
139            return Ok(updates.into_iter().next().expect("checked above"));
140        }
141
142        let mut iter = updates.iter();
143        let first = iter.next().unwrap();
144        let second = iter.next().unwrap();
145        let init = match self.reduce_erased(first, second) {
146            Ok(v) => v,
147            Err(e) => return Err((e, updates)),
148        };
149        let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
150            Ok(v) => v,
151            Err(e) => return Err((e, updates)),
152        };
153        Ok(reduced)
154    }
155
156    /// Typehash of the underlying [`CommReducer`] type.
157    fn typehash(&self) -> u64;
158}
159
160impl<R, T> ErasedCommReducer for R
161where
162    R: CommReducer<Update = T> + Named,
163    T: Serialize + DeserializeOwned + Named,
164{
165    fn reduce_erased(
166        &self,
167        left: &wirevalue::Any,
168        right: &wirevalue::Any,
169    ) -> anyhow::Result<wirevalue::Any> {
170        let left = left.deserialized::<T>()?;
171        let right = right.deserialized::<T>()?;
172        let result = self.reduce(left, right)?;
173        Ok(wirevalue::Any::serialize(&result)?)
174    }
175
176    fn typehash(&self) -> u64 {
177        R::typehash()
178    }
179}
180
181/// A factory for [`ErasedCommReducer`]s. This is used to register a
182/// [`ErasedCommReducer`] type. We cannot register [`ErasedCommReducer`] trait
183/// object directly because the object could have internal state, and cannot be
184/// shared.
185pub struct ReducerFactory {
186    /// Return the typehash of the [`ErasedCommReducer`] type built by this
187    /// factory.
188    pub typehash_f: fn() -> u64,
189    /// The builder function to build the [`ErasedCommReducer`] type.
190    pub builder_f: fn(
191        Option<wirevalue::Any>,
192    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
193}
194
195inventory::collect!(ReducerFactory);
196
197inventory::submit! {
198    ReducerFactory {
199        typehash_f: <SumReducer<i64> as Named>::typehash,
200        builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
201    }
202}
203inventory::submit! {
204    ReducerFactory {
205        typehash_f: <SumReducer<u64> as Named>::typehash,
206        builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
207    }
208}
209inventory::submit! {
210    ReducerFactory {
211        typehash_f: <SemilatticeReducer<Max<i64>> as Named>::typehash,
212        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<i64>>(PhantomData))),
213    }
214}
215inventory::submit! {
216    ReducerFactory {
217        typehash_f: <SemilatticeReducer<Max<u64>> as Named>::typehash,
218        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<u64>>(PhantomData))),
219    }
220}
221inventory::submit! {
222    ReducerFactory {
223        typehash_f: <SemilatticeReducer<Min<i64>> as Named>::typehash,
224        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<i64>>(PhantomData))),
225    }
226}
227inventory::submit! {
228    ReducerFactory {
229        typehash_f: <SemilatticeReducer<Min<u64>> as Named>::typehash,
230        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<u64>>(PhantomData))),
231    }
232}
233inventory::submit! {
234    ReducerFactory {
235        typehash_f: <SemilatticeReducer<WatermarkUpdate<i64>> as Named>::typehash,
236        builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<i64>>(PhantomData))),
237    }
238}
239inventory::submit! {
240    ReducerFactory {
241        typehash_f: <SemilatticeReducer<WatermarkUpdate<u64>> as Named>::typehash,
242        builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<u64>>(PhantomData))),
243    }
244}
245inventory::submit! {
246    ReducerFactory {
247        typehash_f: <SemilatticeReducer<GCounterUpdate> as Named>::typehash,
248        builder_f: |_| Ok(Box::new(SemilatticeReducer::<GCounterUpdate>(PhantomData))),
249    }
250}
251inventory::submit! {
252    ReducerFactory {
253        typehash_f: <SemilatticeReducer<PNCounterUpdate> as Named>::typehash,
254        builder_f: |_| Ok(Box::new(SemilatticeReducer::<PNCounterUpdate>(PhantomData))),
255    }
256}
257
258/// Build a reducer object with the given typehash's [CommReducer] type, and
259/// return the type-erased version of it.
260pub(crate) fn resolve_reducer(
261    typehash: u64,
262    builder_params: Option<wirevalue::Any>,
263) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
264    static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
265    let factories = FACTORY_MAP.get_or_init(|| {
266        let mut map = HashMap::new();
267        for factory in inventory::iter::<ReducerFactory> {
268            map.insert((factory.typehash_f)(), factory);
269        }
270        map
271    });
272
273    factories
274        .get(&typehash)
275        .map(|f| (f.builder_f)(builder_params))
276        .transpose()
277}
278
279#[derive(typeuri::Named)]
280struct SumReducer<T>(PhantomData<T>);
281
282impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
283    type Update = T;
284
285    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
286        Ok(left + right)
287    }
288}
289
290/// Accumulate the sum of received updates. The inner function performs the
291/// summation between an update and the current state.
292struct SumAccumulator<T>(PhantomData<T>);
293
294impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
295    type State = T;
296    type Update = T;
297
298    fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
299        *state = *state + update;
300        Ok(())
301    }
302
303    fn reducer_spec(&self) -> Option<ReducerSpec> {
304        Some(ReducerSpec {
305            typehash: <SumReducer<T> as Named>::typehash(),
306            builder_params: None,
307        })
308    }
309}
310
311/// Accumulate the sum of received updates.
312///
313/// # Note: Not a CRDT
314///
315/// This accumulator is *not idempotent* and is therefore *not
316/// suitable* for distributed scatter/gather patterns with
317/// at-least-once delivery semantics. Duplicate updates will be
318/// counted multiple times:
319///
320/// ```text
321/// sum(1, 2, 2, 3) = 8  (expected 6 if second 2 is duplicate)
322/// ```
323///
324/// ## When to use:
325/// - Single-source accumulation with exactly-once delivery
326/// - Local (non-distributed) aggregation
327/// - When upstream deduplication is guaranteed
328///
329/// ## CRDT Alternative:
330/// For distributed use cases, consider using a GCounter CRDT instead,
331/// which tracks per-replica increments and uses pointwise-max for
332/// merging (commutative, associative, and idempotent).
333///
334/// *See also*: [`Max`], [`Min`] (proper lattice-based CRDTs)
335pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
336-> impl Accumulator<State = T, Update = T> {
337    SumAccumulator(PhantomData)
338}
339
340/// Generic reducer for any JoinSemilattice type.
341#[derive(typeuri::Named)]
342struct SemilatticeReducer<L>(PhantomData<L>);
343
344impl<L: JoinSemilattice + Clone> CommReducer for SemilatticeReducer<L> {
345    type Update = L;
346
347    fn reduce(&self, left: L, right: L) -> anyhow::Result<L> {
348        Ok(left.join(&right))
349    }
350}
351
352/// Generic accumulator for any JoinSemilattice type.
353struct SemilatticeAccumulator<L>(PhantomData<L>);
354
355impl<L: JoinSemilattice + Clone + Named + 'static> Accumulator for SemilatticeAccumulator<L> {
356    type State = L;
357    type Update = L;
358
359    fn accumulate(&self, state: &mut L, update: L) -> anyhow::Result<()> {
360        *state = state.join(&update);
361        Ok(())
362    }
363
364    fn reducer_spec(&self) -> Option<ReducerSpec> {
365        Some(ReducerSpec {
366            typehash: <SemilatticeReducer<L> as Named>::typehash(),
367            builder_params: None,
368        })
369    }
370}
371
372/// Create an accumulator for any JoinSemilattice type.
373///
374/// This is the primary way to create accumulators for lattice-based
375/// types like `Max<T>`, `Min<T>`, `GCounterUpdate`, `PNCounterUpdate`,
376/// and `WatermarkUpdate<T>`.
377///
378/// # Example
379///
380/// ```ignore
381/// use hyperactor::accum::{join_semilattice, Max};
382///
383/// let max_acc = join_semilattice::<Max<u64>>();
384/// ```
385pub fn join_semilattice<L: JoinSemilattice + Clone + Named + 'static>()
386-> impl Accumulator<State = L, Update = L> {
387    SemilatticeAccumulator::<L>(PhantomData)
388}
389
390/// Re-export Max from algebra.
391pub use algebra::Max;
392/// Re-export Min from algebra.
393pub use algebra::Min;
394
395/// Update from ranks for watermark accumulator using Last-Writer-Wins
396/// CRDT.
397///
398/// This is a proper CRDT that tracks the latest value from each rank
399/// using logical timestamps. When updates from the same rank are
400/// merged, the one with the higher timestamp wins. This allows ranks
401/// to report values that may decrease (e.g., during failure recovery)
402/// while maintaining proper commutativity and idempotence.
403///
404/// # CRDT Properties
405///
406/// - *Commutative*: Merge order doesn't matter (timestamps resolve
407///   conflicts)
408/// - *Idempotent*: Merging duplicate updates has no effect
409/// - *Convergent*: All replicas converge to the same state
410///
411/// # Watermark Semantics
412///
413/// The watermark is the minimum value across all ranks' *latest*
414/// reports. "Latest" is determined by logical timestamp, not arrival
415/// order.
416#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
417pub struct WatermarkUpdate<T>(algebra::LatticeMap<reference::Index, algebra::LWW<T>>);
418
419impl<T: Ord + Clone> WatermarkUpdate<T> {
420    /// Get the watermark value (minimum of all ranks' current values).
421    ///
422    /// WatermarkUpdate is guaranteed to be initialized by the accumulator
423    /// before it is sent to the user.
424    pub fn get(&self) -> &T {
425        self.0
426            .iter()
427            .map(|(_, lww)| &lww.value)
428            .min()
429            .expect("watermark should have been initialized")
430    }
431
432    /// Get the current value for a specific rank, if present.
433    pub fn get_rank(&self, rank: reference::Index) -> Option<&T> {
434        self.0.get(&rank).map(|lww| &lww.value)
435    }
436
437    /// Get the number of ranks currently tracked.
438    pub fn num_ranks(&self) -> usize {
439        self.0.len()
440    }
441}
442
443impl<T> From<(reference::Index, T, u64)> for WatermarkUpdate<T> {
444    /// Create a watermark update from (rank, value, timestamp).
445    ///
446    /// The timestamp should be a logical clock value (Lamport clock, sequence
447    /// number, or monotonic counter) that increases with each update from
448    /// the same rank.
449    fn from((rank, value, timestamp): (reference::Index, T, u64)) -> Self {
450        let mut map = algebra::LatticeMap::new();
451        // Use rank as replica ID - each rank is a unique writer
452        map.insert(rank, algebra::LWW::new(value, timestamp, rank as u64));
453        Self(map)
454    }
455}
456
457impl<T: Clone + PartialEq> JoinSemilattice for WatermarkUpdate<T> {
458    fn join(&self, other: &Self) -> Self {
459        WatermarkUpdate(self.0.join(&other.0))
460    }
461}
462
463/// State for a grow-only distributed counter (GCounter CRDT).
464///
465/// Each rank maintains its own count. The total value is the sum of
466/// all ranks' counts. Merge takes pointwise max.
467///
468/// # CRDT Properties
469///
470/// - *Commutative*: Merge order doesn't matter
471/// - *Associative*: Grouping doesn't matter
472/// - *Idempotent*: Merging duplicate updates has no effect
473/// - *Convergent*: All replicas converge to the same state
474#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
475pub struct GCounterUpdate(algebra::LatticeMap<reference::Index, Max<u64>>);
476wirevalue::register_type!(GCounterUpdate);
477
478impl GCounterUpdate {
479    /// Total counter value (sum of all ranks' counts).
480    pub fn get(&self) -> u64 {
481        self.0.iter().map(|(_, max)| max.0).sum()
482    }
483
484    /// Get count for a specific rank.
485    pub fn get_rank(&self, rank: reference::Index) -> Option<u64> {
486        self.0.get(&rank).map(|max| max.0)
487    }
488
489    /// Number of ranks that have contributed.
490    pub fn num_ranks(&self) -> usize {
491        self.0.len()
492    }
493}
494
495impl From<(reference::Index, u64)> for GCounterUpdate {
496    /// Create a GCounter update from (rank, count).
497    fn from((rank, count): (reference::Index, u64)) -> Self {
498        let mut map = algebra::LatticeMap::new();
499        map.insert(rank, Max(count));
500        Self(map)
501    }
502}
503
504impl JoinSemilattice for GCounterUpdate {
505    fn join(&self, other: &Self) -> Self {
506        GCounterUpdate(self.0.join(&other.0))
507    }
508}
509
510/// State for an increment/decrement distributed counter (PNCounter
511/// CRDT).
512///
513/// Internally uses two GCounters: one for increments (P), one for
514/// decrements (N). The value is P - N. Each is merged independently
515/// via pointwise max.
516#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
517pub struct PNCounterUpdate {
518    p: algebra::LatticeMap<reference::Index, Max<u64>>,
519    n: algebra::LatticeMap<reference::Index, Max<u64>>,
520}
521wirevalue::register_type!(PNCounterUpdate);
522
523impl PNCounterUpdate {
524    /// Counter value (sum of increments minus sum of decrements).
525    pub fn get(&self) -> i64 {
526        let p: u64 = self.p.iter().map(|(_, m)| m.0).sum();
527        let n: u64 = self.n.iter().map(|(_, m)| m.0).sum();
528        p as i64 - n as i64
529    }
530
531    /// Create an increment update for a rank.
532    pub fn inc(rank: reference::Index, delta: u64) -> Self {
533        let mut p = algebra::LatticeMap::new();
534        p.insert(rank, Max(delta));
535        Self {
536            p,
537            n: algebra::LatticeMap::new(),
538        }
539    }
540
541    /// Create a decrement update for a rank.
542    pub fn dec(rank: reference::Index, delta: u64) -> Self {
543        let mut n = algebra::LatticeMap::new();
544        n.insert(rank, Max(delta));
545        Self {
546            p: algebra::LatticeMap::new(),
547            n,
548        }
549    }
550
551    /// Number of ranks that have contributed increments.
552    pub fn num_inc_ranks(&self) -> usize {
553        self.p.len()
554    }
555
556    /// Number of ranks that have contributed decrements.
557    pub fn num_dec_ranks(&self) -> usize {
558        self.n.len()
559    }
560}
561
562impl JoinSemilattice for PNCounterUpdate {
563    fn join(&self, other: &Self) -> Self {
564        PNCounterUpdate {
565            p: self.p.join(&other.p),
566            n: self.n.join(&other.n),
567        }
568    }
569}
570
571#[cfg(test)]
572mod tests {
573    use std::fmt::Debug;
574
575    use maplit::hashmap;
576    use typeuri::Named;
577
578    use super::*;
579
580    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<wirevalue::Any> {
581        values
582            .into_iter()
583            .map(|n| wirevalue::Any::serialize(&n).unwrap())
584            .collect()
585    }
586
587    #[test]
588    fn test_comm_reducer_numeric() {
589        let u64_numbers_sum: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
590        let i64_numbers_sum: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);
591        let u64_numbers_max: Vec<_> = serialize(vec![Max(1u64), Max(3u64), Max(1100u64)]);
592        let i64_numbers_max: Vec<_> = serialize(vec![Max(-123i64), Max(33i64), Max(110i64)]);
593        let u64_numbers_min: Vec<_> = serialize(vec![Min(1u64), Min(3u64), Min(1100u64)]);
594        let i64_numbers_min: Vec<_> = serialize(vec![Min(-123i64), Min(33i64), Min(110i64)]);
595        {
596            let typehash = <SemilatticeReducer<Max<u64>> as Named>::typehash();
597            assert_eq!(
598                resolve_reducer(typehash, None)
599                    .unwrap()
600                    .unwrap()
601                    .reduce_updates(u64_numbers_max.clone())
602                    .unwrap()
603                    .deserialized::<Max<u64>>()
604                    .unwrap(),
605                Max(1100u64),
606            );
607
608            let typehash = <SemilatticeReducer<Min<u64>> as Named>::typehash();
609            assert_eq!(
610                resolve_reducer(typehash, None)
611                    .unwrap()
612                    .unwrap()
613                    .reduce_updates(u64_numbers_min.clone())
614                    .unwrap()
615                    .deserialized::<Min<u64>>()
616                    .unwrap(),
617                Min(1u64),
618            );
619
620            let typehash = <SumReducer<u64> as Named>::typehash();
621            assert_eq!(
622                resolve_reducer(typehash, None)
623                    .unwrap()
624                    .unwrap()
625                    .reduce_updates(u64_numbers_sum)
626                    .unwrap()
627                    .deserialized::<u64>()
628                    .unwrap(),
629                1104u64,
630            );
631        }
632
633        {
634            let typehash = <SemilatticeReducer<Max<i64>> as Named>::typehash();
635            assert_eq!(
636                resolve_reducer(typehash, None)
637                    .unwrap()
638                    .unwrap()
639                    .reduce_updates(i64_numbers_max.clone())
640                    .unwrap()
641                    .deserialized::<Max<i64>>()
642                    .unwrap(),
643                Max(110i64),
644            );
645
646            let typehash = <SemilatticeReducer<Min<i64>> as Named>::typehash();
647            assert_eq!(
648                resolve_reducer(typehash, None)
649                    .unwrap()
650                    .unwrap()
651                    .reduce_updates(i64_numbers_min.clone())
652                    .unwrap()
653                    .deserialized::<Min<i64>>()
654                    .unwrap(),
655                Min(-123i64),
656            );
657
658            let typehash = <SumReducer<i64> as Named>::typehash();
659            assert_eq!(
660                resolve_reducer(typehash, None)
661                    .unwrap()
662                    .unwrap()
663                    .reduce_updates(i64_numbers_sum)
664                    .unwrap()
665                    .deserialized::<i64>()
666                    .unwrap(),
667                20i64,
668            );
669        }
670    }
671
672    #[test]
673    fn test_comm_reducer_watermark() {
674        // With LWW, we need timestamps. Assign in order of appearance.
675        let u64_updates = serialize::<WatermarkUpdate<u64>>(
676            vec![
677                (1, 1, 0),   // rank 1: value 1, ts 0
678                (0, 2, 1),   // rank 0: value 2, ts 1
679                (0, 1, 2),   // rank 0: value 1, ts 2 (later ts, wins over value 2)
680                (3, 35, 3),  // rank 3: value 35, ts 3
681                (0, 9, 4),   // rank 0: value 9, ts 4 (latest for rank 0)
682                (1, 10, 5),  // rank 1: value 10, ts 5 (latest for rank 1)
683                (3, 32, 6),  // rank 3: value 32, ts 6
684                (3, 0, 7),   // rank 3: value 0, ts 7
685                (3, 321, 8), // rank 3: value 321, ts 8 (latest for rank 3)
686            ]
687            .into_iter()
688            .map(|(k, v, ts)| WatermarkUpdate::from((k, v, ts)))
689            .collect(),
690        );
691        let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
692            vec![
693                (0, 2, 0),   // rank 0: value 2, ts 0
694                (1, 1, 1),   // rank 1: value 1, ts 1
695                (3, 35, 2),  // rank 3: value 35, ts 2
696                (0, 1, 3),   // rank 0: value 1, ts 3
697                (1, -10, 4), // rank 1: value -10, ts 4
698                (3, 32, 5),  // rank 3: value 32, ts 5
699                (3, 0, 6),   // rank 3: value 0, ts 6
700                (3, -99, 7), // rank 3: value -99, ts 7 (latest for rank 3)
701                (0, -9, 8),  // rank 0: value -9, ts 8 (latest for rank 0)
702            ]
703            .into_iter()
704            .map(WatermarkUpdate::from)
705            .collect(),
706        );
707
708        fn verify<T: Ord + Clone + PartialEq + DeserializeOwned + Debug + Named>(
709            updates: Vec<wirevalue::Any>,
710            expected: HashMap<reference::Index, T>,
711        ) {
712            let typehash = <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash();
713            let result = resolve_reducer(typehash, None)
714                .unwrap()
715                .unwrap()
716                .reduce_updates(updates)
717                .unwrap()
718                .deserialized::<WatermarkUpdate<T>>()
719                .unwrap();
720
721            // Check each expected rank value
722            for (rank, expected_value) in &expected {
723                assert_eq!(
724                    result.get_rank(*rank).unwrap(),
725                    expected_value,
726                    "Mismatch for rank {rank}"
727                );
728            }
729            // Also verify no extra ranks
730            assert_eq!(result.num_ranks(), expected.len());
731        }
732
733        verify::<i64>(
734            i64_updates,
735            hashmap! {
736                0 => -9,   // latest ts for rank 0
737                1 => -10,  // latest ts for rank 1
738                3 => -99,  // latest ts for rank 3
739            },
740        );
741
742        verify::<u64>(
743            u64_updates,
744            hashmap! {
745                0 => 9,    // latest ts for rank 0
746                1 => 10,   // latest ts for rank 1
747                3 => 321,  // latest ts for rank 3
748            },
749        );
750    }
751
752    #[test]
753    fn test_accum_reducer_numeric() {
754        assert_eq!(
755            sum::<u64>().reducer_spec().unwrap().typehash,
756            <SumReducer::<u64> as Named>::typehash(),
757        );
758        assert_eq!(
759            sum::<i64>().reducer_spec().unwrap().typehash,
760            <SumReducer::<i64> as Named>::typehash(),
761        );
762
763        assert_eq!(
764            join_semilattice::<Min<u64>>()
765                .reducer_spec()
766                .unwrap()
767                .typehash,
768            <SemilatticeReducer<Min<u64>> as Named>::typehash(),
769        );
770        assert_eq!(
771            join_semilattice::<Min<i64>>()
772                .reducer_spec()
773                .unwrap()
774                .typehash,
775            <SemilatticeReducer<Min<i64>> as Named>::typehash(),
776        );
777
778        assert_eq!(
779            join_semilattice::<Max<u64>>()
780                .reducer_spec()
781                .unwrap()
782                .typehash,
783            <SemilatticeReducer<Max<u64>> as Named>::typehash(),
784        );
785        assert_eq!(
786            join_semilattice::<Max<i64>>()
787                .reducer_spec()
788                .unwrap()
789                .typehash,
790            <SemilatticeReducer<Max<i64>> as Named>::typehash(),
791        );
792    }
793
794    #[test]
795    fn test_accum_reducer_watermark() {
796        fn verify<T: Clone + PartialEq + Named + 'static>() {
797            assert_eq!(
798                join_semilattice::<WatermarkUpdate<T>>()
799                    .reducer_spec()
800                    .unwrap()
801                    .typehash,
802                <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash(),
803            );
804        }
805        verify::<u64>();
806        verify::<i64>();
807    }
808
809    #[test]
810    fn test_watermark_accumulator() {
811        let accumulator = join_semilattice::<WatermarkUpdate<u64>>();
812        let ranks_values_expectations = [
813            // send in descending order (with timestamps 0, 1, 2)
814            (0, 1003, 0, 1003),
815            (1, 1002, 1, 1002),
816            (2, 1001, 2, 1001),
817            // send in ascending order (timestamps 3, 4, 5)
818            (0, 100, 3, 100),
819            (1, 101, 4, 100),
820            (2, 102, 5, 100),
821            // send same values (timestamps 6, 7, 8)
822            (0, 100, 6, 100),
823            (1, 101, 7, 100),
824            (2, 102, 8, 100),
825            // shuffle rank 0 to be largest, and make rank 1 smallest (timestamps 9, 10, 11)
826            (0, 1000, 9, 101),
827            // shuffle rank 1 to be largest, and make rank 2 smallest
828            (1, 1100, 10, 102),
829            // shuffle rank 2 to be largest, and make rank 0 smallest
830            (2, 1200, 11, 1000),
831            // Increase their value, but do not change their order (timestamps 12, 13, 14)
832            (0, 1001, 12, 1001),
833            (1, 1101, 13, 1001),
834            (2, 1201, 14, 1001),
835            // decrease their values (timestamps 15, 16, 17)
836            (2, 102, 15, 102),
837            (1, 101, 16, 101),
838            (0, 100, 17, 100),
839        ];
840        let mut state = WatermarkUpdate::default();
841        for (rank, value, ts, expected) in ranks_values_expectations {
842            accumulator
843                .accumulate(&mut state, WatermarkUpdate::from((rank, value, ts)))
844                .unwrap();
845            assert_eq!(
846                state.get(),
847                &expected,
848                "rank is {rank}; value is {value}; ts is {ts}"
849            );
850        }
851    }
852
853    #[test]
854    fn test_comm_reducer_gcounter() {
855        // Updates from different ranks
856        let updates = serialize::<GCounterUpdate>(vec![
857            GCounterUpdate::from((0, 10)),
858            GCounterUpdate::from((1, 20)),
859            GCounterUpdate::from((0, 15)), // rank 0 increases to 15
860            GCounterUpdate::from((2, 5)),
861            GCounterUpdate::from((1, 25)), // rank 1 increases to 25
862        ]);
863
864        let typehash = <SemilatticeReducer<GCounterUpdate> as Named>::typehash();
865        let result = resolve_reducer(typehash, None)
866            .unwrap()
867            .unwrap()
868            .reduce_updates(updates)
869            .unwrap()
870            .deserialized::<GCounterUpdate>()
871            .unwrap();
872
873        // Each rank should have its max value
874        assert_eq!(result.get_rank(0), Some(15));
875        assert_eq!(result.get_rank(1), Some(25));
876        assert_eq!(result.get_rank(2), Some(5));
877        assert_eq!(result.num_ranks(), 3);
878        // Total is sum of max values: 15 + 25 + 5 = 45
879        assert_eq!(result.get(), 45);
880    }
881
882    #[test]
883    fn test_accum_reducer_gcounter() {
884        assert_eq!(
885            join_semilattice::<GCounterUpdate>()
886                .reducer_spec()
887                .unwrap()
888                .typehash,
889            <SemilatticeReducer<GCounterUpdate> as Named>::typehash(),
890        );
891    }
892
893    #[test]
894    fn test_gcounter_accumulator() {
895        let accumulator = join_semilattice::<GCounterUpdate>();
896        // (rank, count, expected_total)
897        let ranks_counts_expectations: [(reference::Index, u64, u64); 17] = [
898            // initialize all 3 ranks in descending order
899            (0, 1000, 1000),
900            (1, 100, 1100),
901            (2, 10, 1110),
902            // increase in ascending order
903            (2, 20, 1120),
904            (1, 200, 1220),
905            (0, 2000, 2220),
906            // same values (idempotent - no change)
907            (0, 2000, 2220),
908            (1, 200, 2220),
909            (2, 20, 2220),
910            // lower values (ignored - max wins)
911            (0, 1, 2220),
912            (1, 1, 2220),
913            (2, 1, 2220),
914            // shuffle which rank has max: make rank 2 largest
915            (2, 5000, 7200), // 2000 + 200 + 5000
916            // make rank 1 largest
917            (1, 6000, 13000), // 2000 + 6000 + 5000
918            // make rank 0 largest again
919            (0, 10000, 21000), // 10000 + 6000 + 5000
920            // all ranks increase together
921            (0, 10001, 21001),
922            (1, 6001, 21002),
923        ];
924        let mut state = GCounterUpdate::default();
925        for (rank, count, expected) in ranks_counts_expectations {
926            accumulator
927                .accumulate(&mut state, GCounterUpdate::from((rank, count)))
928                .unwrap();
929            assert_eq!(state.get(), expected, "rank is {rank}; count is {count}");
930        }
931        // Verify final per-rank values
932        assert_eq!(state.get_rank(0), Some(10001));
933        assert_eq!(state.get_rank(1), Some(6001));
934        assert_eq!(state.get_rank(2), Some(5000));
935        assert_eq!(state.get_rank(3), None);
936        assert_eq!(state.num_ranks(), 3);
937    }
938
939    #[test]
940    fn test_gcounter_commutativity() {
941        // Verify that order of accumulation doesn't matter
942        let updates = [
943            GCounterUpdate::from((0, 10)),
944            GCounterUpdate::from((1, 20)),
945            GCounterUpdate::from((0, 15)),
946            GCounterUpdate::from((2, 5)),
947            GCounterUpdate::from((1, 25)),
948        ];
949
950        // Forward order
951        let accumulator = join_semilattice::<GCounterUpdate>();
952        let mut forward = GCounterUpdate::default();
953        for update in updates.iter().cloned() {
954            accumulator.accumulate(&mut forward, update).unwrap();
955        }
956
957        // Reverse order
958        let mut reverse = GCounterUpdate::default();
959        for update in updates.iter().rev().cloned() {
960            accumulator.accumulate(&mut reverse, update).unwrap();
961        }
962
963        assert_eq!(forward.get(), reverse.get());
964        assert_eq!(forward.get(), 45); // 15 + 25 + 5
965        assert_eq!(forward.get_rank(0), reverse.get_rank(0));
966        assert_eq!(forward.get_rank(1), reverse.get_rank(1));
967        assert_eq!(forward.get_rank(2), reverse.get_rank(2));
968    }
969
970    #[test]
971    fn test_comm_reducer_pncounter() {
972        // Updates from different ranks with increments and decrements
973        let updates = serialize::<PNCounterUpdate>(vec![
974            PNCounterUpdate::inc(0, 10),
975            PNCounterUpdate::inc(1, 20),
976            PNCounterUpdate::dec(0, 5),
977            PNCounterUpdate::inc(0, 15), // rank 0 inc increases to 15
978            PNCounterUpdate::dec(1, 8),
979            PNCounterUpdate::dec(0, 7), // rank 0 dec increases to 7
980        ]);
981
982        let typehash = <SemilatticeReducer<PNCounterUpdate> as Named>::typehash();
983        let result = resolve_reducer(typehash, None)
984            .unwrap()
985            .unwrap()
986            .reduce_updates(updates)
987            .unwrap()
988            .deserialized::<PNCounterUpdate>()
989            .unwrap();
990
991        // Each rank should have its max values for both inc and dec
992        // rank 0: inc=15, dec=7 -> contribution = 15-7 = 8
993        // rank 1: inc=20, dec=8 -> contribution = 20-8 = 12
994        // Total: 8 + 12 = 20
995        assert_eq!(result.get(), 20);
996        assert_eq!(result.num_inc_ranks(), 2);
997        assert_eq!(result.num_dec_ranks(), 2);
998    }
999
1000    #[test]
1001    fn test_accum_reducer_pncounter() {
1002        assert_eq!(
1003            join_semilattice::<PNCounterUpdate>()
1004                .reducer_spec()
1005                .unwrap()
1006                .typehash,
1007            <SemilatticeReducer<PNCounterUpdate> as Named>::typehash(),
1008        );
1009    }
1010
1011    #[test]
1012    fn test_pncounter_accumulator() {
1013        let accumulator = join_semilattice::<PNCounterUpdate>();
1014        // Helper to make updates clearer
1015        #[derive(Clone, Copy, Debug)]
1016        enum Op {
1017            Inc(reference::Index, u64),
1018            Dec(reference::Index, u64),
1019        }
1020        use Op::*;
1021
1022        // (operation, expected_total)
1023        // State tracked: p0, p1, p2 (increments), n0, n1, n2 (decrements)
1024        // Total = (p0 + p1 + p2) - (n0 + n1 + n2)
1025        let ops_expectations = [
1026            // initialize all 3 ranks with increments
1027            (Inc(0, 100), 100), // p: 100,0,0 n: 0,0,0 = 100
1028            (Inc(1, 50), 150),  // p: 100,50,0 n: 0,0,0 = 150
1029            (Inc(2, 25), 175),  // p: 100,50,25 n: 0,0,0 = 175
1030            // add decrements
1031            (Dec(0, 10), 165), // p: 100,50,25 n: 10,0,0 = 175-10 = 165
1032            (Dec(1, 5), 160),  // p: 100,50,25 n: 10,5,0 = 175-15 = 160
1033            (Dec(2, 2), 158),  // p: 100,50,25 n: 10,5,2 = 175-17 = 158
1034            // increase increments
1035            (Inc(0, 200), 258), // p: 200,50,25 n: 10,5,2 = 275-17 = 258
1036            (Inc(1, 100), 308), // p: 200,100,25 n: 10,5,2 = 325-17 = 308
1037            (Inc(2, 50), 333),  // p: 200,100,50 n: 10,5,2 = 350-17 = 333
1038            // increase decrements
1039            (Dec(0, 20), 323), // p: 200,100,50 n: 20,5,2 = 350-27 = 323
1040            (Dec(1, 15), 313), // p: 200,100,50 n: 20,15,2 = 350-37 = 313
1041            (Dec(2, 5), 310),  // p: 200,100,50 n: 20,15,5 = 350-40 = 310
1042            // duplicate updates (idempotent - no change)
1043            (Inc(0, 200), 310),
1044            (Dec(1, 15), 310),
1045            // lower values (ignored - max wins)
1046            (Inc(0, 1), 310),
1047            (Dec(0, 1), 310),
1048            // make decrements larger than increments for some ranks
1049            (Dec(2, 60), 255),  // p: 200,100,50 n: 20,15,60 = 350-95 = 255
1050            (Dec(1, 120), 150), // p: 200,100,50 n: 20,120,60 = 350-200 = 150
1051            // rank 1 now contributes negatively: 100 - 120 = -20
1052            (Inc(2, 60), 160), // p: 200,100,60 n: 20,120,60 = 360-200 = 160
1053            // shuffle: make rank 0 contribute most
1054            (Inc(0, 1000), 960), // p: 1000,100,60 n: 20,120,60 = 1160-200 = 960
1055            (Dec(2, 100), 920),  // p: 1000,100,60 n: 20,120,100 = 1160-240 = 920
1056        ];
1057
1058        let mut state = PNCounterUpdate::default();
1059        for (i, (op, expected)) in ops_expectations.iter().enumerate() {
1060            let update = match op {
1061                Inc(rank, delta) => PNCounterUpdate::inc(*rank, *delta),
1062                Dec(rank, delta) => PNCounterUpdate::dec(*rank, *delta),
1063            };
1064            accumulator.accumulate(&mut state, update).unwrap();
1065            assert_eq!(state.get(), *expected, "step {i}: {op:?}");
1066        }
1067
1068        // Verify final state
1069        assert_eq!(state.num_inc_ranks(), 3);
1070        assert_eq!(state.num_dec_ranks(), 3);
1071    }
1072
1073    #[test]
1074    fn test_pncounter_commutativity() {
1075        // Verify that order of accumulation doesn't matter
1076        let updates = [
1077            PNCounterUpdate::inc(0, 10),
1078            PNCounterUpdate::inc(1, 20),
1079            PNCounterUpdate::dec(0, 5),
1080            PNCounterUpdate::inc(0, 15),
1081            PNCounterUpdate::dec(1, 8),
1082            PNCounterUpdate::dec(2, 3),
1083            PNCounterUpdate::inc(2, 12),
1084        ];
1085
1086        // Forward order
1087        let accumulator = join_semilattice::<PNCounterUpdate>();
1088        let mut forward = PNCounterUpdate::default();
1089        for update in updates.iter().cloned() {
1090            accumulator.accumulate(&mut forward, update).unwrap();
1091        }
1092
1093        // Reverse order
1094        let mut reverse = PNCounterUpdate::default();
1095        for update in updates.iter().rev().cloned() {
1096            accumulator.accumulate(&mut reverse, update).unwrap();
1097        }
1098
1099        assert_eq!(forward.get(), reverse.get());
1100        assert_eq!(forward.get(), 31); // (15 + 20 + 12) - (5 + 8 + 3) = 47 - 16 = 31
1101        assert_eq!(forward.num_inc_ranks(), reverse.num_inc_ranks());
1102        assert_eq!(forward.num_dec_ranks(), reverse.num_dec_ranks());
1103    }
1104}