hyperactor/channel/net/
duplex.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Duplex-mode channels over the net link layer.
10//!
11//! A single physical connection carries messages in both directions,
12//! each with independent sequence/ack state.
13//!
14//! ## Wire protocol
15//!
16//! Each connection starts with a unified `LinkInit` header (13 bytes,
17//! unframed) containing only the `session_id`:
18//!
19//! ```text
20//! [magic: 4B "LNK\0"] [session_id: 8B u64 BE]
21//! ```
22//!
23//! After the init, the standard tagged frame format is used. The tag
24//! byte in the 8-byte header distinguishes logical channels:
25//!
26//! - `INITIATOR_TO_ACCEPTOR = 0x00`
27//! - `ACCEPTOR_TO_INITIATOR = 0x01`
28
29use std::sync::Arc;
30
31use async_trait::async_trait;
32use backoff::ExponentialBackoffBuilder;
33use backoff::backoff::Backoff;
34use dashmap::DashMap;
35use tokio::sync::mpsc;
36use tokio::sync::oneshot;
37use tokio::sync::watch;
38use tokio::time::Instant;
39use tokio_util::sync::CancellationToken;
40
41use super::ClientError;
42use super::Link;
43use super::LinkStatus;
44use super::ServerError;
45use super::SessionId;
46use super::log_send_error;
47use super::read_link_init;
48use super::server::AcceptorLink;
49use super::server::ServerHandle;
50use super::session;
51use super::session::Next;
52use super::session::Session;
53use crate::RemoteMessage;
54use crate::channel::ChannelAddr;
55use crate::channel::ChannelError;
56use crate::channel::ChannelTransport;
57use crate::channel::Rx;
58use crate::channel::SendError;
59use crate::channel::Tx;
60use crate::channel::TxStatus;
61use crate::channel::net::Stream;
62use crate::channel::net::meta;
63use crate::channel::net::tls;
64use crate::metrics;
65use crate::sync::mvar::MVar;
66
67/// Public duplex server that yields `(DuplexRx<In>, DuplexTx<Out>)` pairs.
68pub struct DuplexServer<In: RemoteMessage, Out: RemoteMessage> {
69    accept_rx: mpsc::Receiver<(DuplexRx<In>, DuplexTx<Out>)>,
70    _handle: ServerHandle,
71    addr: ChannelAddr,
72}
73
74impl<In: RemoteMessage, Out: RemoteMessage> DuplexServer<In, Out> {
75    /// Accept a new duplex link, returning `(rx, tx)` handles.
76    pub async fn accept(&mut self) -> Result<(DuplexRx<In>, DuplexTx<Out>), ChannelError> {
77        self.accept_rx.recv().await.ok_or(ChannelError::Closed)
78    }
79
80    /// The address this server is listening on.
81    pub fn addr(&self) -> &ChannelAddr {
82        &self.addr
83    }
84}
85
86/// Receiver half of a duplex channel.
87pub struct DuplexRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr);
88
89impl<M: RemoteMessage> DuplexRx<M> {
90    pub(super) fn new(rx: mpsc::Receiver<M>, addr: ChannelAddr) -> Self {
91        Self(rx, addr)
92    }
93}
94
95#[async_trait]
96impl<M: RemoteMessage> Rx<M> for DuplexRx<M> {
97    async fn recv(&mut self) -> Result<M, ChannelError> {
98        self.0.recv().await.ok_or(ChannelError::Closed)
99    }
100
101    fn addr(&self) -> ChannelAddr {
102        self.1.clone()
103    }
104
105    async fn join(self) {}
106}
107
108/// Sender half of a duplex channel.
109pub struct DuplexTx<M: RemoteMessage> {
110    tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
111    addr: ChannelAddr,
112    status: watch::Receiver<TxStatus>,
113}
114
115impl<M: RemoteMessage> DuplexTx<M> {
116    pub(super) fn new(
117        tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
118        addr: ChannelAddr,
119        status: watch::Receiver<TxStatus>,
120    ) -> Self {
121        Self { tx, addr, status }
122    }
123}
124
125#[async_trait]
126impl<M: RemoteMessage> Tx<M> for DuplexTx<M> {
127    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
128        let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
129        if let Err(mpsc::error::SendError((message, return_channel, _))) =
130            self.tx
131                .send((message, return_channel, tokio::time::Instant::now()))
132        {
133            let reason = self.status.borrow().as_closed().map(|r| r.to_string());
134            let _ = return_channel.send(SendError {
135                error: ChannelError::Closed,
136                message,
137                reason,
138            });
139        }
140    }
141
142    fn addr(&self) -> ChannelAddr {
143        self.addr.clone()
144    }
145
146    fn status(&self) -> &watch::Receiver<TxStatus> {
147        &self.status
148    }
149}
150
151impl<M: RemoteMessage> Clone for DuplexTx<M> {
152    fn clone(&self) -> Self {
153        Self {
154            tx: self.tx.clone(),
155            addr: self.addr.clone(),
156            status: self.status.clone(),
157        }
158    }
159}
160
161/// Start a duplex server on the given address.
162pub fn serve<In: RemoteMessage, Out: RemoteMessage>(
163    addr: ChannelAddr,
164) -> Result<DuplexServer<In, Out>, ServerError> {
165    let (mut listener, channel_addr) = super::listen(addr)?;
166
167    let (accept_tx, accept_rx) = mpsc::channel(16);
168    let cancel_token = CancellationToken::new();
169    let child_token = cancel_token.child_token();
170
171    let is_tls = matches!(
172        channel_addr.transport(),
173        ChannelTransport::Tls | ChannelTransport::MetaTls(_)
174    );
175    let dest = channel_addr.clone();
176    let prepare = move |stream: Box<dyn Stream>, source: ChannelAddr| {
177        let dest = dest.clone();
178        async move {
179            if is_tls {
180                let tls_acceptor = match dest.transport() {
181                    ChannelTransport::Tls => tls::tls_acceptor()?,
182                    _ => meta::tls_acceptor(true)?,
183                };
184                let mut tls_stream = tls_acceptor.accept(stream).await?;
185                let session_id = read_link_init(&mut tls_stream)
186                    .await
187                    .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
188                Ok((session_id, Box::new(tls_stream) as Box<dyn Stream>))
189            } else {
190                let mut stream = stream;
191                let session_id = read_link_init(&mut stream)
192                    .await
193                    .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
194                Ok((session_id, stream))
195            }
196        }
197    };
198
199    let sessions: Arc<DashMap<SessionId, MVar<Box<dyn Stream>>>> = Arc::new(DashMap::new());
200    let child_cancel = CancellationToken::new();
201    let dispatch_dest = channel_addr.clone();
202    let dispatch = {
203        let sessions = Arc::clone(&sessions);
204        let accept_tx = accept_tx.clone();
205        let child_cancel = child_cancel.clone();
206        let dest = dispatch_dest;
207        move |session_id: SessionId, stream: Box<dyn Stream>| {
208            let sessions = Arc::clone(&sessions);
209            let accept_tx = accept_tx.clone();
210            let cancel = child_cancel.child_token();
211            let dest = dest.clone();
212            async move {
213                dispatch_duplex_stream::<In, Out>(
214                    session_id, stream, &sessions, dest, &accept_tx, cancel,
215                )
216                .await;
217            }
218        }
219    };
220
221    let ca = channel_addr.clone();
222    let join_handle = tokio::spawn(async move {
223        let result =
224            super::server::accept_loop(&mut listener, &ca, &child_token, prepare, dispatch).await;
225        child_cancel.cancel();
226        result
227    });
228
229    let server_handle = ServerHandle::new(join_handle, cancel_token, channel_addr.clone());
230
231    Ok(DuplexServer {
232        accept_rx,
233        _handle: server_handle,
234        addr: channel_addr,
235    })
236}
237
238/// Helper to distinguish send errors from recv errors in duplex select.
239enum Either {
240    Send(session::SendLoopError),
241    Recv(session::RecvLoopError),
242}
243
244/// Dispatch a stream to the appropriate duplex session, creating one
245/// if this is the first connection for the given session ID.
246async fn dispatch_duplex_stream<In: RemoteMessage, Out: RemoteMessage>(
247    session_id: SessionId,
248    stream: Box<dyn Stream>,
249    sessions: &DashMap<SessionId, MVar<Box<dyn Stream>>>,
250    addr: ChannelAddr,
251    accept_tx: &mpsc::Sender<(DuplexRx<In>, DuplexTx<Out>)>,
252    cancel: CancellationToken,
253) {
254    let mvar = {
255        let entry = sessions.entry(session_id);
256        match entry {
257            dashmap::mapref::entry::Entry::Occupied(e) => e.get().clone(),
258            dashmap::mapref::entry::Entry::Vacant(e) => {
259                let mvar: MVar<Box<dyn Stream>> = MVar::empty();
260                let link = AcceptorLink {
261                    dest: addr.clone(),
262                    session_id,
263                    stream: mvar.clone(),
264                    cancel: cancel.clone(),
265                };
266
267                let (inbound_tx, inbound_rx) = mpsc::channel::<In>(1024);
268                let (outbound_tx, outbound_rx) =
269                    mpsc::unbounded_channel::<(Out, oneshot::Sender<SendError<Out>>, Instant)>();
270                let (notify, status) = watch::channel(TxStatus::Active);
271                let net_rx = DuplexRx(inbound_rx, addr.clone());
272                let net_tx = DuplexTx {
273                    tx: outbound_tx,
274                    addr: addr.clone(),
275                    status,
276                };
277                let _ = accept_tx.send((net_rx, net_tx)).await;
278
279                let session_ct = cancel.clone();
280                let dest = addr.clone();
281                tokio::spawn(async move {
282                    let mut session = Session::new(link);
283                    let mut recv_next = Next { seq: 0, ack: 0 };
284                    let log_id = format!("duplex server {:016x}", session_id.0);
285                    let mut deliveries = session::Deliveries {
286                        outbox: session::Outbox::new(log_id.clone(), dest, session_id.0),
287                        unacked: session::Unacked::new(None, log_id),
288                    };
289                    let mut outbound_rx = outbound_rx;
290
291                    loop {
292                        let connected = match session.connect().await {
293                            Ok(s) => s,
294                            Err(_) => break,
295                        };
296                        deliveries.requeue_unacked();
297                        let result = {
298                            let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
299                            let send_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
300                            tokio::select! {
301                                r = session::recv_connected::<In, _, _>(
302                                    &recv_stream,
303                                    &inbound_tx,
304                                    &mut recv_next,
305                                ) => r.map_err(Either::Recv),
306                                r = session::send_connected(
307                                    &send_stream,
308                                    &mut deliveries,
309                                    &mut outbound_rx,
310                                ) => r.map_err(Either::Send),
311                                _ = session_ct.cancelled() => Err(Either::Recv(session::RecvLoopError::Cancelled)),
312                            }
313                        };
314
315                        let terminal = match &result {
316                            Ok(()) => {
317                                tracing::info!(
318                                    session_id = session_id.0,
319                                    "duplex recv_connected returned EOF, awaiting reconnect"
320                                );
321                                false
322                            }
323                            Err(Either::Send(session::SendLoopError::Io(err))) => {
324                                tracing::info!(
325                                    session_id = session_id.0,
326                                    error = %err,
327                                    "duplex send error (recoverable)",
328                                );
329                                false
330                            }
331                            Err(Either::Recv(session::RecvLoopError::Io(err))) => {
332                                tracing::info!(
333                                    session_id = session_id.0,
334                                    error = %err,
335                                    "duplex recv error (recoverable)",
336                                );
337                                false
338                            }
339                            Err(Either::Send(e)) => {
340                                tracing::info!(
341                                    session_id = session_id.0,
342                                    error = %e,
343                                    "duplex send terminal error"
344                                );
345                                true
346                            }
347                            Err(Either::Recv(e)) => {
348                                tracing::info!(
349                                    session_id = session_id.0,
350                                    error = %e,
351                                    "duplex recv terminal error"
352                                );
353                                true
354                            }
355                        };
356                        session = connected.release();
357                        if terminal {
358                            break;
359                        }
360                    }
361
362                    let _ = notify.send(TxStatus::Closed("duplex session ended".into()));
363                });
364
365                e.insert(mvar.clone());
366                mvar
367            }
368        }
369    };
370
371    mvar.put(stream).await;
372}
373
/// Establish a duplex (bidirectional) session over the given link.
/// Returns send and receive handles.
///
/// This is the initiator side. A task spawned on the crate runtime repeatedly
/// connects `link`, then concurrently sends queued `Out` messages on the
/// `INITIATOR_TO_ACCEPTOR` channel and receives `In` messages from the
/// `ACCEPTOR_TO_INITIATOR` channel.
///
/// The following reconnect after an exponential backoff: a clean return, a
/// send error that `log_send_error` reports as non-terminal, and a recv I/O
/// error. Terminal errors and a failed `connect` end the task. When the task
/// ends, the returned `DuplexTx`'s status becomes `TxStatus::Closed`.
pub(crate) fn spawn<Out: RemoteMessage, In: RemoteMessage>(
    link: impl Link,
) -> (DuplexTx<Out>, DuplexRx<In>) {
    let addr = link.dest();
    let session_id = link.link_id();
    let (outbound_tx, outbound_rx) = tokio::sync::mpsc::unbounded_channel();
    let (inbound_tx, inbound_rx) = tokio::sync::mpsc::channel::<In>(1024);
    let (notify, status) = watch::channel(TxStatus::Active);
    let dest = addr.clone();
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id),
        };
        let mut outbound_rx = outbound_rx;
        let mut recv_next = Next { seq: 0, ack: 0 };
        // Reconnect delays start at 10ms and double each time (with 10%
        // jitter), capped at 5s. There is no overall deadline.
        let mut reconnect_backoff = ExponentialBackoffBuilder::new()
            .with_initial_interval(std::time::Duration::from_millis(10))
            .with_multiplier(2.0)
            .with_randomization_factor(0.1)
            .with_max_interval(std::time::Duration::from_secs(5))
            .with_max_elapsed_time(None)
            .build();

        let mut link_status = LinkStatus::NeverConnected;

        loop {
            // A connect failure is terminal for the session.
            let connected = match session.connect().await {
                Ok(s) => s,
                Err(_) => break,
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "duplex",
                    "reason" => "link connected",
                ),
            );

            // Count reconnects that still have unacknowledged messages.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "duplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            // Presumably moves unacknowledged messages back into the outbox
            // so they are resent on this connection (per the method name).
            deliveries.requeue_unacked();

            link_status.connected();
            let connected_at = tokio::time::Instant::now();

            // Run both directions until either one completes. `select!` drops
            // the other branch's future at that point.
            let result = {
                let send_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
                let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
                tokio::select! {
                    r = session::send_connected(
                        &send_stream, &mut deliveries, &mut outbound_rx,
                    ) => r.map_err(Either::Send),
                    r = session::recv_connected::<In, _, _>(
                        &recv_stream, &inbound_tx, &mut recv_next,
                    ) => r.map_err(Either::Recv),
                }
            };

            link_status.disconnected();

            // A connection that stayed up for more than 1s counts as healthy,
            // so the next reconnect starts again from the initial delay.
            if connected_at.elapsed() > tokio::time::Duration::from_secs(1) {
                reconnect_backoff.reset();
            }

            let terminal = match &result {
                // NOTE(review): `Ok(())` can come from either select branch,
                // but the log message only names send_connected.
                Ok(()) => {
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            "duplex send_connected returned EOF, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                    false
                }
                Err(Either::Send(e)) => {
                    // `log_send_error` decides whether this send error ends
                    // the session.
                    let terminal = log_send_error(e, &dest, session_id.0, "duplex", &link_status);
                    if !terminal {
                        // Recoverable send error — reconnect after backoff.
                        if let Some(delay) = reconnect_backoff.next_backoff() {
                            tracing::info!(
                                dest = %dest,
                                session_id = session_id.0,
                                error = %e,
                                delay_ms = delay.as_millis() as u64,
                                mode = "duplex",
                                "send error (recoverable), reconnecting after backoff; {link_status}",
                            );
                            tokio::time::sleep(delay).await;
                        }
                    }
                    terminal
                }
                Err(Either::Recv(session::RecvLoopError::Io(err))) => {
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            error = %err,
                            delay_ms = delay.as_millis() as u64,
                            mode = "duplex",
                            "recv error (recoverable), reconnecting after backoff; {link_status}",
                        );
                        tokio::time::sleep(delay).await;
                    }
                    // NOTE(review): this recv-side I/O error is recorded with
                    // `ChannelErrorType::SendError`. Confirm whether a
                    // recv-specific error type was intended.
                    metrics::CHANNEL_ERRORS.add(
                        1,
                        hyperactor_telemetry::kv_pairs!(
                            "dest" => dest.to_string(),
                            "session_id" => session_id.0.to_string(),
                            "error_type" => metrics::ChannelErrorType::SendError.as_str(),
                            "mode" => "duplex",
                        ),
                    );
                    false
                }
                // Any non-I/O recv error ends the session.
                Err(Either::Recv(e)) => {
                    tracing::info!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %e,
                        "duplex recv terminal error; {link_status}"
                    );
                    true
                }
            };
            session = connected.release();
            if terminal {
                break;
            }
        }

        // Publish the closed status seen by `DuplexTx::status` and by
        // `do_post` when the outbound queue is gone.
        let _ = notify.send(TxStatus::Closed("duplex session ended".into()));
    });
    (
        DuplexTx::new(outbound_tx, addr.clone(), status),
        DuplexRx::new(inbound_rx, addr),
    )
}
531
532/// Connect to a duplex server, returning tx and rx handles.
533pub fn dial<Out: RemoteMessage, In: RemoteMessage>(
534    addr: ChannelAddr,
535) -> Result<(DuplexTx<Out>, DuplexRx<In>), ClientError> {
536    Ok(spawn(super::link(addr)?))
537}
538
#[cfg(test)]
mod tests {
    use timed_test::async_timed_test;

