hyperactor/
accum.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Defines the accumulator trait and some common accumulators.
10
11use std::collections::HashMap;
12use std::marker::PhantomData;
13use std::sync::OnceLock;
14use std::time::Duration;
15
16use serde::Deserialize;
17use serde::Serialize;
18use serde::de::DeserializeOwned;
19use typeuri::Named;
20
21// for macros
22use crate::config;
23use crate::reference::Index;
24
25/// An accumulator is a object that accumulates updates into a state.
26pub trait Accumulator {
27    /// The type of the accumulated state.
28    type State;
29    /// The type of the updates sent to the accumulator. Updates will be
30    /// accumulated into type [Self::State].
31    type Update;
32
33    /// Accumulate an update into the current state.
34    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()>;
35
36    /// The specification used to build the reducer.
37    fn reducer_spec(&self) -> Option<ReducerSpec>;
38}
39
40/// Serializable information needed to build a comm reducer.
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, typeuri::Named)]
42pub struct ReducerSpec {
43    /// The typehash of the underlying [Self::Reducer] type.
44    pub typehash: u64,
45    /// The parameters used to build the reducer.
46    pub builder_params: Option<wirevalue::Any>,
47}
48wirevalue::register_type!(ReducerSpec);
49
50/// Runtime behavior of reducers.
51#[derive(
52    Debug,
53    Clone,
54    PartialEq,
55    Serialize,
56    Deserialize,
57    typeuri::Named,
58    Default
59)]
60pub struct ReducerOpts {
61    /// The maximum interval between updates. When unspecified, a default
62    /// interval is used.
63    pub max_update_interval: Option<Duration>,
64    /// The initial interval for the first update. When unspecified, defaults to 1ms.
65    /// This allows quick flushing of single messages while using exponential backoff
66    /// to reach max_update_interval for batched messages.
67    pub initial_update_interval: Option<Duration>,
68}
69
70impl ReducerOpts {
71    pub(crate) fn max_update_interval(&self) -> Duration {
72        self.max_update_interval
73            .unwrap_or(hyperactor_config::global::get(config::SPLIT_MAX_BUFFER_AGE))
74    }
75
76    pub(crate) fn initial_update_interval(&self) -> Duration {
77        self.initial_update_interval
78            .unwrap_or(Duration::from_millis(1))
79    }
80}
81
82/// Commutative reducer for an accumulator. This is used to coallesce updates.
83/// For example, if the accumulator is a sum, its reducer calculates and returns
84/// the sum of 2 updates. This is helpful in split ports, where a large number
85/// of updates can be reduced into a smaller number of updates before being sent
86/// to the parent port.
87pub trait CommReducer {
88    /// The type of updates to be reduced.
89    type Update;
90
91    /// Reduce 2 updates into a single update.
92    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update>;
93}
94
95/// Type erased version of [CommReducer].
96pub trait ErasedCommReducer {
97    /// Reduce 2 updates into a single update.
98    fn reduce_erased(
99        &self,
100        left: &wirevalue::Any,
101        right: &wirevalue::Any,
102    ) -> anyhow::Result<wirevalue::Any>;
103
104    /// Reducer an non-empty vector of updates. Return Error if the vector is
105    /// empty.
106    fn reduce_updates(
107        &self,
108        updates: Vec<wirevalue::Any>,
109    ) -> Result<wirevalue::Any, (anyhow::Error, Vec<wirevalue::Any>)> {
110        if updates.is_empty() {
111            return Err((anyhow::anyhow!("empty updates"), updates));
112        }
113        if updates.len() == 1 {
114            return Ok(updates.into_iter().next().expect("checked above"));
115        }
116
117        let mut iter = updates.iter();
118        let first = iter.next().unwrap();
119        let second = iter.next().unwrap();
120        let init = match self.reduce_erased(first, second) {
121            Ok(v) => v,
122            Err(e) => return Err((e, updates)),
123        };
124        let reduced = match iter.try_fold(init, |acc, e| self.reduce_erased(&acc, e)) {
125            Ok(v) => v,
126            Err(e) => return Err((e, updates)),
127        };
128        Ok(reduced)
129    }
130
131    /// Typehash of the underlying [`CommReducer`] type.
132    fn typehash(&self) -> u64;
133}
134
135impl<R, T> ErasedCommReducer for R
136where
137    R: CommReducer<Update = T> + Named,
138    T: Serialize + DeserializeOwned + Named,
139{
140    fn reduce_erased(
141        &self,
142        left: &wirevalue::Any,
143        right: &wirevalue::Any,
144    ) -> anyhow::Result<wirevalue::Any> {
145        let left = left.deserialized::<T>()?;
146        let right = right.deserialized::<T>()?;
147        let result = self.reduce(left, right)?;
148        Ok(wirevalue::Any::serialize(&result)?)
149    }
150
151    fn typehash(&self) -> u64 {
152        R::typehash()
153    }
154}
155
156/// A factory for [`ErasedCommReducer`]s. This is used to register a
157/// [`ErasedCommReducer`] type. We cannot register [`ErasedCommReducer`] trait
158/// object directly because the object could have internal state, and cannot be
159/// shared.
160pub struct ReducerFactory {
161    /// Return the typehash of the [`ErasedCommReducer`] type built by this
162    /// factory.
163    pub typehash_f: fn() -> u64,
164    /// The builder function to build the [`ErasedCommReducer`] type.
165    pub builder_f: fn(
166        Option<wirevalue::Any>,
167    ) -> anyhow::Result<Box<dyn ErasedCommReducer + Sync + Send + 'static>>,
168}
169
170inventory::collect!(ReducerFactory);
171
172inventory::submit! {
173    ReducerFactory {
174        typehash_f: <SumReducer<i64> as Named>::typehash,
175        builder_f: |_| Ok(Box::new(SumReducer::<i64>(PhantomData))),
176    }
177}
178inventory::submit! {
179    ReducerFactory {
180        typehash_f: <SumReducer<u64> as Named>::typehash,
181        builder_f: |_| Ok(Box::new(SumReducer::<u64>(PhantomData))),
182    }
183}
184inventory::submit! {
185    ReducerFactory {
186        typehash_f: <MaxReducer::<i64> as Named>::typehash,
187        builder_f: |_| Ok(Box::new(MaxReducer::<i64>(PhantomData))),
188    }
189}
190inventory::submit! {
191    ReducerFactory {
192        typehash_f: <MaxReducer::<u64> as Named>::typehash,
193        builder_f: |_| Ok(Box::new(MaxReducer::<u64>(PhantomData))),
194    }
195}
196inventory::submit! {
197    ReducerFactory {
198        typehash_f: <MinReducer::<i64> as Named>::typehash,
199        builder_f: |_| Ok(Box::new(MinReducer::<i64>(PhantomData))),
200    }
201}
202inventory::submit! {
203    ReducerFactory {
204        typehash_f: <MinReducer::<u64> as Named>::typehash,
205        builder_f: |_| Ok(Box::new(MinReducer::<u64>(PhantomData))),
206    }
207}
208inventory::submit! {
209    ReducerFactory {
210        typehash_f: <WatermarkUpdateReducer::<i64> as Named>::typehash,
211        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<i64>(PhantomData))),
212    }
213}
214inventory::submit! {
215    ReducerFactory {
216        typehash_f: <WatermarkUpdateReducer::<u64> as Named>::typehash,
217        builder_f: |_| Ok(Box::new(WatermarkUpdateReducer::<u64>(PhantomData))),
218    }
219}
220
221/// Build a reducer object with the given typehash's [CommReducer] type, and
222/// return the type-erased version of it.
223pub(crate) fn resolve_reducer(
224    typehash: u64,
225    builder_params: Option<wirevalue::Any>,
226) -> anyhow::Result<Option<Box<dyn ErasedCommReducer + Sync + Send + 'static>>> {
227    static FACTORY_MAP: OnceLock<HashMap<u64, &'static ReducerFactory>> = OnceLock::new();
228    let factories = FACTORY_MAP.get_or_init(|| {
229        let mut map = HashMap::new();
230        for factory in inventory::iter::<ReducerFactory> {
231            map.insert((factory.typehash_f)(), factory);
232        }
233        map
234    });
235
236    factories
237        .get(&typehash)
238        .map(|f| (f.builder_f)(builder_params))
239        .transpose()
240}
241
242#[derive(typeuri::Named)]
243struct SumReducer<T>(PhantomData<T>);
244
245impl<T: std::ops::Add<Output = T> + Copy + 'static> CommReducer for SumReducer<T> {
246    type Update = T;
247
248    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
249        Ok(left + right)
250    }
251}
252
253/// Accumulate the sum of received updates. The inner function performs the
254/// summation between an update and the current state.
255struct SumAccumulator<T>(PhantomData<T>);
256
257impl<T: std::ops::Add<Output = T> + Copy + Named + 'static> Accumulator for SumAccumulator<T> {
258    type State = T;
259    type Update = T;
260
261    fn accumulate(&self, state: &mut T, update: T) -> anyhow::Result<()> {
262        *state = *state + update;
263        Ok(())
264    }
265
266    fn reducer_spec(&self) -> Option<ReducerSpec> {
267        Some(ReducerSpec {
268            typehash: <SumReducer<T> as Named>::typehash(),
269            builder_params: None,
270        })
271    }
272}
273
274/// Accumulate the sum of received updates.
275pub fn sum<T: std::ops::Add<Output = T> + Copy + Named + 'static>()
276-> impl Accumulator<State = T, Update = T> {
277    SumAccumulator(PhantomData)
278}
279
280#[derive(typeuri::Named)]
281struct MaxReducer<T>(PhantomData<T>);
282
283impl<T: Ord> CommReducer for MaxReducer<T> {
284    type Update = T;
285
286    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
287        Ok(std::cmp::max(left, right))
288    }
289}
290
/// The state of the max accumulator: the largest update accumulated so far,
/// or nothing when no update has been accumulated yet.
#[derive(Debug, Clone)]
pub struct Max<T>(Option<T>);

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, although the empty state never needs a default `T`.
impl<T> Default for Max<T> {
    /// Create the empty state, holding no value.
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Max<T> {
    /// Get the accumulated value.
    ///
    /// # Panics
    ///
    /// Panics if no update has been accumulated into this state yet.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
303
304/// Accumulate the max of received updates.
305struct MaxAccumulator<T>(PhantomData<T>);
306
307impl<T: Ord + Copy + Named + 'static> Accumulator for MaxAccumulator<T> {
308    type State = Max<T>;
309    type Update = T;
310
311    fn accumulate(&self, state: &mut Self::State, update: T) -> anyhow::Result<()> {
312        match state.0.as_mut() {
313            Some(s) => *s = std::cmp::max(*s, update),
314            None => *state = Max(Some(update)),
315        }
316        Ok(())
317    }
318
319    fn reducer_spec(&self) -> Option<ReducerSpec> {
320        Some(ReducerSpec {
321            typehash: <MaxReducer<T> as Named>::typehash(),
322            builder_params: None,
323        })
324    }
325}
326
327/// Accumulate the max of received updates (i.e. the largest value of all
328/// received updates).
329pub fn max<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Max<T>, Update = T> {
330    MaxAccumulator(PhantomData::<T>)
331}
332
333#[derive(typeuri::Named)]
334struct MinReducer<T>(PhantomData<T>);
335
336impl<T: Ord> CommReducer for MinReducer<T> {
337    type Update = T;
338
339    fn reduce(&self, left: T, right: T) -> anyhow::Result<T> {
340        Ok(std::cmp::min(left, right))
341    }
342}
343
/// The state of the min accumulator: the smallest update accumulated so far,
/// or nothing when no update has been accumulated yet.
#[derive(Debug, Clone)]
pub struct Min<T>(Option<T>);

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, although the empty state never needs a default `T`.
impl<T> Default for Min<T> {
    /// Create the empty state, holding no value.
    fn default() -> Self {
        Self(None)
    }
}

impl<T> Min<T> {
    /// Get the accumulated value.
    ///
    /// # Panics
    ///
    /// Panics if no update has been accumulated into this state yet.
    pub fn get(&self) -> &T {
        self.0
            .as_ref()
            .expect("accumulator state should have been intialized.")
    }
}
356
357/// Accumulate the min of received updates.
358struct MinAccumulator<T>(PhantomData<T>);
359
360impl<T: Ord + Copy + Named + 'static> Accumulator for MinAccumulator<T> {
361    type State = Min<T>;
362    type Update = T;
363
364    fn accumulate(&self, state: &mut Min<T>, update: T) -> anyhow::Result<()> {
365        match state.0.as_mut() {
366            Some(s) => *s = std::cmp::min(*s, update),
367            None => *state = Min(Some(update)),
368        }
369        Ok(())
370    }
371
372    fn reducer_spec(&self) -> Option<ReducerSpec> {
373        Some(ReducerSpec {
374            typehash: <MinReducer<T> as Named>::typehash(),
375            builder_params: None,
376        })
377    }
378}
379
380/// Accumulate the min of received updates (i.e. the smallest value of all
381/// received updates).
382pub fn min<T: Ord + Copy + Named + 'static>() -> impl Accumulator<State = Min<T>, Update = T> {
383    MinAccumulator(PhantomData)
384}
385
386/// Update from ranks for watermark accumulator, where map' key is the rank, and
387/// map's value is the update from that rank.
388#[derive(Default, Debug, Clone, Serialize, Deserialize, typeuri::Named)]
389pub struct WatermarkUpdate<T>(HashMap<Index, T>);
390
391impl<T: Ord> WatermarkUpdate<T> {
392    /// Get the watermark value. WatermarkUpdate is guarranteed to be initialized by
393    /// accumulator before it is sent to the user.
394    // TODO(pzhang) optimize this and only iterate when there is a new min.
395    pub fn get(&self) -> &T {
396        self.0
397            .values()
398            .min()
399            .expect("watermark should have been intialized.")
400    }
401}
402
403impl<T: PartialEq> WatermarkUpdate<T> {
404    /// See [`WatermarkUpdateReducer`]'s documentation for the merge semantics.
405    fn merge(old: Self, new: Self) -> Self {
406        let mut map = old.0;
407        for (k, v) in new.0 {
408            map.insert(k, v);
409        }
410        Self(map)
411    }
412}
413
414impl<T> From<(Index, T)> for WatermarkUpdate<T> {
415    fn from((rank, value): (Index, T)) -> Self {
416        let mut map = HashMap::with_capacity(1);
417        map.insert(rank, value);
418        Self(map)
419    }
420}
421
422/// Merge an old update and a new update. If a rank exists in boths updates,
423/// only keep its value from the new update.
424#[derive(typeuri::Named)]
425struct WatermarkUpdateReducer<T>(PhantomData<T>);
426
427impl<T: PartialEq> CommReducer for WatermarkUpdateReducer<T> {
428    type Update = WatermarkUpdate<T>;
429
430    fn reduce(&self, left: Self::Update, right: Self::Update) -> anyhow::Result<Self::Update> {
431        Ok(WatermarkUpdate::merge(left, right))
432    }
433}
434
435struct LowWatermarkUpdateAccumulator<T>(PhantomData<T>);
436
437impl<T: Ord + Copy + Named + 'static> Accumulator for LowWatermarkUpdateAccumulator<T> {
438    type State = WatermarkUpdate<T>;
439    type Update = WatermarkUpdate<T>;
440
441    fn accumulate(&self, state: &mut Self::State, update: Self::Update) -> anyhow::Result<()> {
442        let current = std::mem::replace(&mut *state, WatermarkUpdate(HashMap::new()));
443        // TODO(pzhang) optimize this and only iterate when there is a new state.
444        *state = WatermarkUpdate::merge(current, update);
445        Ok(())
446    }
447
448    fn reducer_spec(&self) -> Option<ReducerSpec> {
449        Some(ReducerSpec {
450            typehash: <WatermarkUpdateReducer<T> as Named>::typehash(),
451            builder_params: None,
452        })
453    }
454}
455
456/// Accumulate the min value among the ranks, aka. low watermark, based on the
457/// ranks' latest updates. Ranks' previous updates are discarded, and not used
458/// in the min value calculation.
459///
460/// The main difference bwtween low wartermark accumulator and [`MinAccumulator`]
461/// is, `MinAccumulator` takes previous updates into consideration too, and thus
462/// returns the min of the whole history.
463pub fn low_watermark<T: Ord + Copy + Named + 'static>()
464-> impl Accumulator<State = WatermarkUpdate<T>, Update = WatermarkUpdate<T>> {
465    LowWatermarkUpdateAccumulator(PhantomData)
466}
467
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use maplit::hashmap;
    use typeuri::Named;

    use super::*;

    /// Serialize every value into a type-erased [`wirevalue::Any`].
    fn serialize<T: Serialize + Named>(values: Vec<T>) -> Vec<wirevalue::Any> {
        let mut erased = Vec::with_capacity(values.len());
        for value in values {
            erased.push(wirevalue::Any::serialize(&value).unwrap());
        }
        erased
    }

    /// Resolve the reducer registered under `typehash`, reduce `updates`
    /// with it, and deserialize the reduced update as `U`.
    fn reduce_with<U: DeserializeOwned + Named>(typehash: u64, updates: Vec<wirevalue::Any>) -> U {
        let reducer = resolve_reducer(typehash, None).unwrap().unwrap();
        let reduced = reducer.reduce_updates(updates).unwrap();
        reduced.deserialized::<U>().unwrap()
    }

    #[test]
    fn test_comm_reducer_numeric() {
        // Unsigned inputs.
        let u64_numbers: Vec<_> = serialize(vec![1u64, 3u64, 1100u64]);

        let max_u64 = <MaxReducer<u64> as Named>::typehash();
        assert_eq!(reduce_with::<u64>(max_u64, u64_numbers.clone()), 1100u64);

        let min_u64 = <MinReducer<u64> as Named>::typehash();
        assert_eq!(reduce_with::<u64>(min_u64, u64_numbers.clone()), 1u64);

        let sum_u64 = <SumReducer<u64> as Named>::typehash();
        assert_eq!(reduce_with::<u64>(sum_u64, u64_numbers), 1104u64);

        // Signed inputs.
        let i64_numbers: Vec<_> = serialize(vec![-123i64, 33i64, 110i64]);

        let max_i64 = <MaxReducer<i64> as Named>::typehash();
        assert_eq!(reduce_with::<i64>(max_i64, i64_numbers.clone()), 110i64);

        let min_i64 = <MinReducer<i64> as Named>::typehash();
        assert_eq!(reduce_with::<i64>(min_i64, i64_numbers.clone()), -123i64);

        let sum_i64 = <SumReducer<i64> as Named>::typehash();
        assert_eq!(reduce_with::<i64>(sum_i64, i64_numbers), 20i64);
    }

    #[test]
    fn test_comm_reducer_watermark() {
        /// Reduce `updates` with the watermark reducer for `T` and compare
        /// the per-rank values with `expected`.
        fn verify<T: PartialEq + DeserializeOwned + Debug + Named>(
            updates: Vec<wirevalue::Any>,
            expected: HashMap<Index, T>,
        ) {
            let typehash = <WatermarkUpdateReducer<T> as Named>::typehash();
            let reduced: WatermarkUpdate<T> = reduce_with(typehash, updates);
            assert_eq!(reduced.0, expected);
        }

        let i64_updates: Vec<_> = serialize::<WatermarkUpdate<i64>>(
            vec![
                (0, 2),
                (1, 1),
                (3, 35),
                (0, 1),
                (1, -10),
                (3, 32),
                (3, 0),
                (3, -99),
                (0, -9),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );
        // Only the last value sent by every rank survives the reduction.
        verify::<i64>(
            i64_updates,
            hashmap! {
                0 => -9,
                1 => -10,
                3 => -99,
            },
        );

        let u64_updates: Vec<_> = serialize::<WatermarkUpdate<u64>>(
            vec![
                (1, 1),
                (0, 2),
                (0, 1),
                (3, 35),
                (0, 9),
                (1, 10),
                (3, 32),
                (3, 0),
                (3, 321),
            ]
            .into_iter()
            .map(WatermarkUpdate::from)
            .collect(),
        );
        verify::<u64>(
            u64_updates,
            hashmap! {
                0 => 9,
                1 => 10,
                3 => 321,
            },
        );
    }

    #[test]
    fn test_accum_reducer_numeric() {
        // Every numeric accumulator must advertise its matching reducer.
        let sum_u64 = sum::<u64>().reducer_spec().unwrap();
        assert_eq!(sum_u64.typehash, <SumReducer<u64> as Named>::typehash());
        let sum_i64 = sum::<i64>().reducer_spec().unwrap();
        assert_eq!(sum_i64.typehash, <SumReducer<i64> as Named>::typehash());

        let min_u64 = min::<u64>().reducer_spec().unwrap();
        assert_eq!(min_u64.typehash, <MinReducer<u64> as Named>::typehash());
        let min_i64 = min::<i64>().reducer_spec().unwrap();
        assert_eq!(min_i64.typehash, <MinReducer<i64> as Named>::typehash());

        let max_u64 = max::<u64>().reducer_spec().unwrap();
        assert_eq!(max_u64.typehash, <MaxReducer<u64> as Named>::typehash());
        let max_i64 = max::<i64>().reducer_spec().unwrap();
        assert_eq!(max_i64.typehash, <MaxReducer<i64> as Named>::typehash());
    }

    #[test]
    fn test_accum_reducer_watermark() {
        /// The low watermark accumulator for `T` must advertise the
        /// watermark reducer for `T`.
        fn verify<T: Ord + Copy + Named>() {
            let spec = low_watermark::<T>().reducer_spec().unwrap();
            assert_eq!(
                spec.typehash,
                <WatermarkUpdateReducer<T> as Named>::typehash(),
            );
        }
        verify::<u64>();
        verify::<i64>();
    }

    #[test]
    fn test_watermark_accumulator() {
        let accumulator = low_watermark::<u64>();
        // Each entry is (rank, value sent by that rank, expected watermark
        // after the value has been accumulated).
        let ranks_values_expectations = [
            // send in descending order
            (0, 1003, 1003),
            (1, 1002, 1002),
            (2, 1001, 1001),
            // send in ascending order
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // send same as accumulator's cache
            (0, 100, 100),
            (1, 101, 100),
            (2, 102, 100),
            // shuffle rank 0 to be largest, and make rank 1 smallest
            (0, 1000, 101),
            // shuffle rank 1 to be largest, and make rank 2 smallest
            (1, 1100, 102),
            // shuffle rank 2 to be largest, and make rank 0 smallest
            (2, 1200, 1000),
            // increase their values, but do not change their order
            (0, 1001, 1001),
            (1, 1101, 1001),
            (2, 1201, 1001),
            // decrease their values
            (2, 102, 102),
            (1, 101, 101),
            (0, 100, 100),
        ];
        let mut state = WatermarkUpdate(HashMap::new());
        for (rank, value, expected) in ranks_values_expectations {
            let update = WatermarkUpdate::from((rank, value));
            accumulator.accumulate(&mut state, update).unwrap();
            assert_eq!(state.get(), &expected, "rank is {rank}; value is {value}");
        }
    }
}
718}