hyperactor/channel/net/
duplex.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Duplex-mode channels over the net link layer.
10//!
11//! A single physical connection carries messages in both directions,
12//! each with independent sequence/ack state.
13//!
14//! ## Wire protocol
15//!
//! Each connection starts with a unified `LinkInit` header (12 bytes,
//! unframed: a 4-byte magic followed by the 8-byte `session_id`):
18//!
19//! ```text
20//! [magic: 4B "LNK\0"] [session_id: 8B u64 BE]
21//! ```
22//!
23//! After the init, the standard tagged frame format is used. The tag
24//! byte in the 8-byte header distinguishes logical channels:
25//!
26//! - `INITIATOR_TO_ACCEPTOR = 0x00`
27//! - `ACCEPTOR_TO_INITIATOR = 0x01`
28
29use std::sync::Arc;
30
31use async_trait::async_trait;
32use dashmap::DashMap;
33use tokio::sync::mpsc;
34use tokio::sync::oneshot;
35use tokio::sync::watch;
36use tokio::time::Instant;
37use tokio_util::sync::CancellationToken;
38
39use super::ClientError;
40use super::Link;
41use super::ServerError;
42use super::SessionId;
43use super::log_send_error;
44use super::read_link_init;
45use super::server::AcceptorLink;
46use super::server::ServerHandle;
47use super::session;
48use super::session::Next;
49use super::session::Session;
50use crate::RemoteMessage;
51use crate::channel::ChannelAddr;
52use crate::channel::ChannelError;
53use crate::channel::ChannelTransport;
54use crate::channel::Rx;
55use crate::channel::SendError;
56use crate::channel::Tx;
57use crate::channel::TxStatus;
58use crate::channel::net::Stream;
59use crate::channel::net::meta;
60use crate::channel::net::tls;
61use crate::metrics;
62use crate::sync::mvar::MVar;
63
64/// Public duplex server that yields `(DuplexRx<In>, DuplexTx<Out>)` pairs.
65pub struct DuplexServer<In: RemoteMessage, Out: RemoteMessage> {
66    accept_rx: mpsc::Receiver<(DuplexRx<In>, DuplexTx<Out>)>,
67    _handle: ServerHandle,
68    addr: ChannelAddr,
69}
70
71impl<In: RemoteMessage, Out: RemoteMessage> DuplexServer<In, Out> {
72    /// Accept a new duplex link, returning `(rx, tx)` handles.
73    pub async fn accept(&mut self) -> Result<(DuplexRx<In>, DuplexTx<Out>), ChannelError> {
74        self.accept_rx.recv().await.ok_or(ChannelError::Closed)
75    }
76
77    /// The address this server is listening on.
78    pub fn addr(&self) -> &ChannelAddr {
79        &self.addr
80    }
81}
82
83/// Receiver half of a duplex channel.
84pub struct DuplexRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr);
85
86impl<M: RemoteMessage> DuplexRx<M> {
87    pub(super) fn new(rx: mpsc::Receiver<M>, addr: ChannelAddr) -> Self {
88        Self(rx, addr)
89    }
90}
91
92#[async_trait]
93impl<M: RemoteMessage> Rx<M> for DuplexRx<M> {
94    async fn recv(&mut self) -> Result<M, ChannelError> {
95        self.0.recv().await.ok_or(ChannelError::Closed)
96    }
97
98    fn addr(&self) -> ChannelAddr {
99        self.1.clone()
100    }
101
102    async fn join(self) {}
103}
104
105/// Sender half of a duplex channel.
106pub struct DuplexTx<M: RemoteMessage> {
107    tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
108    addr: ChannelAddr,
109    status: watch::Receiver<TxStatus>,
110}
111
112impl<M: RemoteMessage> DuplexTx<M> {
113    pub(super) fn new(
114        tx: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
115        addr: ChannelAddr,
116        status: watch::Receiver<TxStatus>,
117    ) -> Self {
118        Self { tx, addr, status }
119    }
120}
121
122#[async_trait]
123impl<M: RemoteMessage> Tx<M> for DuplexTx<M> {
124    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
125        let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
126        if let Err(mpsc::error::SendError((message, return_channel, _))) =
127            self.tx
128                .send((message, return_channel, tokio::time::Instant::now()))
129        {
130            let _ = return_channel.send(SendError {
131                error: ChannelError::Closed,
132                message,
133                reason: None,
134            });
135        }
136    }
137
138    fn addr(&self) -> ChannelAddr {
139        self.addr.clone()
140    }
141
142    fn status(&self) -> &watch::Receiver<TxStatus> {
143        &self.status
144    }
145}
146
147impl<M: RemoteMessage> Clone for DuplexTx<M> {
148    fn clone(&self) -> Self {
149        Self {
150            tx: self.tx.clone(),
151            addr: self.addr.clone(),
152            status: self.status.clone(),
153        }
154    }
155}
156
157/// Start a duplex server on the given address.
158pub fn serve<In: RemoteMessage, Out: RemoteMessage>(
159    addr: ChannelAddr,
160) -> Result<DuplexServer<In, Out>, ServerError> {
161    let (mut listener, channel_addr) = super::listen(addr)?;
162
163    let (accept_tx, accept_rx) = mpsc::channel(16);
164    let cancel_token = CancellationToken::new();
165    let child_token = cancel_token.child_token();
166
167    let is_tls = matches!(
168        channel_addr.transport(),
169        ChannelTransport::Tls | ChannelTransport::MetaTls(_)
170    );
171    let dest = channel_addr.clone();
172    let prepare = move |stream: Box<dyn Stream>, source: ChannelAddr| {
173        let dest = dest.clone();
174        async move {
175            if is_tls {
176                let tls_acceptor = match dest.transport() {
177                    ChannelTransport::Tls => tls::tls_acceptor()?,
178                    _ => meta::tls_acceptor(true)?,
179                };
180                let mut tls_stream = tls_acceptor.accept(stream).await?;
181                let session_id = read_link_init(&mut tls_stream)
182                    .await
183                    .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
184                Ok((session_id, Box::new(tls_stream) as Box<dyn Stream>))
185            } else {
186                let mut stream = stream;
187                let session_id = read_link_init(&mut stream)
188                    .await
189                    .map_err(|e| anyhow::anyhow!("LinkInit read failed from {}: {}", source, e))?;
190                Ok((session_id, stream))
191            }
192        }
193    };
194
195    let sessions: Arc<DashMap<SessionId, MVar<Box<dyn Stream>>>> = Arc::new(DashMap::new());
196    let child_cancel = CancellationToken::new();
197    let dispatch_dest = channel_addr.clone();
198    let dispatch = {
199        let sessions = Arc::clone(&sessions);
200        let accept_tx = accept_tx.clone();
201        let child_cancel = child_cancel.clone();
202        let dest = dispatch_dest;
203        move |session_id: SessionId, stream: Box<dyn Stream>| {
204            let sessions = Arc::clone(&sessions);
205            let accept_tx = accept_tx.clone();
206            let cancel = child_cancel.child_token();
207            let dest = dest.clone();
208            async move {
209                dispatch_duplex_stream::<In, Out>(
210                    session_id, stream, &sessions, dest, &accept_tx, cancel,
211                )
212                .await;
213            }
214        }
215    };
216
217    let ca = channel_addr.clone();
218    let join_handle = tokio::spawn(async move {
219        let result =
220            super::server::accept_loop(&mut listener, &ca, &child_token, prepare, dispatch).await;
221        child_cancel.cancel();
222        result
223    });
224
225    let server_handle = ServerHandle::new(join_handle, cancel_token, channel_addr.clone());
226
227    Ok(DuplexServer {
228        accept_rx,
229        _handle: server_handle,
230        addr: channel_addr,
231    })
232}
233
234/// Helper to distinguish send errors from recv errors in duplex select.
235enum Either {
236    Send(session::SendLoopError),
237    Recv(session::RecvLoopError),
238}
239
240/// Dispatch a stream to the appropriate duplex session, creating one
241/// if this is the first connection for the given session ID.
242async fn dispatch_duplex_stream<In: RemoteMessage, Out: RemoteMessage>(
243    session_id: SessionId,
244    stream: Box<dyn Stream>,
245    sessions: &DashMap<SessionId, MVar<Box<dyn Stream>>>,
246    addr: ChannelAddr,
247    accept_tx: &mpsc::Sender<(DuplexRx<In>, DuplexTx<Out>)>,
248    cancel: CancellationToken,
249) {
250    let mvar = {
251        let entry = sessions.entry(session_id);
252        match entry {
253            dashmap::mapref::entry::Entry::Occupied(e) => e.get().clone(),
254            dashmap::mapref::entry::Entry::Vacant(e) => {
255                let mvar: MVar<Box<dyn Stream>> = MVar::empty();
256                let link = AcceptorLink {
257                    dest: addr.clone(),
258                    session_id,
259                    stream: mvar.clone(),
260                    cancel: cancel.clone(),
261                };
262
263                let (inbound_tx, inbound_rx) = mpsc::channel::<In>(1024);
264                let (outbound_tx, outbound_rx) =
265                    mpsc::unbounded_channel::<(Out, oneshot::Sender<SendError<Out>>, Instant)>();
266                let (notify, status) = watch::channel(TxStatus::Active);
267                let net_rx = DuplexRx(inbound_rx, addr.clone());
268                let net_tx = DuplexTx {
269                    tx: outbound_tx,
270                    addr: addr.clone(),
271                    status,
272                };
273                let _ = accept_tx.send((net_rx, net_tx)).await;
274
275                let session_ct = cancel.clone();
276                let dest = addr.clone();
277                tokio::spawn(async move {
278                    let mut session = Session::new(link);
279                    let mut recv_next = Next { seq: 0, ack: 0 };
280                    let log_id = format!("duplex server {:016x}", session_id.0);
281                    let mut deliveries = session::Deliveries {
282                        outbox: session::Outbox::new(log_id.clone(), dest, session_id.0),
283                        unacked: session::Unacked::new(None, log_id),
284                    };
285                    let mut outbound_rx = outbound_rx;
286
287                    loop {
288                        let connected = match session.connect().await {
289                            Ok(s) => s,
290                            Err(_) => break,
291                        };
292                        deliveries.requeue_unacked();
293                        let result = {
294                            let recv_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
295                            let send_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
296                            tokio::select! {
297                                r = session::recv_connected::<In, _, _>(
298                                    &recv_stream,
299                                    &inbound_tx,
300                                    &mut recv_next,
301                                ) => r.map_err(Either::Recv),
302                                r = session::send_connected(
303                                    &send_stream,
304                                    &mut deliveries,
305                                    &mut outbound_rx,
306                                ) => r.map_err(Either::Send),
307                                _ = session_ct.cancelled() => Err(Either::Recv(session::RecvLoopError::Cancelled)),
308                            }
309                        };
310
311                        let terminal = match &result {
312                            Ok(()) => false,
313                            Err(Either::Send(session::SendLoopError::Io(err))) => {
314                                tracing::info!(
315                                    session_id = session_id.0,
316                                    error = %err,
317                                    "duplex send/recv error",
318                                );
319                                false
320                            }
321                            Err(Either::Recv(session::RecvLoopError::Io(_))) => false,
322                            _ => true,
323                        };
324                        session = connected.release();
325                        if terminal {
326                            break;
327                        }
328                    }
329
330                    let _ = notify.send(TxStatus::Closed);
331                });
332
333                e.insert(mvar.clone());
334                mvar
335            }
336        }
337    };
338
339    mvar.put(stream).await;
340}
341
342/// Establish a duplex (bidirectional) session over the given link.
343/// Returns send and receive handles.
344pub(crate) fn spawn<Out: RemoteMessage, In: RemoteMessage>(
345    link: impl Link,
346) -> (DuplexTx<Out>, DuplexRx<In>) {
347    let addr = link.dest();
348    let session_id = link.link_id();
349    let (outbound_tx, outbound_rx) = tokio::sync::mpsc::unbounded_channel();
350    let (inbound_tx, inbound_rx) = tokio::sync::mpsc::channel::<In>(1024);
351    let (notify, status) = watch::channel(TxStatus::Active);
352    let dest = addr.clone();
353    crate::init::get_runtime().spawn(async move {
354        let mut session = Session::new(link);
355        let log_id = format!("session {}.{:016x}", dest, session_id.0);
356        let mut deliveries = session::Deliveries {
357            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
358            unacked: session::Unacked::new(None, log_id),
359        };
360        let mut outbound_rx = outbound_rx;
361        let mut recv_next = Next { seq: 0, ack: 0 };
362
363        loop {
364            let connected = match session.connect().await {
365                Ok(s) => s,
366                Err(_) => break,
367            };
368
369            metrics::CHANNEL_CONNECTIONS.add(
370                1,
371                hyperactor_telemetry::kv_pairs!(
372                    "transport" => dest.transport().to_string(),
373                    "mode" => "duplex",
374                    "reason" => "link connected",
375                ),
376            );
377
378            if !deliveries.unacked.is_empty() {
379                metrics::CHANNEL_RECONNECTIONS.add(
380                    1,
381                    hyperactor_telemetry::kv_pairs!(
382                        "dest" => dest.to_string(),
383                        "transport" => dest.transport().to_string(),
384                        "mode" => "duplex",
385                        "reason" => "reconnect_with_unacked",
386                    ),
387                );
388            }
389            deliveries.requeue_unacked();
390            let result = {
391                let send_stream = connected.stream(super::INITIATOR_TO_ACCEPTOR);
392                let recv_stream = connected.stream(super::ACCEPTOR_TO_INITIATOR);
393                tokio::select! {
394                    r = session::send_connected(
395                        &send_stream, &mut deliveries, &mut outbound_rx,
396                    ) => r.map_err(Either::Send),
397                    r = session::recv_connected::<In, _, _>(
398                        &recv_stream, &inbound_tx, &mut recv_next,
399                    ) => r.map_err(Either::Recv),
400                }
401            };
402
403            let terminal = match &result {
404                Ok(()) => false, // EOF — reconnect
405                Err(Either::Send(e)) => log_send_error(e, &dest, session_id.0, "duplex"),
406                Err(Either::Recv(session::RecvLoopError::Io(err))) => {
407                    tracing::info!(
408                        dest = %dest,
409                        session_id = session_id.0,
410                        error = %err,
411                        mode = "duplex",
412                        "recv error",
413                    );
414                    metrics::CHANNEL_ERRORS.add(
415                        1,
416                        hyperactor_telemetry::kv_pairs!(
417                            "dest" => dest.to_string(),
418                            "session_id" => session_id.0.to_string(),
419                            "error_type" => metrics::ChannelErrorType::SendError.as_str(),
420                            "mode" => "duplex",
421                        ),
422                    );
423                    false
424                }
425                _ => true, // terminal
426            };
427            session = connected.release();
428            if terminal {
429                break;
430            }
431        }
432
433        let _ = notify.send(TxStatus::Closed);
434    });
435    (
436        DuplexTx::new(outbound_tx, addr.clone(), status),
437        DuplexRx::new(inbound_rx, addr),
438    )
439}
440
441/// Connect to a duplex server, returning tx and rx handles.
442pub fn dial<Out: RemoteMessage, In: RemoteMessage>(
443    addr: ChannelAddr,
444) -> Result<(DuplexTx<Out>, DuplexRx<In>), ClientError> {
445    Ok(spawn(super::link(addr)?))
446}
447
#[cfg(test)]
mod tests {
    use timed_test::async_timed_test;

