Skip to main content

hyperactor/channel/net/
duplex.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Duplex-mode channels over the net link layer.
10//!
11//! A single physical connection carries messages in both directions,
12//! each with independent sequence/ack state.
13//!
14//! ## Wire protocol
15//!
16//! Each connection starts with a unified `LinkInit` header (13 bytes,
17//! unframed) containing only the `session_id`:
18//!
19//! ```text
20//! [magic: 4B "LNK\0"] [session_id: 8B u64 BE]
21//! ```
22//!
23//! After the init, the standard tagged frame format is used. The tag
24//! byte in the 8-byte header distinguishes logical channels:
25//!
26//! - `INITIATOR_TO_ACCEPTOR = 0x00`
27//! - `ACCEPTOR_TO_INITIATOR = 0x01`
28
29use std::sync::Arc;
30
31use async_trait::async_trait;
32use backoff::ExponentialBackoffBuilder;
33use backoff::backoff::Backoff;
34use dashmap::DashMap;
35use tokio::sync::mpsc;
36use tokio::sync::oneshot;
37use tokio::sync::watch;
38use tokio::time::Instant;
39use tokio_util::sync::CancellationToken;
40
41use super::ClientError;
42use super::Link;
43use super::LinkStatus;
44use super::ServerError;
45use super::SessionId;
46use super::log_send_error;
47use super::read_link_init;
48use super::server::AcceptorLink;
49use super::server::ServerHandle;
50use super::session;
51use super::session::Next;
52use super::session::Session;
53use crate::RemoteMessage;
54use crate::channel::ChannelAddr;
55use crate::channel::ChannelError;
56use crate::channel::ChannelTransport;
57use crate::channel::Rx;
58use crate::channel::SendError;
59use crate::channel::Tx;
60use crate::channel::TxStatus;
61use crate::channel::net::Stream;
62use crate::channel::net::meta;
63use crate::channel::net::tls;
64use crate::metrics;
65
66/// Public duplex server that yields `(DuplexRx<In>, DuplexTx<Out>)` pairs.
67pub struct DuplexServer<In: RemoteMessage, Out: RemoteMessage> {
68    accept_rx: mpsc::Receiver<(DuplexRx<In>, DuplexTx<Out>)>,
69    handle: ServerHandle,
70    addr: ChannelAddr,
71}
72
73impl<In: RemoteMessage, Out: RemoteMessage> DuplexServer<In, Out> {
74    /// Accept a new duplex link, returning `(rx, tx)` handles.
75    pub async fn accept(&mut self) -> Result<(DuplexRx<In>, DuplexTx<Out>), ChannelError> {
76        self.accept_rx.recv().await.ok_or(ChannelError::Closed)
77    }
78
79    /// The address this server is listening on.
80    pub fn addr(&self) -> &ChannelAddr {
81        &self.addr
82    }
83
84    /// Gracefully shut down the duplex server. Cancels the listener
85    /// and awaits its task; structured concurrency in
86    /// [`dispatch_duplex_stream`] guarantees every in-flight session
87    /// has finished its terminal cleanup (final ack flush + `Closed`
88    /// emit) before this returns.
89    pub async fn join(mut self) {
90        self.handle.stop(&format!(
91            "DuplexServer joined; channel address: {}",
92            self.addr
93        ));
94        let _ = (&mut self.handle).await;
95    }
96}
97
98/// Receiver half of a duplex channel.
99pub struct DuplexRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr);
100
101impl<M: RemoteMessage> DuplexRx<M> {
102    pub(super) fn new(rx: mpsc::Receiver<M>, addr: ChannelAddr) -> Self {
103        Self(rx, addr)
104    }
105}
106
107#[async_trait]
108impl<M: RemoteMessage> Rx<M> for DuplexRx<M> {
109    async fn recv(&mut self) -> Result<M, ChannelError> {
110        self.0.recv().await.ok_or(ChannelError::Closed)
111    }
112
113    fn addr(&self) -> ChannelAddr {
114        self.1.clone()
115    }
116
117    async fn join(self) {}
118}
119
120/// A handle to a duplex client session: wraps the send/recv halves
121/// and the spawned task driving the connection. Owns a cancellation
122/// token so callers can deterministically stop the recv/send loop
123/// via [`DuplexClient::join`].
124///
125/// Dropping a `DuplexClient` does *not* cancel — that would tear
126/// down sessions whose tx/rx halves the application has handed off
127/// elsewhere (e.g., into a mailbox). Call [`join`](Self::join) for
128/// orderly shutdown.
129pub struct DuplexClient<Out: RemoteMessage, In: RemoteMessage> {
130    tx: DuplexTx<Out>,
131    rx: Option<DuplexRx<In>>,
132    join_handle: tokio::task::JoinHandle<()>,
133    cancel_token: CancellationToken,
134    addr: ChannelAddr,
135}
136
137impl<Out: RemoteMessage, In: RemoteMessage> DuplexClient<Out, In> {
138    /// Get a new clone of the [`DuplexTx`] for sending messages to
139    /// the peer.
140    pub fn tx(&self) -> DuplexTx<Out> {
141        self.tx.clone()
142    }
143
144    /// Take the [`DuplexRx`] out of the client. Returns `None` on
145    /// subsequent calls — the receiver is single-consumer.
146    pub fn take_rx(&mut self) -> Option<DuplexRx<In>> {
147        self.rx.take()
148    }
149
150    /// The peer address this client dialed.
151    pub fn addr(&self) -> &ChannelAddr {
152        &self.addr
153    }
154
155    /// Gracefully shut down the duplex client session. Cancels the
156    /// recv/send loop's cancellation token (which the spawned task
157    /// observes in its `select!`s) and awaits the spawned task. On
158    /// return, the task has finished its terminal cleanup (final
159    /// ack flush on `ACCEPTOR_TO_INITIATOR`) and dropped its
160    /// [`inbound_tx`](super::session) / outbound receiver halves —
161    /// so any in-progress [`DuplexRx::recv`](super::Rx::recv) on
162    /// the receiver half resolves with [`ChannelError::Closed`].
163    pub async fn join(self) {
164        self.cancel_token.cancel();
165        let _ = self.join_handle.await;
166    }
167}
168
169impl<Out: RemoteMessage, In: RemoteMessage> std::fmt::Debug for DuplexClient<Out, In> {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("DuplexClient")
172            .field("addr", &self.addr)
173            .field("rx_taken", &self.rx.is_none())
174            .finish()
175    }
176}
177
178/// Sender half of a duplex channel.
179pub struct DuplexTx<M: RemoteMessage> {
180    tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
181    addr: ChannelAddr,
182    status: watch::Receiver<TxStatus>,
183}
184
185impl<M: RemoteMessage> DuplexTx<M> {
186    pub(super) fn new(
187        tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
188        addr: ChannelAddr,
189        status: watch::Receiver<TxStatus>,
190    ) -> Self {
191        Self { tx, addr, status }
192    }
193}
194
195#[async_trait]
196impl<M: RemoteMessage> Tx<M> for DuplexTx<M> {
197    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
198        let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
199        if let Err(mpsc::error::SendError((message, return_channel, _))) =
200            self.tx
201                .send((message, return_channel, tokio::time::Instant::now()))
202        {
203            let reason = self.status.borrow().as_closed().map(|r| r.to_string());
204            let _ = return_channel.send(SendError {
205                error: ChannelError::Closed,
206                message,
207                reason,
208            });
209        }
210    }
211
212    fn addr(&self) -> ChannelAddr {
213        self.addr.clone()
214    }
215
216    fn status(&self) -> &watch::Receiver<TxStatus> {
217        &self.status
218    }
219}
220
221impl<M: RemoteMessage> Clone for DuplexTx<M> {
222    fn clone(&self) -> Self {
223        Self {
224            tx: self.tx.clone(),
225            addr: self.addr.clone(),
226            status: self.status.clone(),
227        }
228    }
229}
230
231impl<M: RemoteMessage> std::fmt::Debug for DuplexTx<M> {
232    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233        f.debug_struct("DuplexTx")
234            .field("addr", &self.addr)
235            .finish()
236    }
237}
238
239/// Start a duplex server on the given address.
240pub fn serve<In: RemoteMessage, Out: RemoteMessage>(
241    addr: ChannelAddr,
242    listener: Option<std::net::TcpListener>,
243) -> Result<DuplexServer<In, Out>, ServerError> {
244    let (mut listener, channel_addr) = super::listen_with_prebound(addr, listener)?;
245
246    let (accept_tx, accept_rx) = mpsc::channel(16);
247    let cancel_token = CancellationToken::new();
248    let child_token = cancel_token.child_token();
249
250    let is_tls = matches!(
251        channel_addr.transport(),
252        ChannelTransport::Tls | ChannelTransport::MetaTls(_)
253    );
254    let dest = channel_addr.clone();
255    let prepare = move |stream: Box<dyn Stream>, source: ChannelAddr| {
256        let dest = dest.clone();
257        async move {
258            if is_tls {
259                let tls_acceptor = match dest.transport() {
260                    ChannelTransport::Tls => tls::tls_acceptor()?,
261                    _ => meta::tls_acceptor(true)?,
262                };
263                let mut tls_stream = tls_acceptor.accept(stream).await?;
264                let link_init = read_link_init(&mut tls_stream)
265                    .await
266                    .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
267                Ok((link_init, Box::new(tls_stream) as Box<dyn Stream>))
268            } else {
269                let mut stream = stream;
270                let link_init = read_link_init(&mut stream)
271                    .await
272                    .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
273                Ok((link_init, stream))
274            }
275        }
276    };
277
278    let sessions: Arc<DashMap<SessionId, mpsc::UnboundedSender<Box<dyn Stream>>>> =
279        Arc::new(DashMap::new());
280    let dispatch_dest = channel_addr.clone();
281    let dispatch_cancel = child_token.clone();
282    let dispatch = {
283        let sessions = Arc::clone(&sessions);
284        let accept_tx = accept_tx.clone();
285        let dest = dispatch_dest;
286        move |link_init: super::LinkInit, stream: Box<dyn Stream>| {
287            let sessions = Arc::clone(&sessions);
288            let accept_tx = accept_tx.clone();
289            let cancel = dispatch_cancel.clone();
290            let dest = dest.clone();
291            async move {
292                dispatch_duplex_stream::<In, Out>(
293                    link_init.session_id,
294                    stream,
295                    sessions,
296                    dest,
297                    &accept_tx,
298                    cancel,
299                )
300                .await;
301            }
302        }
303    };
304
305    let ca = channel_addr.clone();
306    let join_handle = tokio::spawn(async move {
307        super::server::accept_loop(&mut listener, &ca, &child_token, prepare, dispatch).await
308    });
309
310    let server_handle = ServerHandle::new(join_handle, cancel_token, channel_addr.clone());
311
312    Ok(DuplexServer {
313        accept_rx,
314        handle: server_handle,
315        addr: channel_addr,
316    })
317}
318
/// Test-only counterpart of [`serve`] that runs the accept loop
/// over an arbitrary [`super::Listener`]. Mirrors
/// [`super::server::serve_with_listener`] but for duplex servers;
/// wire-level tests can stage `DuplexStream`s through a custom
/// listener and inspect terminal-flush behavior on the read side.
/// No TLS handshake is performed: each stream is boxed and its
/// `LinkInit` header read directly.
#[cfg(test)]
pub(super) fn serve_with_listener<In, Out, L>(
    mut listener: L,
    channel_addr: ChannelAddr,
) -> Result<DuplexServer<In, Out>, ServerError>
where
    In: RemoteMessage,
    Out: RemoteMessage,
    L: super::Listener + 'static,
    L::Stream: Unpin + std::fmt::Debug + 'static,
{
    let (accept_tx, accept_rx) = mpsc::channel(16);
    let cancel_token = CancellationToken::new();
    let loop_token = cancel_token.child_token();

    // Per-connection preparation: box the raw stream and read the
    // `LinkInit` header from it.
    let prepare = |raw: L::Stream, source: ChannelAddr| async move {
        let mut stream: Box<dyn Stream> = Box::new(raw);
        let link_init = read_link_init(&mut stream)
            .await
            .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
        Ok((link_init, stream))
    };

    // Per-connection dispatch: route the prepared stream to its
    // session, keyed by the session id from `LinkInit`.
    let dispatch = {
        let sessions: Arc<DashMap<SessionId, mpsc::UnboundedSender<Box<dyn Stream>>>> =
            Arc::new(DashMap::new());
        let dest = channel_addr.clone();
        let session_cancel = loop_token.clone();
        move |link_init: super::LinkInit, stream: Box<dyn Stream>| {
            let sessions = Arc::clone(&sessions);
            let accept_tx = accept_tx.clone();
            let cancel = session_cancel.clone();
            let dest = dest.clone();
            async move {
                dispatch_duplex_stream::<In, Out>(
                    link_init.session_id,
                    stream,
                    sessions,
                    dest,
                    &accept_tx,
                    cancel,
                )
                .await;
            }
        }
    };

    let handle = {
        let listen_addr = channel_addr.clone();
        let accept_task = tokio::spawn(async move {
            super::server::accept_loop(&mut listener, &listen_addr, &loop_token, prepare, dispatch)
                .await
        });
        ServerHandle::new(accept_task, cancel_token, channel_addr.clone())
    };

    Ok(DuplexServer {
        accept_rx,
        handle,
        addr: channel_addr,
    })
}
386
/// Helper to distinguish send errors from recv errors in duplex select.
///
/// Both duplex loops race the send and recv halves (and the
/// cancellation token) in one `select!`, so the failing side has to
/// be tagged to give the arms a common error type.
enum Either {
    /// The send half failed.
    Send(session::SendLoopError),
    /// The recv half failed. Cancellation is also reported through
    /// this variant, as `RecvLoopError::Cancelled`.
    Recv(session::RecvLoopError),
}
392
/// Dispatch a stream to the appropriate duplex session, creating one
/// if this is the first connection for the given session ID.
///
/// Structured concurrency: the first dispatch for a session runs the
/// recv/send loop inline and only returns after its terminal cleanup
/// (flush any pending recv ack, emit `Closed` on cancellation).
/// Reconnects hand the connection off via the per-session channel
/// and return immediately. [`accept_loop`](super::server::accept_loop)
/// joins every dispatch in its `connections` `JoinSet`, so it
/// finishes only after every recv/send loop has finished — same
/// contract as the simplex [`dispatch_stream`](super::server::dispatch_stream).
///
/// Parameters:
/// - `session_id`: key into `sessions`; also stamped on the link and
///   on every log line.
/// - `stream`: the freshly prepared connection to route.
/// - `sessions`: per-session senders through which connections are
///   fed to the task that owns the session.
/// - `addr`: address given to the `DuplexRx`/`DuplexTx` handles and
///   used as the link's `dest`.
/// - `accept_tx`: queue on which the new `(rx, tx)` pair is
///   published for the application to accept.
/// - `cancel`: token raced against the recv/send loops and handed to
///   the `AcceptorLink`.
///
/// On exit from the loop the session entry is removed from
/// `sessions` and the tx status is set to
/// `TxStatus::Closed("duplex session ended")`.
async fn dispatch_duplex_stream<In: RemoteMessage, Out: RemoteMessage>(
    session_id: SessionId,
    stream: Box<dyn Stream>,
    sessions: Arc<DashMap<SessionId, mpsc::UnboundedSender<Box<dyn Stream>>>>,
    addr: ChannelAddr,
    accept_tx: &mpsc::Sender<(DuplexRx<In>, DuplexTx<Out>)>,
    cancel: CancellationToken,
) {
    // Insert into the session map and drop the DashMap shard guard
    // before any await. Vacant inserts the sender side; occupied
    // returns the existing sender so this dispatch acts as a feeder.
    // Unbounded so this task can publish the sender before draining
    // the first conn without deadlocking a concurrent feeder for
    // the same session_id.
    let entry_result = {
        let entry = sessions.entry(session_id);
        match entry {
            dashmap::mapref::entry::Entry::Occupied(e) => Err(e.get().clone()),
            dashmap::mapref::entry::Entry::Vacant(e) => {
                let (sender, receiver) = mpsc::unbounded_channel::<Box<dyn Stream>>();
                e.insert(sender.clone());
                Ok((sender, receiver))
            }
        }
    };

    // `Err` carries the existing session's sender (feeder path);
    // `Ok` carries the new channel pair (owner path).
    let (sender, receiver) = match entry_result {
        Err(sender) => {
            // Feeder: forward the conn through the existing channel.
            // Send returns Err if the processor task exited and
            // dropped the receiver — drop the conn in that case.
            let _ = sender.send(stream);
            return;
        }
        Ok(pair) => pair,
    };

    // First dispatch for this session_id: set up the duplex
    // application-facing handles and yield them to the server.
    let (inbound_tx, inbound_rx) = mpsc::channel::<In>(1024);
    let (outbound_tx, mut outbound_rx) =
        mpsc::unbounded_channel::<(Out, oneshot::Sender<SendError<Out>>, Instant)>();
    let (notify, status) = watch::channel(TxStatus::Active);
    let net_rx = DuplexRx(inbound_rx, addr.clone());
    let net_tx = DuplexTx {
        tx: outbound_tx,
        addr: addr.clone(),
        status,
    };
    // The send result is ignored: if the accept queue has been
    // closed, the handles are dropped here and the loop below still
    // runs.
    // NOTE(review): this await is not raced against `cancel`; with a
    // bounded accept queue it waits for capacity when the queue is
    // full — confirm that is acceptable during shutdown.
    let _ = accept_tx.send((net_rx, net_tx)).await;

    // Hand the first connection off through the channel; the loop
    // below picks it up via `session.connect().await`.
    let _ = sender.send(stream);
    // Drop the local sender; the clone stored in `sessions` remains
    // for feeders.
    drop(sender);

    let link = AcceptorLink {
        dest: addr.clone(),
        session_id,
        stream: receiver,
        cancel: cancel.clone(),
    };
    let session_ct = cancel;
    let dest = addr;
    let log_id = format!("duplex server {:016x}", session_id.0);
    let mut deliveries = session::Deliveries {
        outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
        unacked: session::Unacked::new(None, log_id),
    };
    let mut session = Session::new(link);
    // Receive-side sequence/ack cursor; lives across reconnects.
    let mut recv_next = Next { seq: 0, ack: 0 };

    // One iteration per physical connection of this session.
    loop {
        // A connect error ends the session.
        let connected = match session.connect().await {
            Ok(s) => s,
            Err(_) => break,
        };
        // Presumably moves not-yet-acked messages back into the
        // outbox for resending on this connection — confirm in
        // `session::Deliveries`.
        deliveries.requeue_unacked();
        // Race the recv half, the send half, and cancellation; the
        // first to finish decides `result`. The acceptor receives on
        // `INITIATOR_TO_ACCEPTOR` and sends on
        // `ACCEPTOR_TO_INITIATOR`.
        let result = {
            let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
            let send_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
            tokio::select! {
                r = session::recv_connected::<In, _, _>(
                    &recv_stream,
                    &inbound_tx,
                    &mut recv_next,
                ) => r.map_err(Either::Recv),
                r = session::send_connected(
                    &send_stream,
                    &mut deliveries,
                    &mut outbound_rx,
                ) => r.map_err(Either::Send),
                _ = session_ct.cancelled() => Err(Either::Recv(session::RecvLoopError::Cancelled)),
            }
        };

        // Classify the outcome: `Ok` and I/O errors keep the session
        // alive awaiting a reconnect; every other error is terminal.
        let terminal = match &result {
            Ok(()) => {
                // NOTE(review): either select arm can yield `Ok(())`,
                // but the log text names `recv_connected` only.
                tracing::info!(
                    session_id = session_id.0,
                    "duplex recv_connected returned EOF, awaiting reconnect"
                );
                false
            }
            Err(Either::Send(session::SendLoopError::Io(err))) => {
                tracing::info!(
                    session_id = session_id.0,
                    error = %err,
                    "duplex send error (recoverable)",
                );
                false
            }
            Err(Either::Recv(session::RecvLoopError::Io(err))) => {
                tracing::info!(
                    session_id = session_id.0,
                    error = %err,
                    "duplex recv error (recoverable)",
                );
                false
            }
            Err(Either::Send(e)) => {
                tracing::info!(
                    session_id = session_id.0,
                    error = %e,
                    "duplex send terminal error"
                );
                true
            }
            Err(Either::Recv(e)) => {
                tracing::info!(
                    session_id = session_id.0,
                    error = %e,
                    "duplex recv terminal error"
                );
                true
            }
        };

        // Flush any pending recv ack so the peer's unacked queue
        // clears cleanly before this connection goes away. Mirrors
        // the simplex server's drain logic (see `dispatch_stream`);
        // without it, peers retry-loop until
        // `MESSAGE_DELIVERY_TIMEOUT`. Best effort: a failed write is
        // only logged at debug level and the cursor is left as-is.
        if recv_next.ack < recv_next.seq {
            let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
            let ack = super::serialize_response(super::NetRxResponse::Ack(recv_next.seq - 1))
                .expect("serialize ack");
            let mut completion = recv_stream.write(ack);
            match completion.drive().await {
                Ok(()) => {
                    recv_next.ack = recv_next.seq;
                }
                Err(e) => {
                    tracing::debug!(
                        session_id = session_id.0,
                        error = %e,
                        "duplex: failed to flush acks during cleanup"
                    );
                }
            }
        }

        // On terminal exit, tell the peer we're closing so it stops
        // trying to reconnect: `Reject` for a sequence error,
        // `Closed` for cancellation or an app-side close. Other
        // outcomes send nothing.
        let terminal_response = match &result {
            Err(Either::Recv(session::RecvLoopError::SequenceError(reason))) => {
                Some(super::NetRxResponse::Reject(reason.clone()))
            }
            Err(Either::Recv(session::RecvLoopError::Cancelled))
            | Err(Either::Send(session::SendLoopError::AppClosed)) => {
                Some(super::NetRxResponse::Closed)
            }
            _ => None,
        };
        if let Some(rsp) = terminal_response {
            let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
            let data = super::serialize_response(rsp).expect("serialize terminal response");
            let mut completion = recv_stream.write(data);
            // Best effort: the write result is ignored.
            let _ = completion.drive().await;
        }

        // Give the connection back to obtain the disconnected
        // session for the next `connect()`.
        session = connected.release();
        if terminal {
            break;
        }
    }

    // Recv/send loop is finished — drop the session entry so a later
    // reconnect for the same session_id starts a fresh dispatch task
    // instead of feeding a dead channel; any in-flight feeder's send
    // fails after the link's receiver above is dropped.
    sessions.remove(&session_id);

    // Publish the terminal status to every `DuplexTx` clone.
    let _ = notify.send(TxStatus::Closed("duplex session ended".into()));
}
600
/// Establish a duplex (bidirectional) session over the given link.
/// Returns a [`DuplexClient`] wrapping the send/recv halves and the
/// spawned recv/send task; the client owns a cancellation token so
/// callers can deterministically tear the session down via
/// [`DuplexClient::join`].
///
/// The task is spawned on the runtime returned by
/// `crate::init::get_runtime()`. It loops over connections: connect,
/// race the send and recv halves against cancellation, then either
/// reconnect after an exponential backoff (EOF or recoverable
/// error) or exit (terminal error or cancellation). When the task
/// exits, the tx status is set to
/// `TxStatus::Closed("duplex session ended")`.
pub(crate) fn spawn<Out: RemoteMessage, In: RemoteMessage>(
    link: impl Link,
) -> DuplexClient<Out, In> {
    let addr = link.dest();
    let session_id = link.link_id();
    // Outbound queue is unbounded; inbound queue holds up to 1024
    // messages.
    let (outbound_tx, outbound_rx) = tokio::sync::mpsc::unbounded_channel();
    let (inbound_tx, inbound_rx) = tokio::sync::mpsc::channel::<In>(1024);
    let (notify, status) = watch::channel(TxStatus::Active);
    let cancel_token = CancellationToken::new();
    let task_cancel = cancel_token.clone();
    let dest = addr.clone();
    let join_handle = crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id),
        };
        let mut outbound_rx = outbound_rx;
        // Receive-side sequence/ack cursor; lives across reconnects.
        let mut recv_next = Next { seq: 0, ack: 0 };
        // Reconnect delay: starts at 10 ms, doubles each time with
        // 10% randomization, capped at 5 s, and never gives up.
        let mut reconnect_backoff = ExponentialBackoffBuilder::new()
            .with_initial_interval(std::time::Duration::from_millis(10))
            .with_multiplier(2.0)
            .with_randomization_factor(0.1)
            .with_max_interval(std::time::Duration::from_secs(5))
            .with_max_elapsed_time(None)
            .build();

        let mut link_status = LinkStatus::NeverConnected;

        // One iteration per physical connection of this session.
        loop {
            // Race connect against cancel so a `DuplexClient::join`
            // call mid-dial doesn't have to wait for the dial
            // backoff to elapse.
            let connected = tokio::select! {
                result = session.connect() => match result {
                    Ok(s) => s,
                    Err(_) => break,
                },
                _ = task_cancel.cancelled() => break,
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "duplex",
                    "reason" => "link connected",
                ),
            );

            // Counted as a reconnection only when messages are still
            // awaiting acknowledgement.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "duplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            // Presumably moves not-yet-acked messages back into the
            // outbox for resending on this connection — confirm in
            // `session::Deliveries`.
            deliveries.requeue_unacked();

            link_status.connected();
            let connected_at = tokio::time::Instant::now();

            // Race the send half, the recv half, and cancellation;
            // the first to finish decides `result`. The initiator
            // sends on `INITIATOR_TO_ACCEPTOR` and receives on
            // `ACCEPTOR_TO_INITIATOR`.
            let result = {
                let send_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
                let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
                tokio::select! {
                    r = session::send_connected(
                        &send_stream, &mut deliveries, &mut outbound_rx,
                    ) => r.map_err(Either::Send),
                    r = session::recv_connected::<In, _, _>(
                        &recv_stream, &inbound_tx, &mut recv_next,
                    ) => r.map_err(Either::Recv),
                    _ = task_cancel.cancelled() => Err(Either::Recv(session::RecvLoopError::Cancelled)),
                }
            };

            link_status.disconnected();

            // A connection that lasted more than a second restarts
            // the backoff schedule from the initial interval.
            if connected_at.elapsed() > tokio::time::Duration::from_secs(1) {
                reconnect_backoff.reset();
            }

            // Classify the outcome. Non-terminal arms sleep for the
            // next backoff delay before the loop reconnects.
            let terminal = match &result {
                Ok(()) => {
                    // NOTE(review): either select arm can yield
                    // `Ok(())`, but the log text names
                    // `send_connected` only.
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            "duplex send_connected returned EOF, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                    false
                }
                Err(Either::Send(e)) => {
                    // `log_send_error` reports whether the send error
                    // is terminal.
                    let terminal = log_send_error(e, &dest, session_id.0, "duplex", &link_status);
                    if !terminal {
                        // Recoverable send error — reconnect after backoff.
                        if let Some(delay) = reconnect_backoff.next_backoff() {
                            tracing::info!(
                                dest = %dest,
                                session_id = session_id.0,
                                error = %e,
                                delay_ms = delay.as_millis() as u64,
                                mode = "duplex",
                                "send error (recoverable), reconnecting after backoff; {link_status}",
                            );
                            tokio::time::sleep(delay).await;
                        }
                    }
                    terminal
                }
                Err(Either::Recv(session::RecvLoopError::Io(err))) => {
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            error = %err,
                            delay_ms = delay.as_millis() as u64,
                            mode = "duplex",
                            "recv error (recoverable), reconnecting after backoff; {link_status}",
                        );
                        tokio::time::sleep(delay).await;
                    }
                    // NOTE(review): this is the recv-error branch but
                    // the metric is labelled with
                    // `ChannelErrorType::SendError`, and it is
                    // recorded only after the backoff sleep — confirm
                    // both are intended.
                    metrics::CHANNEL_ERRORS.add(
                        1,
                        hyperactor_telemetry::kv_pairs!(
                            "dest" => dest.to_string(),
                            "session_id" => session_id.0.to_string(),
                            "error_type" => metrics::ChannelErrorType::SendError.as_str(),
                            "mode" => "duplex",
                        ),
                    );
                    false
                }
                Err(Either::Recv(e)) => {
                    tracing::info!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %e,
                        "duplex recv terminal error; {link_status}"
                    );
                    true
                }
            };

            // Flush any pending recv ack so the peer's send-side
            // unacked queue clears cleanly before this connection
            // goes away. Mirrors the cleanup in
            // `dispatch_duplex_stream` but on the other tag — the
            // initiator reads data on `ACCEPTOR_TO_INITIATOR`, so
            // its acks travel back on that same tag. Best effort: a
            // failed write is only logged at debug level and the
            // cursor is left as-is.
            if recv_next.ack < recv_next.seq {
                let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
                let ack = super::serialize_response(super::NetRxResponse::Ack(recv_next.seq - 1))
                    .expect("serialize ack");
                let mut completion = recv_stream.write(ack);
                match completion.drive().await {
                    Ok(()) => {
                        recv_next.ack = recv_next.seq;
                    }
                    Err(e) => {
                        tracing::debug!(
                            dest = %dest,
                            session_id = session_id.0,
                            error = %e,
                            "duplex client: failed to flush acks during cleanup"
                        );
                    }
                }
            }

            // On terminal exit, tell the peer we're closing (or
            // rejecting) on the same tag we read data on. Mirrors
            // `dispatch_duplex_stream`; without it, the peer's
            // server-side dispatch keeps awaiting a reconnect that
            // will never come.
            let terminal_response = match &result {
                Err(Either::Recv(session::RecvLoopError::SequenceError(reason))) => {
                    Some(super::NetRxResponse::Reject(reason.clone()))
                }
                Err(Either::Recv(session::RecvLoopError::Cancelled))
                | Err(Either::Send(session::SendLoopError::AppClosed)) => {
                    Some(super::NetRxResponse::Closed)
                }
                _ => None,
            };
            if let Some(rsp) = terminal_response {
                let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
                let data =
                    super::serialize_response(rsp).expect("serialize terminal response");
                let mut completion = recv_stream.write(data);
                // Best effort: the write result is ignored.
                let _ = completion.drive().await;
            }

            // Give the connection back to obtain the disconnected
            // session for the next `connect()`.
            session = connected.release();
            if terminal {
                break;
            }
        }

        // Publish the terminal status to every `DuplexTx` clone.
        let _ = notify.send(TxStatus::Closed("duplex session ended".into()));
    });
    let tx = DuplexTx::new(outbound_tx, addr.clone(), status);
    let rx = DuplexRx::new(inbound_rx, addr.clone());
    DuplexClient {
        tx,
        rx: Some(rx),
        join_handle,
        cancel_token,
        addr,
    }
}
825
826/// Connect to a duplex server. Returns a [`DuplexClient`] wrapping
827/// the send/recv halves and the spawned recv/send task; callers use
828/// [`DuplexClient::tx`] / [`DuplexClient::take_rx`] to extract the
829/// halves and [`DuplexClient::join`] to deterministically shut the
830/// session down.
831pub fn dial<Out: RemoteMessage, In: RemoteMessage>(
832    addr: ChannelAddr,
833) -> Result<DuplexClient<Out, In>, ClientError> {
834    Ok(spawn(super::link(addr, super::SessionId::random(), 0)?))
835}
836
#[cfg(test)]
mod tests {
    use timed_test::async_timed_test;

