1use std::sync::Arc;
30
31use async_trait::async_trait;
32use backoff::ExponentialBackoffBuilder;
33use backoff::backoff::Backoff;
34use dashmap::DashMap;
35use tokio::sync::mpsc;
36use tokio::sync::oneshot;
37use tokio::sync::watch;
38use tokio::time::Instant;
39use tokio_util::sync::CancellationToken;
40
41use super::ClientError;
42use super::Link;
43use super::LinkStatus;
44use super::ServerError;
45use super::SessionId;
46use super::log_send_error;
47use super::read_link_init;
48use super::server::AcceptorLink;
49use super::server::ServerHandle;
50use super::session;
51use super::session::Next;
52use super::session::Session;
53use crate::RemoteMessage;
54use crate::channel::ChannelAddr;
55use crate::channel::ChannelError;
56use crate::channel::ChannelTransport;
57use crate::channel::Rx;
58use crate::channel::SendError;
59use crate::channel::Tx;
60use crate::channel::TxStatus;
61use crate::channel::net::Stream;
62use crate::channel::net::meta;
63use crate::channel::net::tls;
64use crate::metrics;
65use crate::sync::mvar::MVar;
66
67pub struct DuplexServer<In: RemoteMessage, Out: RemoteMessage> {
69 accept_rx: mpsc::Receiver<(DuplexRx<In>, DuplexTx<Out>)>,
70 _handle: ServerHandle,
71 addr: ChannelAddr,
72}
73
74impl<In: RemoteMessage, Out: RemoteMessage> DuplexServer<In, Out> {
75 pub async fn accept(&mut self) -> Result<(DuplexRx<In>, DuplexTx<Out>), ChannelError> {
77 self.accept_rx.recv().await.ok_or(ChannelError::Closed)
78 }
79
80 pub fn addr(&self) -> &ChannelAddr {
82 &self.addr
83 }
84}
85
86pub struct DuplexRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr);
88
89impl<M: RemoteMessage> DuplexRx<M> {
90 pub(super) fn new(rx: mpsc::Receiver<M>, addr: ChannelAddr) -> Self {
91 Self(rx, addr)
92 }
93}
94
95#[async_trait]
96impl<M: RemoteMessage> Rx<M> for DuplexRx<M> {
97 async fn recv(&mut self) -> Result<M, ChannelError> {
98 self.0.recv().await.ok_or(ChannelError::Closed)
99 }
100
101 fn addr(&self) -> ChannelAddr {
102 self.1.clone()
103 }
104
105 async fn join(self) {}
106}
107
108pub struct DuplexTx<M: RemoteMessage> {
110 tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
111 addr: ChannelAddr,
112 status: watch::Receiver<TxStatus>,
113}
114
115impl<M: RemoteMessage> DuplexTx<M> {
116 pub(super) fn new(
117 tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
118 addr: ChannelAddr,
119 status: watch::Receiver<TxStatus>,
120 ) -> Self {
121 Self { tx, addr, status }
122 }
123}
124
125#[async_trait]
126impl<M: RemoteMessage> Tx<M> for DuplexTx<M> {
127 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
128 let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
129 if let Err(mpsc::error::SendError((message, return_channel, _))) =
130 self.tx
131 .send((message, return_channel, tokio::time::Instant::now()))
132 {
133 let reason = self.status.borrow().as_closed().map(|r| r.to_string());
134 let _ = return_channel.send(SendError {
135 error: ChannelError::Closed,
136 message,
137 reason,
138 });
139 }
140 }
141
142 fn addr(&self) -> ChannelAddr {
143 self.addr.clone()
144 }
145
146 fn status(&self) -> &watch::Receiver<TxStatus> {
147 &self.status
148 }
149}
150
151impl<M: RemoteMessage> Clone for DuplexTx<M> {
152 fn clone(&self) -> Self {
153 Self {
154 tx: self.tx.clone(),
155 addr: self.addr.clone(),
156 status: self.status.clone(),
157 }
158 }
159}
160
161pub fn serve<In: RemoteMessage, Out: RemoteMessage>(
163 addr: ChannelAddr,
164) -> Result<DuplexServer<In, Out>, ServerError> {
165 let (mut listener, channel_addr) = super::listen(addr)?;
166
167 let (accept_tx, accept_rx) = mpsc::channel(16);
168 let cancel_token = CancellationToken::new();
169 let child_token = cancel_token.child_token();
170
171 let is_tls = matches!(
172 channel_addr.transport(),
173 ChannelTransport::Tls | ChannelTransport::MetaTls(_)
174 );
175 let dest = channel_addr.clone();
176 let prepare = move |stream: Box<dyn Stream>, source: ChannelAddr| {
177 let dest = dest.clone();
178 async move {
179 if is_tls {
180 let tls_acceptor = match dest.transport() {
181 ChannelTransport::Tls => tls::tls_acceptor()?,
182 _ => meta::tls_acceptor(true)?,
183 };
184 let mut tls_stream = tls_acceptor.accept(stream).await?;
185 let session_id = read_link_init(&mut tls_stream)
186 .await
187 .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
188 Ok((session_id, Box::new(tls_stream) as Box<dyn Stream>))
189 } else {
190 let mut stream = stream;
191 let session_id = read_link_init(&mut stream)
192 .await
193 .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
194 Ok((session_id, stream))
195 }
196 }
197 };
198
199 let sessions: Arc<DashMap<SessionId, MVar<Box<dyn Stream>>>> = Arc::new(DashMap::new());
200 let child_cancel = CancellationToken::new();
201 let dispatch_dest = channel_addr.clone();
202 let dispatch = {
203 let sessions = Arc::clone(&sessions);
204 let accept_tx = accept_tx.clone();
205 let child_cancel = child_cancel.clone();
206 let dest = dispatch_dest;
207 move |session_id: SessionId, stream: Box<dyn Stream>| {
208 let sessions = Arc::clone(&sessions);
209 let accept_tx = accept_tx.clone();
210 let cancel = child_cancel.child_token();
211 let dest = dest.clone();
212 async move {
213 dispatch_duplex_stream::<In, Out>(
214 session_id, stream, &sessions, dest, &accept_tx, cancel,
215 )
216 .await;
217 }
218 }
219 };
220
221 let ca = channel_addr.clone();
222 let join_handle = tokio::spawn(async move {
223 let result =
224 super::server::accept_loop(&mut listener, &ca, &child_token, prepare, dispatch).await;
225 child_cancel.cancel();
226 result
227 });
228
229 let server_handle = ServerHandle::new(join_handle, cancel_token, channel_addr.clone());
230
231 Ok(DuplexServer {
232 accept_rx,
233 _handle: server_handle,
234 addr: channel_addr,
235 })
236}
237
238enum Either {
240 Send(session::SendLoopError),
241 Recv(session::RecvLoopError),
242}
243
244async fn dispatch_duplex_stream<In: RemoteMessage, Out: RemoteMessage>(
247 session_id: SessionId,
248 stream: Box<dyn Stream>,
249 sessions: &DashMap<SessionId, MVar<Box<dyn Stream>>>,
250 addr: ChannelAddr,
251 accept_tx: &mpsc::Sender<(DuplexRx<In>, DuplexTx<Out>)>,
252 cancel: CancellationToken,
253) {
254 let mvar = {
255 let entry = sessions.entry(session_id);
256 match entry {
257 dashmap::mapref::entry::Entry::Occupied(e) => e.get().clone(),
258 dashmap::mapref::entry::Entry::Vacant(e) => {
259 let mvar: MVar<Box<dyn Stream>> = MVar::empty();
260 let link = AcceptorLink {
261 dest: addr.clone(),
262 session_id,
263 stream: mvar.clone(),
264 cancel: cancel.clone(),
265 };
266
267 let (inbound_tx, inbound_rx) = mpsc::channel::<In>(1024);
268 let (outbound_tx, outbound_rx) =
269 mpsc::unbounded_channel::<(Out, oneshot::Sender<SendError<Out>>, Instant)>();
270 let (notify, status) = watch::channel(TxStatus::Active);
271 let net_rx = DuplexRx(inbound_rx, addr.clone());
272 let net_tx = DuplexTx {
273 tx: outbound_tx,
274 addr: addr.clone(),
275 status,
276 };
277 let _ = accept_tx.send((net_rx, net_tx)).await;
278
279 let session_ct = cancel.clone();
280 let dest = addr.clone();
281 tokio::spawn(async move {
282 let mut session = Session::new(link);
283 let mut recv_next = Next { seq: 0, ack: 0 };
284 let log_id = format!("duplex server {:016x}", session_id.0);
285 let mut deliveries = session::Deliveries {
286 outbox: session::Outbox::new(log_id.clone(), dest, session_id.0),
287 unacked: session::Unacked::new(None, log_id),
288 };
289 let mut outbound_rx = outbound_rx;
290
291 loop {
292 let connected = match session.connect().await {
293 Ok(s) => s,
294 Err(_) => break,
295 };
296 deliveries.requeue_unacked();
297 let result = {
298 let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
299 let send_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
300 tokio::select! {
301 r = session::recv_connected::<In, _, _>(
302 &recv_stream,
303 &inbound_tx,
304 &mut recv_next,
305 ) => r.map_err(Either::Recv),
306 r = session::send_connected(
307 &send_stream,
308 &mut deliveries,
309 &mut outbound_rx,
310 ) => r.map_err(Either::Send),
311 _ = session_ct.cancelled() => Err(Either::Recv(session::RecvLoopError::Cancelled)),
312 }
313 };
314
315 let terminal = match &result {
316 Ok(()) => {
317 tracing::info!(
318 session_id = session_id.0,
319 "duplex recv_connected returned EOF, awaiting reconnect"
320 );
321 false
322 }
323 Err(Either::Send(session::SendLoopError::Io(err))) => {
324 tracing::info!(
325 session_id = session_id.0,
326 error = %err,
327 "duplex send error (recoverable)",
328 );
329 false
330 }
331 Err(Either::Recv(session::RecvLoopError::Io(err))) => {
332 tracing::info!(
333 session_id = session_id.0,
334 error = %err,
335 "duplex recv error (recoverable)",
336 );
337 false
338 }
339 Err(Either::Send(e)) => {
340 tracing::info!(
341 session_id = session_id.0,
342 error = %e,
343 "duplex send terminal error"
344 );
345 true
346 }
347 Err(Either::Recv(e)) => {
348 tracing::info!(
349 session_id = session_id.0,
350 error = %e,
351 "duplex recv terminal error"
352 );
353 true
354 }
355 };
356 session = connected.release();
357 if terminal {
358 break;
359 }
360 }
361
362 let _ = notify.send(TxStatus::Closed("duplex session ended".into()));
363 });
364
365 e.insert(mvar.clone());
366 mvar
367 }
368 }
369 };
370
371 mvar.put(stream).await;
372}
373
374pub(crate) fn spawn<Out: RemoteMessage, In: RemoteMessage>(
377 link: impl Link,
378) -> (DuplexTx<Out>, DuplexRx<In>) {
379 let addr = link.dest();
380 let session_id = link.link_id();
381 let (outbound_tx, outbound_rx) = tokio::sync::mpsc::unbounded_channel();
382 let (inbound_tx, inbound_rx) = tokio::sync::mpsc::channel::<In>(1024);
383 let (notify, status) = watch::channel(TxStatus::Active);
384 let dest = addr.clone();
385 crate::init::get_runtime().spawn(async move {
386 let mut session = Session::new(link);
387 let log_id = format!("session {}.{:016x}", dest, session_id.0);
388 let mut deliveries = session::Deliveries {
389 outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
390 unacked: session::Unacked::new(None, log_id),
391 };
392 let mut outbound_rx = outbound_rx;
393 let mut recv_next = Next { seq: 0, ack: 0 };
394 let mut reconnect_backoff = ExponentialBackoffBuilder::new()
395 .with_initial_interval(std::time::Duration::from_millis(10))
396 .with_multiplier(2.0)
397 .with_randomization_factor(0.1)
398 .with_max_interval(std::time::Duration::from_secs(5))
399 .with_max_elapsed_time(None)
400 .build();
401
402 let mut link_status = LinkStatus::NeverConnected;
403
404 loop {
405 let connected = match session.connect().await {
406 Ok(s) => s,
407 Err(_) => break,
408 };
409
410 metrics::CHANNEL_CONNECTIONS.add(
411 1,
412 hyperactor_telemetry::kv_pairs!(
413 "transport" => dest.transport().to_string(),
414 "mode" => "duplex",
415 "reason" => "link connected",
416 ),
417 );
418
419 if !deliveries.unacked.is_empty() {
420 metrics::CHANNEL_RECONNECTIONS.add(
421 1,
422 hyperactor_telemetry::kv_pairs!(
423 "dest" => dest.to_string(),
424 "transport" => dest.transport().to_string(),
425 "mode" => "duplex",
426 "reason" => "reconnect_with_unacked",
427 ),
428 );
429 }
430 deliveries.requeue_unacked();
431
432 link_status.connected();
433 let connected_at = tokio::time::Instant::now();
434
435 let result = {
436 let send_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
437 let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
438 tokio::select! {
439 r = session::send_connected(
440 &send_stream, &mut deliveries, &mut outbound_rx,
441 ) => r.map_err(Either::Send),
442 r = session::recv_connected::<In, _, _>(
443 &recv_stream, &inbound_tx, &mut recv_next,
444 ) => r.map_err(Either::Recv),
445 }
446 };
447
448 link_status.disconnected();
449
450 if connected_at.elapsed() > tokio::time::Duration::from_secs(1) {
451 reconnect_backoff.reset();
452 }
453
454 let terminal = match &result {
455 Ok(()) => {
456 if let Some(delay) = reconnect_backoff.next_backoff() {
457 tracing::info!(
458 dest = %dest,
459 session_id = session_id.0,
460 delay_ms = delay.as_millis() as u64,
461 "duplex send_connected returned EOF, reconnecting after backoff; {link_status}"
462 );
463 tokio::time::sleep(delay).await;
464 }
465 false
466 }
467 Err(Either::Send(e)) => {
468 let terminal = log_send_error(e, &dest, session_id.0, "duplex", &link_status);
469 if !terminal {
470 if let Some(delay) = reconnect_backoff.next_backoff() {
472 tracing::info!(
473 dest = %dest,
474 session_id = session_id.0,
475 error = %e,
476 delay_ms = delay.as_millis() as u64,
477 mode = "duplex",
478 "send error (recoverable), reconnecting after backoff; {link_status}",
479 );
480 tokio::time::sleep(delay).await;
481 }
482 }
483 terminal
484 }
485 Err(Either::Recv(session::RecvLoopError::Io(err))) => {
486 if let Some(delay) = reconnect_backoff.next_backoff() {
487 tracing::info!(
488 dest = %dest,
489 session_id = session_id.0,
490 error = %err,
491 delay_ms = delay.as_millis() as u64,
492 mode = "duplex",
493 "recv error (recoverable), reconnecting after backoff; {link_status}",
494 );
495 tokio::time::sleep(delay).await;
496 }
497 metrics::CHANNEL_ERRORS.add(
498 1,
499 hyperactor_telemetry::kv_pairs!(
500 "dest" => dest.to_string(),
501 "session_id" => session_id.0.to_string(),
502 "error_type" => metrics::ChannelErrorType::SendError.as_str(),
503 "mode" => "duplex",
504 ),
505 );
506 false
507 }
508 Err(Either::Recv(e)) => {
509 tracing::info!(
510 dest = %dest,
511 session_id = session_id.0,
512 error = %e,
513 "duplex recv terminal error; {link_status}"
514 );
515 true
516 }
517 };
518 session = connected.release();
519 if terminal {
520 break;
521 }
522 }
523
524 let _ = notify.send(TxStatus::Closed("duplex session ended".into()));
525 });
526 (
527 DuplexTx::new(outbound_tx, addr.clone(), status),
528 DuplexRx::new(inbound_rx, addr),
529 )
530}
531
532pub fn dial<Out: RemoteMessage, In: RemoteMessage>(
534 addr: ChannelAddr,
535) -> Result<(DuplexTx<Out>, DuplexRx<In>), ClientError> {
536 Ok(spawn(super::link(addr)?))
537}
538
#[cfg(test)]
mod tests {
    use timed_test::async_timed_test;