    use super::*;
    use crate::channel::ChannelTransport;

    /// IPv6 loopback TCP address with an OS-assigned port.
    fn loopback_tcp() -> ChannelAddr {
        ChannelAddr::Tcp("[::1]:0".parse().unwrap())
    }

    #[async_timed_test(timeout_secs = 30)]
    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_basic() {
        let mut server = serve::<u64, String>(loopback_tcp()).unwrap();
        let server_addr = server.addr().clone();

        // The client sends u64 and receives String; the server is the mirror image.
        let (client_tx, mut client_rx) = dial::<u64, String>(server_addr).unwrap();
        let (mut server_rx, server_tx) = server.accept().await.unwrap();

        // One message in each direction.
        client_tx.post(42);
        assert_eq!(server_rx.recv().await.unwrap(), 42);
        server_tx.post("hello".to_string());
        assert_eq!(client_rx.recv().await.unwrap(), "hello");

        // Interleaved traffic in both directions.
        for n in 0..10u64 {
            client_tx.post(n);
            assert_eq!(server_rx.recv().await.unwrap(), n);

            let reply = format!("msg-{}", n);
            server_tx.post(reply.clone());
            assert_eq!(client_rx.recv().await.unwrap(), reply);
        }
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_multiple_links() {
        let mut server = serve::<u64, u64>(loopback_tcp()).unwrap();
        let server_addr = server.addr().clone();

        // Two independent clients, each paired with its own server-side handles.
        let (first_tx, mut first_rx) = dial::<u64, u64>(server_addr.clone()).unwrap();
        let (mut first_srv_rx, first_srv_tx) = server.accept().await.unwrap();

        let (second_tx, mut second_rx) = dial::<u64, u64>(server_addr).unwrap();
        let (mut second_srv_rx, second_srv_tx) = server.accept().await.unwrap();

        // Traffic on the first link.
        first_tx.post(100);
        assert_eq!(first_srv_rx.recv().await.unwrap(), 100);
        first_srv_tx.post(200);
        assert_eq!(first_rx.recv().await.unwrap(), 200);

        // Traffic on the second link.
        second_tx.post(300);
        assert_eq!(second_srv_rx.recv().await.unwrap(), 300);
        second_srv_tx.post(400);
        assert_eq!(second_rx.recv().await.unwrap(), 400);
    }

    /// Post every value in `values` and wait for the echo of each one.
    async fn echo_round_trips(
        tx: &DuplexTx<u64>,
        rx: &mut DuplexRx<u64>,
        values: std::ops::Range<u64>,
    ) -> anyhow::Result<()> {
        for value in values {
            tx.post(value);
            assert_eq!(rx.recv().await?, value);
        }
        Ok(())
    }

    /// Ping-pong helper: the server echoes every message it receives.
    /// Returns the elapsed time for `iterations` timed round-trips.
    async fn duplex_ping_pong(
        addr: ChannelAddr,
        iterations: usize,
    ) -> anyhow::Result<std::time::Duration> {
        let mut server = serve::<u64, u64>(addr)?;
        let server_addr = server.addr().clone();

        let echo_task = tokio::spawn(async move {
            let (mut rx, tx) = server.accept().await.unwrap();
            while let Ok(msg) = rx.recv().await {
                tx.post(msg);
            }
        });

        let (client_tx, mut client_rx) = dial::<u64, u64>(server_addr).unwrap();

        // Warm the link up before timing anything.
        echo_round_trips(&client_tx, &mut client_rx, 0..10).await?;

        let started = std::time::Instant::now();
        echo_round_trips(&client_tx, &mut client_rx, 0..iterations as u64).await?;
        let elapsed = started.elapsed();

        echo_task.abort();
        Ok(elapsed)
    }

    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_duplex_ping_pong_tcp() {
        let elapsed = duplex_ping_pong(loopback_tcp(), 100).await.unwrap();
        println!("TCP duplex: 100 round-trips in {elapsed:?}");
    }

    #[async_timed_test(timeout_secs = 30)]
    async fn test_duplex_ping_pong_unix() {
        let elapsed = duplex_ping_pong(ChannelAddr::any(ChannelTransport::Unix), 100)
            .await
            .unwrap();
        println!("Unix duplex: 100 round-trips in {elapsed:?}");
    }
}