    use super::*;
    use crate::channel::ChannelTransport;

    #[async_timed_test(timeout_secs = 30)]
    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_basic() {
        let mut server =
            serve::<u64, String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        let server_addr = server.addr().clone();

        // Client: sends u64, receives String.
        let (client_tx, mut client_rx) = dial::<u64, String>(server_addr).unwrap();

        // Server: receives u64, sends String.
        let (mut server_rx, server_tx) = server.accept().await.unwrap();

        // Client sends to server.
        client_tx.post(42);
        let received = server_rx.recv().await.unwrap();
        assert_eq!(received, 42);

        // Server sends to client.
        server_tx.post("hello".to_string());
        let received = client_rx.recv().await.unwrap();
        assert_eq!(received, "hello");

        // Multiple messages both ways.
        for i in 0..10u64 {
            client_tx.post(i);
            assert_eq!(server_rx.recv().await.unwrap(), i);

            server_tx.post(format!("msg-{}", i));
            assert_eq!(client_rx.recv().await.unwrap(), format!("msg-{}", i));
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_multiple_links() {
        let mut server = serve::<u64, u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        let server_addr = server.addr().clone();

        // Two independent clients.
        let (tx1, mut rx1) = dial::<u64, u64>(server_addr.clone()).unwrap();
        let (mut srx1, stx1) = server.accept().await.unwrap();

        let (tx2, mut rx2) = dial::<u64, u64>(server_addr).unwrap();
        let (mut srx2, stx2) = server.accept().await.unwrap();

        // Send on link 1.
        tx1.post(100);
        assert_eq!(srx1.recv().await.unwrap(), 100);
        stx1.post(200);
        assert_eq!(rx1.recv().await.unwrap(), 200);

        // Send on link 2.
        tx2.post(300);
        assert_eq!(srx2.recv().await.unwrap(), 300);
        stx2.post(400);
        assert_eq!(rx2.recv().await.unwrap(), 400);
    }

    /// Ping-pong helper: server echoes back each message it receives.
    /// Returns elapsed time for `iterations` round-trips.
    ///
    /// Setup failures (serve or dial) are propagated as errors rather than
    /// panicking, consistent with the `anyhow::Result` return type.
    async fn duplex_ping_pong(
        addr: ChannelAddr,
        iterations: usize,
    ) -> anyhow::Result<std::time::Duration> {
        let mut server = serve::<u64, u64>(addr)?;
        let server_addr = server.addr().clone();

        // Echo server: sends back every message until the link closes.
        let server_handle = tokio::spawn(async move {
            let (mut rx, tx) = server.accept().await.unwrap();
            while let Ok(msg) = rx.recv().await {
                tx.post(msg);
            }
        });

        let (client_tx, mut client_rx) = dial::<u64, u64>(server_addr)?;

        // Warmup.
        for i in 0..10u64 {
            client_tx.post(i);
            assert_eq!(client_rx.recv().await?, i);
        }

        let start = std::time::Instant::now();
        for i in 0..iterations as u64 {
            client_tx.post(i);
            assert_eq!(client_rx.recv().await?, i);
        }
        let elapsed = start.elapsed();

        server_handle.abort();
        Ok(elapsed)
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_ping_pong_tcp() {
        let elapsed = duplex_ping_pong(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), 100)
            .await
            .unwrap();
        println!("TCP duplex: 100 round-trips in {elapsed:?}");
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_duplex_ping_pong_unix() {
        let elapsed = duplex_ping_pong(ChannelAddr::any(ChannelTransport::Unix), 100)
            .await
            .unwrap();
        println!("Unix duplex: 100 round-trips in {elapsed:?}");
    }
}