hyperactor/
accum.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines the accumulator trait and some common accumulators.
10
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14
15use serde::Deserialize;
16use serde::Serialize;
17use serde::de::DeserializeOwned;
18
19use crate as hyperactor; // for macros
20use crate::Named;
21use crate::data::Serialized;
22use crate::reference::Index;
23
24/// An accumulator is a object that accumulates updates into a state.
25pub trait Accumulator {
26    /// The type of the accumulated state.
27    type State;
28    /// The type of the updates sent to the accumulator. Updates will be
29    /// accumulated into type [Self::State].
30    type Update;
31
32    /// Accumulate an update into the current state.
33    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;
34
35    /// The specification used to build the reducer.
36    fn reducer_spec(&self) -> Option<ReducerSpec>;
37}
38
39/// Serializable information needed to build a comm reducer.
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
41pub struct ReducerSpec {
42    /// The typehash of the underlying [Self::Reducer] type.
43    pub typehash: u64,
44    /// The parameters used to build the reducer.
45    pub builder_params: Option<Serialized>,
46}
47
48/// Commutative reducer for an accumulator. This is used to coallesce updates.
49/// For example, if the accumulator is a sum, its reducer calculates and returns
50/// the sum of 2 updates. This is helpful in split ports, where a large number
51/// of updates can be reduced into a smaller number of updates before being sent
52/// to the parent port.
53pub trait CommReducer {
54    /// The type of updates to be reduced.
55    type Update;
56
57    /// Reduce 2 updates into a single update.
58    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
59}
60
61/// Type erased version of [CommReducer].
62pub trait ErasedCommReducer {
63    /// Reduce 2 updates into a single update.
64    fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized>;
65
66    /// Reducer an non-empty vector of updates. Return Error if the vector is
67    /// empty.
68    fn reduce_updates(
69        &self,
70        updates: Vec<Serialized>,
71    ) -> Result<Serialized, (anyhow::Error, Vec<Serialized>)> {
72        if updates.is_empty() {
73            return Err((anyhow::anyhow!("empty updates"), updates));
74        }
75        if updates.len() == 1 {
76            return Ok(updates.into_iter().next().expect("checked above"));
77        }
78
79        let mut iter = updates.iter();
80        let first = iter.next().unwrap();
81        let second = iter.next().unwrap();
82        let init = match self.reduce_erased(first, second) {
83            Ok(v) => v,
84            Err(e) => return Err((e, updates)),
85        };
86        let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
87            Ok(v) => v,
88            Err(e) => return Err((e, updates)),
89        };
90        Ok(reduced)
91    }
92
93    /// Typehash of the underlying [`CommReducer`] type.
94    fn typehash(&self) -> u64;
95}
96
97impl<R, T> ErasedCommReducer for R
98where
99    R: CommReducer<Update = T> + Named,
100    T: Serialize + DeserializeOwned + Named,
101{
102    fn reduce_erased(&self, left: &Serialized, right: &Serialized) -> anyhow::Result<Serialized> {
103        let left = left.deserialized::<T>()?;
104        let right = right.deserialized::<T>()?;
105        let result = self.reduce(left, right)?;
106        Ok(Serialized::serialize(&result)?)
107    }
108
109    fn typehash(&self) -> u64 {
110        R::typehash()
111    }
112}
113
114/// A factory for [`ErasedCommReducer`]s. This is used to register a
115/// [`ErasedCommReducer`] type. We cannot register [`ErasedCommReducer`] trait
116/// object directly because the object could have internal state, and cannot be
117/// shared.
118pub struct ReducerFactory {
119    /// Return the typehash of the [`ErasedCommReducer`] type built by this
120    /// factory.
121    pub typehash_f: fn() -> u64,
122    /// The builder function to build the [`ErasedCommReducer`] type.
123    pub builder_f: fn(
124        Option<Serialized>,
125    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
126}
127
128inventory::collect!(ReducerFactory);
129
130inventory::submit! {
131    ReducerFactory {
132        typehash_f: <SumReducer<i64> as Named>::typehash,
133        builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
134    }
135}
136inventory::submit! {
137    ReducerFactory {
138        typehash_f: <SumReducer<u64> as Named>::typehash,
139        builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
140    }
141}
142inventory::submit! {
143    ReducerFactory {
144        typehash_f: <MaxReducer::<i64> as Named>::typehash,
145        builder_f: |_| Ok(Box::new(MaxReducer::<i64>(PhantomData))),
146    }
147}
148inventory::submit! {
149    ReducerFactory {
150        typehash_f: <MaxReducer::<u64> as Named>::typehash,
151        builder_f: |_| Ok(Box::new(MaxReducer::<u64>(PhantomData))),
152    }
153}
154inventory::submit! {
155    ReducerFactory {
156        typehash_f: <MinReducer::<i64> as Named>::typehash,
157        builder_f: |_| Ok(Box::new(MinReducer::<i64>(PhantomData))),
158    }
159}
160inventory::submit! {
161    ReducerFactory {
162        typehash_f: <MinReducer::<u64> as Named>::typehash,
163        builder_f: |_| Ok(Box::new(MinReducer::<u64>(PhantomData))),
164    }
165}
166inventory::submit! {
167    ReducerFactory {
168        typehash_f: <WatermarkUpdateReducer::<i64> as Named>::typehash,
169        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<i64>(PhantomData))),
170    }
171}
172inventory::submit! {
173    ReducerFactory {
174        typehash_f: <WatermarkUpdateReducer::<u64> as Named>::typehash,
175        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<u64>(PhantomData))),
176    }
177}
178
179/// Build a reducer object with the given typehash's [CommReducer] type, and
180/// return the type-erased version of it.
181pub(crate) fn resolve_reducer(
182    typehash: u64,
183    builder_params: Option<Serialized>,
184) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
185    static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
186    let factories = FACTORY_MAP.get_or_init(|| {
187        let mut map = HashMap::new();
188        for factory in inventory::iter::<ReducerFactory> {
189            map.insert((factory.typehash_f)(), factory);
190        }
191        map
192    });
193
194    factories
195        .get(&typehash)
196        .map(|f| (f.builder_f)(builder_params))
197        .transpose()
198}
199
200#[derive(Named)]
201struct SumReducer<T>(PhantomData<T>);
202
203impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
204    type Update = T;
205
206    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
207        Ok(left + right)
208    }
209}
210
211/// Accumulate the sum of received updates. The inner function performs the
212/// summation between an update and the current state.
213struct SumAccumulator<T>(PhantomData<T>);
214
215impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
216    type State = T;
217    type Update = T;
218
219    fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
220        *state = *state + update;
221        Ok(())
222    }
223
224    fn reducer_spec(&self) -> Option<ReducerSpec> {
225        Some(ReducerSpec {
226            typehash: <SumReducer<T> as Named>::typehash(),
227            builder_params: None,
228        })
229    }
230}
231
232/// Accumulate the sum of received updates.
233pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
234-> impl Accumulator<State = T, Update = T> {
235    SumAccumulator(PhantomData)
236}
237
238#[derive(Named)]
239struct MaxReducer<T>(PhantomData<T>);
240
241impl<T: Ord> CommReducer for MaxReducer<T> {
242    type Update = T;
243
244    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
245        Ok(std::cmp::max(left, right))
246    }
247}
248
/// The state of a [`max`] accumulator: the largest update accumulated so far,
/// or nothing if no update has been accumulated yet.
#[derive(Debug, Clone, Default)]
pub struct Max<T>(Option<T>);