    use super::*;
    use crate::channel::ChannelTransport;

    /// TCP address on the IPv6 loopback interface with port 0, as
    /// used by every TCP-based test in this module.
    fn loopback_tcp() -> ChannelAddr {
        ChannelAddr::Tcp("[::1]:0".parse().unwrap())
    }

    // A single dialed client and its accepted server-side link
    // exchange messages in both directions: one message each way,
    // then ten alternating round-trips.
    #[async_timed_test(timeout_secs = 30)]
    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_basic() {
        let mut server = serve::<u64, String>(loopback_tcp(), None).unwrap();
        let listen_addr = server.addr().clone();

        // Dialing side: posts u64, receives String.
        let mut client = dial::<u64, String>(listen_addr).unwrap();
        let to_server = client.tx();
        let mut from_server = client.take_rx().unwrap();

        // Accepting side: receives u64, posts String.
        let (mut from_client, to_client) = server.accept().await.unwrap();

        // Client -> server.
        to_server.post(42);
        assert_eq!(from_client.recv().await.unwrap(), 42);

        // Server -> client.
        to_client.post("hello".to_string());
        assert_eq!(from_server.recv().await.unwrap(), "hello");

        // Alternate directions for several rounds.
        for i in 0..10u64 {
            to_server.post(i);
            assert_eq!(from_client.recv().await.unwrap(), i);

            let reply = format!("msg-{}", i);
            to_client.post(reply.clone());
            assert_eq!(from_server.recv().await.unwrap(), reply);
        }
    }

