1use std::sync::Arc;
30
31use async_trait::async_trait;
32use dashmap::DashMap;
33use tokio::sync::mpsc;
34use tokio::sync::oneshot;
35use tokio::sync::watch;
36use tokio::time::Instant;
37use tokio_util::sync::CancellationToken;
38
39use super::ClientError;
40use super::Link;
41use super::ServerError;
42use super::SessionId;
43use super::log_send_error;
44use super::read_link_init;
45use super::server::AcceptorLink;
46use super::server::ServerHandle;
47use super::session;
48use super::session::Next;
49use super::session::Session;
50use crate::RemoteMessage;
51use crate::channel::ChannelAddr;
52use crate::channel::ChannelError;
53use crate::channel::ChannelTransport;
54use crate::channel::Rx;
55use crate::channel::SendError;
56use crate::channel::Tx;
57use crate::channel::TxStatus;
58use crate::channel::net::Stream;
59use crate::channel::net::meta;
60use crate::channel::net::tls;
61use crate::metrics;
62use crate::sync::mvar::MVar;
63
/// Accepting side of a duplex channel.
///
/// Each new session observed by the accept loop is surfaced through
/// [`DuplexServer::accept`] as a `(DuplexRx<In>, DuplexTx<Out>)` pair.
pub struct DuplexServer<In: RemoteMessage, Out: RemoteMessage> {
    /// Queue of per-session channel halves produced by the dispatch path.
    accept_rx: mpsc::Receiver<(DuplexRx<In>, DuplexTx<Out>)>,
    /// Handle for the background accept-loop task. It is never read here;
    /// presumably it is held so the loop lives as long as the server —
    /// confirm against `ServerHandle`.
    _handle: ServerHandle,
    /// Address the listener is actually bound to.
    addr: ChannelAddr,
}
70
71impl<In: RemoteMessage, Out: RemoteMessage> DuplexServer<In, Out> {
72 pub async fn accept(&mut self) -> Result<(DuplexRx<In>, DuplexTx<Out>), ChannelError> {
74 self.accept_rx.recv().await.ok_or(ChannelError::Closed)
75 }
76
77 pub fn addr(&self) -> &ChannelAddr {
79 &self.addr
80 }
81}
82
/// Receiving half of a duplex channel: an inbound message queue (`.0`)
/// paired with the channel address reported by `Rx::addr` (`.1`).
pub struct DuplexRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr);
85
86impl<M: RemoteMessage> DuplexRx<M> {
87 pub(super) fn new(rx: mpsc::Receiver<M>, addr: ChannelAddr) -> Self {
88 Self(rx, addr)
89 }
90}
91
92#[async_trait]
93impl<M: RemoteMessage> Rx<M> for DuplexRx<M> {
94 async fn recv(&mut self) -> Result<M, ChannelError> {
95 self.0.recv().await.ok_or(ChannelError::Closed)
96 }
97
98 fn addr(&self) -> ChannelAddr {
99 self.1.clone()
100 }
101
102 async fn join(self) {}
103}
104
/// Sending half of a duplex channel.
pub struct DuplexTx<M: RemoteMessage> {
    /// Outbound queue. Each entry carries the message, a one-shot channel
    /// used to hand the message back on failure, and the enqueue time.
    tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
    /// Address reported by `Tx::addr`.
    addr: ChannelAddr,
    /// Watch on the transmit status (`Active` until the session task
    /// publishes `Closed`).
    status: watch::Receiver<TxStatus>,
}
111
112impl<M: RemoteMessage> DuplexTx<M> {
113 pub(super) fn new(
114 tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
115 addr: ChannelAddr,
116 status: watch::Receiver<TxStatus>,
117 ) -> Self {
118 Self { tx, addr, status }
119 }
120}
121
122#[async_trait]
123impl<M: RemoteMessage> Tx<M> for DuplexTx<M> {
124 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
125 let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
126 if let Err(mpsc::error::SendError((message, return_channel, _))) =
127 self.tx
128 .send((message, return_channel, tokio::time::Instant::now()))
129 {
130 let _ = return_channel.send(SendError {
131 error: ChannelError::Closed,
132 message,
133 reason: None,
134 });
135 }
136 }
137
138 fn addr(&self) -> ChannelAddr {
139 self.addr.clone()
140 }
141
142 fn status(&self) -> &watch::Receiver<TxStatus> {
143 &self.status
144 }
145}
146
147impl<M: RemoteMessage> Clone for DuplexTx<M> {
148 fn clone(&self) -> Self {
149 Self {
150 tx: self.tx.clone(),
151 addr: self.addr.clone(),
152 status: self.status.clone(),
153 }
154 }
155}
156
157pub fn serve<In: RemoteMessage, Out: RemoteMessage>(
159 addr: ChannelAddr,
160) -> Result<DuplexServer<In, Out>, ServerError> {
161 let (mut listener, channel_addr) = super::listen(addr)?;
162
163 let (accept_tx, accept_rx) = mpsc::channel(16);
164 let cancel_token = CancellationToken::new();
165 let child_token = cancel_token.child_token();
166
167 let is_tls = matches!(
168 channel_addr.transport(),
169 ChannelTransport::Tls | ChannelTransport::MetaTls(_)
170 );
171 let dest = channel_addr.clone();
172 let prepare = move |stream: Box<dyn Stream>, source: ChannelAddr| {
173 let dest = dest.clone();
174 async move {
175 if is_tls {
176 let tls_acceptor = match dest.transport() {
177 ChannelTransport::Tls => tls::tls_acceptor()?,
178 _ => meta::tls_acceptor(true)?,
179 };
180 let mut tls_stream = tls_acceptor.accept(stream).await?;
181 let session_id = read_link_init(&mut tls_stream)
182 .await
183 .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
184 Ok((session_id, Box::new(tls_stream) as Box<dyn Stream>))
185 } else {
186 let mut stream = stream;
187 let session_id = read_link_init(&mut stream)
188 .await
189 .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
190 Ok((session_id, stream))
191 }
192 }
193 };
194
195 let sessions: Arc<DashMap<SessionId, MVar<Box<dyn Stream>>>> = Arc::new(DashMap::new());
196 let child_cancel = CancellationToken::new();
197 let dispatch_dest = channel_addr.clone();
198 let dispatch = {
199 let sessions = Arc::clone(&sessions);
200 let accept_tx = accept_tx.clone();
201 let child_cancel = child_cancel.clone();
202 let dest = dispatch_dest;
203 move |session_id: SessionId, stream: Box<dyn Stream>| {
204 let sessions = Arc::clone(&sessions);
205 let accept_tx = accept_tx.clone();
206 let cancel = child_cancel.child_token();
207 let dest = dest.clone();
208 async move {
209 dispatch_duplex_stream::<In, Out>(
210 session_id, stream, &sessions, dest, &accept_tx, cancel,
211 )
212 .await;
213 }
214 }
215 };
216
217 let ca = channel_addr.clone();
218 let join_handle = tokio::spawn(async move {
219 let result =
220 super::server::accept_loop(&mut listener, &ca, &child_token, prepare, dispatch).await;
221 child_cancel.cancel();
222 result
223 });
224
225 let server_handle = ServerHandle::new(join_handle, cancel_token, channel_addr.clone());
226
227 Ok(DuplexServer {
228 accept_rx,
229 _handle: server_handle,
230 addr: channel_addr,
231 })
232}
233
/// Records which half of a connected duplex session ended with an error,
/// so the reconnect loops can classify send and receive failures
/// separately.
enum Either {
    /// The send loop returned an error.
    Send(session::SendLoopError),
    /// The receive loop returned an error (also used for cancellation on
    /// the acceptor side).
    Recv(session::RecvLoopError),
}
239
240async fn dispatch_duplex_stream<In: RemoteMessage, Out: RemoteMessage>(
243 session_id: SessionId,
244 stream: Box<dyn Stream>,
245 sessions: &DashMap<SessionId, MVar<Box<dyn Stream>>>,
246 addr: ChannelAddr,
247 accept_tx: &mpsc::Sender<(DuplexRx<In>, DuplexTx<Out>)>,
248 cancel: CancellationToken,
249) {
250 let mvar = {
251 let entry = sessions.entry(session_id);
252 match entry {
253 dashmap::mapref::entry::Entry::Occupied(e) => e.get().clone(),
254 dashmap::mapref::entry::Entry::Vacant(e) => {
255 let mvar: MVar<Box<dyn Stream>> = MVar::empty();
256 let link = AcceptorLink {
257 dest: addr.clone(),
258 session_id,
259 stream: mvar.clone(),
260 cancel: cancel.clone(),
261 };
262
263 let (inbound_tx, inbound_rx) = mpsc::channel::<In>(1024);
264 let (outbound_tx, outbound_rx) =
265 mpsc::unbounded_channel::<(Out, oneshot::Sender<SendError<Out>>, Instant)>();
266 let (notify, status) = watch::channel(TxStatus::Active);
267 let net_rx = DuplexRx(inbound_rx, addr.clone());
268 let net_tx = DuplexTx {
269 tx: outbound_tx,
270 addr: addr.clone(),
271 status,
272 };
273 let _ = accept_tx.send((net_rx, net_tx)).await;
274
275 let session_ct = cancel.clone();
276 let dest = addr.clone();
277 tokio::spawn(async move {
278 let mut session = Session::new(link);
279 let mut recv_next = Next { seq: 0, ack: 0 };
280 let log_id = format!("duplex server {:016x}", session_id.0);
281 let mut deliveries = session::Deliveries {
282 outbox: session::Outbox::new(log_id.clone(), dest, session_id.0),
283 unacked: session::Unacked::new(None, log_id),
284 };
285 let mut outbound_rx = outbound_rx;
286
287 loop {
288 let connected = match session.connect().await {
289 Ok(s) => s,
290 Err(_) => break,
291 };
292 deliveries.requeue_unacked();
293 let result = {
294 let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
295 let send_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
296 tokio::select! {
297 r = session::recv_connected::<In, _, _>(
298 &recv_stream,
299 &inbound_tx,
300 &mut recv_next,
301 ) => r.map_err(Either::Recv),
302 r = session::send_connected(
303 &send_stream,
304 &mut deliveries,
305 &mut outbound_rx,
306 ) => r.map_err(Either::Send),
307 _ = session_ct.cancelled() => Err(Either::Recv(session::RecvLoopError::Cancelled)),
308 }
309 };
310
311 let terminal = match &result {
312 Ok(()) => false,
313 Err(Either::Send(session::SendLoopError::Io(err))) => {
314 tracing::info!(
315 session_id = session_id.0,
316 error = %err,
317 "duplex send/recv error",
318 );
319 false
320 }
321 Err(Either::Recv(session::RecvLoopError::Io(_))) => false,
322 _ => true,
323 };
324 session = connected.release();
325 if terminal {
326 break;
327 }
328 }
329
330 let _ = notify.send(TxStatus::Closed);
331 });
332
333 e.insert(mvar.clone());
334 mvar
335 }
336 }
337 };
338
339 mvar.put(stream).await;
340}
341
/// Spawns the initiator-side session task for `link` and returns the
/// channel halves used to talk to it.
///
/// The task runs on the runtime returned by `crate::init::get_runtime()`.
/// It loops: connect, run the send and receive loops concurrently until
/// one of them returns, classify the result, release the connection, and
/// either reconnect or stop. When the loop ends, `TxStatus::Closed` is
/// published on the status watch of the returned `DuplexTx`.
pub(crate) fn spawn<Out: RemoteMessage, In: RemoteMessage>(
    link: impl Link,
) -> (DuplexTx<Out>, DuplexRx<In>) {
    let addr = link.dest();
    let session_id = link.link_id();
    let (outbound_tx, outbound_rx) = tokio::sync::mpsc::unbounded_channel();
    let (inbound_tx, inbound_rx) = tokio::sync::mpsc::channel::<In>(1024);
    let (notify, status) = watch::channel(TxStatus::Active);
    let dest = addr.clone();
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id),
        };
        let mut outbound_rx = outbound_rx;
        let mut recv_next = Next { seq: 0, ack: 0 };

        loop {
            // A failed connect ends the session.
            let connected = match session.connect().await {
                Ok(s) => s,
                Err(_) => break,
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "duplex",
                    "reason" => "link connected",
                ),
            );

            // Unacked deliveries at connect time are counted as a
            // reconnection.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "duplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            deliveries.requeue_unacked();
            let result = {
                // On the initiator side we send on the initiator->acceptor
                // stream and receive on the reverse one. The stream
                // bindings are scoped to this block so they are gone
                // before `connected.release()` below.
                let send_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
                let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
                tokio::select! {
                    r = session::send_connected(
                        &send_stream, &mut deliveries, &mut outbound_rx,
                    ) => r.map_err(Either::Send),
                    r = session::recv_connected::<In, _, _>(
                        &recv_stream, &inbound_tx, &mut recv_next,
                    ) => r.map_err(Either::Recv),
                }
            };

            // `terminal == true` stops the loop; otherwise we reconnect.
            let terminal = match &result {
                Ok(()) => false,
                // `log_send_error` decides whether a send error is terminal.
                Err(Either::Send(e)) => log_send_error(e, &dest, session_id.0, "duplex"),
                Err(Either::Recv(session::RecvLoopError::Io(err))) => {
                    tracing::info!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        mode = "duplex",
                        "recv error",
                    );
                    // NOTE(review): this is the receive path, yet the metric
                    // is labelled with `ChannelErrorType::SendError` —
                    // confirm whether a receive-specific label is intended.
                    metrics::CHANNEL_ERRORS.add(
                        1,
                        hyperactor_telemetry::kv_pairs!(
                            "dest" => dest.to_string(),
                            "session_id" => session_id.0.to_string(),
                            "error_type" => metrics::ChannelErrorType::SendError.as_str(),
                            "mode" => "duplex",
                        ),
                    );
                    false
                }
                // Any other receive error ends the session.
                _ => true,
            };
            session = connected.release();
            if terminal {
                break;
            }
        }

        // Best effort: there may be no status watchers left.
        let _ = notify.send(TxStatus::Closed);
    });
    (
        DuplexTx::new(outbound_tx, addr.clone(), status),
        DuplexRx::new(inbound_rx, addr),
    )
}
440
441pub fn dial<Out: RemoteMessage, In: RemoteMessage>(
443 addr: ChannelAddr,
444) -> Result<(DuplexTx<Out>, DuplexRx<In>), ClientError> {
445 Ok(spawn(super::link(addr)?))
446}
447
#[cfg(test)]
mod tests {
    use timed_test::async_timed_test;