    use super::*;
    use crate::channel::ChannelTransport;

    /// IPv6 loopback TCP address with port 0.
    fn loopback_tcp() -> ChannelAddr {
        ChannelAddr::Tcp("[::1]:0".parse().unwrap())
    }

    /// One client, one server: messages flow in both directions, first once and then
    /// interleaved ten times. Ignored unless built with `fbcode_build`.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_basic() {
        let mut server = serve::<u64, String>(loopback_tcp()).unwrap();
        let target = server.addr().clone();

        // Client sends `u64` and receives `String`; the server is the mirror image.
        let (client_tx, mut client_rx) = dial::<u64, String>(target).unwrap();

        let (mut server_rx, server_tx) = server.accept().await.unwrap();

        // Client to server.
        client_tx.post(42);
        assert_eq!(server_rx.recv().await.unwrap(), 42);

        // Server to client.
        server_tx.post("hello".to_string());
        assert_eq!(client_rx.recv().await.unwrap(), "hello");

        // Alternate directions a few more times.
        for i in 0..10u64 {
            client_tx.post(i);
            assert_eq!(server_rx.recv().await.unwrap(), i);

            let reply = format!("msg-{}", i);
            server_tx.post(reply.clone());
            assert_eq!(client_rx.recv().await.unwrap(), reply);
        }
    }

    /// Two clients dialing the same server get independent links. Ignored unless built with
    /// `fbcode_build`.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_multiple_links() {
        let mut server = serve::<u64, u64>(loopback_tcp()).unwrap();
        let target = server.addr().clone();

        let (first_tx, mut first_rx) = dial::<u64, u64>(target.clone()).unwrap();
        let (mut first_srv_rx, first_srv_tx) = server.accept().await.unwrap();

        let (second_tx, mut second_rx) = dial::<u64, u64>(target).unwrap();
        let (mut second_srv_rx, second_srv_tx) = server.accept().await.unwrap();

        // First link, both directions.
        first_tx.post(100);
        assert_eq!(first_srv_rx.recv().await.unwrap(), 100);
        first_srv_tx.post(200);
        assert_eq!(first_rx.recv().await.unwrap(), 200);

        // Second link, both directions.
        second_tx.post(300);
        assert_eq!(second_srv_rx.recv().await.unwrap(), 300);
        second_srv_tx.post(400);
        assert_eq!(second_rx.recv().await.unwrap(), 400);
    }

    /// Runs an echo server on `addr` and measures how long `iterations` sequential
    /// round-trips take from a single client.
    async fn duplex_ping_pong(
        addr: ChannelAddr,
        iterations: usize,
    ) -> anyhow::Result<std::time::Duration> {
        let mut server = serve::<u64, u64>(addr)?;
        let target = server.addr().clone();

        // Echo every received message back until the receive side reports an error.
        let echo_task = tokio::spawn(async move {
            let (mut rx, tx) = server.accept().await.unwrap();
            while let Ok(msg) = rx.recv().await {
                tx.post(msg);
            }
        });

        let (client_tx, mut client_rx) = dial::<u64, u64>(target).unwrap();

        // Ten untimed round-trips before the clock starts.
        for i in 0..10u64 {
            client_tx.post(i);
            assert_eq!(client_rx.recv().await?, i);
        }

        let start = std::time::Instant::now();
        for i in 0..iterations as u64 {
            client_tx.post(i);
            assert_eq!(client_rx.recv().await?, i);
        }
        let elapsed = start.elapsed();

        echo_task.abort();
        Ok(elapsed)
    }

    /// 100 timed round-trips over TCP. Ignored unless built with `fbcode_build`.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_ping_pong_tcp() {
        let elapsed = duplex_ping_pong(loopback_tcp(), 100).await.unwrap();
        println!("TCP duplex: 100 round-trips in {elapsed:?}");
    }

    /// 100 timed round-trips over a Unix transport.
    #[async_timed_test(timeout_secs = 30)]
    async fn test_duplex_ping_pong_unix() {
        let unix_addr = ChannelAddr::any(ChannelTransport::Unix);
        let elapsed = duplex_ping_pong(unix_addr, 100).await.unwrap();
        println!("Unix duplex: 100 round-trips in {elapsed:?}");
    }
}