1use std::any::TypeId;
13use std::collections::HashMap;
14use std::collections::HashSet;
15use std::fmt;
16use std::ops::DerefMut;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::sync::Mutex;
20
21use dashmap::DashMap;
22use hyperactor_config::AttrValue;
23use hyperactor_config::attrs::declare_attrs;
24use serde::Deserialize;
25use serde::Serialize;
26use tokio::sync::mpsc;
27use tokio::sync::mpsc::error::SendError;
28use typeuri::Named;
29use uuid::Uuid;
30
31use crate::ActorAddr;
32use crate::PortAddr;
33use crate::actor::Signal;
34use crate::introspect::IntrospectMessage;
35
36static BYPASS_TYPE_IDS: LazyLock<HashSet<TypeId>> =
52 LazyLock::new(|| HashSet::from([TypeId::of::<Signal>(), TypeId::of::<IntrospectMessage>()]));
53
54static BYPASS_ACTOR_PORTS: LazyLock<HashSet<u64>> =
55 LazyLock::new(|| HashSet::from([Signal::port(), IntrospectMessage::port()]));
56
57pub(crate) fn is_bypass_workq_type_id(id: TypeId) -> bool {
60 BYPASS_TYPE_IDS.contains(&id)
61}
62
63pub(crate) fn is_bypass_workq_actor_port(port: u64) -> bool {
66 BYPASS_ACTOR_PORTS.contains(&port)
67}
68
/// Reordering state kept for a single sender session.
struct BufferState<T> {
    /// Sequence number of the last message forwarded for this session.
    /// Starts at 0, so the first expected sequence number is 1.
    last_seq: u64,
    /// Messages that arrived ahead of the expected sequence number, keyed by
    /// their own sequence number.
    buffer: HashMap<u64, T>,
}

