hyperactor/
ordering.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! This module contains utilities that help ensure messages are delivered in
//! order for any given sender and receiver actor pair.
11
12use std::collections::HashMap;
13use std::fmt;
14use std::ops::DerefMut;
15use std::sync::Arc;
16use std::sync::Mutex;
17
18use dashmap::DashMap;
19use hyperactor_config::AttrValue;
20use hyperactor_config::attrs::declare_attrs;
21use serde::Deserialize;
22use serde::Serialize;
23use tokio::sync::mpsc;
24use tokio::sync::mpsc::error::SendError;
25use typeuri::Named;
26use uuid::Uuid;
27
28use crate::reference;
29
/// A client's re-ordering buffer state.
struct BufferState<T> {
    /// The last sequence number sent to the receiver for this client.
    /// Sequence numbers start at 1; 0 means no message has been sent yet.
    last_seq: u64,
    /// Out-of-order messages held back so that messages are delivered
    /// strictly in per-client sequence order.
    ///
    /// The map's key is the sequence number; the value is the message.
    buffer: HashMap<u64, T>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `T: Default` bound, which message types are not required to satisfy.
impl<T> Default for BufferState<T> {
    fn default() -> Self {
        Self {
            last_seq: 0,
            buffer: HashMap::new(),
        }
    }
}
50
51/// A sender that ensures messages are delivered in per-client sequence order.
52pub(crate) struct OrderedSender<T> {
53    tx: mpsc::UnboundedSender<T>,
54    /// Map's key is session ID, and value is the buffer state of that session.
55    states: Arc<DashMap<Uuid, Arc<Mutex<BufferState<T>>>>>,
56    pub(crate) enable_buffering: bool,
57    /// The identify of this object, which is used to distiguish it in debugging.
58    log_id: String,
59}
60
61/// A receiver that receives messages in per-client sequence order.
62pub(crate) fn ordered_channel<T>(
63    log_id: String,
64    enable_buffering: bool,
65) -> (OrderedSender<T>, mpsc::UnboundedReceiver<T>) {
66    let (tx, rx) = mpsc::unbounded_channel();
67    (
68        OrderedSender {
69            tx,
70            states: Arc::new(DashMap::new()),
71            enable_buffering,
72            log_id,
73        },
74        rx,
75    )
76}
77
78#[derive(Debug)]
79pub(crate) enum OrderedSenderError<T> {
80    InvalidZeroSeq(T),
81    SendError(SendError<T>),
82    FlushError(anyhow::Error),
83}
84
85impl<T> Clone for OrderedSender<T> {
86    fn clone(&self) -> Self {
87        Self {
88            tx: self.tx.clone(),
89            states: self.states.clone(),
90            enable_buffering: self.enable_buffering,
91            log_id: self.log_id.clone(),
92        }
93    }
94}
95
96impl<T> OrderedSender<T> {
97    /// Buffer msgs if necessary, and deliver them to receiver based on their
98    /// seqs in monotonically increasing order. Note seq is scoped by `sender`
99    /// so the ordering is also scoped by it.
100    ///
101    /// Locking behavior:
102    ///
103    /// For the same channel,
104    /// * Calls from the same client will be serialized with a lock.
105    /// * calls from different clients will be executed concurrently.
106    pub(crate) fn send(
107        &self,
108        session_id: Uuid,
109        seq_no: u64,
110        msg: T,
111    ) -> Result<(), OrderedSenderError<T>> {
112        use std::cmp::Ordering;
113
114        assert!(self.enable_buffering);
115        if seq_no == 0 {
116            return Err(OrderedSenderError::InvalidZeroSeq(msg));
117        }
118
119        // Make sure only this session's state is locked, not all states.
120        let state = self.states.entry(session_id).or_default().value().clone();
121        let mut state_guard = state.lock().unwrap();
122        let BufferState { last_seq, buffer } = state_guard.deref_mut();
123
124        match seq_no.cmp(&(*last_seq + 1)) {
125            Ordering::Less => {
126                tracing::warn!(
127                    "{} duplicate message from session {} with seq no: {}",
128                    self.log_id,
129                    session_id,
130                    seq_no,
131                );
132            }
133            Ordering::Greater => {
134                // Future message: buffer until the gap is filled.
135                let old = buffer.insert(seq_no, msg);
136                assert!(
137                    old.is_none(),
138                    "{}: same seq is insert to buffer twice: {}",
139                    self.log_id,
140                    seq_no
141                );
142            }
143            Ordering::Equal => {
144                // In-order: deliver, then flush consecutives from buffer until
145                // it reaches a gap.
146                self.tx.send(msg).map_err(OrderedSenderError::SendError)?;
147                *last_seq += 1;
148
149                while let Some(m) = buffer.remove(&(*last_seq + 1)) {
150                    match self.tx.send(m) {
151                        Ok(()) => *last_seq += 1,
152                        Err(err) => {
153                            let flush_err = OrderedSenderError::FlushError(anyhow::anyhow!(
154                                "failed to flush buffered message: {}",
155                                err
156                            ));
157                            buffer.insert(*last_seq + 1, err.0);
158                            return Err(flush_err);
159                        }
160                    }
161                }
162                // We do not remove a client's state even if its buffer becomes
163                // empty. This is because a duplicate message might arrive after
164                // the buffer became empty. Removing the state would cause the
165                // duplicate message to be delivered.
166            }
167        }
168
169        Ok(())
170    }
171
172    pub(crate) fn direct_send(&self, msg: T) -> Result<(), SendError<T>> {
173        self.tx.send(msg)
174    }
175}
176
177/// Key for sequence assignment.
178/// Actor ports share a sequence per actor; non-actor ports get individual sequences.
179#[derive(Clone, Debug, Hash, PartialEq, Eq)]
180enum SeqKey {
181    /// Shared sequence for all actor ports of an actor
182    Actor(reference::ActorId),
183    /// Individual sequence for a specific non-actor port
184    Port(reference::PortId),
185}
186
187/// A message's sequencer number infomation.
188#[derive(Debug, Serialize, Deserialize, Clone, Named, AttrValue, PartialEq)]
189pub enum SeqInfo {
190    /// Messages with the same session ID should be delivered in order.
191    Session {
192        /// Message's session ID
193        session_id: Uuid,
194        /// Message's sequence number in the given session.
195        seq: u64,
196    },
197    /// This message does not have a seq number and should be delivered
198    /// immediately.
199    Direct,
200}
201
202impl fmt::Display for SeqInfo {
203    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204        match self {
205            Self::Direct => write!(f, "direct"),
206            Self::Session { session_id, seq } => write!(f, "{}:{}", session_id, seq),
207        }
208    }
209}
210
211impl std::str::FromStr for SeqInfo {
212    type Err = anyhow::Error;
213
214    fn from_str(s: &str) -> Result<Self, Self::Err> {
215        if s == "direct" {
216            return Ok(SeqInfo::Direct);
217        }
218
219        let parts: Vec<_> = s.split(':').collect();
220        if parts.len() != 2 {
221            return Err(anyhow::anyhow!("invalid SeqInfo: {}", s));
222        }
223        let session_id: Uuid = parts[0].parse()?;
224        let seq: u64 = parts[1].parse()?;
225        Ok(SeqInfo::Session { session_id, seq })
226    }
227}
228
229declare_attrs! {
230    /// The sender of this message, the session ID, and the message's sequence
231    /// number assigned by this session.
232    pub attr SEQ_INFO: SeqInfo;
233}
234
235/// Used by sender to track the message sequence numbers it sends to each destination.
236/// Each [Sequencer] object has a session id, sequence numbers are scoped by
237/// the (session_id, SeqKey) pair.
238#[derive(Clone, Debug)]
239pub struct Sequencer {
240    session_id: Uuid,
241    // Map's key is the sequence key (actor or port), value is the last seq number.
242    last_seqs: Arc<Mutex<HashMap<SeqKey, u64>>>,
243}
244
245impl Sequencer {
246    pub(crate) fn new(session_id: Uuid) -> Self {
247        Self {
248            session_id,
249            last_seqs: Arc::new(Mutex::new(HashMap::new())),
250        }
251    }
252
253    /// Assign the next seq for a port, mutate the sequencer with the new seq,
254    /// and return the new seq.
255    ///
256    /// - Actor ports: share the same sequence scheme per actor (keyed by ActorId)
257    /// - Non-actor ports: get individual sequence schemes (keyed by PortId)
258    pub fn assign_seq(&self, port_id: &reference::PortId) -> SeqInfo {
259        let key = if port_id.is_actor_port() {
260            SeqKey::Actor(port_id.actor_id().clone())
261        } else {
262            SeqKey::Port(port_id.clone())
263        };
264
265        let mut guard = self.last_seqs.lock().unwrap();
266        let entry = guard.entry(key).or_default();
267        *entry += 1;
268        SeqInfo::Session {
269            session_id: self.session_id,
270            seq: *entry,
271        }
272    }
273
274    /// Id of the session this sequencer belongs to.
275    pub fn session_id(&self) -> Uuid {
276        self.session_id
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::ids::test_actor_id;

    /// Test message type 1 for actor port sequencing tests.
    #[derive(Named)]
    struct TestMsg1;

    /// Test message type 2 for actor port sequencing tests.
    #[derive(Named)]
    struct TestMsg2;

    /// Collects every message currently available on `rx` without waiting.
    fn drain_try_recv<T>(rx: &mut mpsc::UnboundedReceiver<T>) -> Vec<T> {
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }

    /// Helper to extract seq from SeqInfo::Session variant (for tests only)
    fn get_seq(seq_info: SeqInfo) -> u64 {
        match seq_info {
            SeqInfo::Session { seq, .. } => seq,
            SeqInfo::Direct => panic!("expected Session variant, got Direct"),
        }
    }

    #[test]
    fn test_ordered_channel_single_client_send_in_order() {
        let session_id_a = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<u64>("test".to_string(), true);
        for s in 1..=10 {
            tx.send(session_id_a, s, s).unwrap();
            let got = drain_try_recv(&mut rx);
            assert_eq!(got, vec![s]);
        }
    }

    #[test]
    fn test_ordered_channel_single_client_send_out_of_order() {
        let session_id_a = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<u64>("test".to_string(), true);

        // Send 2 to 4 in descending order: all should buffer until 1 arrives.
        for s in (2..=4).rev() {
            tx.send(session_id_a, s, s).unwrap();
        }

        // Send 7 to 9 in descending order: all should buffer until 1-6 have
        // been delivered.
        for s in (7..=9).rev() {
            tx.send(session_id_a, s, s).unwrap();
        }

        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "nothing should be delivered yet"
        );

        // Now send 1: should deliver 1 then flush 2-4.
        tx.send(session_id_a, 1, 1).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![1, 2, 3, 4]);

        // Now send 5: should deliver immediately but not flush 7-9.
        tx.send(session_id_a, 5, 5).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![5]);

        // Now send 6: should deliver 6 then flush 7-9.
        tx.send(session_id_a, 6, 6).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![6, 7, 8, 9]);

        // Send 10: should deliver immediately.
        tx.send(session_id_a, 10, 10).unwrap();
        let got = drain_try_recv(&mut rx);
        assert_eq!(got, vec![10]);
    }

    #[test]
    fn test_ordered_channel_multi_clients() {
        let session_id_a = Uuid::now_v7();
        let session_id_b = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<(Uuid, u64)>("test".to_string(), true);

        // A1 -> deliver
        tx.send(session_id_a, 1, (session_id_a, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_a, 1)]);
        // B1 -> deliver
        tx.send(session_id_b, 1, (session_id_b, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_b, 1)]);
        for s in (3..=5).rev() {
            // A3-5 -> buffer (waiting for A2)
            tx.send(session_id_a, s, (session_id_a, s)).unwrap();
            // B3-5 -> buffer (waiting for B2)
            tx.send(session_id_b, s, (session_id_b, s)).unwrap();
        }
        for s in (7..=9).rev() {
            // A7-9 -> buffer (waiting for A2 and A6)
            tx.send(session_id_a, s, (session_id_a, s)).unwrap();
            // B7-9 -> buffer (waiting for B2 and B6)
            tx.send(session_id_b, s, (session_id_b, s)).unwrap();
        }
        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "nothing should be delivered yet"
        );

        // A2 -> deliver A2 then flush A3-5
        tx.send(session_id_a, 2, (session_id_a, 2)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_a, 2),
                (session_id_a, 3),
                (session_id_a, 4),
                (session_id_a, 5),
            ]
        );
        // B2 -> deliver B2 then flush B3-5
        tx.send(session_id_b, 2, (session_id_b, 2)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_b, 2),
                (session_id_b, 3),
                (session_id_b, 4),
                (session_id_b, 5),
            ]
        );

        // A6 -> should deliver immediately and flush A7-9
        tx.send(session_id_a, 6, (session_id_a, 6)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_a, 6),
                (session_id_a, 7),
                (session_id_a, 8),
                (session_id_a, 9)
            ]
        );
        // B6 -> should deliver immediately and flush B7-9
        tx.send(session_id_b, 6, (session_id_b, 6)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_b, 6),
                (session_id_b, 7),
                (session_id_b, 8),
                (session_id_b, 9)
            ]
        );
    }

    #[test]
    fn test_ordered_channel_duplicates() {
        let session_id_a = Uuid::now_v7();
        fn verify_empty_buffers<T>(states: &DashMap<Uuid, Arc<Mutex<BufferState<T>>>>) {
            for entry in states.iter() {
                assert!(entry.value().lock().unwrap().buffer.is_empty());
            }
        }

        let (tx, mut rx) = ordered_channel::<(Uuid, u64)>("test".to_string(), true);
        // A1 -> deliver
        tx.send(session_id_a, 1, (session_id_a, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_a, 1)]);
        verify_empty_buffers(&tx.states);
        // duplicate A1 -> drop even if the message is different.
        tx.send(session_id_a, 1, (session_id_a, 1_000)).unwrap();
        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "duplicate should be dropped"
        );
        verify_empty_buffers(&tx.states);
        // A2 -> deliver
        tx.send(session_id_a, 2, (session_id_a, 2)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_a, 2)]);
        verify_empty_buffers(&tx.states);
        // late A1 duplicate -> drop
        tx.send(session_id_a, 1, (session_id_a, 1_001)).unwrap();
        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "late duplicate should be dropped"
        );
        verify_empty_buffers(&tx.states);
    }

    #[test]
    fn test_sequencer_clone() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let actor_id = test_actor_id("test_0", "test");
        let port_id = actor_id.port_id(1);

        // Modify original sequencer
        sequencer.assign_seq(&port_id);
        sequencer.assign_seq(&port_id);

        // Clone should share the same state
        let cloned_sequencer = sequencer.clone();
        assert_eq!(sequencer.session_id(), cloned_sequencer.session_id());
        assert_eq!(get_seq(cloned_sequencer.assign_seq(&port_id)), 3);
    }

    #[test]
    fn test_sequencer_actor_ports_share_sequence() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let actor_id = test_actor_id("worker_0", "worker");
        // Two different actor ports for the same actor (using Named::port())
        let actor_port_1 = actor_id.port_id(TestMsg1::port());
        let actor_port_2 = actor_id.port_id(TestMsg2::port());

        // Actor ports should share a sequence (keyed by ActorId)
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_2)), 2); // continues from 1
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_1)), 3);

        // Actor ports from a different actor get their own shared sequence
        let actor_id_2 = test_actor_id("worker_1", "worker");
        let actor_port_3 = actor_id_2.port_id(TestMsg1::port());
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_3)), 1); // independent from actor_id
    }

    #[test]
    fn test_sequencer_non_actor_ports_have_independent_sequences() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let actor_id_0 = test_actor_id("worker_0", "worker");
        let actor_id_1 = test_actor_id("worker_1", "worker");

        // Non-actor ports from the same actor (without ACTOR_PORT_BIT)
        let port_1 = actor_id_0.port_id(1);
        let port_2 = actor_id_0.port_id(2);

        // Non-actor ports should have independent sequences (keyed by PortId)
        assert_eq!(get_seq(sequencer.assign_seq(&port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&port_2)), 1); // independent, starts at 1
        assert_eq!(get_seq(sequencer.assign_seq(&port_1)), 2);
        assert_eq!(get_seq(sequencer.assign_seq(&port_2)), 2);

        // Non-actor ports from different actors are also independent
        let port_3 = actor_id_1.port_id(1);
        assert_eq!(get_seq(sequencer.assign_seq(&port_3)), 1); // independent from port_1
        assert_eq!(get_seq(sequencer.assign_seq(&port_1)), 3);
        assert_eq!(get_seq(sequencer.assign_seq(&port_3)), 2);
    }

    #[test]
    fn test_sequencer_mixed_actor_and_non_actor_ports() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let actor_id = test_actor_id("worker_0", "worker");

        // Actor ports (share sequence per actor)
        let actor_port_1 = actor_id.port_id(TestMsg1::port());
        let actor_port_2 = actor_id.port_id(TestMsg2::port());

        // Non-actor ports (independent sequences per port)
        let non_actor_port_1 = actor_id.port_id(1);
        let non_actor_port_2 = actor_id.port_id(2);

        // Interleave sends to all port types
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&non_actor_port_1)), 1); // independent
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_2)), 2); // continues actor sequence
        assert_eq!(get_seq(sequencer.assign_seq(&non_actor_port_2)), 1); // independent
        assert_eq!(get_seq(sequencer.assign_seq(&non_actor_port_1)), 2); // continues its own
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_1)), 3); // continues actor sequence
        assert_eq!(get_seq(sequencer.assign_seq(&non_actor_port_2)), 2); // continues its own
    }
}