1use std::collections::HashMap;
13use std::fmt;
14use std::ops::DerefMut;
15use std::sync::Arc;
16use std::sync::Mutex;
17
18use dashmap::DashMap;
19use hyperactor_config::AttrValue;
20use hyperactor_config::attrs::declare_attrs;
21use serde::Deserialize;
22use serde::Serialize;
23use tokio::sync::mpsc;
24use tokio::sync::mpsc::error::SendError;
25use typeuri::Named;
26use uuid::Uuid;
27
28use crate::reference;
29
/// Per-session reordering state used by `OrderedSender`.
struct BufferState<T> {
    /// Sequence number of the most recent message forwarded to the channel.
    /// It starts at 0, so the first message expected for a session has
    /// sequence number 1.
    last_seq: u64,
    /// Messages that arrived ahead of their turn, keyed by sequence number.
    /// They stay here until every earlier sequence number has been forwarded.
    buffer: HashMap<u64, T>,
}

// Written by hand on purpose: `#[derive(Default)]` would add a `T: Default`
// bound, which an empty buffer does not need.
impl<T> Default for BufferState<T> {
    /// Creates the state of a session that has not forwarded anything yet.
    fn default() -> Self {
        let buffer = HashMap::new();
        Self {
            last_seq: 0,
            buffer,
        }
    }
}
50
51pub(crate) struct OrderedSender<T> {
53 tx: mpsc::UnboundedSender<T>,
54 states: Arc<DashMap<Uuid, Arc<Mutex<BufferState<T>>>>>,
56 pub(crate) enable_buffering: bool,
57 log_id: String,
59}
60
61pub(crate) fn ordered_channel<T>(
63 log_id: String,
64 enable_buffering: bool,
65) -> (OrderedSender<T>, mpsc::UnboundedReceiver<T>) {
66 let (tx, rx) = mpsc::unbounded_channel();
67 (
68 OrderedSender {
69 tx,
70 states: Arc::new(DashMap::new()),
71 enable_buffering,
72 log_id,
73 },
74 rx,
75 )
76}
77
78#[derive(Debug)]
79pub(crate) enum OrderedSenderError<T> {
80 InvalidZeroSeq(T),
81 SendError(SendError<T>),
82 FlushError(anyhow::Error),
83}
84
85impl<T> Clone for OrderedSender<T> {
86 fn clone(&self) -> Self {
87 Self {
88 tx: self.tx.clone(),
89 states: self.states.clone(),
90 enable_buffering: self.enable_buffering,
91 log_id: self.log_id.clone(),
92 }
93 }
94}
95
96impl<T> OrderedSender<T> {
97 pub(crate) fn send(
107 &self,
108 session_id: Uuid,
109 seq_no: u64,
110 msg: T,
111 ) -> Result<(), OrderedSenderError<T>> {
112 use std::cmp::Ordering;
113
114 assert!(self.enable_buffering);
115 if seq_no == 0 {
116 return Err(OrderedSenderError::InvalidZeroSeq(msg));
117 }
118
119 let state = self.states.entry(session_id).or_default().value().clone();
121 let mut state_guard = state.lock().unwrap();
122 let BufferState { last_seq, buffer } = state_guard.deref_mut();
123
124 match seq_no.cmp(&(*last_seq + 1)) {
125 Ordering::Less => {
126 tracing::warn!(
127 "{} duplicate message from session {} with seq no: {}",
128 self.log_id,
129 session_id,
130 seq_no,
131 );
132 }
133 Ordering::Greater => {
134 let old = buffer.insert(seq_no, msg);
136 assert!(
137 old.is_none(),
138 "{}: same seq is insert to buffer twice: {}",
139 self.log_id,
140 seq_no
141 );
142 }
143 Ordering::Equal => {
144 self.tx.send(msg).map_err(OrderedSenderError::SendError)?;
147 *last_seq += 1;
148
149 while let Some(m) = buffer.remove(&(*last_seq + 1)) {
150 match self.tx.send(m) {
151 Ok(()) => *last_seq += 1,
152 Err(err) => {
153 let flush_err = OrderedSenderError::FlushError(anyhow::anyhow!(
154 "failed to flush buffered message: {}",
155 err
156 ));
157 buffer.insert(*last_seq + 1, err.0);
158 return Err(flush_err);
159 }
160 }
161 }
162 }
167 }
168
169 Ok(())
170 }
171
172 pub(crate) fn direct_send(&self, msg: T) -> Result<(), SendError<T>> {
173 self.tx.send(msg)
174 }
175}
176
177#[derive(Clone, Debug, Hash, PartialEq, Eq)]
180enum SeqKey {
181 Actor(reference::ActorId),
183 Port(reference::PortId),
185}
186
187#[derive(Debug, Serialize, Deserialize, Clone, Named, AttrValue, PartialEq)]
189pub enum SeqInfo {
190 Session {
192 session_id: Uuid,
194 seq: u64,
196 },
197 Direct,
200}
201
202impl fmt::Display for SeqInfo {
203 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 match self {
205 Self::Direct => write!(f, "direct"),
206 Self::Session { session_id, seq } => write!(f, "{}:{}", session_id, seq),
207 }
208 }
209}
210
211impl std::str::FromStr for SeqInfo {
212 type Err = anyhow::Error;
213
214 fn from_str(s: &str) -> Result<Self, Self::Err> {
215 if s == "direct" {
216 return Ok(SeqInfo::Direct);
217 }
218
219 let parts: Vec<_> = s.split(':').collect();
220 if parts.len() != 2 {
221 return Err(anyhow::anyhow!("invalid SeqInfo: {}", s));
222 }
223 let session_id: Uuid = parts[0].parse()?;
224 let seq: u64 = parts[1].parse()?;
225 Ok(SeqInfo::Session { session_id, seq })
226 }
227}
228
229declare_attrs! {
230 pub attr SEQ_INFO: SeqInfo;
233}
234
235#[derive(Clone, Debug)]
239pub struct Sequencer {
240 session_id: Uuid,
241 last_seqs: Arc<Mutex<HashMap<SeqKey, u64>>>,
243}
244
245impl Sequencer {
246 pub(crate) fn new(session_id: Uuid) -> Self {
247 Self {
248 session_id,
249 last_seqs: Arc::new(Mutex::new(HashMap::new())),
250 }
251 }
252
253 pub fn assign_seq(&self, port_id: &reference::PortId) -> SeqInfo {
259 let key = if port_id.is_actor_port() {
260 SeqKey::Actor(port_id.actor_id().clone())
261 } else {
262 SeqKey::Port(port_id.clone())
263 };
264
265 let mut guard = self.last_seqs.lock().unwrap();
266 let entry = guard.entry(key).or_default();
267 *entry += 1;
268 SeqInfo::Session {
269 session_id: self.session_id,
270 seq: *entry,
271 }
272 }
273
274 pub fn session_id(&self) -> Uuid {
276 self.session_id
277 }
278}
279
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::testing::ids::test_actor_id;

    /// First message type; the tests only use `TestMsg1::port()`.
    #[derive(Named)]
    struct TestMsg1;

    /// Second message type; the tests only use `TestMsg2::port()`.
    #[derive(Named)]
    struct TestMsg2;

    /// Collects everything that can currently be read from `rx` without
    /// waiting.
    fn drain_try_recv<T: std::fmt::Debug + Clone>(rx: &mut mpsc::UnboundedReceiver<T>) -> Vec<T> {
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }

    /// Asserts that `rx` has nothing to deliver right now.
    fn assert_nothing_delivered<T: std::fmt::Debug + Clone>(rx: &mut mpsc::UnboundedReceiver<T>) {
        assert!(
            drain_try_recv(rx).is_empty(),
            "nothing should be delivered yet"
        );
    }

    /// Extracts the sequence number from a `Session` value; panics on
    /// `Direct`.
    fn get_seq(seq_info: SeqInfo) -> u64 {
        match seq_info {
            SeqInfo::Direct => panic!("expected Session variant, got Direct"),
            SeqInfo::Session { seq, .. } => seq,
        }
    }

    /// Calls `assign_seq` once per entry of `ports`, in order, and returns
    /// the assigned sequence numbers.
    fn assigned_seqs(sequencer: &Sequencer, ports: &[&reference::PortId]) -> Vec<u64> {
        ports
            .iter()
            .map(|port| get_seq(sequencer.assign_seq(port)))
            .collect()
    }

    #[test]
    fn test_ordered_channel_single_client_send_in_order() {
        let session = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<u64>("test".to_string(), true);

        // Every in-order message must come out immediately, on its own.
        for seq in 1..=10 {
            tx.send(session, seq, seq).unwrap();
            assert_eq!(drain_try_recv(&mut rx), vec![seq]);
        }
    }

    #[test]
    fn test_ordered_channel_single_client_send_out_of_order() {
        let session = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<u64>("test".to_string(), true);

        // Send 4, 3, 2 and then 9, 8, 7: all ahead of the expected seq 1.
        for seq in (2..=4).rev().chain((7..=9).rev()) {
            tx.send(session, seq, seq).unwrap();
        }
        assert_nothing_delivered(&mut rx);

        // Each step sends one seq and lists what that send must release.
        let steps: [(u64, Vec<u64>); 4] = [
            (1, vec![1, 2, 3, 4]),
            (5, vec![5]),
            (6, vec![6, 7, 8, 9]),
            (10, vec![10]),
        ];
        for (seq, released) in steps {
            tx.send(session, seq, seq).unwrap();
            assert_eq!(drain_try_recv(&mut rx), released);
        }
    }

    #[test]
    fn test_ordered_channel_multi_clients() {
        let session_a = Uuid::now_v7();
        let session_b = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<(Uuid, u64)>("test".to_string(), true);

        // Seq 1 of each session is in order and delivered immediately.
        for session in [session_a, session_b] {
            tx.send(session, 1, (session, 1)).unwrap();
            assert_eq!(drain_try_recv(&mut rx), vec![(session, 1)]);
        }

        // Interleave future messages of both sessions: 5, 4, 3 then 9, 8, 7.
        for seq in (3..=5).rev().chain((7..=9).rev()) {
            for session in [session_a, session_b] {
                tx.send(session, seq, (session, seq)).unwrap();
            }
        }
        assert_nothing_delivered(&mut rx);

        // Filling a gap in one session releases only that session's run.
        let refills: [(u64, std::ops::RangeInclusive<u64>); 2] = [(2, 2..=5), (6, 6..=9)];
        for (gap, released) in refills {
            for session in [session_a, session_b] {
                tx.send(session, gap, (session, gap)).unwrap();
                let expected: Vec<(Uuid, u64)> =
                    released.clone().map(|seq| (session, seq)).collect();
                assert_eq!(drain_try_recv(&mut rx), expected);
            }
        }
    }

    #[test]
    fn test_ordered_channel_duplicates() {
        /// Asserts that no session has anything left in its buffer.
        fn verify_empty_buffers<T>(states: &DashMap<Uuid, Arc<Mutex<BufferState<T>>>>) {
            for entry in states.iter() {
                assert!(entry.value().lock().unwrap().buffer.is_empty());
            }
        }

        let session = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<(Uuid, u64)>("test".to_string(), true);

        // Seq 1 is delivered.
        tx.send(session, 1, (session, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session, 1)]);
        verify_empty_buffers(&tx.states);

        // Seq 1 again: dropped, neither delivered nor buffered.
        tx.send(session, 1, (session, 1_000)).unwrap();
        assert_nothing_delivered(&mut rx);
        verify_empty_buffers(&tx.states);

        // Seq 2 is delivered.
        tx.send(session, 2, (session, 2)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session, 2)]);
        verify_empty_buffers(&tx.states);

        // A stale seq 1 after seq 2 is dropped as well.
        tx.send(session, 1, (session, 1_001)).unwrap();
        assert_nothing_delivered(&mut rx);
        verify_empty_buffers(&tx.states);
    }

    #[test]
    fn test_sequencer_clone() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let actor_id = test_actor_id("test_0", "test");
        let port_id = actor_id.port_id(1);

        // Use up sequence numbers 1 and 2 on the original.
        sequencer.assign_seq(&port_id);
        sequencer.assign_seq(&port_id);

        // The clone keeps the session id and continues the same counter.
        let cloned_sequencer = sequencer.clone();
        assert_eq!(sequencer.session_id(), cloned_sequencer.session_id());
        assert_eq!(get_seq(cloned_sequencer.assign_seq(&port_id)), 3);
    }

    #[test]
    fn test_sequencer_actor_ports_share_sequence() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let worker_0 = test_actor_id("worker_0", "worker");
        let actor_port_1 = worker_0.port_id(TestMsg1::port());
        let actor_port_2 = worker_0.port_id(TestMsg2::port());

        // Both ports belong to the same actor and draw from one counter.
        assert_eq!(
            assigned_seqs(&sequencer, &[&actor_port_1, &actor_port_2, &actor_port_1]),
            vec![1, 2, 3]
        );

        // A port of a different actor starts its own sequence.
        let worker_1 = test_actor_id("worker_1", "worker");
        let actor_port_3 = worker_1.port_id(TestMsg1::port());
        assert_eq!(get_seq(sequencer.assign_seq(&actor_port_3)), 1);
    }

    #[test]
    fn test_sequencer_non_actor_ports_have_independent_sequences() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let worker_0 = test_actor_id("worker_0", "worker");
        let worker_1 = test_actor_id("worker_1", "worker");

        let port_1 = worker_0.port_id(1);
        let port_2 = worker_0.port_id(2);
        let port_3 = worker_1.port_id(1);

        // Each port counts on its own, whichever actor owns it.
        assert_eq!(
            assigned_seqs(
                &sequencer,
                &[&port_1, &port_2, &port_1, &port_2, &port_3, &port_1, &port_3],
            ),
            vec![1, 1, 2, 2, 1, 3, 2]
        );
    }

    #[test]
    fn test_sequencer_mixed_actor_and_non_actor_ports() {
        let sequencer = Sequencer::new(Uuid::now_v7());

        let actor_id = test_actor_id("worker_0", "worker");

        let actor_port_1 = actor_id.port_id(TestMsg1::port());
        let actor_port_2 = actor_id.port_id(TestMsg2::port());

        let non_actor_port_1 = actor_id.port_id(1);
        let non_actor_port_2 = actor_id.port_id(2);

        // The actor ports share one counter; each non-actor port has its own.
        assert_eq!(
            assigned_seqs(
                &sequencer,
                &[
                    &actor_port_1,
                    &non_actor_port_1,
                    &actor_port_2,
                    &non_actor_port_2,
                    &non_actor_port_1,
                    &actor_port_1,
                    &non_actor_port_2,
                ],
            ),
            vec![1, 1, 2, 1, 2, 3, 2]
        );
    }
}