impl<T> Max<T> {
    /// Get the accumulated value.
    ///
    /// # Panics
    ///
    /// Panics if no update has been accumulated into this state yet (e.g. a
    /// freshly `Default`-constructed `Max`).
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
261
262/// Accumulate the max of received updates.
263struct MaxAccumulator<T>(PhantomData<T>);
264
265impl<T: Ord + Copy + Named + 'static> Accumulator for MaxAccumulator<T> {
266    type State = Max<T>;
267    type Update = T;
268
269    fn accumulate(&self, state: &mut Self::State, update: T) -> anyhow::Result<()> {
270        match state.0.as_mut() {
271            Some(s) => *s = std::cmp::max(*s, update),
272            None => *state = Max(Some(update)),
273        }
274        Ok(())
275    }
276
277    fn reducer_spec(&self) -> Option<ReducerSpec> {
278        Some(ReducerSpec {
279            typehash: <MaxReducer<T> as Named>::typehash(),
280            builder_params: None,
281        })
282    }
283}
284
285/// Accumulate the max of received updates (i.e. the largest value of all
286/// received updates).
287pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>, Update = T> {
288    MaxAccumulator(PhantomData::<T>)
289}
290
291#[derive(Named)]
292struct MinReducer<T>(PhantomData<T>);
293
294impl<T: Ord> CommReducer for MinReducer<T> {
295    type Update = T;
296
297    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
298        Ok(std::cmp::min(left, right))
299    }
300}
301
/// The state of a [`min`] accumulator: the smallest update accumulated so far,
/// or nothing if no update has been accumulated yet.
#[derive(Debug, Clone, Default)]
pub struct Min<T>(Option<T>);