    use super::*;
    use crate::channel::ChannelTransport;

    /// One client, one server: messages flow in both directions, with a
    /// different message type per direction (`u64` up, `String` down).
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_basic() {
        let mut server =
            serve::<u64, String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        let server_addr = server.addr().clone();

        // Client sends `u64`, receives `String` (mirror of the server).
        let (client_tx, mut client_rx) = dial::<u64, String>(server_addr).unwrap();

        let (mut server_rx, server_tx) = server.accept().await.unwrap();

        // Client -> server.
        client_tx.post(42);
        let received = server_rx.recv().await.unwrap();
        assert_eq!(received, 42);

        // Server -> client.
        server_tx.post("hello".to_string());
        let received = client_rx.recv().await.unwrap();
        assert_eq!(received, "hello");

        // Alternate directions several times.
        for i in 0..10u64 {
            client_tx.post(i);
            assert_eq!(server_rx.recv().await.unwrap(), i);

            server_tx.post(format!("msg-{}", i));
            assert_eq!(client_rx.recv().await.unwrap(), format!("msg-{}", i));
        }
    }

    /// Two clients dial the same server; each gets its own accepted pair
    /// and traffic on one link does not show up on the other.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_multiple_links() {
        let mut server = serve::<u64, u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        let server_addr = server.addr().clone();

