hyperactor/
accum.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines the accumulator trait and some common accumulators.
10
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use serde::Deserialize;
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19
20use crate as hyperactor; // for macros
21use crate::Named;
22use crate::config;
23use crate::data::Serialized;
24use crate::reference::Index;
25
/// An accumulator is an object that accumulates updates into a state.
///
/// Implementors define how a single [`Self::Update`] is folded into a
/// [`Self::State`], and may describe a reducer through
/// [`Accumulator::reducer_spec`].
pub trait Accumulator {
    /// The type of the accumulated state.
    type State;
    /// The type of the updates sent to the accumulator. Updates will be
    /// accumulated into type [Self::State].
    type Update;

    /// Accumulate an update into the current state.
    ///
    /// `state` is modified in place. An error is returned when the update
    /// could not be applied.
    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;

    /// The specification used to build the reducer.
    ///
    /// The return type is optional; presumably `None` means the accumulator
    /// has no associated reducer (confirm against callers).
    fn reducer_spec(&self) -> Option<ReducerSpec>;
}
40
/// Serializable information needed to build a comm reducer.
///
/// The `typehash` is the key passed to [`resolve_reducer`], and
/// `builder_params` is forwarded to the matching factory's builder function.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub struct ReducerSpec {
    /// The typehash of the underlying [`CommReducer`] type.
    pub typehash: u64,
    /// The parameters used to build the reducer.
    pub builder_params: Option<Serialized>,
}
49
/// Runtime behavior of reducers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Default)]
pub struct ReducerOpts {
    /// The maximum interval between updates. When unspecified, a default
    /// interval is used (read from the `SPLIT_MAX_BUFFER_AGE` configuration
    /// value by the crate-internal `max_update_interval` accessor).
    pub max_update_interval: Option<Duration>,
}
57
58impl ReducerOpts {
59    pub(crate) fn max_update_interval(&self) -> Duration {
60        self.max_update_interval
61            .unwrap_or(config::global::get(config::SPLIT_MAX_BUFFER_AGE))
62    }
63}
64
/// Commutative reducer for an accumulator. This is used to coalesce updates.
/// For example, if the accumulator is a sum, its reducer calculates and returns
/// the sum of 2 updates. This is helpful in split ports, where a large number
/// of updates can be reduced into a smaller number of updates before being sent
/// to the parent port.
pub trait CommReducer {
    /// The type of updates to be reduced.
    type Update;

