Skip to main content

hyperactor/
accum.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines the accumulator trait and some common accumulators.
10
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use algebra::JoinSemilattice;
17use enum_as_inner::EnumAsInner;
18use serde::Deserialize;
19use serde::Serialize;
20use serde::de::DeserializeOwned;
21use typeuri::Named;
22
23// for macros
24use crate::config;
25
26/// An accumulator is a object that accumulates updates into a state.
27pub trait Accumulator {
28    /// The type of the accumulated state.
29    type State;
30    /// The type of the updates sent to the accumulator. Updates will be
31    /// accumulated into type [Self::State].
32    type Update;
33
34    /// Accumulate an update into the current state.
35    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;
36
37    /// The specification used to build the reducer.
38    fn reducer_spec(&self) -> Option<ReducerSpec>;
39}
40
41/// Serializable information needed to build a comm reducer.
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
43pub struct ReducerSpec {
44    /// The typehash of the underlying [Self::Reducer] type.
45    pub typehash: u64,
46    /// The parameters used to build the reducer.
47    pub builder_params: Option<wirevalue::Any>,
48}
49wirevalue::register_type!(ReducerSpec);
50
51/// Options for streaming reducer mode.
52#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Default)]
53pub struct StreamingReducerOpts {
54    /// The maximum interval between updates. When unspecified, a default
55    /// interval is used.
56    pub max_update_interval: Option<Duration>,
57    /// The initial interval for the first update. When unspecified, defaults to 1ms.
58    /// This allows quick flushing of single messages while using exponential backoff
59    /// to reach max_update_interval for batched messages.
60    pub initial_update_interval: Option<Duration>,
61}
62
63/// The mode in which a reducer operates.
64#[derive(
65    Debug,
66    Clone,
67    PartialEq,
68    Serialize,
69    Deserialize,
70    EnumAsInner,
71    typeuri::Named
72)]
73pub enum ReducerMode {
74    /// Streaming mode: continuously reduce and emit updates based on buffer size/timeout.
75    Streaming(StreamingReducerOpts),
76    /// Once mode: accumulate exactly `n` values, emit a single reduced update, then tear down.
77    Once(usize),
78}
79
80impl Default for ReducerMode {
81    fn default() -> Self {
82        ReducerMode::Streaming(StreamingReducerOpts::default())
83    }
84}
85
86impl ReducerMode {
87    pub(crate) fn max_update_interval(&self) -> Duration {
88        match self {
89            ReducerMode::Streaming(opts) => opts
90                .max_update_interval
91                .unwrap_or(hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_AGE)),
92            ReducerMode::Once(_) => Duration::MAX,
93        }
94    }
95
96    pub(crate) fn initial_update_interval(&self) -> Duration {
97        match self {
98            ReducerMode::Streaming(opts) => opts
99                .initial_update_interval
100                .unwrap_or(Duration::from_millis(1)),
101            ReducerMode::Once(_) => Duration::MAX,
102        }
103    }
104}
105
106/// Commutative reducer for an accumulator. This is used to coallesce updates.
107/// For example, if the accumulator is a sum, its reducer calculates and returns
108/// the sum of 2 updates. This is helpful in split ports, where a large number
109/// of updates can be reduced into a smaller number of updates before being sent
110/// to the parent port.
111pub trait CommReducer {
112    /// The type of updates to be reduced.
113    type Update;
114
115    /// Reduce 2 updates into a single update.
116    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
117}
118
119/// Type erased version of [CommReducer].
120pub trait ErasedCommReducer {
121    /// Reduce 2 updates into a single update.
122    fn reduce_erased(
123        &self,
124        left: &wirevalue::Any,
125        right: &wirevalue::Any,
126    ) -> anyhow::Result<wirevalue::Any>;
127
128    /// Reducer an non-empty vector of updates. Return Error if the vector is
129    /// empty.
130    fn reduce_updates(
131        &self,
132        updates: Vec<wirevalue::Any>,
133    ) -> Result<wirevalue::Any, (anyhow::Error, Vec<wirevalue::Any>)> {
134        if updates.is_empty() {
135            return Err((anyhow::anyhow!("empty updates"), updates));
136        }
137        if updates.len() == 1 {
138            return Ok(updates.into_iter().next().expect("checked above"));
139        }
140
141        let mut iter = updates.iter();
142        let first = iter.next().unwrap();
143        let second = iter.next().unwrap();
144        let init = match self.reduce_erased(first, second) {
145            Ok(v) => v,
146            Err(e) => return Err((e, updates)),
147        };
148        let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
149            Ok(v) => v,
150            Err(e) => return Err((e, updates)),
151        };
152        Ok(reduced)
153    }
154
155    /// Typehash of the underlying [`CommReducer`] type.
156    fn typehash(&self) -> u64;
157}
158
159impl<R, T> ErasedCommReducer for R
160where
161    R: CommReducer<Update = T> + Named,
162    T: Serialize + DeserializeOwned + Named,
163{
164    fn reduce_erased(
165        &self,
166        left: &wirevalue::Any,
167        right: &wirevalue::Any,
168    ) -> anyhow::Result<wirevalue::Any> {
169        let left = left.deserialized::<T>()?;
170        let right = right.deserialized::<T>()?;
171        let result = self.reduce(left, right)?;
172        Ok(wirevalue::Any::serialize(&result)?)
173    }
174
175    fn typehash(&self) -> u64 {
176        R::typehash()
177    }
178}
179
180/// A factory for [`ErasedCommReducer`]s. This is used to register a
181/// [`ErasedCommReducer`] type. We cannot register [`ErasedCommReducer`] trait
182/// object directly because the object could have internal state, and cannot be
183/// shared.
184pub struct ReducerFactory {
185    /// Return the typehash of the [`ErasedCommReducer`] type built by this
186    /// factory.
187    pub typehash_f: fn() -> u64,
188    /// The builder function to build the [`ErasedCommReducer`] type.
189    pub builder_f: fn(
190        Option<wirevalue::Any>,
191    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
192}
193
194inventory::collect!(ReducerFactory);
195
196inventory::submit! {
197    ReducerFactory {
198        typehash_f: <SumReducer<i64> as Named>::typehash,
199        builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
200    }
201}
202inventory::submit! {
203    ReducerFactory {
204        typehash_f: <SumReducer<u64> as Named>::typehash,
205        builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
206    }
207}
208inventory::submit! {
209    ReducerFactory {
210        typehash_f: <SemilatticeReducer<Max<i64>> as Named>::typehash,
211        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<i64>>(PhantomData))),
212    }
213}
214inventory::submit! {
215    ReducerFactory {
216        typehash_f: <SemilatticeReducer<Max<u64>> as Named>::typehash,
217        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Max<u64>>(PhantomData))),
218    }
219}
220inventory::submit! {
221    ReducerFactory {
222        typehash_f: <SemilatticeReducer<Min<i64>> as Named>::typehash,
223        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<i64>>(PhantomData))),
224    }
225}
226inventory::submit! {
227    ReducerFactory {
228        typehash_f: <SemilatticeReducer<Min<u64>> as Named>::typehash,
229        builder_f: |_| Ok(Box::new(SemilatticeReducer::<Min<u64>>(PhantomData))),
230    }
231}
232inventory::submit! {
233    ReducerFactory {
234        typehash_f: <SemilatticeReducer<WatermarkUpdate<i64>> as Named>::typehash,
235        builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<i64>>(PhantomData))),
236    }
237}
238inventory::submit! {
239    ReducerFactory {
240        typehash_f: <SemilatticeReducer<WatermarkUpdate<u64>> as Named>::typehash,
241        builder_f: |_| Ok(Box::new(SemilatticeReducer::<WatermarkUpdate<u64>>(PhantomData))),
242    }
243}
244inventory::submit! {
245    ReducerFactory {
246        typehash_f: <SemilatticeReducer<GCounterUpdate> as Named>::typehash,
247        builder_f: |_| Ok(Box::new(SemilatticeReducer::<GCounterUpdate>(PhantomData))),
248    }
249}
250inventory::submit! {
251    ReducerFactory {
252        typehash_f: <SemilatticeReducer<PNCounterUpdate> as Named>::typehash,
253        builder_f: |_| Ok(Box::new(SemilatticeReducer::<PNCounterUpdate>(PhantomData))),
254    }
255}
256
257/// Build a reducer object with the given typehash's [CommReducer] type, and
258/// return the type-erased version of it.
259pub(crate) fn resolve_reducer(
260    typehash: u64,
261    builder_params: Option<wirevalue::Any>,
262) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
263    static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
264    let factories = FACTORY_MAP.get_or_init(|| {
265        let mut map = HashMap::new();
266        for factory in inventory::iter::<ReducerFactory> {
267            map.insert((factory.typehash_f)(), factory);
268        }
269        map
270    });
271
272    factories
273        .get(&typehash)
274        .map(|f| (f.builder_f)(builder_params))
275        .transpose()
276}
277
278#[derive(typeuri::Named)]
279struct SumReducer<T>(PhantomData<T>);
280
281impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
282    type Update = T;
283
284    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
285        Ok(left + right)
286    }
287}
288
289/// Accumulate the sum of received updates. The inner function performs the
290/// summation between an update and the current state.
291struct SumAccumulator<T>(PhantomData<T>);
292
293impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
294    type State = T;
295    type Update = T;
296
297    fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
298        *state = *state + update;
299        Ok(())
300    }
301
302    fn reducer_spec(&self) -> Option<ReducerSpec> {
303        Some(ReducerSpec {
304            typehash: <SumReducer<T> as Named>::typehash(),
305            builder_params: None,
306        })
307    }
308}
309
310/// Accumulate the sum of received updates.
311///
312/// # Note: Not a CRDT
313///
314/// This accumulator is *not idempotent* and is therefore *not
315/// suitable* for distributed scatter/gather patterns with
316/// at-least-once delivery semantics. Duplicate updates will be
317/// counted multiple times:
318///
319/// ```text
320/// sum(1, 2, 2, 3) = 8  (expected 6 if second 2 is duplicate)
321/// ```
322///
323/// ## When to use:
324/// - Single-source accumulation with exactly-once delivery
325/// - Local (non-distributed) aggregation
326/// - When upstream deduplication is guaranteed
327///
328/// ## CRDT Alternative:
329/// For distributed use cases, consider using a GCounter CRDT instead,
330/// which tracks per-replica increments and uses pointwise-max for
331/// merging (commutative, associative, and idempotent).
332///
333/// *See also*: [`Max`], [`Min`] (proper lattice-based CRDTs)
334pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
335-> impl Accumulator<State = T, Update = T> {
336    SumAccumulator(PhantomData)
337}
338
339/// Generic reducer for any JoinSemilattice type.
340#[derive(typeuri::Named)]
341struct SemilatticeReducer<L>(PhantomData<L>);
342
343impl<L: JoinSemilattice + Clone> CommReducer for SemilatticeReducer<L> {
344    type Update = L;
345
346    fn reduce(&self, left: L, right: L) -> anyhow::Result<L> {
347        Ok(left.join(&right))
348    }
349}
350
351/// Generic accumulator for any JoinSemilattice type.
352struct SemilatticeAccumulator<L>(PhantomData<L>);
353
354impl<L: JoinSemilattice + Clone + Named + 'static> Accumulator for SemilatticeAccumulator<L> {
355    type State = L;
356    type Update = L;
357
358    fn accumulate(&self, state: &mut L, update: L) -> anyhow::Result<()> {
359        *state = state.join(&update);
360        Ok(())
361    }
362
363    fn reducer_spec(&self) -> Option<ReducerSpec> {
364        Some(ReducerSpec {
365            typehash: <SemilatticeReducer<L> as Named>::typehash(),
366            builder_params: None,
367        })
368    }
369}
370
371/// Create an accumulator for any JoinSemilattice type.
372///
373/// This is the primary way to create accumulators for lattice-based
374/// types like `Max<T>`, `Min<T>`, `GCounterUpdate`, `PNCounterUpdate`,
375/// and `WatermarkUpdate<T>`.
376///
377/// # Example
378///
379/// ```ignore
380/// use hyperactor::accum::{join_semilattice, Max};
381///
382/// let max_acc = join_semilattice::<Max<u64>>();
383/// ```
384pub fn join_semilattice<L: JoinSemilattice + Clone + Named + 'static>()
385-> impl Accumulator<State = L, Update = L> {
386    SemilatticeAccumulator::<L>(PhantomData)
387}
388
389/// Re-export Max from algebra.
390pub use algebra::Max;
391/// Re-export Min from algebra.
392pub use algebra::Min;
393
394/// Update from ranks for watermark accumulator using Last-Writer-Wins
395/// CRDT.
396///
397/// This is a proper CRDT that tracks the latest value from each rank
398/// using logical timestamps. When updates from the same rank are
399/// merged, the one with the higher timestamp wins. This allows ranks
400/// to report values that may decrease (e.g., during failure recovery)
401/// while maintaining proper commutativity and idempotence.
402///
403/// # CRDT Properties
404///
405/// - *Commutative*: Merge order doesn't matter (timestamps resolve
406///   conflicts)
407/// - *Idempotent*: Merging duplicate updates has no effect
408/// - *Convergent*: All replicas converge to the same state
409///
410/// # Watermark Semantics
411///
412/// The watermark is the minimum value across all ranks' *latest*
413/// reports. "Latest" is determined by logical timestamp, not arrival
414/// order.
415#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
416pub struct WatermarkUpdate<T>(algebra::LatticeMap<usize, algebra::LWW<T>>);
417
418impl<T: Ord + Clone> WatermarkUpdate<T> {
419    /// Get the watermark value (minimum of all ranks' current values).
420    ///
421    /// WatermarkUpdate is guaranteed to be initialized by the accumulator
422    /// before it is sent to the user.
423    pub fn get(&self) -> &T {
424        self.0
425            .iter()
426            .map(|(_, lww)| &lww.value)
427            .min()
428            .expect("watermark should have been initialized")
429    }
430
431    /// Get the current value for a specific rank, if present.
432    pub fn get_rank(&self, rank: usize) -> Option<&T> {
433        self.0.get(&rank).map(|lww| &lww.value)
434    }
435
436    /// Get the number of ranks currently tracked.
437    pub fn num_ranks(&self) -> usize {
438        self.0.len()
439    }
440}
441
442impl<T> From<(usize, T, u64)> for WatermarkUpdate<T> {
443    /// Create a watermark update from (rank, value, timestamp).
444    ///
445    /// The timestamp should be a logical clock value (Lamport clock, sequence
446    /// number, or monotonic counter) that increases with each update from
447    /// the same rank.
448    fn from((rank, value, timestamp): (usize, T, u64)) -> Self {
449        let mut map = algebra::LatticeMap::new();
450        // Use rank as replica ID - each rank is a unique writer
451        map.insert(rank, algebra::LWW::new(value, timestamp, rank as u64));
452        Self(map)
453    }
454}
455
456impl<T: Clone + PartialEq> JoinSemilattice for WatermarkUpdate<T> {
457    fn join(&self, other: &Self) -> Self {
458        WatermarkUpdate(self.0.join(&other.0))
459    }
460}
461
462/// State for a grow-only distributed counter (GCounter CRDT).
463///
464/// Each rank maintains its own count. The total value is the sum of
465/// all ranks' counts. Merge takes pointwise max.
466///
467/// # CRDT Properties
468///
469/// - *Commutative*: Merge order doesn't matter
470/// - *Associative*: Grouping doesn't matter
471/// - *Idempotent*: Merging duplicate updates has no effect
472/// - *Convergent*: All replicas converge to the same state
473#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
474pub struct GCounterUpdate(algebra::LatticeMap<usize, Max<u64>>);
475wirevalue::register_type!(GCounterUpdate);
476
477impl GCounterUpdate {
478    /// Total counter value (sum of all ranks' counts).
479    pub fn get(&self) -> u64 {
480        self.0.iter().map(|(_, max)| max.0).sum()
481    }
482
483    /// Get count for a specific rank.
484    pub fn get_rank(&self, rank: usize) -> Option<u64> {
485        self.0.get(&rank).map(|max| max.0)
486    }
487
488    /// Number of ranks that have contributed.
489    pub fn num_ranks(&self) -> usize {
490        self.0.len()
491    }
492}
493
494impl From<(usize, u64)> for GCounterUpdate {
495    /// Create a GCounter update from (rank, count).
496    fn from((rank, count): (usize, u64)) -> Self {
497        let mut map = algebra::LatticeMap::new();
498        map.insert(rank, Max(count));
499        Self(map)
500    }
501}
502
503impl JoinSemilattice for GCounterUpdate {
504    fn join(&self, other: &Self) -> Self {
505        GCounterUpdate(self.0.join(&other.0))
506    }
507}
508
509/// State for an increment/decrement distributed counter (PNCounter
510/// CRDT).
511///
512/// Internally uses two GCounters: one for increments (P), one for
513/// decrements (N). The value is P - N. Each is merged independently
514/// via pointwise max.
515#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
516pub struct PNCounterUpdate {
517    p: algebra::LatticeMap<usize, Max<u64>>,
518    n: algebra::LatticeMap<usize, Max<u64>>,
519}
520wirevalue::register_type!(PNCounterUpdate);
521
522impl PNCounterUpdate {
523    /// Counter value (sum of increments minus sum of decrements).
524    pub fn get(&self) -> i64 {
525        let p: u64 = self.p.iter().map(|(_, m)| m.0).sum();
526        let n: u64 = self.n.iter().map(|(_, m)| m.0).sum();
527        p as i64 - n as i64
528    }
529
530    /// Create an increment update for a rank.
531    pub fn inc(rank: usize, delta: u64) -> Self {
532        let mut p = algebra::LatticeMap::new();
533        p.insert(rank, Max(delta));
534        Self {
535            p,
536            n: algebra::LatticeMap::new(),
537        }
538    }
539
540    /// Create a decrement update for a rank.
541    pub fn dec(rank: usize, delta: u64) -> Self {
542        let mut n = algebra::LatticeMap::new();
543        n.insert(rank, Max(delta));
544        Self {
545            p: algebra::LatticeMap::new(),
546            n,
547        }
548    }
549
550    /// Number of ranks that have contributed increments.
551    pub fn num_inc_ranks(&self) -> usize {
552        self.p.len()
553    }
554
555    /// Number of ranks that have contributed decrements.
556    pub fn num_dec_ranks(&self) -> usize {
557        self.n.len()
558    }
559}
560
561impl JoinSemilattice for PNCounterUpdate {
562    fn join(&self, other: &Self) -> Self {
563        PNCounterUpdate {
564            p: self.p.join(&other.p),
565            n: self.n.join(&other.n),
566        }
567    }
568}
569
570#[cfg(test)]
571mod tests {
572    use std::fmt::Debug;
573
574    use maplit::hashmap;
575    use typeuri::Named;
576
577    use super::*;
578
579    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<wirevalue::Any> {
580        values
581            .into_iter()
582            .map(|n| wirevalue::Any::serialize(&n).unwrap())
583            .collect()
584    }
585
586    #[test]
587    fn test_comm_reducer_numeric() {
588        let u64_numbers_sum: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
589        let i64_numbers_sum: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);
590        let u64_numbers_max: Vec<_> = serialize(vec![Max(1u64), Max(3u64), Max(1100u64)]);
591        let i64_numbers_max: Vec<_> = serialize(vec![Max(-123i64), Max(33i64), Max(110i64)]);
592        let u64_numbers_min: Vec<_> = serialize(vec![Min(1u64), Min(3u64), Min(1100u64)]);
593        let i64_numbers_min: Vec<_> = serialize(vec![Min(-123i64), Min(33i64), Min(110i64)]);
594        {
595            let typehash = <SemilatticeReducer<Max<u64>> as Named>::typehash();
596            assert_eq!(
597                resolve_reducer(typehash, None)
598                    .unwrap()
599                    .unwrap()
600                    .reduce_updates(u64_numbers_max.clone())
601                    .unwrap()
602                    .deserialized::<Max<u64>>()
603                    .unwrap(),
604                Max(1100u64),
605            );
606
607            let typehash = <SemilatticeReducer<Min<u64>> as Named>::typehash();
608            assert_eq!(
609                resolve_reducer(typehash, None)
610                    .unwrap()
611                    .unwrap()
612                    .reduce_updates(u64_numbers_min.clone())
613                    .unwrap()
614                    .deserialized::<Min<u64>>()
615                    .unwrap(),
616                Min(1u64),
617            );
618
619            let typehash = <SumReducer<u64> as Named>::typehash();
620            assert_eq!(
621                resolve_reducer(typehash, None)
622                    .unwrap()
623                    .unwrap()
624                    .reduce_updates(u64_numbers_sum)
625                    .unwrap()
626                    .deserialized::<u64>()
627                    .unwrap(),
628                1104u64,
629            );
630        }
631
632        {
633            let typehash = <SemilatticeReducer<Max<i64>> as Named>::typehash();
634            assert_eq!(
635                resolve_reducer(typehash, None)
636                    .unwrap()
637                    .unwrap()
638                    .reduce_updates(i64_numbers_max.clone())
639                    .unwrap()
640                    .deserialized::<Max<i64>>()
641                    .unwrap(),
642                Max(110i64),
643            );
644
645            let typehash = <SemilatticeReducer<Min<i64>> as Named>::typehash();
646            assert_eq!(
647                resolve_reducer(typehash, None)
648                    .unwrap()
649                    .unwrap()
650                    .reduce_updates(i64_numbers_min.clone())
651                    .unwrap()
652                    .deserialized::<Min<i64>>()
653                    .unwrap(),
654                Min(-123i64),
655            );
656
657            let typehash = <SumReducer<i64> as Named>::typehash();
658            assert_eq!(
659                resolve_reducer(typehash, None)
660                    .unwrap()
661                    .unwrap()
662                    .reduce_updates(i64_numbers_sum)
663                    .unwrap()
664                    .deserialized::<i64>()
665                    .unwrap(),
666                20i64,
667            );
668        }
669    }
670
671    #[test]
672    fn test_comm_reducer_watermark() {
673        // With LWW, we need timestamps. Assign in order of appearance.
674        let u64_updates = serialize::<WatermarkUpdate<u64>>(
675            vec![
676                (1, 1, 0),   // rank 1: value 1, ts 0
677                (0, 2, 1),   // rank 0: value 2, ts 1
678                (0, 1, 2),   // rank 0: value 1, ts 2 (later ts, wins over value 2)
679                (3, 35, 3),  // rank 3: value 35, ts 3
680                (0, 9, 4),   // rank 0: value 9, ts 4 (latest for rank 0)
681                (1, 10, 5),  // rank 1: value 10, ts 5 (latest for rank 1)
682                (3, 32, 6),  // rank 3: value 32, ts 6
683                (3, 0, 7),   // rank 3: value 0, ts 7
684                (3, 321, 8), // rank 3: value 321, ts 8 (latest for rank 3)
685            ]
686            .into_iter()
687            .map(|(k, v, ts)| WatermarkUpdate::from((k, v, ts)))
688            .collect(),
689        );
690        let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
691            vec![
692                (0, 2, 0),   // rank 0: value 2, ts 0
693                (1, 1, 1),   // rank 1: value 1, ts 1
694                (3, 35, 2),  // rank 3: value 35, ts 2
695                (0, 1, 3),   // rank 0: value 1, ts 3
696                (1, -10, 4), // rank 1: value -10, ts 4
697                (3, 32, 5),  // rank 3: value 32, ts 5
698                (3, 0, 6),   // rank 3: value 0, ts 6
699                (3, -99, 7), // rank 3: value -99, ts 7 (latest for rank 3)
700                (0, -9, 8),  // rank 0: value -9, ts 8 (latest for rank 0)
701            ]
702            .into_iter()
703            .map(WatermarkUpdate::from)
704            .collect(),
705        );
706
707        fn verify<T: Ord + Clone + PartialEq + DeserializeOwned + Debug + Named>(
708            updates: Vec<wirevalue::Any>,
709            expected: HashMap<usize, T>,
710        ) {
711            let typehash = <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash();
712            let result = resolve_reducer(typehash, None)
713                .unwrap()
714                .unwrap()
715                .reduce_updates(updates)
716                .unwrap()
717                .deserialized::<WatermarkUpdate<T>>()
718                .unwrap();
719
720            // Check each expected rank value
721            for (rank, expected_value) in &expected {
722                assert_eq!(
723                    result.get_rank(*rank).unwrap(),
724                    expected_value,
725                    "Mismatch for rank {rank}"
726                );
727            }
728            // Also verify no extra ranks
729            assert_eq!(result.num_ranks(), expected.len());
730        }
731
732        verify::<i64>(
733            i64_updates,
734            hashmap! {
735                0 => -9,   // latest ts for rank 0
736                1 => -10,  // latest ts for rank 1
737                3 => -99,  // latest ts for rank 3
738            },
739        );
740
741        verify::<u64>(
742            u64_updates,
743            hashmap! {
744                0 => 9,    // latest ts for rank 0
745                1 => 10,   // latest ts for rank 1
746                3 => 321,  // latest ts for rank 3
747            },
748        );
749    }
750
751    #[test]
752    fn test_accum_reducer_numeric() {
753        assert_eq!(
754            sum::<u64>().reducer_spec().unwrap().typehash,
755            <SumReducer::<u64> as Named>::typehash(),
756        );
757        assert_eq!(
758            sum::<i64>().reducer_spec().unwrap().typehash,
759            <SumReducer::<i64> as Named>::typehash(),
760        );
761
762        assert_eq!(
763            join_semilattice::<Min<u64>>()
764                .reducer_spec()
765                .unwrap()
766                .typehash,
767            <SemilatticeReducer<Min<u64>> as Named>::typehash(),
768        );
769        assert_eq!(
770            join_semilattice::<Min<i64>>()
771                .reducer_spec()
772                .unwrap()
773                .typehash,
774            <SemilatticeReducer<Min<i64>> as Named>::typehash(),
775        );
776
777        assert_eq!(
778            join_semilattice::<Max<u64>>()
779                .reducer_spec()
780                .unwrap()
781                .typehash,
782            <SemilatticeReducer<Max<u64>> as Named>::typehash(),
783        );
784        assert_eq!(
785            join_semilattice::<Max<i64>>()
786                .reducer_spec()
787                .unwrap()
788                .typehash,
789            <SemilatticeReducer<Max<i64>> as Named>::typehash(),
790        );
791    }
792
793    #[test]
794    fn test_accum_reducer_watermark() {
795        fn verify<T: Clone + PartialEq + Named + 'static>() {
796            assert_eq!(
797                join_semilattice::<WatermarkUpdate<T>>()
798                    .reducer_spec()
799                    .unwrap()
800                    .typehash,
801                <SemilatticeReducer<WatermarkUpdate<T>> as Named>::typehash(),
802            );
803        }
804        verify::<u64>();
805        verify::<i64>();
806    }
807
808    #[test]
809    fn test_watermark_accumulator() {
810        let accumulator = join_semilattice::<WatermarkUpdate<u64>>();
811        let ranks_values_expectations = [
812            // send in descending order (with timestamps 0, 1, 2)
813            (0, 1003, 0, 1003),
814            (1, 1002, 1, 1002),
815            (2, 1001, 2, 1001),
816            // send in ascending order (timestamps 3, 4, 5)
817            (0, 100, 3, 100),
818            (1, 101, 4, 100),
819            (2, 102, 5, 100),
820            // send same values (timestamps 6, 7, 8)
821            (0, 100, 6, 100),
822            (1, 101, 7, 100),
823            (2, 102, 8, 100),
824            // shuffle rank 0 to be largest, and make rank 1 smallest (timestamps 9, 10, 11)
825            (0, 1000, 9, 101),
826            // shuffle rank 1 to be largest, and make rank 2 smallest
827            (1, 1100, 10, 102),
828            // shuffle rank 2 to be largest, and make rank 0 smallest
829            (2, 1200, 11, 1000),
830            // Increase their value, but do not change their order (timestamps 12, 13, 14)
831            (0, 1001, 12, 1001),
832            (1, 1101, 13, 1001),
833            (2, 1201, 14, 1001),
834            // decrease their values (timestamps 15, 16, 17)
835            (2, 102, 15, 102),
836            (1, 101, 16, 101),
837            (0, 100, 17, 100),
838        ];
839        let mut state = WatermarkUpdate::default();
840        for (rank, value, ts, expected) in ranks_values_expectations {
841            accumulator
842                .accumulate(&mut state, WatermarkUpdate::from((rank, value, ts)))
843                .unwrap();
844            assert_eq!(
845                state.get(),
846                &expected,
847                "rank is {rank}; value is {value}; ts is {ts}"
848            );
849        }
850    }
851
852    #[test]
853    fn test_comm_reducer_gcounter() {
854        // Updates from different ranks
855        let updates = serialize::<GCounterUpdate>(vec![
856            GCounterUpdate::from((0, 10)),
857            GCounterUpdate::from((1, 20)),
858            GCounterUpdate::from((0, 15)), // rank 0 increases to 15
859            GCounterUpdate::from((2, 5)),
860            GCounterUpdate::from((1, 25)), // rank 1 increases to 25
861        ]);
862
863        let typehash = <SemilatticeReducer<GCounterUpdate> as Named>::typehash();
864        let result = resolve_reducer(typehash, None)
865            .unwrap()
866            .unwrap()
867            .reduce_updates(updates)
868            .unwrap()
869            .deserialized::<GCounterUpdate>()
870            .unwrap();
871
872        // Each rank should have its max value
873        assert_eq!(result.get_rank(0), Some(15));
874        assert_eq!(result.get_rank(1), Some(25));
875        assert_eq!(result.get_rank(2), Some(5));
876        assert_eq!(result.num_ranks(), 3);
877        // Total is sum of max values: 15 + 25 + 5 = 45
878        assert_eq!(result.get(), 45);
879    }
880
#[test]
fn test_accum_reducer_gcounter() {
    // The reducer spec advertised by the GCounter join accumulator must
    // carry the type hash of the semilattice reducer for the same type.
    let advertised = join_semilattice::<GCounterUpdate>()
        .reducer_spec()
        .unwrap();
    let reducer_hash = <SemilatticeReducer<GCounterUpdate> as Named>::typehash();
    assert_eq!(advertised.typehash, reducer_hash);
}
891
#[test]
fn test_gcounter_accumulator() {
    // Each row is (rank, reported count, total expected after merging it).
    // The expected totals treat the counter as the sum, over ranks, of the
    // largest count each rank has reported so far.
    let steps: [(usize, u64, u64); 17] = [
        // first report from each of the 3 ranks, largest count first
        (0, 1000, 1000),
        (1, 100, 1100),
        (2, 10, 1110),
        // every rank raises its count, smallest rank value first
        (2, 20, 1120),
        (1, 200, 1220),
        (0, 2000, 2220),
        // repeating the current counts changes nothing
        (0, 2000, 2220),
        (1, 200, 2220),
        (2, 20, 2220),
        // counts below the current maximum are ignored
        (0, 1, 2220),
        (1, 1, 2220),
        (2, 1, 2220),
        // rank 2 becomes the largest contributor
        (2, 5000, 7200), // 2000 + 200 + 5000
        // then rank 1
        (1, 6000, 13000), // 2000 + 6000 + 5000
        // then rank 0 again
        (0, 10000, 21000), // 10000 + 6000 + 5000
        // ranks 0 and 1 each tick up by one; rank 2 stays at 5000
        (0, 10001, 21001),
        (1, 6001, 21002),
    ];

    let accumulator = join_semilattice::<GCounterUpdate>();
    let mut state = GCounterUpdate::default();
    for &(rank, count, want_total) in steps.iter() {
        let update = GCounterUpdate::from((rank, count));
        accumulator.accumulate(&mut state, update).unwrap();
        assert_eq!(state.get(), want_total, "rank is {rank}; count is {count}");
    }

    // Final shape: exactly three ranks, each holding its largest report,
    // and nothing recorded for a rank that never reported.
    assert_eq!(state.num_ranks(), 3);
    assert_eq!(state.get_rank(0), Some(10001));
    assert_eq!(state.get_rank(1), Some(6001));
    assert_eq!(state.get_rank(2), Some(5000));
    assert_eq!(state.get_rank(3), None);
}
937
#[test]
fn test_gcounter_commutativity() {
    // Feeding the same updates front-to-back and back-to-front must leave
    // the accumulator in the same observable state.
    let updates = [
        GCounterUpdate::from((0, 10)),
        GCounterUpdate::from((1, 20)),
        GCounterUpdate::from((0, 15)),
        GCounterUpdate::from((2, 5)),
        GCounterUpdate::from((1, 25)),
    ];

    let accumulator = join_semilattice::<GCounterUpdate>();
    // Folds a sequence of updates into a fresh default state.
    let fold = |sequence: Vec<GCounterUpdate>| {
        let mut merged = GCounterUpdate::default();
        for update in sequence {
            accumulator.accumulate(&mut merged, update).unwrap();
        }
        merged
    };

    let forward = fold(updates.to_vec());
    let reverse = fold(updates.iter().rev().cloned().collect());

    assert_eq!(forward.get(), reverse.get());
    assert_eq!(forward.get(), 45); // 15 + 25 + 5
    for rank in 0..3 {
        assert_eq!(forward.get_rank(rank), reverse.get_rank(rank));
    }
}
968
#[test]
fn test_comm_reducer_pncounter() {
    // Increment and decrement reports from ranks 0 and 1. Rank 0 reports
    // each side twice, the later report carrying the larger value.
    let pending = vec![
        PNCounterUpdate::inc(0, 10),
        PNCounterUpdate::inc(1, 20),
        PNCounterUpdate::dec(0, 5),
        PNCounterUpdate::inc(0, 15), // rank 0 inc goes 10 -> 15
        PNCounterUpdate::dec(1, 8),
        PNCounterUpdate::dec(0, 7), // rank 0 dec goes 5 -> 7
    ];
    let updates = serialize::<PNCounterUpdate>(pending);

    // Look the reducer up by type hash, the same way a spec would name it.
    let typehash = <SemilatticeReducer<PNCounterUpdate> as Named>::typehash();
    let reducer = resolve_reducer(typehash, None).unwrap().unwrap();
    let reduced = reducer.reduce_updates(updates).unwrap();
    let result = reduced.deserialized::<PNCounterUpdate>().unwrap();

    // Expected value, taking the largest inc and dec seen per rank:
    //   rank 0: 15 - 7 = 8
    //   rank 1: 20 - 8 = 12
    //   total:  8 + 12 = 20
    assert_eq!(result.get(), 20);
    assert_eq!(result.num_inc_ranks(), 2);
    assert_eq!(result.num_dec_ranks(), 2);
}
998
#[test]
fn test_accum_reducer_pncounter() {
    // The reducer spec advertised by the PNCounter join accumulator must
    // carry the type hash of the semilattice reducer for the same type.
    let advertised = join_semilattice::<PNCounterUpdate>()
        .reducer_spec()
        .unwrap();
    let reducer_hash = <SemilatticeReducer<PNCounterUpdate> as Named>::typehash();
    assert_eq!(advertised.typehash, reducer_hash);
}
1009
#[test]
fn test_pncounter_accumulator() {
    // One scripted step: a report for the increment side or the decrement
    // side, written as (rank, value).
    #[derive(Clone, Copy, Debug)]
    enum Op {
        Inc(usize, u64),
        Dec(usize, u64),
    }
    use Op::*;

    // Turns a scripted step into the update handed to the accumulator.
    let to_update = |op: Op| match op {
        Op::Inc(rank, value) => PNCounterUpdate::inc(rank, value),
        Op::Dec(rank, value) => PNCounterUpdate::dec(rank, value),
    };

    // Each row is (step, total expected after merging it).
    // Row comments track p0,p1,p2 (increment side) and n0,n1,n2
    // (decrement side); total = (p0 + p1 + p2) - (n0 + n1 + n2).
    let script = [
        // first increment report from each of the 3 ranks
        (Inc(0, 100), 100), // p: 100,0,0 n: 0,0,0 = 100
        (Inc(1, 50), 150),  // p: 100,50,0 n: 0,0,0 = 150
        (Inc(2, 25), 175),  // p: 100,50,25 n: 0,0,0 = 175
        // first decrement report from each rank
        (Dec(0, 10), 165), // p: 100,50,25 n: 10,0,0 = 175-10 = 165
        (Dec(1, 5), 160),  // p: 100,50,25 n: 10,5,0 = 175-15 = 160
        (Dec(2, 2), 158),  // p: 100,50,25 n: 10,5,2 = 175-17 = 158
        // raise every increment
        (Inc(0, 200), 258), // p: 200,50,25 n: 10,5,2 = 275-17 = 258
        (Inc(1, 100), 308), // p: 200,100,25 n: 10,5,2 = 325-17 = 308
        (Inc(2, 50), 333),  // p: 200,100,50 n: 10,5,2 = 350-17 = 333
        // raise every decrement
        (Dec(0, 20), 323), // p: 200,100,50 n: 20,5,2 = 350-27 = 323
        (Dec(1, 15), 313), // p: 200,100,50 n: 20,15,2 = 350-37 = 313
        (Dec(2, 5), 310),  // p: 200,100,50 n: 20,15,5 = 350-40 = 310
        // repeating current values changes nothing
        (Inc(0, 200), 310),
        (Dec(1, 15), 310),
        // values below the current maximum are ignored
        (Inc(0, 1), 310),
        (Dec(0, 1), 310),
        // push some decrements past the matching increments
        (Dec(2, 60), 255),  // p: 200,100,50 n: 20,15,60 = 350-95 = 255
        (Dec(1, 120), 150), // p: 200,100,50 n: 20,120,60 = 350-200 = 150
        // rank 1 is now net negative: 100 - 120 = -20
        (Inc(2, 60), 160), // p: 200,100,60 n: 20,120,60 = 360-200 = 160
        // rank 0 becomes the dominant contributor
        (Inc(0, 1000), 960), // p: 1000,100,60 n: 20,120,60 = 1160-200 = 960
        (Dec(2, 100), 920),  // p: 1000,100,60 n: 20,120,100 = 1160-240 = 920
    ];

    let accumulator = join_semilattice::<PNCounterUpdate>();
    let mut state = PNCounterUpdate::default();
    for (i, &(op, want_total)) in script.iter().enumerate() {
        accumulator.accumulate(&mut state, to_update(op)).unwrap();
        assert_eq!(state.get(), want_total, "step {i}: {op:?}");
    }

    // All three ranks reported on both sides.
    assert_eq!(state.num_inc_ranks(), 3);
    assert_eq!(state.num_dec_ranks(), 3);
}
1071
#[test]
fn test_pncounter_commutativity() {
    // Feeding the same updates front-to-back and back-to-front must leave
    // the accumulator in the same observable state.
    let updates = [
        PNCounterUpdate::inc(0, 10),
        PNCounterUpdate::inc(1, 20),
        PNCounterUpdate::dec(0, 5),
        PNCounterUpdate::inc(0, 15),
        PNCounterUpdate::dec(1, 8),
        PNCounterUpdate::dec(2, 3),
        PNCounterUpdate::inc(2, 12),
    ];

    let accumulator = join_semilattice::<PNCounterUpdate>();
    // Folds a sequence of updates into a fresh default state.
    let fold = |sequence: Vec<PNCounterUpdate>| {
        let mut merged = PNCounterUpdate::default();
        for update in sequence {
            accumulator.accumulate(&mut merged, update).unwrap();
        }
        merged
    };

    let forward = fold(updates.to_vec());
    let reverse = fold(updates.iter().rev().cloned().collect());

    assert_eq!(forward.get(), reverse.get());
    assert_eq!(forward.get(), 31); // (15 + 20 + 12) - (5 + 8 + 3) = 47 - 16 = 31
    assert_eq!(forward.num_inc_ranks(), reverse.num_inc_ranks());
    assert_eq!(forward.num_dec_ranks(), reverse.num_dec_ranks());
}
1103}