    // Two clients dialed against the same server each get their own
    // accepted link, and traffic on one link is delivered on that
    // link's endpoints.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_multiple_links() {
        let mut server = serve::<u64, u64>(loopback_tcp(), None).unwrap();
        let listen_addr = server.addr().clone();

        // First link: dial, then accept.
        let mut first = dial::<u64, u64>(listen_addr.clone()).unwrap();
        let first_tx = first.tx();
        let mut first_rx = first.take_rx().unwrap();
        let (mut first_srv_rx, first_srv_tx) = server.accept().await.unwrap();

        // Second link: dial, then accept.
        let mut second = dial::<u64, u64>(listen_addr).unwrap();
        let second_tx = second.tx();
        let mut second_rx = second.take_rx().unwrap();
        let (mut second_srv_rx, second_srv_tx) = server.accept().await.unwrap();

        // Round-trip on the first link.
        first_tx.post(100);
        assert_eq!(first_srv_rx.recv().await.unwrap(), 100);
        first_srv_tx.post(200);
        assert_eq!(first_rx.recv().await.unwrap(), 200);

        // Round-trip on the second link.
        second_tx.post(300);
        assert_eq!(second_srv_rx.recv().await.unwrap(), 300);
        second_srv_tx.post(400);
        assert_eq!(second_rx.recv().await.unwrap(), 400);
    }

    /// Ping-pong driver: serves on `addr`, spawns a task that echoes
    /// every received message back to the sender, then runs ten
    /// untimed warmup round-trips followed by `iterations` timed ones.
    /// Returns the elapsed time of the timed round-trips.
    async fn duplex_ping_pong(
        addr: ChannelAddr,
        iterations: usize,
    ) -> anyhow::Result<std::time::Duration> {
        let mut server = serve::<u64, u64>(addr, None)?;
        let listen_addr = server.addr().clone();

        // Echo task: runs until the receive side reports an error or
        // the task is aborted below.
        let echo_task = tokio::spawn(async move {
            let (mut inbound, outbound) = server.accept().await.unwrap();
            loop {
                match inbound.recv().await {
                    Ok(msg) => {
                        outbound.post(msg);
                    }
                    Err(_) => break,
                }
            }
        });

        let mut client = dial::<u64, u64>(listen_addr).unwrap();
        let ping = client.tx();
        let mut pong = client.take_rx().unwrap();

        // Warmup round-trips, excluded from the measurement.
        for i in 0..10u64 {
            ping.post(i);
            assert_eq!(pong.recv().await?, i);
        }

        // Timed round-trips.
        let started = std::time::Instant::now();
        for i in 0..iterations as u64 {
            ping.post(i);
            assert_eq!(pong.recv().await?, i);
        }
        let elapsed = started.elapsed();

        echo_task.abort();
        Ok(elapsed)
    }

    // 100 ping-pong round-trips over TCP loopback.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_ping_pong_tcp() {
        let elapsed = duplex_ping_pong(loopback_tcp(), 100).await.unwrap();
        println!("TCP duplex: 100 round-trips in {elapsed:?}");
    }

    // 100 ping-pong round-trips over a Unix transport address.
    #[async_timed_test(timeout_secs = 30)]
    async fn test_duplex_ping_pong_unix() {
        let unix_addr = ChannelAddr::any(ChannelTransport::Unix);
        let elapsed = duplex_ping_pong(unix_addr, 100).await.unwrap();
        println!("Unix duplex: 100 round-trips in {elapsed:?}");
    }
}