impl<T> Min<T> {
    /// Get the accumulated value.
    ///
    /// # Panics
    ///
    /// Panics if no update has been accumulated into this state yet (e.g. a
    /// freshly `Default`-constructed `Min`).
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
314
315/// Accumulate the min of received updates.
316struct MinAccumulator<T>(PhantomData<T>);
317
318impl<T: Ord + Copy + Named + 'static> Accumulator for MinAccumulator<T> {
319    type State = Min<T>;
320    type Update = T;
321
322    fn accumulate(&self, state: &mut Min<T>, update: T) -> anyhow::Result<()> {
323        match state.0.as_mut() {
324            Some(s) => *s = std::cmp::min(*s, update),
325            None => *state = Min(Some(update)),
326        }
327        Ok(())
328    }
329
330    fn reducer_spec(&self) -> Option<ReducerSpec> {
331        Some(ReducerSpec {
332            typehash: <MinReducer<T> as Named>::typehash(),
333            builder_params: None,
334        })
335    }
336}
337
338/// Accumulate the min of received updates (i.e. the smallest value of all
339/// received updates).
340pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>, Update = T> {
341    MinAccumulator(PhantomData)
342}
343
344/// Update from ranks for watermark accumulator, where map' key is the rank, and
345/// map's value is the update from that rank.
346#[derive(Default, Debug, Clone, Serialize, Deserialize, Named)]
347pub struct WatermarkUpdate<T>(HashMap<Index, T>);
348
349impl<T: Ord> WatermarkUpdate<T> {
350    /// Get the watermark value. WatermarkUpdate is guarranteed to be initialized by
351    /// accumulator before it is sent to the user.
352    // TODO(pzhang) optimize this and only iterate when there is a new min.
353    pub fn get(&self) -> &T {
354        self.0
355            .values()
356            .min()
357            .expect("watermark should have been intialized.")
358    }
359}
360
361impl<T: PartialEq> WatermarkUpdate<T> {
362    /// See [`WatermarkUpdateReducer`]'s documentation for the merge semantics.
363    fn merge(old: Self, new: Self) -> Self {
364        let mut map = old.0;
365        for (k, v) in new.0 {
366            map.insert(k, v);
367        }
368        Self(map)
369    }
370}
371
372impl<T> From<(Index, T)> for WatermarkUpdate<T> {
373    fn from((rank, value): (Index, T)) -> Self {
374        let mut map = HashMap::with_capacity(1);
375        map.insert(rank, value);
376        Self(map)
377    }
378}
379
380/// Merge an old update and a new update. If a rank exists in boths updates,
381/// only keep its value from the new update.
382#[derive(Named)]
383struct WatermarkUpdateReducer<T>(PhantomData<T>);
384
385impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
386    type Update = WatermarkUpdate<T>;
387
388    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
389        Ok(WatermarkUpdate::merge(left, right))
390    }
391}
392
393struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
394
395impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {
396    type State = WatermarkUpdate<T>;
397    type Update = WatermarkUpdate<T>;
398
399    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
400        let current = std::mem::replace(&mut *state, WatermarkUpdate(HashMap::new()));
401        // TODO(pzhang) optimize this and only iterate when there is a new state.
402        *state = WatermarkUpdate::merge(current, update);
403        Ok(())
404    }
405
406    fn reducer_spec(&self) -> Option<ReducerSpec> {
407        Some(ReducerSpec {
408            typehash: <WatermarkUpdateReducer<T> as Named>::typehash(),
409            builder_params: None,
410        })
411    }
412}
413
414/// Accumulate the min value among the ranks, aka. low watermark, based on the
415/// ranks' latest updates. Ranks' previous updates are discarded, and not used
416/// in the min value calculation.
417///
418/// The main difference bwtween low wartermark accumulator and [`MinAccumulator`]
419/// is, `MinAccumulator` takes previous updates into consideration too, and thus
420/// returns the min of the whole history.
421pub fn low_watermark<T: Ord + Copy + Named + 'static>()
422-> impl Accumulator<State = WatermarkUpdate<T>, Update = WatermarkUpdate<T>> {
423    LowWatermarkUpdateAccumulator(PhantomData)
424}
425
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;

    use super::*;
    use crate::Named;

    /// Serialize every value in `values`, panicking on failure.
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<Serialized> {
        values
            .into_iter()
            .map(|n| Serialized::serialize(&n).unwrap())
            .collect()
    }

    #[test]
    fn test_comm_reducer_numeric() {
        let u64_numbers: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);
        let i64_numbers: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);
        {
            let typehash = <MaxReducer<u64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers.clone())
                    .unwrap()
                    .deserialized::<u64>()
                    .unwrap(),
                1100u64,
            );

            let typehash = <MinReducer<u64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers.clone())
                    .unwrap()
                    .deserialized::<u64>()
                    .unwrap(),
                1u64,
            );

            let typehash = <SumReducer<u64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(u64_numbers)
                    .unwrap()
                    .deserialized::<u64>()
                    .unwrap(),
                1104u64,
            );
        }

        {
            let typehash = <MaxReducer<i64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers.clone())
                    .unwrap()
                    .deserialized::<i64>()
                    .unwrap(),
                110i64,
            );

            let typehash = <MinReducer<i64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers.clone())
                    .unwrap()
                    .deserialized::<i64>()
                    .unwrap(),
                -123i64,
            );

            let typehash = <SumReducer<i64> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(i64_numbers)
                    .unwrap()
                    .deserialized::<i64>()
                    .unwrap(),
                20i64,
            );
        }
    }

    #[test]
    fn test_comm_reducer_watermark() {
        let u64_updates = serialize::<WatermarkUpdate<u64>>(
            vec![
                (1, 1),
                (0, 2),
                (0, 1),
                (3, 35),
                (0, 9),
                (1, 10),
                (3, 32),
                (3, 0),
                (3, 321),
            ]
            .into_iter()
            .map(|(k, v)| WatermarkUpdate::from((k, v)))
            .collect(),
        );
        let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
            vec![
                (0, 2),
                (1, 1),
                (3, 35),
                (0, 1),
                (1, -10),
                (3, 32),
                (3, 0),
                (3, -99),
                (0, -9),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );

        fn verify<T: PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<Serialized>,
            expected: HashMap<Index, T>,
        ) {
            // Resolve the reducer registered for `T` itself (not a fixed
            // instantiation), so both the `i64` and the `u64` registrations
            // are exercised.
            let typehash = <WatermarkUpdateReducer<T> as Named>::typehash();
            assert_eq!(
                resolve_reducer(typehash, None)
                    .unwrap()
                    .unwrap()
                    .reduce_updates(updates)
                    .unwrap()
                    .deserialized::<WatermarkUpdate<T>>()
                    .unwrap()
                    .0,
                expected,
            );
        }

        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    #[test]
    fn test_accum_reducer_numeric() {
        assert_eq!(
            sum::<u64>().reducer_spec().unwrap().typehash,
            <SumReducer::<u64> as Named>::typehash(),
        );
        assert_eq!(
            sum::<i64>().reducer_spec().unwrap().typehash,
            <SumReducer::<i64> as Named>::typehash(),
        );

        assert_eq!(
            min::<u64>().reducer_spec().unwrap().typehash,
            <MinReducer::<u64> as Named>::typehash(),
        );
        assert_eq!(
            min::<i64>().reducer_spec().unwrap().typehash,
            <MinReducer::<i64> as Named>::typehash(),
        );

        assert_eq!(
            max::<u64>().reducer_spec().unwrap().typehash,
            <MaxReducer::<u64> as Named>::typehash(),
        );
        assert_eq!(
            max::<i64>().reducer_spec().unwrap().typehash,
            <MaxReducer::<i64> as Named>::typehash(),
        );
    }

    #[test]
    fn test_accum_reducer_watermark() {
        fn verify<T: Ord + Copy + Named>() {
            assert_eq!(
                low_watermark::<T>().reducer_spec().unwrap().typehash,
                <WatermarkUpdateReducer::<T> as Named>::typehash(),
            );
        }
        verify::<u64>();
        verify::<i64>();
    }

    #[test]
    fn test_watermark_accumulator() {
        let accumulator = low_watermark::<u64>();
        let ranks_values_expectations = [
            // send in descending order
            (0, 1003, 1003),
            (1, 1002, 1002),
            (2, 1001, 1001),
            // send in ascending order
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // send same as accumulator's cache
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // shuffle rank 0 to be largest, and make rank 1 smallest
            (0, 1000, 101),
            // shuffle rank 1 to be largest, and make rank 2 smallest
            (1, 1100, 102),
            // shuffle rank 2 to be largest, and make rank 0 smallest
            (2, 1200, 1000),
            // Increase their value, but do not change their order
            (0, 1001, 1001),
            (1, 1101, 1001),
            (2, 1201, 1001),
            // decrease their values
            (2, 102, 102),
            (1, 101, 101),
            (0, 100, 100),
        ];
        let mut state = WatermarkUpdate(HashMap::new());
        for (rank, value, expected) in ranks_values_expectations {
            accumulator
                .accumulate(&mut state, WatermarkUpdate::from((rank, value)))
                .unwrap();
            assert_eq!(state.get(), &expected, "rank is {rank}; value is {value}");
        }
    }
}