// Written by hand rather than derived so that no `T: Default` bound is
// imposed on the buffered message type.
impl<T> Default for BufferState<T> {
    fn default() -> Self {
        let buffer = HashMap::default();
        Self {
            last_seq: 0,
            buffer,
        }
    }
}
89
/// Sending half created by [`ordered_channel`]. Its `send` method forwards
/// messages to the underlying channel in per-session sequence order,
/// buffering messages that arrive early.
pub(crate) struct OrderedSender<T> {
    /// Underlying unbounded channel that messages are forwarded to.
    tx: mpsc::UnboundedSender<T>,
    /// Per-session reordering state, keyed by session id. The map is behind
    /// an `Arc`, so clones of this sender share the same state.
    states: Arc<DashMap<Uuid, Arc<Mutex<BufferState<T>>>>>,
    /// Whether sequenced sending is enabled. `send` asserts that this is
    /// `true`; `direct_send` does not consult it.
    pub(crate) enable_buffering: bool,
    /// Identifier included in this sender's log and panic messages.
    log_id: String,
}
99
100pub(crate) fn ordered_channel<T>(
102 log_id: String,
103 enable_buffering: bool,
104) -> (OrderedSender<T>, mpsc::UnboundedReceiver<T>) {
105 let (tx, rx) = mpsc::unbounded_channel();
106 (
107 OrderedSender {
108 tx,
109 states: Arc::new(DashMap::new()),
110 enable_buffering,
111 log_id,
112 },
113 rx,
114 )
115}
116
/// Errors returned by `OrderedSender::send`.
#[derive(Debug)]
pub(crate) enum OrderedSenderError<T> {
    /// The caller passed sequence number 0; the rejected message is handed
    /// back to the caller.
    InvalidZeroSeq(T),
    /// Forwarding the in-order message to the underlying channel failed.
    SendError(SendError<T>),
    /// Forwarding a previously buffered message failed while flushing the
    /// buffer after an in-order delivery.
    FlushError(anyhow::Error),
}
123
124impl<T> Clone for OrderedSender<T> {
125 fn clone(&self) -> Self {
126 Self {
127 tx: self.tx.clone(),
128 states: self.states.clone(),
129 enable_buffering: self.enable_buffering,
130 log_id: self.log_id.clone(),
131 }
132 }
133}
134
135impl<T> OrderedSender<T> {
136 pub(crate) fn send(
146 &self,
147 session_id: Uuid,
148 seq_no: u64,
149 msg: T,
150 ) -> Result<(), OrderedSenderError<T>> {
151 use std::cmp::Ordering;
152
153 assert!(self.enable_buffering);
154 if seq_no == 0 {
155 return Err(OrderedSenderError::InvalidZeroSeq(msg));
156 }
157
158 let state = self.states.entry(session_id).or_default().value().clone();
160 let mut state_guard = state.lock().unwrap();
161 let BufferState { last_seq, buffer } = state_guard.deref_mut();
162
163 match seq_no.cmp(&(*last_seq + 1)) {
164 Ordering::Less => {
165 tracing::warn!(
166 "{} duplicate message from session {} with seq no: {}",
167 self.log_id,
168 session_id,
169 seq_no,
170 );
171 }
172 Ordering::Greater => {
173 let old = buffer.insert(seq_no, msg);
175 assert!(
176 old.is_none(),
177 "{}: same seq is insert to buffer twice: {}",
178 self.log_id,
179 seq_no
180 );
181 }
182 Ordering::Equal => {
183 self.tx.send(msg).map_err(OrderedSenderError::SendError)?;
186 *last_seq += 1;
187
188 while let Some(m) = buffer.remove(&(*last_seq + 1)) {
189 match self.tx.send(m) {
190 Ok(()) => *last_seq += 1,
191 Err(err) => {
192 let flush_err = OrderedSenderError::FlushError(anyhow::anyhow!(
193 "failed to flush buffered message: {}",
194 err
195 ));
196 buffer.insert(*last_seq + 1, err.0);
197 return Err(flush_err);
198 }
199 }
200 }
201 }
206 }
207
208 Ok(())
209 }
210
211 pub(crate) fn direct_send(&self, msg: T) -> Result<(), SendError<T>> {
212 self.tx.send(msg)
213 }
214}
215
/// Key identifying which counter `Sequencer::assign_seq` increments.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum SeqKey {
    /// One counter per actor; used for handler ports that are not bypass
    /// ports.
    Actor(ActorAddr),
    /// One counter per individual port; used for every other port.
    Port(PortAddr),
}
225
/// Sequencing information attached to a message.
#[derive(Debug, Serialize, Deserialize, Clone, Named, AttrValue, PartialEq)]
pub enum SeqInfo {
    /// The message was assigned a sequence number within a session, as
    /// produced by `Sequencer::assign_seq`.
    Session {
        /// Session that assigned the sequence number.
        session_id: Uuid,
        /// Sequence number; `Sequencer::assign_seq` starts each counter at 1.
        seq: u64,
    },
    /// The message carries no session sequence number.
    // NOTE(review): presumably for messages that skip reordering; confirm
    // against the code that produces this variant.
    Direct,
}
240
241impl fmt::Display for SeqInfo {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 match self {
244 Self::Direct => write!(f, "direct"),
245 Self::Session { session_id, seq } => write!(f, "{}:{}", session_id, seq),
246 }
247 }
248}
249
250impl std::str::FromStr for SeqInfo {
251 type Err = anyhow::Error;
252
253 fn from_str(s: &str) -> Result<Self, Self::Err> {
254 if s == "direct" {
255 return Ok(SeqInfo::Direct);
256 }
257
258 let parts: Vec<_> = s.split(':').collect();
259 if parts.len() != 2 {
260 return Err(anyhow::anyhow!("invalid SeqInfo: {}", s));
261 }
262 let session_id: Uuid = parts[0].parse()?;
263 let seq: u64 = parts[1].parse()?;
264 Ok(SeqInfo::Session { session_id, seq })
265 }
266}
267
// Declares the `SEQ_INFO` attribute key, whose value type is `SeqInfo`.
// NOTE(review): presumably attached to message headers so the receiver can
// reorder by session; confirm against the code that reads this attribute.
declare_attrs! {
    pub attr SEQ_INFO: SeqInfo;
}
273
/// Assigns increasing sequence numbers to outgoing messages on behalf of one
/// session. Clones share the same counters because the counter map is behind
/// an `Arc`.
#[derive(Clone, Debug)]
pub struct Sequencer {
    /// Session id stamped on every `SeqInfo::Session` this sequencer
    /// produces.
    session_id: Uuid,
    /// Last sequence number handed out for each key; a missing key means no
    /// number has been assigned for it yet.
    last_seqs: Arc<Mutex<HashMap<SeqKey, u64>>>,
}
283
284impl Sequencer {
285 pub(crate) fn new(session_id: Uuid) -> Self {
286 Self {
287 session_id,
288 last_seqs: Arc::new(Mutex::new(HashMap::new())),
289 }
290 }
291
292 pub fn assign_seq(&self, port_id: &PortAddr) -> SeqInfo {
301 let key = if port_id.is_handler_port() && !is_bypass_workq_actor_port(port_id.index()) {
302 SeqKey::Actor(port_id.actor_addr().clone())
303 } else {
304 SeqKey::Port(port_id.clone())
305 };
306
307 let mut guard = self.last_seqs.lock().unwrap();
308 let entry = guard.entry(key).or_default();
309 *entry += 1;
310 SeqInfo::Session {
311 session_id: self.session_id,
312 seq: *entry,
313 }
314 }
315
316 pub fn session_id(&self) -> Uuid {
318 self.session_id
319 }
320}
321
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::port::Port;
    use crate::testing::ids::test_actor_id;

    // Marker message types. Only their `Named::port()` values are used, to
    // build port addresses that the sequencer tests below expect to behave as
    // handler ports.
    #[derive(Named)]
    struct TestMsg1;

    #[derive(Named)]
    struct TestMsg2;

    /// Collects every message currently available on `rx` without blocking.
    fn drain_try_recv<T: std::fmt::Debug + Clone>(rx: &mut mpsc::UnboundedReceiver<T>) -> Vec<T> {
        let mut out = Vec::new();
        while let Ok(m) = rx.try_recv() {
            out.push(m);
        }
        out
    }

    /// Extracts the sequence number from a `SeqInfo::Session`; panics on
    /// `SeqInfo::Direct`.
    fn get_seq(seq_info: SeqInfo) -> u64 {
        match seq_info {
            SeqInfo::Session { seq, .. } => seq,
            SeqInfo::Direct => panic!("expected Session variant, got Direct"),
        }
    }

    #[test]
    fn test_ordered_channel_single_client_send_in_order() {
        let session_id_a = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<u64>("test".to_string(), true);
        // In-order messages are delivered immediately, one at a time.
        for s in 1..=10 {
            tx.send(session_id_a, s, s).unwrap();
            let got = drain_try_recv(&mut rx);
            assert_eq!(got, vec![s]);
        }
    }

    #[test]
    fn test_ordered_channel_single_client_send_out_of_order() {
        let session_id_a = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<u64>("test".to_string(), true);

        // Send 4, 3, 2: all ahead of the expected seq 1, so all are buffered.
        for s in (2..=4).rev() {
            tx.send(session_id_a, s, s).unwrap();
        }

        // Send 9, 8, 7: also buffered.
        for s in (7..=9).rev() {
            tx.send(session_id_a, s, s).unwrap();
        }

        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "nothing should be delivered yet"
        );

        // Seq 1 fills the first gap and flushes 2..=4.
        tx.send(session_id_a, 1, 1).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![1, 2, 3, 4]);

        // Seq 5 is in order, but 6 is still missing, so only 5 is delivered.
        tx.send(session_id_a, 5, 5).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![5]);

        // Seq 6 fills the second gap and flushes 7..=9.
        tx.send(session_id_a, 6, 6).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![6, 7, 8, 9]);

        // The buffer is empty again, so 10 is delivered on its own.
        tx.send(session_id_a, 10, 10).unwrap();
        let got = drain_try_recv(&mut rx);
        assert_eq!(got, vec![10]);
    }

    #[test]
    fn test_ordered_channel_multi_clients() {
        let session_id_a = Uuid::now_v7();
        let session_id_b = Uuid::now_v7();
        let (tx, mut rx) = ordered_channel::<(Uuid, u64)>("test".to_string(), true);

        // Each session has its own sequence space, so both may send seq 1.
        tx.send(session_id_a, 1, (session_id_a, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_a, 1)]);
        tx.send(session_id_b, 1, (session_id_b, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_b, 1)]);

        // Interleave out-of-order sends from both sessions; all are buffered.
        for s in (3..=5).rev() {
            tx.send(session_id_a, s, (session_id_a, s)).unwrap();
            tx.send(session_id_b, s, (session_id_b, s)).unwrap();
        }
        for s in (7..=9).rev() {
            tx.send(session_id_a, s, (session_id_a, s)).unwrap();
            tx.send(session_id_b, s, (session_id_b, s)).unwrap();
        }
        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "nothing should be delivered yet"
        );

        // Filling session A's gap flushes only session A's buffer.
        tx.send(session_id_a, 2, (session_id_a, 2)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_a, 2),
                (session_id_a, 3),
                (session_id_a, 4),
                (session_id_a, 5),
            ]
        );
        // Likewise for session B.
        tx.send(session_id_b, 2, (session_id_b, 2)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_b, 2),
                (session_id_b, 3),
                (session_id_b, 4),
                (session_id_b, 5),
            ]
        );

        tx.send(session_id_a, 6, (session_id_a, 6)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_a, 6),
                (session_id_a, 7),
                (session_id_a, 8),
                (session_id_a, 9)
            ]
        );
        tx.send(session_id_b, 6, (session_id_b, 6)).unwrap();
        assert_eq!(
            drain_try_recv(&mut rx),
            vec![
                (session_id_b, 6),
                (session_id_b, 7),
                (session_id_b, 8),
                (session_id_b, 9)
            ]
        );
    }

    #[test]
    fn test_ordered_channel_duplicates() {
        let session_id_a = Uuid::now_v7();
        // Duplicates must be dropped, never parked in a session buffer.
        fn verify_empty_buffers<T>(states: &DashMap<Uuid, Arc<Mutex<BufferState<T>>>>) {
            for entry in states.iter() {
                assert!(entry.value().lock().unwrap().buffer.is_empty());
            }
        }

        let (tx, mut rx) = ordered_channel::<(Uuid, u64)>("test".to_string(), true);
        tx.send(session_id_a, 1, (session_id_a, 1)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_a, 1)]);
        verify_empty_buffers(&tx.states);

        // Re-sending seq 1 right away is a duplicate: accepted but dropped.
        tx.send(session_id_a, 1, (session_id_a, 1_000)).unwrap();
        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "nothing should be delivered yet"
        );
        verify_empty_buffers(&tx.states);

        tx.send(session_id_a, 2, (session_id_a, 2)).unwrap();
        assert_eq!(drain_try_recv(&mut rx), vec![(session_id_a, 2)]);
        verify_empty_buffers(&tx.states);

        // Seq 1 is still a duplicate after the sequence has advanced further.
        tx.send(session_id_a, 1, (session_id_a, 1_001)).unwrap();
        assert!(
            drain_try_recv(&mut rx).is_empty(),
            "nothing should be delivered yet"
        );
        verify_empty_buffers(&tx.states);
    }

    #[test]
    fn test_sequencer_clone() {
        let sequencer = Sequencer {
            session_id: Uuid::now_v7(),
            last_seqs: Arc::new(Mutex::new(HashMap::new())),
        };

        let actor_ref: ActorAddr = test_actor_id("test_0", "test");
        let port_ref = actor_ref.port_addr(Port::from(1));

        // Advance the counter to 2 on the original.
        sequencer.assign_seq(&port_ref);
        sequencer.assign_seq(&port_ref);

        // The clone shares the session id and the counters, so it continues
        // at 3 rather than restarting at 1.
        let cloned_sequencer = sequencer.clone();
        assert_eq!(sequencer.session_id(), cloned_sequencer.session_id(),);
        assert_eq!(get_seq(cloned_sequencer.assign_seq(&port_ref)), 3);
    }

    #[test]
    fn test_sequencer_handler_ports_share_sequence() {
        let sequencer = Sequencer {
            session_id: Uuid::now_v7(),
            last_seqs: Arc::new(Mutex::new(HashMap::new())),
        };

        let actor_ref: ActorAddr = test_actor_id("worker_0", "worker");
        let handler_port_1 = actor_ref.port_addr(Port::from(TestMsg1::port()));
        let handler_port_2 = actor_ref.port_addr(Port::from(TestMsg2::port()));

        // Two handler ports of the same actor draw from one counter.
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_2)), 2);
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_1)), 3);

        // A different actor starts its own counter at 1.
        let actor_ref_2: ActorAddr = test_actor_id("worker_1", "worker");
        let handler_port_3 = actor_ref_2.port_addr(Port::from(TestMsg1::port()));
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_3)), 1);
    }

    #[test]
    fn test_sequencer_non_handler_ports_have_independent_sequences() {
        let sequencer = Sequencer {
            session_id: Uuid::now_v7(),
            last_seqs: Arc::new(Mutex::new(HashMap::new())),
        };

        let actor_ref_0: ActorAddr = test_actor_id("worker_0", "worker");
        let actor_ref_1: ActorAddr = test_actor_id("worker_1", "worker");

        let port_1 = actor_ref_0.port_addr(Port::from(1));
        let port_2 = actor_ref_0.port_addr(Port::from(2));

        // Each non-handler port has its own counter, even on the same actor.
        assert_eq!(get_seq(sequencer.assign_seq(&port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&port_2)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&port_1)), 2);
        assert_eq!(get_seq(sequencer.assign_seq(&port_2)), 2);

        // The same port index on another actor is a separate counter too.
        let port_3 = actor_ref_1.port_addr(Port::from(1));
        assert_eq!(get_seq(sequencer.assign_seq(&port_3)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&port_1)), 3);
        assert_eq!(get_seq(sequencer.assign_seq(&port_3)), 2);
    }

    #[test]
    fn test_sequencer_mixed_handler_and_non_handler_ports() {
        let sequencer = Sequencer {
            session_id: Uuid::now_v7(),
            last_seqs: Arc::new(Mutex::new(HashMap::new())),
        };

        let actor_ref: ActorAddr = test_actor_id("worker_0", "worker");

        let handler_port_1 = actor_ref.port_addr(Port::from(TestMsg1::port()));
        let handler_port_2 = actor_ref.port_addr(Port::from(TestMsg2::port()));

        let non_handler_port_1 = actor_ref.port_addr(Port::from(1));
        let non_handler_port_2 = actor_ref.port_addr(Port::from(2));

        // Handler ports share the actor counter (1, 2, 3), while each
        // non-handler port counts independently (1, 2).
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&non_handler_port_1)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_2)), 2);
        assert_eq!(get_seq(sequencer.assign_seq(&non_handler_port_2)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&non_handler_port_1)), 2);
        assert_eq!(get_seq(sequencer.assign_seq(&handler_port_1)), 3);
        assert_eq!(get_seq(sequencer.assign_seq(&non_handler_port_2)), 2);
    }

    #[test]
    fn bypass_registry_introspect_message() {
        assert!(is_bypass_workq_type_id(TypeId::of::<IntrospectMessage>()));
        assert!(is_bypass_workq_actor_port(IntrospectMessage::port()));
    }

    #[test]
    fn bypass_registry_signal() {
        assert!(is_bypass_workq_type_id(TypeId::of::<Signal>()));
        assert!(is_bypass_workq_actor_port(Signal::port()));
    }

    #[test]
    fn bypass_registry_lists_have_matching_lengths() {
        // Guards against registering a type in one set but not the other.
        assert_eq!(BYPASS_TYPE_IDS.len(), BYPASS_ACTOR_PORTS.len());
    }

    #[test]
    fn bypass_actor_port_uses_per_port_seq_counter() {
        let sequencer = Sequencer::new(Uuid::now_v7());
        let actor_ref: ActorAddr = test_actor_id("agent_0", "proc_agent");

        let introspect_port = actor_ref.port_addr(Port::from(IntrospectMessage::port()));
        let regular_actor_port = actor_ref.port_addr(Port::from(TestMsg1::port()));

        // The bypass port keeps its own counter and does not consume numbers
        // from the actor-wide counter used by the regular handler port.
        assert_eq!(get_seq(sequencer.assign_seq(&introspect_port)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&regular_actor_port)), 1);
        assert_eq!(get_seq(sequencer.assign_seq(&introspect_port)), 2);
        assert_eq!(get_seq(sequencer.assign_seq(&regular_actor_port)), 2);
    }
}