    /// Reduce 2 updates into a single update.
    ///
    /// Both updates are consumed. An error is returned when they cannot be
    /// combined.
    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
}
77
78/// Type erased version of [CommReducer].
79pub trait ErasedCommReducer {
80    /// Reduce 2 updates into a single update.
81    fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized>;
82
83    /// Reducer an non-empty vector of updates. Return Error if the vector is
84    /// empty.
85    fn reduce_updates(
86        &self,
87        updates: Vec<Serialized>,
88    ) -> Result<Serialized, (anyhow::Error, Vec<Serialized>)> {
89        if updates.is_empty() {
90            return Err((anyhow::anyhow!("empty updates"), updates));
91        }
92        if updates.len() == 1 {
93            return Ok(updates.into_iter().next().expect("checked above"));
94        }
95
96        let mut iter = updates.iter();
97        let first = iter.next().unwrap();
98        let second = iter.next().unwrap();
99        let init = match self.reduce_erased(first, second) {
100            Ok(v) => v,
101            Err(e) => return Err((e, updates)),
102        };
103        let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
104            Ok(v) => v,
105            Err(e) => return Err((e, updates)),
106        };
107        Ok(reduced)
108    }
109
110    /// Typehash of the underlying [`CommReducer`] type.
111    fn typehash(&self) -> u64;
112}
113
114impl<R, T> ErasedCommReducer for R
115where
116    R: CommReducer<Update = T> + Named,
117    T: Serialize + DeserializeOwned + Named,
118{
119    fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized> {
120        let left = left.deserialized::<T>()?;
121        let right = right.deserialized::<T>()?;
122        let result = self.reduce(left, right)?;
123        Ok(Serialized::serialize(&result)?)
124    }
125
126    fn typehash(&self) -> u64 {
127        R::typehash()
128    }
129}
130
/// A factory for [`ErasedCommReducer`]s. This is used to register a
/// [`ErasedCommReducer`] type. We cannot register [`ErasedCommReducer`] trait
/// object directly because the object could have internal state, and cannot be
/// shared.
///
/// Factories are collected through `inventory` and looked up by typehash in
/// [`resolve_reducer`].
pub struct ReducerFactory {
    /// Return the typehash of the [`ErasedCommReducer`] type built by this
    /// factory.
    pub typehash_f: fn() -> u64,
    /// The builder function to build the [`ErasedCommReducer`] type. Its
    /// argument is the optional, serialized builder parameters.
    pub builder_f: fn(
        Option<Serialized>,
    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
}
144
145inventory::collect!(ReducerFactory);
146
147inventory::submit! {
148    ReducerFactory {
149        typehash_f: <SumReducer<i64> as Named>::typehash,
150        builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
151    }
152}
153inventory::submit! {
154    ReducerFactory {
155        typehash_f: <SumReducer<u64> as Named>::typehash,
156        builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
157    }
158}
159inventory::submit! {
160    ReducerFactory {
161        typehash_f: <MaxReducer::<i64> as Named>::typehash,
162        builder_f: |_| Ok(Box::new(MaxReducer::<i64>(PhantomData))),
163    }
164}
165inventory::submit! {
166    ReducerFactory {
167        typehash_f: <MaxReducer::<u64> as Named>::typehash,
168        builder_f: |_| Ok(Box::new(MaxReducer::<u64>(PhantomData))),
169    }
170}
171inventory::submit! {
172    ReducerFactory {
173        typehash_f: <MinReducer::<i64> as Named>::typehash,
174        builder_f: |_| Ok(Box::new(MinReducer::<i64>(PhantomData))),
175    }
176}
177inventory::submit! {
178    ReducerFactory {
179        typehash_f: <MinReducer::<u64> as Named>::typehash,
180        builder_f: |_| Ok(Box::new(MinReducer::<u64>(PhantomData))),
181    }
182}
183inventory::submit! {
184    ReducerFactory {
185        typehash_f: <WatermarkUpdateReducer::<i64> as Named>::typehash,
186        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<i64>(PhantomData))),
187    }
188}
189inventory::submit! {
190    ReducerFactory {
191        typehash_f: <WatermarkUpdateReducer::<u64> as Named>::typehash,
192        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<u64>(PhantomData))),
193    }
194}
195
196/// Build a reducer object with the given typehash's [CommReducer] type, and
197/// return the type-erased version of it.
198pub(crate) fn resolve_reducer(
199    typehash: u64,
200    builder_params: Option<Serialized>,
201) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
202    static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
203    let factories = FACTORY_MAP.get_or_init(|| {
204        let mut map = HashMap::new();
205        for factory in inventory::iter::<ReducerFactory> {
206            map.insert((factory.typehash_f)(), factory);
207        }
208        map
209    });
210
211    factories
212        .get(&typehash)
213        .map(|f| (f.builder_f)(builder_params))
214        .transpose()
215}
216
217#[derive(Named)]
218struct SumReducer<T>(PhantomData<T>);
219
220impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
221    type Update = T;
222
223    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
224        Ok(left + right)
225    }
226}
227
228/// Accumulate the sum of received updates. The inner function performs the
229/// summation between an update and the current state.
230struct SumAccumulator<T>(PhantomData<T>);
231
232impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
233    type State = T;
234    type Update = T;
235
236    fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
237        *state = *state + update;
238        Ok(())
239    }
240
241    fn reducer_spec(&self) -> Option<ReducerSpec> {
242        Some(ReducerSpec {
243            typehash: <SumReducer<T> as Named>::typehash(),
244            builder_params: None,
245        })
246    }
247}
248
249/// Accumulate the sum of received updates.
250pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
251-> impl Accumulator<State = T, Update = T> {
252    SumAccumulator(PhantomData)
253}
254
255#[derive(Named)]
256struct MaxReducer<T>(PhantomData<T>);
257
258impl<T: Ord> CommReducer for MaxReducer<T> {
259    type Update = T;
260
261    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
262        Ok(std::cmp::max(left, right))
263    }
264}
265
/// The state of the accumulator returned by [`max`].
///
/// The state starts out empty and holds a value once at least one update has
/// been accumulated into it.
#[derive(Debug, Clone)]
pub struct Max<T>(Option<T>);