        // Accept after each dial so the pairs line up with the clients.
        let (tx1, mut rx1) = dial::<u64, u64>(server_addr.clone()).unwrap();
        let (mut srx1, stx1) = server.accept().await.unwrap();

        let (tx2, mut rx2) = dial::<u64, u64>(server_addr).unwrap();
        let (mut srx2, stx2) = server.accept().await.unwrap();

        // Link 1 round trip.
        tx1.post(100);
        assert_eq!(srx1.recv().await.unwrap(), 100);
        stx1.post(200);
        assert_eq!(rx1.recv().await.unwrap(), 200);

        // Link 2 round trip.
        tx2.post(300);
        assert_eq!(srx2.recv().await.unwrap(), 300);
        stx2.post(400);
        assert_eq!(rx2.recv().await.unwrap(), 400);
    }

    /// Runs an echo server on `addr` and measures how long `iterations`
    /// sequential round trips take from a single client.
    ///
    /// Ten warm-up round trips are performed before timing starts.
    async fn duplex_ping_pong(
        addr: ChannelAddr,
        iterations: usize,
    ) -> anyhow::Result<std::time::Duration> {
        let mut server = serve::<u64, u64>(addr)?;
        let server_addr = server.addr().clone();

        // Echo task: send every received message straight back.
        let server_handle = tokio::spawn(async move {
            let (mut rx, tx) = server.accept().await.unwrap();
            while let Ok(msg) = rx.recv().await {
                tx.post(msg);
            }
        });

        let (client_tx, mut client_rx) = dial::<u64, u64>(server_addr).unwrap();

        // Warm-up, excluded from the measurement.
        for i in 0..10u64 {
            client_tx.post(i);
            assert_eq!(client_rx.recv().await?, i);
        }

        let start = std::time::Instant::now();
        for i in 0..iterations as u64 {
            client_tx.post(i);
            assert_eq!(client_rx.recv().await?, i);
        }
        let elapsed = start.elapsed();

        server_handle.abort();
        Ok(elapsed)
    }

    /// Ping-pong over TCP loopback.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_ping_pong_tcp() {
        let elapsed = duplex_ping_pong(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), 100)
            .await
            .unwrap();
        println!("TCP duplex: 100 round-trips in {elapsed:?}");
    }

    /// Ping-pong over a Unix-transport address.
    #[async_timed_test(timeout_secs = 30)]
    async fn test_duplex_ping_pong_unix() {
        let elapsed = duplex_ping_pong(ChannelAddr::any(ChannelTransport::Unix), 100)
            .await
            .unwrap();
        println!("Unix duplex: 100 round-trips in {elapsed:?}");
    }
}