// Implemented by hand rather than derived: the derive would require
// `T: Default`, although the empty state never needs a value of type `T`.
impl<T> Default for Max<T> {
    /// Create an empty state that holds no value yet.
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Max<T> {
    /// Get the accumulated value.
    ///
    /// # Panics
    ///
    /// Panics if the state is still empty, i.e. no update has been
    /// accumulated into it.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
278
279/// Accumulate the max of received updates.
280struct MaxAccumulator<T>(PhantomData<T>);
281
282impl<T: Ord + Copy + Named + 'static> Accumulator for MaxAccumulator<T> {
283    type State = Max<T>;
284    type Update = T;
285
286    fn accumulate(&self, state: &mut Self::State, update: T) -> anyhow::Result<()> {
287        match state.0.as_mut() {
288            Some(s) => *s = std::cmp::max(*s, update),
289            None => *state = Max(Some(update)),
290        }
291        Ok(())
292    }
293
294    fn reducer_spec(&self) -> Option<ReducerSpec> {
295        Some(ReducerSpec {
296            typehash: <MaxReducer<T> as Named>::typehash(),
297            builder_params: None,
298        })
299    }
300}
301
302/// Accumulate the max of received updates (i.e. the largest value of all
303/// received updates).
304pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>, Update = T> {
305    MaxAccumulator(PhantomData::<T>)
306}
307
308#[derive(Named)]
309struct MinReducer<T>(PhantomData<T>);
310
311impl<T: Ord> CommReducer for MinReducer<T> {
312    type Update = T;
313
314    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
315        Ok(std::cmp::min(left, right))
316    }
317}
318
/// The state of the accumulator returned by [`min`].
///
/// The state starts out empty and holds a value once at least one update has
/// been accumulated into it.
#[derive(Debug, Clone)]
pub struct Min<T>(Option<T>);

// Implemented by hand rather than derived: the derive would require
// `T: Default`, although the empty state never needs a value of type `T`.
impl<T> Default for Min<T> {
    /// Create an empty state that holds no value yet.
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Min<T> {
    /// Get the accumulated value.
    ///
    /// # Panics
    ///
    /// Panics if the state is still empty, i.e. no update has been
    /// accumulated into it.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
331
332/// Accumulate the min of received updates.
333struct MinAccumulator<T>(PhantomData<T>);
334
335impl<T: Ord + Copy + Named + 'static> Accumulator for MinAccumulator<T> {
336    type State = Min<T>;
337    type Update = T;
338
339    fn accumulate(&self, state: &mut Min<T>, update: T) -> anyhow::Result<()> {
340        match state.0.as_mut() {
341            Some(s) => *s = std::cmp::min(*s, update),
342            None => *state = Min(Some(update)),
343        }
344        Ok(())
345    }
346
347    fn reducer_spec(&self) -> Option<ReducerSpec> {
348        Some(ReducerSpec {
349            typehash: <MinReducer<T> as Named>::typehash(),
350            builder_params: None,
351        })
352    }
353}
354
355/// Accumulate the min of received updates (i.e. the smallest value of all
356/// received updates).
357pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>, Update = T> {
358    MinAccumulator(PhantomData)
359}
360
/// Update from ranks for the watermark accumulator, where the map's key is
/// the rank, and the map's value is the update from that rank.
#[derive(Default, Debug, Clone, Serialize, Deserialize, Named)]
pub struct WatermarkUpdate<T>(HashMap<Index, T>);
365
366impl<T: Ord> WatermarkUpdate<T> {
367    /// Get the watermark value. WatermarkUpdate is guarranteed to be initialized by
368    /// accumulator before it is sent to the user.
369    // TODO(pzhang) optimize this and only iterate when there is a new min.
370    pub fn get(&self) -> &T {
371        self.0
372            .values()
373            .min()
374            .expect("watermark should have been intialized.")
375    }
376}
377
378impl<T: PartialEq> WatermarkUpdate<T> {
379    /// See [`WatermarkUpdateReducer`]'s documentation for the merge semantics.
380    fn merge(old: Self, new: Self) -> Self {
381        let mut map = old.0;
382        for (k, v) in new.0 {
383            map.insert(k, v);
384        }
385        Self(map)
386    }
387}
388
389impl<T> From<(Index, T)> for WatermarkUpdate<T> {
390    fn from((rank, value): (Index, T)) -> Self {
391        let mut map = HashMap::with_capacity(1);
392        map.insert(rank, value);
393        Self(map)
394    }
395}
396
397/// Merge an old update and a new update. If a rank exists in boths updates,
398/// only keep its value from the new update.
399#[derive(Named)]
400struct WatermarkUpdateReducer<T>(PhantomData<T>);
401
402impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
403    type Update = WatermarkUpdate<T>;
404
405    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
406        Ok(WatermarkUpdate::merge(left, right))
407    }
408}
409
410struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
411
412impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {
413    type State = WatermarkUpdate<T>;
414    type Update = WatermarkUpdate<T>;
415
416    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
417        let current = std::mem::replace(&mut *state, WatermarkUpdate(HashMap::new()));
418        // TODO(pzhang) optimize this and only iterate when there is a new state.
419        *state = WatermarkUpdate::merge(current, update);
420        Ok(())
421    }
422
423    fn reducer_spec(&self) -> Option<ReducerSpec> {
424        Some(ReducerSpec {
425            typehash: <WatermarkUpdateReducer<T> as Named>::typehash(),
426            builder_params: None,
427        })
428    }
429}
430
431/// Accumulate the min value among the ranks, aka. low watermark, based on the
432/// ranks' latest updates. Ranks' previous updates are discarded, and not used
433/// in the min value calculation.
434///
435/// The main difference bwtween low wartermark accumulator and [`MinAccumulator`]
436/// is, `MinAccumulator` takes previous updates into consideration too, and thus
437/// returns the min of the whole history.
438pub fn low_watermark<T: Ord + Copy + Named + 'static>()
439-> impl Accumulator<State = WatermarkUpdate<T>, Update = WatermarkUpdate<T>> {
440    LowWatermarkUpdateAccumulator(PhantomData)
441}
442
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;

    use super::*;
    use crate::Named;

    /// Serialize every value, panicking if any of them fails to serialize.
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<Serialized> {
        let mut serialized = Vec::with_capacity(values.len());
        for value in &values {
            serialized.push(Serialized::serialize(value).unwrap());
        }
        serialized
    }

    /// Look up the reducer registered under `typehash`, fold `updates` with
    /// it, and decode the reduced value as `U`.
    fn reduce_as<U: DeserializeOwned + Named>(typehash: u64, updates: Vec<Serialized>) -> U {
        let reducer = resolve_reducer(typehash, None).unwrap().unwrap();
        let reduced = reducer.reduce_updates(updates).unwrap();
        reduced.deserialized::<U>().unwrap()
    }

    #[test]
    fn test_comm_reducer_numeric() {
        let u64_numbers: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
        let i64_numbers: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);

        // Unsigned updates.
        assert_eq!(
            reduce_as::<u64>(<MaxReducer<u64> as Named>::typehash(), u64_numbers.clone()),
            1100u64,
        );
        assert_eq!(
            reduce_as::<u64>(<MinReducer<u64> as Named>::typehash(), u64_numbers.clone()),
            1u64,
        );
        assert_eq!(
            reduce_as::<u64>(<SumReducer<u64> as Named>::typehash(), u64_numbers),
            1104u64,
        );

        // Signed updates.
        assert_eq!(
            reduce_as::<i64>(<MaxReducer<i64> as Named>::typehash(), i64_numbers.clone()),
            110i64,
        );
        assert_eq!(
            reduce_as::<i64>(<MinReducer<i64> as Named>::typehash(), i64_numbers.clone()),
            -123i64,
        );
        assert_eq!(
            reduce_as::<i64>(<SumReducer<i64> as Named>::typehash(), i64_numbers),
            20i64,
        );
    }

    #[test]
    fn test_comm_reducer_watermark() {
        /// Reduce `updates` with the registered watermark reducer for `T` and
        /// compare the merged per-rank map against `expected`.
        fn verify<T: PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<Serialized>,
            expected: HashMap<Index, T>,
        ) {
            let typehash = <WatermarkUpdateReducer<T> as Named>::typehash();
            let reducer = resolve_reducer(typehash, None).unwrap().unwrap();
            let reduced = reducer.reduce_updates(updates).unwrap();
            let merged = reduced.deserialized::<WatermarkUpdate<T>>().unwrap();
            assert_eq!(merged.0, expected);
        }

        // Each entry is a `(rank, value)` pair; later entries for a rank
        // should win over earlier ones.
        let u64_updates = serialize::<WatermarkUpdate<u64>>(
            vec![
                (1, 1),
                (0, 2),
                (0, 1),
                (3, 35),
                (0, 9),
                (1, 10),
                (3, 32),
                (3, 0),
                (3, 321),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );
        let i64_updates = serialize::<WatermarkUpdate<i64>>(
            vec![
                (0, 2),
                (1, 1),
                (3, 35),
                (0, 1),
                (1, -10),
                (3, 32),
                (3, 0),
                (3, -99),
                (0, -9),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );

        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    #[test]
    fn test_accum_reducer_numeric() {
        // Every numeric accumulator must advertise its matching reducer.
        let sum_u64 = sum::<u64>().reducer_spec().unwrap();
        assert_eq!(sum_u64.typehash, <SumReducer<u64> as Named>::typehash());
        let sum_i64 = sum::<i64>().reducer_spec().unwrap();
        assert_eq!(sum_i64.typehash, <SumReducer<i64> as Named>::typehash());

        let min_u64 = min::<u64>().reducer_spec().unwrap();
        assert_eq!(min_u64.typehash, <MinReducer<u64> as Named>::typehash());
        let min_i64 = min::<i64>().reducer_spec().unwrap();
        assert_eq!(min_i64.typehash, <MinReducer<i64> as Named>::typehash());

        let max_u64 = max::<u64>().reducer_spec().unwrap();
        assert_eq!(max_u64.typehash, <MaxReducer<u64> as Named>::typehash());
        let max_i64 = max::<i64>().reducer_spec().unwrap();
        assert_eq!(max_i64.typehash, <MaxReducer<i64> as Named>::typehash());
    }

    #[test]
    fn test_accum_reducer_watermark() {
        /// The low watermark accumulator must advertise the watermark reducer.
        fn verify<T: Ord + Copy + Named>() {
            let spec = low_watermark::<T>().reducer_spec().unwrap();
            assert_eq!(
                spec.typehash,
                <WatermarkUpdateReducer<T> as Named>::typehash(),
            );
        }
        verify::<u64>();
        verify::<i64>();
    }

    #[test]
    fn test_watermark_accumulator() {
        let accumulator = low_watermark::<u64>();
        // Each row is `(rank, value sent by that rank, expected watermark)`.
        let ranks_values_expectations = [
            // send in descending order
            (0, 1003, 1003),
            (1, 1002, 1002),
            (2, 1001, 1001),
            // send in ascending order
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // send same as accumulator's cache
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // shuffle rank 0 to be largest, and make rank 1 smallest
            (0, 1000, 101),
            // shuffle rank 1 to be largest, and make rank 2 smallest
            (1, 1100, 102),
            // shuffle rank 2 to be largest, and make rank 0 smallest
            (2, 1200, 1000),
            // Increase their value, but do not change their order
            (0, 1001, 1001),
            (1, 1101, 1001),
            (2, 1201, 1001),
            // decrease their values
            (2, 102, 102),
            (1, 101, 101),
            (0, 100, 100),
        ];
        let mut state = WatermarkUpdate(HashMap::new());
        for (rank, value, expected) in ranks_values_expectations {
            let update = WatermarkUpdate::from((rank, value));
            accumulator.accumulate(&mut state, update).unwrap();
            assert_eq!(state.get(), &expected, "rank is {rank}; value is {value}");
        }
    }
}