hyperactor/channel/
net.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! A simple socket channel implementation using a single-stream
10//! framing protocol. Each frame is encoded as an 8-byte
11//! **big-endian** length prefix (u64), followed by exactly that many
12//! bytes of payload.
13//!
14//! Message frames carry a `serde_multipart::Message` (not raw
15//! bincode). In compat mode (current default), this is encoded as a
//! sentinel `u64::MAX` followed by a single bincode payload. Response frames
//! are a bincode-serialized `NetRxResponse` enum carrying either the acked
//! sequence number (`Ack`), a `Reject` reason indicating that the server
//! rejected the session, or `Closed` when the channel has been closed.
20//!
21//! Message frame (compat/unipart) example:
22//! ```text
23//! +------------------ len: u64 (BE) ------------------+----------------------- data -----------------------+
24//! | \x00\x00\x00\x00\x00\x00\x00\x10                  | \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF | <bincode bytes> |
25//! |                       16                          |           u64::MAX             |                   |
26//! +---------------------------------------------------+-----------------------------------------------------+
27//! ```
28//!
29//! Response frame (wire format):
30//! ```text
31//! +------------------ len: u64 (BE) ------------------+---------------- data ------------------+
32//! | \x00\x00\x00\x00\x00\x00\x00\x??                  | <bincode acked sequence num or reject> |
33//! +---------------------------------------------------+----------------------------------------+
34//! ```
35//!
36//! I/O is handled by `FrameReader`/`FrameWrite`, which are
37//! cancellation-safe and avoid extra copies. Helper fns
38//! `serialize_response(NetRxResponse) -> Result<Bytes, bincode::Error>`
39//! and `deserialize_response(Bytes) -> Result<NetRxResponse, bincode::Error>`
40//! convert to/from the response payload.
41//!
42//! ### Limits & EOF semantics
43//! * **Max frame size:** frames larger than
44//!   `config::CODEC_MAX_FRAME_LENGTH` are rejected with
45//!   `io::ErrorKind::InvalidData`.
46//! * **EOF handling:** `FrameReader::next()` returns `Ok(None)` only
47//!   when EOF occurs exactly on a frame boundary. If EOF happens
48//!   mid-frame, it returns `Err(io::ErrorKind::UnexpectedEof)`.
49
50use std::fmt;
51use std::fmt::Debug;
52use std::net::ToSocketAddrs;
53use std::time::Duration;
54
55use backoff::ExponentialBackoffBuilder;
56use backoff::backoff::Backoff;
57use bytes::Bytes;
58use enum_as_inner::EnumAsInner;
59use serde::Deserialize;
60use serde::Serialize;
61use serde::de::Error;
62use tokio::io::AsyncRead;
63use tokio::io::AsyncReadExt;
64use tokio::io::AsyncWrite;
65use tokio::io::AsyncWriteExt;
66use tokio::sync::watch;
67use tokio::time::Instant;
68
69use super::*;
70use crate::RemoteMessage;
71
72pub mod duplex;
73mod framed;
74pub(super) mod server;
75pub(super) mod session;
76pub use server::ServerHandle;
77
78pub(crate) trait Stream:
79    AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static
80{
81}
82impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static> Stream for S {}
83
/// Opaque identifier for a session. Generated by the client,
/// sent to the server on each connect so the server can correlate
/// reconnections to the same logical session.
///
/// The wrapped `u64` travels on the wire as 8 big-endian bytes inside
/// the LinkInit header (see `write_link_init` / `read_link_init`) and
/// is displayed as 16 zero-padded lowercase hex digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct SessionId(u64);
89
90impl SessionId {
91    /// Generate a new random session ID.
92    pub fn random() -> Self {
93        Self(rand::random())
94    }
95}
96
97impl fmt::Display for SessionId {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{:016x}", self.0)
100    }
101}
102
/// Logical channel tag for initiator→acceptor traffic. The simplex
/// sender in `spawn` requests its stream with this tag.
pub(crate) const INITIATOR_TO_ACCEPTOR: u8 = 0;

/// Logical channel tag for acceptor→initiator traffic (the reverse
/// direction of [`INITIATOR_TO_ACCEPTOR`]).
pub(crate) const ACCEPTOR_TO_INITIATOR: u8 = 1;
108
/// Fixed-size header sent at the start of every physical connection.
/// This is written/read directly on the wire (not framed), before
/// any session framing begins.
///
/// Wire format (12 bytes, big-endian):
/// ```text
/// [magic: 4B "LNK\0"] [session_id: 8B u64 BE]
/// ```
///
/// `LINK_INIT_MAGIC` is the 4-byte prefix that `read_link_init`
/// checks before accepting the header.
const LINK_INIT_MAGIC: [u8; 4] = *b"LNK\0";
/// Total size in bytes of the LinkInit header: 4 magic bytes plus the
/// 8-byte session ID.
const LINK_INIT_SIZE: usize = 4 + 8;
119
120/// Write a LinkInit header to the stream.
121async fn write_link_init<S: AsyncWrite + Unpin>(
122    stream: &mut S,
123    session_id: SessionId,
124) -> Result<(), std::io::Error> {
125    let mut buf = [0u8; LINK_INIT_SIZE];
126    buf[0..4].copy_from_slice(&LINK_INIT_MAGIC);
127    buf[4..12].copy_from_slice(&session_id.0.to_be_bytes());
128    stream.write_all(&buf).await
129}
130
131/// Read a LinkInit header from the stream.
132async fn read_link_init<S: AsyncRead + Unpin>(stream: &mut S) -> Result<SessionId, std::io::Error> {
133    let mut buf = [0u8; LINK_INIT_SIZE];
134    stream.read_exact(&mut buf).await?;
135    if buf[0..4] != LINK_INIT_MAGIC {
136        return Err(std::io::Error::new(
137            std::io::ErrorKind::InvalidData,
138            format!(
139                "invalid LinkInit magic: expected {:?}, got {:?}",
140                LINK_INIT_MAGIC,
141                &buf[0..4]
142            ),
143        ));
144    }
145    let session_id = SessionId(u64::from_be_bytes(buf[4..12].try_into().unwrap()));
146    Ok(session_id)
147}
148
/// Link represents a network link through which connections may be
/// acquired. The session ID is baked in. Initiator links dial;
/// acceptor links wait for dispatched streams.
///
/// A link can be asked for a connection repeatedly: `spawn` wraps a
/// link in a session and reconnects through it after failures.
#[async_trait]
pub(crate) trait Link: Send + Sync + Debug + 'static {
    /// The underlying stream type.
    type Stream: Stream;

    /// The address of the link's destination.
    fn dest(&self) -> ChannelAddr;

    /// The session ID for this link.
    fn link_id(&self) -> SessionId;

    /// Acquire the next usable connection. For initiator links this
    /// dials; for acceptor links this waits on a dispatch channel.
    ///
    /// Returns a [`ClientError`] when no connection can be produced.
    async fn next(&self) -> Result<Self::Stream, ClientError>;
}
167
168use session::Session;
169
170use crate::config;
171use crate::metrics;
172
/// Connectivity history of a link, used to enrich log and error
/// messages (see the `Display` impl).
pub(crate) enum LinkStatus {
    /// No connection has been established yet.
    NeverConnected,
    /// Currently connected; holds the instant the connection was
    /// recorded via `connected()`.
    Connected(tokio::time::Instant),
    /// Previously connected, now disconnected.
    Disconnected {
        /// When the most recent connection was recorded.
        last_connected: tokio::time::Instant,
        /// When that connection was recorded as lost.
        since: tokio::time::Instant,
    },
}
181
182impl LinkStatus {
183    fn connected(&mut self) {
184        *self = LinkStatus::Connected(tokio::time::Instant::now());
185    }
186
187    fn disconnected(&mut self) {
188        match *self {
189            LinkStatus::Connected(at) => {
190                *self = LinkStatus::Disconnected {
191                    last_connected: at,
192                    since: tokio::time::Instant::now(),
193                };
194            }
195            // Already disconnected or never connected — leave as is.
196            _ => {}
197        }
198    }
199}
200
201impl std::fmt::Display for LinkStatus {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        match self {
204            LinkStatus::NeverConnected => write!(f, "never connected"),
205            LinkStatus::Connected(at) => {
206                write!(f, "connected for {:.1}s", at.elapsed().as_secs_f64())
207            }
208            LinkStatus::Disconnected {
209                last_connected,
210                since,
211            } => {
212                write!(
213                    f,
214                    "last connected {:.1}s ago, disconnected for {:.1}s",
215                    last_connected.elapsed().as_secs_f64(),
216                    since.elapsed().as_secs_f64(),
217                )
218            }
219        }
220    }
221}
222
223/// Log a send-loop error and return `true` if the error is terminal
224/// (caller should exit), `false` if recoverable (caller should reconnect).
225fn log_send_error(
226    error: &session::SendLoopError,
227    dest: &ChannelAddr,
228    session_id: u64,
229    mode: &str,
230    link_status: &LinkStatus,
231) -> bool {
232    match error {
233        session::SendLoopError::Io(err) => {
234            tracing::info!(dest = %dest, session_id, error = %err, mode, "send error; {link_status}");
235            metrics::CHANNEL_ERRORS.add(
236                1,
237                hyperactor_telemetry::kv_pairs!(
238                    "dest" => dest.to_string(),
239                    "session_id" => session_id.to_string(),
240                    "error_type" => metrics::ChannelErrorType::SendError.as_str(),
241                    "mode" => mode.to_string(),
242                ),
243            );
244            false
245        }
246        session::SendLoopError::AppClosed => true,
247        session::SendLoopError::Rejected(reason) => {
248            tracing::error!(dest = %dest, session_id, mode, "server rejected connection: {reason}; {link_status}");
249            true
250        }
251        session::SendLoopError::ServerClosed => {
252            tracing::info!(dest = %dest, session_id, mode, "server closed the channel; {link_status}");
253            true
254        }
255        session::SendLoopError::DeliveryTimeout => {
256            let timeout = hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
257            tracing::error!(
258                dest = %dest, session_id, mode,
259                "failed to receive ack within timeout {timeout:?}; link is currently connected; {link_status}"
260            );
261            true
262        }
263        session::SendLoopError::OversizedFrame(reason) => {
264            tracing::error!(dest = %dest, session_id, mode, "oversized frame: {reason}; {link_status}");
265            true
266        }
267    }
268}
269
/// Establish a simplex (send-only) session over the given link. Returns a send handle.
///
/// The returned [`NetTx`] is usable immediately; all network work
/// happens in a background task spawned on `crate::init::get_runtime()`.
/// That task:
///
/// 1. Waits for the first posted message before connecting (lazy connect).
/// 2. Repeatedly connects through the session, runs
///    `session::send_connected` on the `INITIATOR_TO_ACCEPTOR` stream,
///    and reconnects with exponential backoff after EOF or a
///    recoverable error.
/// 3. On a terminal condition, closes the receiver, hands every
///    unacked/queued/late message back to its return channel with the
///    close reason, and publishes `TxStatus::Closed(reason)`.
pub(crate) fn spawn<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    let dest = link.dest();
    let session_id = link.link_id();
    // `notify` is kept by the background task; `status` is exposed via the handle.
    let (notify, status) = watch::channel(TxStatus::Active);
    let tx = NetTx {
        sender,
        dest: dest.clone(),
        status,
    };
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        // Identifier used to prefix close reasons and passed to the queues.
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id.clone()),
        };
        let mut receiver = receiver;

        // Lazy connect: wait for first message.
        match receiver.recv().await {
            Some(msg) => {
                if let Err(err) = deliveries.outbox.push_back(msg) {
                    tracing::error!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        "failed to push message to outbox"
                    );
                    let _ = notify.send(TxStatus::Closed("failed to push to outbox".into()));
                    return;
                }
            }
            None => {
                // Every sender was dropped before anything was posted.
                let _ = notify.send(TxStatus::Closed("sender dropped".into()));
                return;
            }
        }

        // Reconnect delay: starts at 10ms, doubles each time (±10% jitter),
        // capped at 5s, and never gives up on its own (no max elapsed time).
        let mut reconnect_backoff = ExponentialBackoffBuilder::new()
            .with_initial_interval(Duration::from_millis(10))
            .with_multiplier(2.0)
            .with_randomization_factor(0.1)
            .with_max_interval(Duration::from_secs(5))
            .with_max_elapsed_time(None)
            .build();

        let mut link_status = LinkStatus::NeverConnected;

        // The loop only exits by `break`ing with the reason the channel is closing.
        let reason: String = 'outer: loop {
            // When deliveries report an expiry time, connecting is bounded by
            // that deadline; otherwise connect without a deadline.
            let connected = match deliveries.expiry_time() {
                Some(deadline) => match session.connect_by(deadline).await {
                    Ok(s) => s,
                    Err(_) => {
                        let timeout =
                            hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
                        // Distinguish an expired outbox message from a missing ack.
                        let error_msg = if deliveries.outbox.is_expired(timeout) {
                            format!("failed to deliver message within timeout {timeout:?}; {link_status}")
                        } else {
                            format!(
                                "failed to receive ack within timeout {timeout:?}; \
                                 link is currently broken; {link_status}",
                            )
                        };
                        tracing::error!(
                            dest = %dest, session_id = session_id.0, "{}", error_msg
                        );
                        break 'outer format!("{log_id}: {error_msg}");
                    }
                },
                None => match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break 'outer "session shut down".into(),
                },
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "simplex",
                    "reason" => "link connected",
                ),
            );

            // Messages still awaiting an ack at connect time are counted as a
            // reconnection and requeued below.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "simplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            deliveries.requeue_unacked();

            link_status.connected();
            let connected_at = tokio::time::Instant::now();

            // The stream is scoped to this block so it is dropped before the
            // connection is released back into a session.
            let result = {
                let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                session::send_connected(&stream, &mut deliveries, &mut receiver).await
            };
            session = connected.release();

            link_status.disconnected();

            // Reset backoff if the connection was alive long enough to have
            // been useful (i.e. not an immediate EOF/error).
            if connected_at.elapsed() > Duration::from_secs(1) {
                reconnect_backoff.reset();
            }

            match result {
                Ok(()) => {
                    // EOF — connection closed normally, reconnect after backoff.
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            "send_connected returned EOF, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
                Err(ref e) => {
                    // `log_send_error` returns true for terminal errors.
                    if log_send_error(e, &dest, session_id.0, "simplex", &link_status) {
                        break 'outer format!("{log_id}: {e}");
                    }
                    // Recoverable error — reconnect after backoff.
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            error = %e,
                            "send_connected returned recoverable error, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        };

        tracing::info!(
            dest = %dest, session_id = session_id.0, "NetTx closing: {reason}"
        );

        // Stop accepting new posts, then return everything still pending:
        // unacked messages first, then the outbox, then anything left in the
        // receiver's buffer.
        receiver.close();
        deliveries
            .unacked
            .deque
            .drain(..)
            .chain(deliveries.outbox.deque.drain(..))
            .for_each(|queued| queued.try_return(Some(reason.clone())));
        while let Ok((msg, return_channel, _)) = receiver.try_recv() {
            let _ = return_channel.send(SendError {
                error: ChannelError::Closed,
                message: msg,
                reason: Some(reason.clone()),
            });
        }

        // Publish the final status so observers of the handle see the reason.
        let _ = notify.send(TxStatus::Closed(reason.into()));
    });
    tx
}
441
/// Transport-agnostic link that dispatches to the appropriate
/// transport based on the channel address.
///
/// Built by [`link`]; both `ChannelAddr::Tls` and
/// `ChannelAddr::MetaTls` addresses produce the `Tls` variant.
#[derive(Debug)]
pub(crate) enum NetLink {
    /// Link over plain TCP.
    Tcp(tcp::TcpLink),
    /// Link over a Unix domain socket.
    Unix(unix::UnixLink),
    /// Link over TLS.
    Tls(tls::TlsLink),
}
450
451/// Create a link for the given channel address.
452pub(crate) fn link(addr: ChannelAddr) -> Result<NetLink, ClientError> {
453    match addr {
454        ChannelAddr::Tcp(socket_addr) => Ok(NetLink::Tcp(tcp::link(socket_addr))),
455        ChannelAddr::Unix(unix_addr) => Ok(NetLink::Unix(unix::link(unix_addr))),
456        ChannelAddr::Tls(tls_addr) => Ok(NetLink::Tls(tls::link(tls_addr)?)),
457        ChannelAddr::MetaTls(meta_addr) => Ok(NetLink::Tls(meta::link(meta_addr)?)),
458        other => Err(ClientError::Connect(
459            other,
460            std::io::Error::other("unsupported transport"),
461            "unsupported transport".into(),
462        )),
463    }
464}
465
466#[async_trait]
467impl Link for NetLink {
468    type Stream = Box<dyn Stream>;
469
470    fn dest(&self) -> ChannelAddr {
471        match self {
472            Self::Tcp(l) => l.dest(),
473            Self::Unix(l) => l.dest(),
474            Self::Tls(l) => l.dest(),
475        }
476    }
477
478    fn link_id(&self) -> SessionId {
479        match self {
480            Self::Tcp(l) => l.link_id(),
481            Self::Unix(l) => l.link_id(),
482            Self::Tls(l) => l.link_id(),
483        }
484    }
485
486    async fn next(&self) -> Result<Box<dyn Stream>, ClientError> {
487        match self {
488            Self::Tcp(l) => Ok(Box::new(l.next().await?)),
489            Self::Unix(l) => Ok(Box::new(l.next().await?)),
490            Self::Tls(l) => Ok(Box::new(l.next().await?)),
491        }
492    }
493}
494
/// Listener represents the server side of a network link: it accepts inbound connections.
///
/// This is the counterpart to [`Link`]. Each transport module (tcp, unix, tls)
/// provides both a `Link` impl (for dialing) and a `Listener` impl (for accepting).
#[async_trait]
pub(crate) trait Listener: Send + Unpin + 'static {
    /// The underlying stream type produced by accepting a connection.
    type Stream: Stream;

    /// Accept the next inbound connection, returning the stream and the peer's address.
    ///
    /// Returns a [`ServerError`] when accepting fails.
    async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError>;
}
507
/// Transport-agnostic listener that dispatches to the appropriate
/// transport based on the channel address. TLS has no variant — it
/// uses `Tcp` under the hood (the TLS handshake happens in `prepare`,
/// not the listener).
///
/// Created by [`listen`].
#[derive(Debug)]
pub(crate) enum NetListener {
    /// Listener on a TCP socket (also used for TLS and MetaTLS addresses).
    Tcp(tcp::TcpSocketListener),
    /// Listener on a Unix domain socket.
    Unix(unix::UnixSocketListener),
}
517
518#[async_trait]
519impl Listener for NetListener {
520    type Stream = Box<dyn Stream>;
521
522    async fn accept(&mut self) -> Result<(Box<dyn Stream>, ChannelAddr), ServerError> {
523        match self {
524            Self::Tcp(l) => {
525                let (stream, addr) = l.accept().await?;
526                Ok((Box::new(stream), addr))
527            }
528            Self::Unix(l) => {
529                let (stream, addr) = l.accept().await?;
530                Ok((Box::new(stream), addr))
531            }
532        }
533    }
534}
535
/// Bind a listener for the given channel address. Returns the listener
/// and the canonical address callers should advertise (which encodes
/// the transport — e.g. `ChannelAddr::Tls` for TLS).
///
/// Supported addresses:
/// * `Tcp` — binds the socket address and advertises the listener's
///   actual local address.
/// * `Unix` — binds either the given bound address or, for
///   `Unbound`, an address obtained from a fresh unbound datagram socket.
/// * `Tls` / `MetaTls` — resolves `hostname:port`, binds a TCP
///   listener, and advertises the original hostname with the actual
///   local port.
///
/// # Errors
///
/// Returns `ServerError::Resolve`, `ServerError::Listen`, or
/// `ServerError::Io` depending on the failing step; any other address
/// kind yields `ServerError::Listen` with an "unsupported transport"
/// error.
pub(crate) fn listen(addr: ChannelAddr) -> Result<(NetListener, ChannelAddr), ServerError> {
    match addr {
        ChannelAddr::Tcp(socket_addr) => {
            // Bind with std, switch to non-blocking, then hand over to tokio.
            let std_listener = std::net::TcpListener::bind(socket_addr)
                .map_err(|err| ServerError::Listen(ChannelAddr::Tcp(socket_addr), err))?;
            std_listener
                .set_nonblocking(true)
                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
            // Advertise the address the socket actually bound to, which may
            // differ from the requested one.
            let local_addr = tokio_listener
                .local_addr()
                .map_err(|err| ServerError::Resolve(ChannelAddr::Tcp(socket_addr), err))?;
            let listener = tcp::TcpSocketListener {
                inner: tokio_listener,
                addr: local_addr,
            };
            Ok((NetListener::Tcp(listener), ChannelAddr::Tcp(local_addr)))
        }
        ChannelAddr::Unix(ref unix_addr) => {
            use std::os::unix::net::UnixDatagram as StdUnixDatagram;
            use std::os::unix::net::UnixListener as StdUnixListener;

            // Owned copy of the requested address for error reporting.
            let caddr = addr.clone();
            let maybe_listener = match unix_addr {
                unix::SocketAddr::Bound(sock_addr) => StdUnixListener::bind_addr(sock_addr),
                // NOTE(review): binds to the local address of an unbound datagram
                // socket — presumably so the OS picks the address; confirm.
                unix::SocketAddr::Unbound => StdUnixDatagram::unbound()
                    .and_then(|u| u.local_addr())
                    .and_then(|uaddr| StdUnixListener::bind_addr(&uaddr)),
            };
            let std_listener =
                maybe_listener.map_err(|err| ServerError::Listen(caddr.clone(), err))?;
            std_listener
                .set_nonblocking(true)
                .map_err(|err| ServerError::Listen(caddr.clone(), err))?;
            let local_addr = std_listener
                .local_addr()
                .map_err(|err| ServerError::Resolve(caddr.clone(), err))?;
            // NOTE(review): this step reports `ServerError::Io`, while the
            // equivalent `from_std` step in the TCP/TLS arms reports
            // `ServerError::Listen` — confirm whether that is intentional.
            let tokio_listener = tokio::net::UnixListener::from_std(std_listener)
                .map_err(|err| ServerError::Io(caddr, err))?;
            let bound_addr = unix::SocketAddr::new(local_addr);
            let listener = unix::UnixSocketListener {
                inner: tokio_listener,
                addr: bound_addr.clone(),
            };
            Ok((NetListener::Unix(listener), ChannelAddr::Unix(bound_addr)))
        }
        addr @ (ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_)) => {
            // Remember which flavour was requested so the advertised address
            // uses the same variant.
            let is_meta = matches!(addr, ChannelAddr::MetaTls(_));
            let tls_addr = match addr {
                ChannelAddr::Tls(a) | ChannelAddr::MetaTls(a) => a,
                // The outer arm only matches the two variants handled above.
                _ => unreachable!(),
            };
            let TlsAddr { hostname, port } = tls_addr;
            let make_channel_addr = |h: &str, p: Port| {
                if is_meta {
                    ChannelAddr::MetaTls(TlsAddr::new(h, p))
                } else {
                    ChannelAddr::Tls(TlsAddr::new(h, p))
                }
            };

            // Resolve the hostname to concrete socket addresses.
            let addrs: Vec<core::net::SocketAddr> = (hostname.as_ref(), port)
                .to_socket_addrs()
                .map_err(|err| ServerError::Resolve(make_channel_addr(&hostname, port), err))?
                .collect();

            if addrs.is_empty() {
                return Err(ServerError::Resolve(
                    make_channel_addr(&hostname, port),
                    std::io::Error::other("no available socket addr"),
                ));
            }

            let channel_addr = make_channel_addr(&hostname, port);
            // All resolved addresses are offered to `bind` as candidates.
            let std_listener = std::net::TcpListener::bind(&addrs[..])
                .map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
            std_listener
                .set_nonblocking(true)
                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
            let local_addr = tokio_listener
                .local_addr()
                .map_err(|err| ServerError::Resolve(channel_addr, err))?;
            let listener = tcp::TcpSocketListener {
                inner: tokio_listener,
                addr: local_addr,
            };
            // Advertise the original hostname with the port actually bound.
            Ok((
                NetListener::Tcp(listener),
                make_channel_addr(&hostname, local_addr.port()),
            ))
        }
        other => Err(ServerError::Listen(
            other.clone(),
            std::io::Error::other(format!("unsupported transport: {}", other)),
        )),
    }
}
639
/// Frames are the messages sent between clients and servers over sessions.
///
/// Generic over the message payload type `M`.
#[derive(Debug, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub(super) enum Frame<M> {
    /// Send a message with the provided sequence number.
    Message(u64, M),
}
646
/// Response sent back from the receiving side of a channel. It is
/// carried bincode-encoded in a response frame (see
/// `serialize_response` / `deserialize_response`).
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
pub(super) enum NetRxResponse {
    /// Acknowledgement carrying the acked sequence number.
    Ack(u64),
    /// This session is rejected with the given reason. NetTx should stop reconnecting.
    Reject(String),
    /// This channel is closed.
    Closed,
}
655
656pub(super) fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
657    bincode::serialize(&response).map(|bytes| bytes.into())
658}
659
660pub(super) fn deserialize_response(data: Bytes) -> Result<NetRxResponse, bincode::Error> {
661    bincode::deserialize(&data)
662}
663
/// A Tx implemented on top of a Link. The Tx manages the link state,
/// reconnections, etc.
///
/// Instances are created by `spawn`; the handle itself only enqueues
/// messages for the background task started there.
pub(crate) struct NetTx<M: RemoteMessage> {
    /// Queue into the background task. Each entry is the message, the
    /// channel on which an undeliverable message is returned, and the
    /// instant at which the message was posted.
    sender: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
    /// Destination address of the underlying link.
    dest: ChannelAddr,
    /// Status published by the background task (`Active` until it closes).
    status: watch::Receiver<TxStatus>,
}
671
672#[async_trait]
673impl<M: RemoteMessage> Tx<M> for NetTx<M> {
674    fn addr(&self) -> ChannelAddr {
675        self.dest.clone()
676    }
677
678    fn status(&self) -> &watch::Receiver<TxStatus> {
679        &self.status
680    }
681
682    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
683        tracing::trace!(
684            name = "post",
685            dest = %self.dest,
686            "sending message"
687        );
688
689        let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
690        if let Err(mpsc::error::SendError((message, return_channel, _))) =
691            self.sender
692                .send((message, return_channel, tokio::time::Instant::now()))
693        {
694            let reason = self.status.borrow().as_closed().map(|r| r.to_string());
695            let _ = return_channel.send(SendError {
696                error: ChannelError::Closed,
697                message,
698                reason,
699            });
700        }
701    }
702}
703
/// Receiving end of a net channel.
///
/// Tuple fields, in order: the queue of received messages, the
/// channel's address, and the handle of the channel server, which is
/// stopped when the receiver is joined or dropped.
pub struct NetRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr, ServerHandle);
705
706#[async_trait]
707impl<M: RemoteMessage> Rx<M> for NetRx<M> {
708    async fn recv(&mut self) -> Result<M, ChannelError> {
709        tracing::trace!(
710            name = "recv",
711            dest = %self.1,
712            "receiving message"
713        );
714        self.0.recv().await.ok_or(ChannelError::Closed)
715    }
716
717    fn addr(&self) -> ChannelAddr {
718        self.1.clone()
719    }
720
721    /// Gracefully shut down the channel server, waiting for pending
722    /// acks to be flushed before returning.
723    async fn join(mut self) {
724        self.2
725            .stop(&format!("NetRx joined; channel address: {}", self.1));
726        let _ = (&mut self.2).await;
727        // Drop will call stop() again which is harmless (token already cancelled).
728    }
729}
730
731impl<M: RemoteMessage> Drop for NetRx<M> {
732    fn drop(&mut self) {
733        self.2
734            .stop(&format!("NetRx dropped; channel address: {}", self.1));
735    }
736}
737
738/// Error returned during server operations.
739#[derive(Debug, thiserror::Error)]
740pub enum ServerError {
741    #[error("io: {1}")]
742    Io(ChannelAddr, #[source] std::io::Error),
743    #[error("listen: {0} {1}")]
744    Listen(ChannelAddr, #[source] std::io::Error),
745    #[error("resolve: {0} {1}")]
746    Resolve(ChannelAddr, #[source] std::io::Error),
747    #[error("internal: {0} {1}")]
748    Internal(ChannelAddr, #[source] anyhow::Error),
749}
750
/// Error returned during client (link) operations.
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    /// Connecting to the address failed; carries the underlying I/O
    /// error and an additional description.
    #[error("connection to {0} failed: {1}: {2}")]
    Connect(ChannelAddr, std::io::Error, String),
    /// The address could not be resolved to something connectable
    /// (for example, an unbound Unix address).
    #[error("unable to resolve address: {0}")]
    Resolve(ChannelAddr),
    /// An I/O operation on a connection to the address failed.
    #[error("io: {0} {1}")]
    Io(ChannelAddr, std::io::Error),
    /// Serializing a message destined for the address failed.
    #[error("send {0}: serialize: {1}")]
    Serialize(ChannelAddr, bincode::ErrorKind),
    /// The supplied address was invalid; carries a description.
    #[error("invalid address: {0}")]
    InvalidAddress(String),
}
764
/// Reports whether `addr` uses one of the 'net' transports (TCP,
/// MetaTLS, TLS, or Unix). These currently have different semantics
/// from local transports.
#[cfg(test)]
pub(super) fn is_net_addr(addr: &ChannelAddr) -> bool {
    matches!(
        addr.transport(),
        ChannelTransport::Tcp(_)
            | ChannelTransport::MetaTls(_)
            | ChannelTransport::Tls
            | ChannelTransport::Unix
    )
}
777
778pub(crate) mod unix {
779
780    use core::str;
781    use std::os::unix::net::SocketAddr as StdSocketAddr;
782    use std::os::unix::net::UnixStream as StdUnixStream;
783
784    use rand::Rng;
785    use rand::distr::Alphanumeric;
786    use tokio::net::UnixListener;
787    use tokio::net::UnixStream;
788
789    use super::*;
790
    /// Initiator-side link over a Unix domain socket.
    #[derive(Debug)]
    pub(crate) struct UnixLink {
        /// Socket address to connect to; must be `Bound` for `next()`
        /// to be able to dial.
        pub(super) addr: SocketAddr,
        /// Session ID sent in the LinkInit header on every connection.
        pub(super) session_id: SessionId,
    }
796
797    #[async_trait]
798    impl Link for UnixLink {
799        type Stream = UnixStream;
800
801        fn dest(&self) -> ChannelAddr {
802            ChannelAddr::Unix(self.addr.clone())
803        }
804
805        fn link_id(&self) -> SessionId {
806            self.session_id
807        }
808
809        async fn next(&self) -> Result<Self::Stream, ClientError> {
810            let session_id = self.session_id;
811            let sock_addr = match &self.addr {
812                SocketAddr::Bound(a) => a,
813                SocketAddr::Unbound => return Err(ClientError::Resolve(self.dest())),
814            };
815            let mut backoff = ExponentialBackoffBuilder::new()
816                .with_initial_interval(Duration::from_millis(1))
817                .with_multiplier(2.0)
818                .with_randomization_factor(0.1)
819                .with_max_interval(Duration::from_millis(1000))
820                .with_max_elapsed_time(None)
821                .build();
822            loop {
823                match StdUnixStream::connect_addr(sock_addr) {
824                    Ok(std_stream) => {
825                        std_stream
826                            .set_nonblocking(true)
827                            .map_err(|err| ClientError::Io(self.dest(), err))?;
828                        let mut stream = UnixStream::from_std(std_stream)
829                            .map_err(|err| ClientError::Io(self.dest(), err))?;
830                        write_link_init(&mut stream, session_id)
831                            .await
832                            .map_err(|err| ClientError::Io(self.dest(), err))?;
833                        return Ok(stream);
834                    }
835                    Err(err) => {
836                        tracing::debug!(error = %err, "unix connect failed, backing off");
837                        if let Some(delay) = backoff.next_backoff() {
838                            tokio::time::sleep(delay).await;
839                        }
840                    }
841                }
842            }
843        }
844    }
845
    /// Server-side listener for Unix domain sockets.
    #[derive(Debug)]
    pub(crate) struct UnixSocketListener {
        /// The tokio listener that `accept` polls for incoming connections.
        pub(super) inner: UnixListener,
        /// Address reported in `ServerError::Io` when accepting fails.
        pub(super) addr: SocketAddr,
    }
852
853    #[async_trait]
854    impl super::Listener for UnixSocketListener {
855        type Stream = UnixStream;
856
857        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
858            let (stream, peer_addr) = self
859                .inner
860                .accept()
861                .await
862                .map_err(|err| ServerError::Io(ChannelAddr::Unix(self.addr.clone()), err))?;
863            // tokio::net::unix::SocketAddr -> std::os::unix::net::SocketAddr
864            let std_addr: StdSocketAddr = peer_addr.into();
865            Ok((stream, ChannelAddr::Unix(SocketAddr::new(std_addr))))
866        }
867    }
868
869    /// Create a unix link to the given socket address.
870    pub(crate) fn link(addr: SocketAddr) -> UnixLink {
871        UnixLink {
872            addr,
873            session_id: SessionId::random(),
874        }
875    }
876
    /// Wrapper around std-lib's unix::SocketAddr that lets us implement equality functions
    #[derive(Clone, Debug)]
    pub enum SocketAddr {
        /// A concrete std-lib socket address (boxed).
        Bound(Box<StdSocketAddr>),
        /// No underlying address; displayed as `(unbound)`.
        Unbound,
    }
883
884    impl PartialOrd for SocketAddr {
885        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
886            Some(self.cmp(other))
887        }
888    }
889
890    impl Ord for SocketAddr {
891        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
892            self.to_string().cmp(&other.to_string())
893        }
894    }
895
896    impl<'de> Deserialize<'de> for SocketAddr {
897        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
898        where
899            D: serde::Deserializer<'de>,
900        {
901            let s = String::deserialize(deserializer)?;
902            Self::from_str(&s).map_err(D::Error::custom)
903        }
904    }
905
906    impl Serialize for SocketAddr {
907        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
908        where
909            S: serde::Serializer,
910        {
911            serializer.serialize_str(String::from(self).as_str())
912        }
913    }
914
915    impl From<&SocketAddr> for String {
916        fn from(value: &SocketAddr) -> Self {
917            match value {
918                SocketAddr::Bound(addr) => match addr.as_pathname() {
919                    Some(path) => path
920                        .to_str()
921                        .expect("unable to get str for path")
922                        .to_string(),
923                    #[cfg(target_os = "linux")]
924                    _ => match addr.as_abstract_name() {
925                        Some(name) => format!("@{}", String::from_utf8_lossy(name)),
926                        _ => String::from("(unnamed)"),
927                    },
928                    #[cfg(not(target_os = "linux"))]
929                    _ => String::from("(unnamed)"),
930                },
931                SocketAddr::Unbound => String::from("(unbound)"),
932            }
933        }
934    }
935
936    impl FromStr for SocketAddr {
937        type Err = anyhow::Error;
938
939        fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
940            match s {
941                "" => {
942                    // TODO: ensure this socket doesn't already exist. 24 bytes of randomness should be good for now but is not perfect.
943                    // We can't use annon sockets because those are not valid across processes that aren't in the same process hierarchy aka forked.
944                    let random_string = rand::rng()
945                        .sample_iter(&Alphanumeric)
946                        .take(24)
947                        .map(char::from)
948                        .collect::<String>();
949                    SocketAddr::from_abstract_name(&random_string)
950                }
951                // by convention, named sockets are displayed with an '@' prefix
952                name if name.starts_with("@") => {
953                    SocketAddr::from_abstract_name(name.strip_prefix("@").unwrap())
954                }
955                path => SocketAddr::from_pathname(path),
956            }
957        }
958    }
959
    // NOTE(review): the `PartialEq` impl for this type returns `false` whenever either side
    // is `Unbound` or an unnamed bound address, so such values do not compare equal to
    // themselves even though `Eq` is declared here — confirm this is intentional.
    impl Eq for SocketAddr {}
961    impl std::hash::Hash for SocketAddr {
962        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
963            String::from(self).hash(state);
964        }
965    }
966    impl PartialEq for SocketAddr {
967        fn eq(&self, other: &Self) -> bool {
968            match (self, other) {
969                (Self::Bound(saddr), Self::Bound(oaddr)) => {
970                    if saddr.is_unnamed() || oaddr.is_unnamed() {
971                        return false;
972                    }
973
974                    #[cfg(target_os = "linux")]
975                    {
976                        saddr.as_pathname() == oaddr.as_pathname()
977                            && saddr.as_abstract_name() == oaddr.as_abstract_name()
978                    }
979                    #[cfg(not(target_os = "linux"))]
980                    {
981                        // On non-Linux platforms, only compare pathname since no abstract names
982                        saddr.as_pathname() == oaddr.as_pathname()
983                    }
984                }
985                (Self::Unbound, _) | (_, Self::Unbound) => false,
986            }
987        }
988    }
989
990    impl fmt::Display for SocketAddr {
991        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
992            match self {
993                Self::Bound(addr) => match addr.as_pathname() {
994                    Some(path) => {
995                        write!(f, "{}", path.to_string_lossy())
996                    }
997                    #[cfg(target_os = "linux")]
998                    _ => match addr.as_abstract_name() {
999                        Some(name) => {
1000                            if name.starts_with(b"@") {
1001                                return write!(f, "{}", String::from_utf8_lossy(name));
1002                            }
1003                            write!(f, "@{}", String::from_utf8_lossy(name))
1004                        }
1005                        _ => write!(f, "(unnamed)"),
1006                    },
1007                    #[cfg(not(target_os = "linux"))]
1008                    _ => write!(f, "(unnamed)"),
1009                },
1010                Self::Unbound => write!(f, "(unbound)"),
1011            }
1012        }
1013    }
1014
1015    impl SocketAddr {
1016        /// Wraps the stdlib socket address for use with this module
1017        pub fn new(addr: StdSocketAddr) -> Self {
1018            Self::Bound(Box::new(addr))
1019        }
1020
1021        /// Abstract socket names start with a "@" by convention when displayed. If there is an
1022        /// "@" prefix, it will be stripped from the name before used.
1023        #[cfg(target_os = "linux")]
1024        pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1025            Ok(Self::new(StdSocketAddr::from_abstract_name(
1026                name.strip_prefix("@").unwrap_or(name),
1027            )?))
1028        }
1029
1030        #[cfg(not(target_os = "linux"))]
1031        pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1032            // On non-Linux platforms, convert abstract names to filesystem paths
1033            let name = name.strip_prefix("@").unwrap_or(name);
1034            let path = Self::abstract_to_filesystem_path(name);
1035            Self::from_pathname(&path.to_string_lossy())
1036        }
1037
1038        #[cfg(not(target_os = "linux"))]
1039        fn abstract_to_filesystem_path(abstract_name: &str) -> std::path::PathBuf {
1040            use std::collections::hash_map::DefaultHasher;
1041            use std::hash::Hash;
1042            use std::hash::Hasher;
1043
1044            // Generate a stable hash of the abstract name for deterministic paths
1045            let mut hasher = DefaultHasher::new();
1046            abstract_name.hash(&mut hasher);
1047            let hash = hasher.finish();
1048
1049            // Include process ID to prevent inter-process conflicts
1050            let process_id = std::process::id();
1051
1052            // TODO: we just leak these. Should we do something smarter?
1053            std::path::PathBuf::from(format!("/tmp/hyperactor_{}_{:x}", process_id, hash))
1054        }
1055
1056        /// Pathnames may be absolute or relative.
1057        pub fn from_pathname(name: &str) -> anyhow::Result<Self> {
1058            Ok(Self::new(StdSocketAddr::from_pathname(name)?))
1059        }
1060    }
1061
1062    impl TryFrom<SocketAddr> for StdSocketAddr {
1063        type Error = anyhow::Error;
1064
1065        fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
1066            match value {
1067                SocketAddr::Bound(addr) => Ok(*addr),
1068                SocketAddr::Unbound => Err(anyhow::anyhow!(
1069                    "std::os::unix::SocketAddr must be a bound address"
1070                )),
1071            }
1072        }
1073    }
1074}
1075
1076pub(crate) mod tcp {
1077    use tokio::net::TcpListener;
1078    use tokio::net::TcpStream;
1079
1080    use super::*;
1081
    /// Client-side link to a TCP destination.
    #[derive(Debug)]
    pub(crate) struct TcpLink {
        /// Destination socket address that `next` connects to.
        pub(super) addr: SocketAddr,
        /// Session id returned by `link_id` and passed to `write_link_init` on connect.
        pub(super) session_id: SessionId,
    }
1087
1088    #[async_trait]
1089    impl Link for TcpLink {
1090        type Stream = TcpStream;
1091
1092        fn dest(&self) -> ChannelAddr {
1093            ChannelAddr::Tcp(self.addr)
1094        }
1095
1096        fn link_id(&self) -> SessionId {
1097            self.session_id
1098        }
1099
1100        async fn next(&self) -> Result<Self::Stream, ClientError> {
1101            let session_id = self.session_id;
1102            let mut backoff = ExponentialBackoffBuilder::new()
1103                .with_initial_interval(Duration::from_millis(1))
1104                .with_multiplier(2.0)
1105                .with_randomization_factor(0.1)
1106                .with_max_interval(Duration::from_millis(1000))
1107                .with_max_elapsed_time(None)
1108                .build();
1109            loop {
1110                match TcpStream::connect(&self.addr).await {
1111                    Ok(mut stream) => {
1112                        stream.set_nodelay(true).map_err(|err| {
1113                            ClientError::Connect(
1114                                self.dest(),
1115                                err,
1116                                "cannot disable Nagle algorithm".to_string(),
1117                            )
1118                        })?;
1119                        write_link_init(&mut stream, session_id)
1120                            .await
1121                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1122                        return Ok(stream);
1123                    }
1124                    Err(err) => {
1125                        tracing::debug!(error = %err, "tcp connect failed, backing off");
1126                        if let Some(delay) = backoff.next_backoff() {
1127                            tokio::time::sleep(delay).await;
1128                        }
1129                    }
1130                }
1131            }
1132        }
1133    }
1134
    /// Server-side listener for TCP sockets.
    #[derive(Debug)]
    pub(crate) struct TcpSocketListener {
        /// The tokio listener that `accept` polls for incoming connections.
        pub(super) inner: TcpListener,
        /// Address reported in `ServerError::Io` when accepting or configuring a stream fails.
        pub(super) addr: SocketAddr,
    }
1141
1142    #[async_trait]
1143    impl super::Listener for TcpSocketListener {
1144        type Stream = TcpStream;
1145
1146        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1147            let (stream, peer_addr) = self
1148                .inner
1149                .accept()
1150                .await
1151                .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1152            stream
1153                .set_nodelay(true)
1154                .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1155            Ok((stream, ChannelAddr::Tcp(peer_addr)))
1156        }
1157    }
1158
1159    /// Create a TCP link to the given socket address.
1160    pub(crate) fn link(addr: SocketAddr) -> TcpLink {
1161        TcpLink {
1162            addr,
1163            session_id: SessionId::random(),
1164        }
1165    }
1166}
1167
1168// TODO: Try to simplify the TLS creation T208304433
1169pub(crate) mod meta {
1170    use std::io;
1171    use std::path::PathBuf;
1172    use std::sync::Arc;
1173
1174    use anyhow::Result;
1175    use tokio_rustls::TlsAcceptor;
1176    use tokio_rustls::TlsConnector;
1177
1178    use super::*;
1179    use crate::config::Pem;
1180    use crate::config::PemBundle;
1181
    /// Environment variable that, when set, overrides the CA certificate path.
    const THRIFT_TLS_SRV_CA_PATH_ENV: &str = "THRIFT_TLS_SRV_CA_PATH";
    /// CA certificate path used when `THRIFT_TLS_SRV_CA_PATH` is not set.
    const DEFAULT_SRV_CA_PATH: &str = "/var/facebook/rootcanal/ca.pem";
    /// Environment variable naming the client certificate file.
    const THRIFT_TLS_CL_CERT_PATH_ENV: &str = "THRIFT_TLS_CL_CERT_PATH";
    /// Environment variable naming the client private key file.
    const THRIFT_TLS_CL_KEY_PATH_ENV: &str = "THRIFT_TLS_CL_KEY_PATH";
    /// File used for both the server certificate and the server key.
    const DEFAULT_SERVER_PEM_PATH: &str = "/var/facebook/x509_identities/server.pem";
1187
1188    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
1189    pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1190        // Use right split to allow for ipv6 addresses where ":" is expected.
1191        let parts = addr_string.rsplit_once(":");
1192        match parts {
1193            Some((hostname, port_str)) => {
1194                let Ok(port) = port_str.parse() else {
1195                    return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1196                };
1197                Ok(ChannelAddr::MetaTls(TlsAddr::new(hostname, port)))
1198            }
1199            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1200        }
1201    }
1202
1203    /// Construct a PemBundle for server operations from Meta-specific paths.
1204    /// Server cert and key come from the same file (server.pem).
1205    pub(super) fn get_server_pem_bundle() -> PemBundle {
1206        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1207            .map(PathBuf::from)
1208            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1209        let server_pem_path = PathBuf::from(DEFAULT_SERVER_PEM_PATH);
1210        PemBundle {
1211            ca: Pem::File(ca_path),
1212            cert: Pem::File(server_pem_path.clone()),
1213            key: Pem::File(server_pem_path),
1214        }
1215    }
1216
1217    /// Construct a PemBundle for client operations from Meta-specific env vars.
1218    /// Returns None if client cert/key env vars are not set.
1219    fn get_client_pem_bundle() -> Option<PemBundle> {
1220        let cert_path = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV)?;
1221        let key_path = std::env::var_os(THRIFT_TLS_CL_KEY_PATH_ENV)?;
1222        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1223            .map(PathBuf::from)
1224            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1225        Some(PemBundle {
1226            ca: Pem::File(ca_path),
1227            cert: Pem::File(PathBuf::from(cert_path)),
1228            key: Pem::File(PathBuf::from(key_path)),
1229        })
1230    }
1231
1232    /// Creates a TLS acceptor by looking for necessary certs and keys in a Meta server environment.
1233    pub(crate) fn tls_acceptor(enforce_client_tls: bool) -> Result<TlsAcceptor> {
1234        let bundle = get_server_pem_bundle();
1235        tls::tls_acceptor_from_bundle(&bundle, enforce_client_tls)
1236    }
1237
    /// Try to create a TLS connector for Meta environments.
    ///
    /// Returns `Ok` when the root CA is present (optional client certs
    /// are added when `THRIFT_TLS_CL_CERT_PATH` / `THRIFT_TLS_CL_KEY_PATH`
    /// are set).
    ///
    /// This is a thin `pub(super)` wrapper that forwards directly to the
    /// module-private `tls_connector`.
    pub(super) fn try_tls_connector() -> Result<TlsConnector> {
        tls_connector()
    }
1246
1247    /// Creates a TLS connector by looking for necessary certs and keys in a Meta server environment.
1248    /// Supports optional client authentication (unlike the tls module which always requires it).
1249    fn tls_connector() -> Result<TlsConnector> {
1250        // Ensure ring is installed as the process-level crypto provider.
1251        // No-op when already installed (e.g. under Buck with native-tls).
1252        let _ = rustls::crypto::ring::default_provider().install_default();
1253
1254        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1255            .map(PathBuf::from)
1256            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1257        let ca_pem = Pem::File(ca_path);
1258        let root_store = tls::build_root_store(&ca_pem)?;
1259
1260        // If client certs are available, use mutual TLS; otherwise, no client auth
1261        let config = rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_store));
1262
1263        let config = if let Some(bundle) = get_client_pem_bundle() {
1264            let certs = tls::load_certs(&bundle.cert)?;
1265            let key = tls::load_key(&bundle.key)?;
1266            config
1267                .with_client_auth_cert(certs, key)
1268                .map_err(|e| anyhow::anyhow!("load client certs: {}", e))?
1269        } else {
1270            config.with_no_client_auth()
1271        };
1272
1273        Ok(TlsConnector::from(Arc::new(config)))
1274    }
1275
1276    /// Create a MetaTLS link to the given address.
1277    pub fn link(addr: TlsAddr) -> Result<tls::TlsLink, ClientError> {
1278        let connector = tls_connector().map_err(|e| {
1279            ClientError::Connect(
1280                ChannelAddr::MetaTls(addr.clone()),
1281                io::Error::other(e.to_string()),
1282                "failed to create TLS connector".to_string(),
1283            )
1284        })?;
1285        let TlsAddr { hostname, port } = addr;
1286        Ok(tls::TlsLink {
1287            hostname,
1288            port,
1289            connector,
1290            addr_type: tls::TlsAddrType::MetaTls,
1291            session_id: SessionId::random(),
1292        })
1293    }
1294}
1295
1296/// TLS transport module using configurable certificates via hyperactor config attributes.
1297pub(crate) mod tls {
1298    use std::io;
1299    use std::io::BufReader;
1300    use std::sync::Arc;
1301
1302    use anyhow::Context;
1303    use anyhow::Result;
1304    use rustls::RootCertStore;
1305    use rustls::pki_types::CertificateDer;
1306    use rustls::pki_types::PrivateKeyDer;
1307    use rustls::pki_types::ServerName;
1308    use tokio::net::TcpStream;
1309    use tokio_rustls::TlsAcceptor;
1310    use tokio_rustls::TlsConnector;
1311    use tokio_rustls::client::TlsStream;
1312
1313    use super::*;
1314    use crate::channel::TlsAddr;
1315    use crate::config::Pem;
1316    use crate::config::PemBundle;
1317    use crate::config::TLS_CA;
1318    use crate::config::TLS_CERT;
1319    use crate::config::TLS_KEY;
1320
    /// Distinguishes between Tls and MetaTls for address construction.
    #[derive(Debug, Clone, Copy)]
    pub(crate) enum TlsAddrType {
        /// The link's destination is reported as `ChannelAddr::Tls`.
        Tls,
        /// The link's destination is reported as `ChannelAddr::MetaTls`.
        MetaTls,
    }
1327
1328    /// Parse an address string into a TlsAddr.
1329    #[allow(clippy::result_large_err)]
1330    pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1331        // Use right split to allow for ipv6 addresses where ":" is expected.
1332        let parts = addr_string.rsplit_once(":");
1333        match parts {
1334            Some((hostname, port_str)) => {
1335                let Ok(port) = port_str.parse() else {
1336                    return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1337                };
1338                Ok(ChannelAddr::Tls(TlsAddr::new(hostname, port)))
1339            }
1340            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1341        }
1342    }
1343
1344    /// Load certificates from a Pem value.
1345    pub(super) fn load_certs(pem: &Pem) -> Result<Vec<CertificateDer<'static>>> {
1346        let mut reader = BufReader::new(pem.reader()?);
1347        let certs = rustls_pemfile::certs(&mut reader)
1348            .filter_map(Result::ok)
1349            .collect();
1350        Ok(certs)
1351    }
1352
1353    /// Load a private key from a Pem value.
1354    pub(super) fn load_key(pem: &Pem) -> Result<PrivateKeyDer<'static>> {
1355        let mut reader = BufReader::new(pem.reader()?);
1356        loop {
1357            break match rustls_pemfile::read_one(&mut reader)? {
1358                Some(rustls_pemfile::Item::Pkcs1Key(key)) => Ok(PrivateKeyDer::Pkcs1(key)),
1359                Some(rustls_pemfile::Item::Pkcs8Key(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
1360                Some(rustls_pemfile::Item::Sec1Key(key)) => Ok(PrivateKeyDer::Sec1(key)),
1361                Some(_) => continue,
1362                None => anyhow::bail!("no private key found in TLS key file"),
1363            };
1364        }
1365    }
1366
1367    /// Build root certificate store from the CA pem.
1368    pub(super) fn build_root_store(ca_pem: &Pem) -> Result<RootCertStore> {
1369        let mut root_store = RootCertStore::empty();
1370        let certs = load_certs(ca_pem)?;
1371        root_store.add_parsable_certificates(certs);
1372        Ok(root_store)
1373    }
1374
1375    /// Get the PEM bundle from configuration.
1376    fn get_pem_bundle() -> PemBundle {
1377        PemBundle {
1378            ca: hyperactor_config::global::get_cloned(TLS_CA),
1379            cert: hyperactor_config::global::get_cloned(TLS_CERT),
1380            key: hyperactor_config::global::get_cloned(TLS_KEY),
1381        }
1382    }
1383
1384    /// Creates a TLS acceptor using certificates from the provided PEM bundle.
1385    /// If `enforce_client_tls` is true, requires client certificates for mutual TLS.
1386    pub(super) fn tls_acceptor_from_bundle(
1387        bundle: &PemBundle,
1388        enforce_client_tls: bool,
1389    ) -> Result<TlsAcceptor> {
1390        // Ensure ring is installed as the process-level crypto provider.
1391        // No-op when already installed (e.g. under Buck with native-tls).
1392        let _ = rustls::crypto::ring::default_provider().install_default();
1393
1394        let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1395        let key = load_key(&bundle.key).context("load TLS key")?;
1396        let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1397
1398        let config = rustls::ServerConfig::builder();
1399        let config = if enforce_client_tls {
1400            // Build server config with mutual TLS (require client certs)
1401            let client_verifier =
1402                rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1403                    .build()
1404                    .map_err(|e| anyhow::anyhow!("failed to build client verifier: {}", e))?;
1405            config.with_client_cert_verifier(client_verifier)
1406        } else {
1407            config.with_no_client_auth()
1408        }
1409        .with_single_cert(certs, key)?;
1410
1411        Ok(TlsAcceptor::from(Arc::new(config)))
1412    }
1413
1414    /// Creates a TLS acceptor using certificates from config (always enforces mutual TLS).
1415    pub(crate) fn tls_acceptor() -> Result<TlsAcceptor> {
1416        tls_acceptor_from_bundle(&get_pem_bundle(), true)
1417    }
1418
1419    /// Creates a TLS connector using certificates from the provided PEM bundle.
1420    pub(super) fn tls_connector_from_bundle(bundle: &PemBundle) -> Result<TlsConnector> {
1421        // Ensure ring is installed as the process-level crypto provider.
1422        // No-op when already installed (e.g. under Buck with native-tls).
1423        let _ = rustls::crypto::ring::default_provider().install_default();
1424
1425        let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1426        let key = load_key(&bundle.key).context("load TLS key")?;
1427        let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1428
1429        let config = rustls::ClientConfig::builder()
1430            .with_root_certificates(Arc::new(root_store))
1431            .with_client_auth_cert(certs, key)
1432            .context("configure client auth")?;
1433
1434        Ok(TlsConnector::from(Arc::new(config)))
1435    }
1436
1437    /// Creates a TLS connector using certificates from config.
1438    fn tls_connector() -> Result<TlsConnector> {
1439        tls_connector_from_bundle(&get_pem_bundle())
1440    }
1441
    /// Shared TLS link implementation used by both tls and metatls transports.
    pub(crate) struct TlsLink {
        /// Destination host: resolved to a socket address and used as the TLS server name.
        pub(crate) hostname: Hostname,
        /// Destination port.
        pub(crate) port: Port,
        /// Connector used to perform the TLS handshake over the TCP stream.
        pub(crate) connector: TlsConnector,
        /// Selects whether `dest` reports a `Tls` or `MetaTls` channel address.
        pub(crate) addr_type: TlsAddrType,
        /// Session id returned by `link_id` and passed to `write_link_init` on connect.
        pub(crate) session_id: SessionId,
    }
1450
1451    impl std::fmt::Debug for TlsLink {
1452        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1453            f.debug_struct("TlsLink")
1454                .field("hostname", &self.hostname)
1455                .field("port", &self.port)
1456                .field("addr_type", &self.addr_type)
1457                .finish()
1458        }
1459    }
1460
1461    #[async_trait]
1462    impl Link for TlsLink {
1463        type Stream = TlsStream<TcpStream>;
1464
1465        fn dest(&self) -> ChannelAddr {
1466            let addr = TlsAddr::new(self.hostname.clone(), self.port);
1467            match self.addr_type {
1468                TlsAddrType::Tls => ChannelAddr::Tls(addr),
1469                TlsAddrType::MetaTls => ChannelAddr::MetaTls(addr),
1470            }
1471        }
1472
1473        fn link_id(&self) -> SessionId {
1474            self.session_id
1475        }
1476
1477        async fn next(&self) -> Result<Self::Stream, ClientError> {
1478            let session_id = self.session_id;
1479            let server_name = ServerName::try_from(self.hostname.clone()).map_err(|e| {
1480                ClientError::Connect(
1481                    self.dest(),
1482                    io::Error::other(e.to_string()),
1483                    "invalid server name".to_string(),
1484                )
1485            })?;
1486            let mut backoff = ExponentialBackoffBuilder::new()
1487                .with_initial_interval(Duration::from_millis(1))
1488                .with_multiplier(2.0)
1489                .with_randomization_factor(0.1)
1490                .with_max_interval(Duration::from_millis(1000))
1491                .with_max_elapsed_time(None)
1492                .build();
1493            loop {
1494                let mut addrs = (self.hostname.as_ref(), self.port)
1495                    .to_socket_addrs()
1496                    .map_err(|_| ClientError::Resolve(self.dest()))?;
1497                let addr = addrs.next().ok_or(ClientError::Resolve(self.dest()))?;
1498                match TcpStream::connect(&addr).await {
1499                    Ok(stream) => {
1500                        stream.set_nodelay(true).map_err(|err| {
1501                            ClientError::Connect(
1502                                self.dest(),
1503                                err,
1504                                "cannot disable Nagle algorithm".to_string(),
1505                            )
1506                        })?;
1507                        let mut tls_stream = self
1508                            .connector
1509                            .connect(server_name.clone(), stream)
1510                            .await
1511                            .map_err(|err| {
1512                                tracing::info!(
1513                                    dest = %self.dest(),
1514                                    error = %err,
1515                                    "TLS handshake failed"
1516                                );
1517                                ClientError::Connect(
1518                                    self.dest(),
1519                                    err,
1520                                    format!("cannot establish TLS connection to {:?}", server_name),
1521                                )
1522                            })?;
1523                        write_link_init(&mut tls_stream, session_id)
1524                            .await
1525                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1526                        return Ok(tls_stream);
1527                    }
1528                    Err(err) => {
1529                        tracing::debug!(error = %err, "tls connect failed, backing off");
1530                        if let Some(delay) = backoff.next_backoff() {
1531                            tokio::time::sleep(delay).await;
1532                        }
1533                    }
1534                }
1535            }
1536        }
1537    }
1538
1539    /// Create a TLS link to the given address.
1540    pub fn link(addr: TlsAddr) -> Result<TlsLink, ClientError> {
1541        let connector = tls_connector().map_err(|e| {
1542            ClientError::Connect(
1543                ChannelAddr::Tls(addr.clone()),
1544                io::Error::other(e.to_string()),
1545                "failed to create TLS connector".to_string(),
1546            )
1547        })?;
1548        let TlsAddr { hostname, port } = addr;
1549        Ok(TlsLink {
1550            hostname,
1551            port,
1552            connector,
1553            addr_type: TlsAddrType::Tls,
1554            session_id: SessionId::random(),
1555        })
1556    }
1557
1558    #[cfg(test)]
1559    mod tests {
1560        use timed_test::async_timed_test;
1561
1562        use super::*;
1563        use crate::channel::Rx;
1564        use crate::channel::net::server;
1565        use crate::config::Pem;
1566        use crate::config::TLS_CA;
1567        use crate::config::TLS_CERT;
1568        use crate::config::TLS_KEY;
1569
1570        // Dummy test certificates generated with openssl for testing only.
1571        // These certificates include Subject Alternative Names (SAN) for localhost, 127.0.0.1, and ::1
1572        // CA certificate
1573        const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1574MIIDBTCCAe2gAwIBAgIUaGNmboiIosG+8Up0vgDr/+cg+2IwDQYJKoZIhvcNAQEL
1575BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1576NzA4MzlaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
1577DwAwggEKAoIBAQC9RBoMYXCajklswt8Vi1JI1lEYzic0WNOmz45vG/7H6jTWkgL3
1578K5Ri+Seg3MobDNc48YHWXYm4hP9wCzkx8ih3ntT5XiY1My/G3jLUuoIEE9pF/BoJ
1579YQwZVoPNFhA9WhXNRsINf1cXFf8NzRfXpxBfKWtQJxYXU4JiDBQ6rLnQQABo8JmQ
1580vYFhJbBaYip5jTSiVNn7mB1zNr5jsVxuoSF53Pb7xQ76bwBdOq4zd6PSxL5/lr4G
1581cHSoxwZQdZMG7PL6hbxDQ2S2YI2lYVET1zwc2WPKCfjbEXBC/jzx828CInQtuksk
158218gJt6xHkTFEA8CSA29GM3lejnwYWf51xyyBAgMBAAGjUzBRMB0GA1UdDgQWBBRX
1583cbxSZ70NsUkAS3Hhy6irugywJDAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1584ugywJDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA7aAFfyW67
1585Z+uGSVYhpsT/uH/3Z3nr7X1smTz5CGEfq2czEcTC7gbYI2l8GZ47GPfnAvHTBZVm
1586V/XncBCsj7/thOh2jYEHFyCbPckoaSCRyCOnK7LPUlr4HN5uP9EFe45qBLCJDEoY
1587GTTw7MtzwdovfjchNfKQCTtkBJCXQ95WLCf6UOh02Sn28UTlgfXzF0X0FrcWqWa3
1588uJZd4XOo4O6hKKlHaBaQPiEr++1xc3SWPV7jZHbckI/vKBnDdEZ9JQX5fFZuypUI
1589sgomYHxvxrU2hWx+7k53CRdjfaIvT9Ie44z9sSdsU/+blw2S8f/ZTmuECoIAAXYO
15900qpzlxZMdr7T
1591-----END CERTIFICATE-----"#;
1592
1593        // Server certificate (signed by CA) with SAN for localhost, 127.0.0.1, ::1
1594        const TEST_SERVER_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1595MIIDJDCCAgygAwIBAgIUaz66DsWaH5ZXM4hCFnbVbMsyN1cwDQYJKoZIhvcNAQEL
1596BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1597NzA4MzlaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQAD
1598ggEPADCCAQoCggEBAKCbp++qNyTn5LOsV0h9gLKJALBcjg2A14I3804N9UyDhPW2
1599QKQ2W424u2P1MfKrw/2C+CErGlrADlnco2RQVDAarAIuGdFvBOt5UezqOS7Mk4OS
16009MlS7NZnMbc37KuM9UIG5ScJjXR/Z5z9dxeR0I9y3n0Ix6khbV7tOSHobiweI0FI
16018LftBS+CQnXr6vbWPcHcW6Z0FHUv7IWhqMWmv9PlZRGe9Y6VzXrRp0PBnZMOnAYf
1602aMQUwYRswWdm9j9Z1sMdTJ14G+KVmO3Vj6XI6Sm9uIcYhlwG/kORwogJFWlVuP9o
1603rloFRCjyHJ1d7GZqqnRyHHDDCBms8ed+3YfEYQECAwEAAaNwMG4wLAYDVR0RBCUw
1604I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMB0GA1UdDgQWBBQl
1605J4vxUoCzqqeTwQAiLqE8wYezKzAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1606ugywJDANBgkqhkiG9w0BAQsFAAOCAQEAnXHIBDQ4AHAMV71piTOuI41ShASQed6L
1607bi7XUMZgZDslLkfU1vnP3BlwpliraBsAytSYQC6kbytOuz1uQ4K7yLb2tAAmUgEO
1608EdIVt9SXr5tCcIPeLmInF0pysPqjZO8n7vtJyd9gryKqdhm1uzA7WQWq/Az8a9Sk
1609uW2J6Oc5p6P7Mf3/ixqXzvGRo8rzu0CUJOJ67UTE/HhbJuplQ5dep5CEEOAIsAtH
1610zn9O4rW92ueBkoBJM++YILS1vQ7jKc2N3RNrnHm7FeootBrtR9mBi0TH97K73ZPZ
16112Cdhnym0CsCJggrllFGH32cYo7+K2PO7/4oj5XbBCSWcssicvd8ovg==
1612-----END CERTIFICATE-----"#;
1613
1614        // Server private key
1615        // test-only embedded key -- though it might expire at some point:
1616        // @lint-ignore PRIVATEKEY insecure-private-key-storage
1617        const TEST_SERVER_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
1618MIIEugIBADANBgkqhkiG9w0BAQEFAASCBKQwggSgAgEAAoIBAQCgm6fvqjck5+Sz
1619rFdIfYCyiQCwXI4NgNeCN/NODfVMg4T1tkCkNluNuLtj9THyq8P9gvghKxpawA5Z
16203KNkUFQwGqwCLhnRbwTreVHs6jkuzJODkvTJUuzWZzG3N+yrjPVCBuUnCY10f2ec
1621/XcXkdCPct59CMepIW1e7Tkh6G4sHiNBSPC37QUvgkJ16+r21j3B3FumdBR1L+yF
1622oajFpr/T5WURnvWOlc160adDwZ2TDpwGH2jEFMGEbMFnZvY/WdbDHUydeBvilZjt
16231Y+lyOkpvbiHGIZcBv5DkcKICRVpVbj/aK5aBUQo8hydXexmaqp0chxwwwgZrPHn
1624ft2HxGEBAgMBAAECgf8G5qlQov+7ljs9fSpC8yGUik59RXzVF7Qq5DyQHglsQDp2
1625VF5yr+M/M7DZmq+KvdauDfKbej6np5j2Q4TByrHTX1IExfZWCW8srwnWJDpQyHmO
1626LcJW5DlI/SYluUFyHZxsOd+ezcpGNzM8i6eSW7GaeFUXCkmJ+uW4LnlF+7bALnnd
1627D6sak/58EsII+IJyd4lFn+voszlPn3CZGR0jkp21rvpaKgrMIsKVWWQO/sLDU5pr
1628VbpBThcLU5gRcnQouQX12e2VTCIlFu75WTsJ8V/KnEaOZUVlU/B/Bs+WQF3U+/Jo
1629eX4N+D6OsEcNQjERAFyWujxsl1WpD4uSsbFMN0ECgYEA2b7AdL+oKPQHku2KcBhr
1630Zw8K4tMDlr2VPPNwZcBTLo+O71vv/xXjMcXrXmowzkgEQckUmt1VB46riyydhwdP
1631/n9ciWcz0Va/nwHR6Y9F9unBiyUBP7PRhRyjQyRZZRGDSJvP+Xmc5UJFpRr07VLU
1632nfgMXDj37vXzKDpfhdEB2nkCgYEAvNMfA8P8w3+6246x5YHflvTkPdw+2oyge+LD
1633mphB/w7SF8mlyNGloj3+KBZmd9SkvT57wCvO96Y9/n+mBAVisRggc0hK4ymOVYhb
1634+im/JvqGQMbVeg6iCOHnWdaZf9tL8uVsegQy3kVTN7vAa+CMFgX1dt65cGBX6XkB
163544pYmMkCgYALhbiRdQLlB+TOtZs5y1EDpxwgXKI3+9hF3Wv5NnAwapBZwje0++eF
16363r9Rw7TJda4j/QwGFehF+hrBxp6fYpetE/hFnRx0225Qb7w368j8A+ql/lNOl6li
1637rd1F1EqWupKD6RrcTL8sspEU55RGaretlE6zIqCcGI/BdTVQ03qRoQKBgHDC3zWf
1638d7XD9HGjQGdfbIe4jQjIGxzmd/wjik4q+NZ5IkukVwWa9P/zZ3DHF8Ad05dT1hEH
16392FwaAdGWpyyljq9VSiOuG1KXAXHgsZSuE4ISf9P1KYzvaiJFzaPfvOEWs79E9MfU
16409A+6dJzG2X1SpjWMr26iSTlrv3QkmFUqzAfJAoGASBkn4wls+oC5rv/Mch43pBv5
1641UmKru4ltnEHJZdbSi2DJ+AnDLD222JCasb1VT1tm2XgW6DBqrdVRPPP6GOlB0MHU
1642+3ULtZxAczt7I+ST2bo0/DV2Hse89Cm63w4wLOiVZs7+1wrAzJZLokWF7Q5gesra
1643u19txmtkiMEH+aNmekk=
1644-----END PRIVATE KEY-----"#;
1645
1646        #[async_timed_test(timeout_secs = 30)]
1647        async fn test_tls_basic() {
1648            // Ensure ring is installed as the default crypto provider
1649            // (no-op if already installed, e.g. under Buck with native-tls).
1650            let _ = rustls::crypto::ring::default_provider().install_default();
1651
1652            // Set up TLS config using the standard override pattern
1653            let config = hyperactor_config::global::lock();
1654            let _guard_cert =
1655                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1656            let _guard_key =
1657                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1658            let _guard_ca =
1659                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1660
1661            // Create a TLS server bound to localhost with dynamic port
1662            let addr = TlsAddr::new("localhost", 0);
1663
1664            let (local_addr, mut rx) =
1665                server::serve::<u64>(ChannelAddr::Tls(addr)).expect("failed to serve");
1666
1667            // Dial the server
1668            let tx: super::NetTx<u64> = super::spawn(
1669                link(match &local_addr {
1670                    ChannelAddr::Tls(addr) => addr.clone(),
1671                    _ => panic!("unexpected address type"),
1672                })
1673                .expect("failed to create link"),
1674            );
1675
1676            // Send a message
1677            tx.post(42u64);
1678
1679            // Receive the message
1680            let received = rx.recv().await.expect("failed to receive");
1681            assert_eq!(received, 42u64);
1682        }
1683
1684        #[async_timed_test(timeout_secs = 30)]
1685        async fn test_tls_multiple_messages() {
1686            let _ = rustls::crypto::ring::default_provider().install_default();
1687
1688            // Set up TLS config using the standard override pattern
1689            let config = hyperactor_config::global::lock();
1690            let _guard_cert =
1691                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1692            let _guard_key =
1693                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1694            let _guard_ca =
1695                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1696
1697            let addr = TlsAddr::new("localhost", 0);
1698
1699            let (local_addr, mut rx) =
1700                server::serve::<String>(ChannelAddr::Tls(addr)).expect("failed to serve");
1701            let tx: super::NetTx<String> = super::spawn(
1702                link(match &local_addr {
1703                    ChannelAddr::Tls(addr) => addr.clone(),
1704                    _ => panic!("unexpected address type"),
1705                })
1706                .expect("failed to create link"),
1707            );
1708
1709            // Send multiple messages
1710            for i in 0..10 {
1711                tx.post(format!("message {}", i));
1712            }
1713
1714            // Receive all messages
1715            for i in 0..10 {
1716                let received = rx.recv().await.expect("failed to receive");
1717                assert_eq!(received, format!("message {}", i));
1718            }
1719        }
1720
1721        #[test]
1722        fn test_tls_parse_hostname_port() {
1723            let addr = parse("localhost:8080").expect("failed to parse");
1724            assert!(matches!(
1725                addr,
1726                ChannelAddr::Tls(TlsAddr { hostname, port })
1727                    if hostname == "localhost" && port == 8080
1728            ));
1729        }
1730
1731        #[test]
1732        fn test_tls_parse_socket_addr() {
1733            let addr = parse("127.0.0.1:8080").expect("failed to parse");
1734            assert!(matches!(
1735                addr,
1736                ChannelAddr::Tls(TlsAddr { hostname, port })
1737                    if hostname == "127.0.0.1" && port == 8080
1738            ));
1739        }
1740
1741        #[test]
1742        fn test_tls_certs_parsing() {
1743            // Verify that the test certificates can be parsed correctly
1744            let cert_pem = Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec());
1745            let key_pem = Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec());
1746            let ca_pem = Pem::Value(TEST_CA_CERT.as_bytes().to_vec());
1747
1748            let certs = super::load_certs(&cert_pem).expect("failed to load certs");
1749            assert!(!certs.is_empty(), "expected at least one certificate");
1750
1751            let _key = super::load_key(&key_pem).expect("failed to load key");
1752
1753            let root_store = super::build_root_store(&ca_pem).expect("failed to build root store");
1754            assert!(!root_store.is_empty(), "expected at least one CA cert");
1755        }
1756
1757        #[test]
1758        fn test_tls_acceptor_creation() {
1759            // Ensure ring is installed as the default crypto provider
1760            // (no-op if already installed, e.g. under Buck with native-tls).
1761            let _ = rustls::crypto::ring::default_provider().install_default();
1762
1763            // Set up TLS config using the standard override pattern
1764            let config = hyperactor_config::global::lock();
1765            let _guard_cert =
1766                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1767            let _guard_key =
1768                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1769            let _guard_ca =
1770                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1771
1772            // Verify that we can create a TLS acceptor
1773            let _acceptor = super::tls_acceptor().expect("failed to create TLS acceptor");
1774        }
1775
1776        #[test]
1777        fn test_tls_connector_creation() {
1778            // Ensure ring is installed as the default crypto provider
1779            // (no-op if already installed, e.g. under Buck with native-tls).
1780            let _ = rustls::crypto::ring::default_provider().install_default();
1781
1782            // Set up TLS config using the standard override pattern
1783            let config = hyperactor_config::global::lock();
1784            let _guard_cert =
1785                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1786            let _guard_key =
1787                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1788            let _guard_ca =
1789                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1790
1791            // Verify that we can create a TLS connector
1792            let _connector = super::tls_connector().expect("failed to create TLS connector");
1793        }
1794    }
1795}
1796
1797/// Build the OSS PemBundle from hyperactor_config attributes.
1798fn oss_pem_bundle() -> crate::config::PemBundle {
1799    crate::config::PemBundle {
1800        ca: hyperactor_config::global::get_cloned(crate::config::TLS_CA),
1801        cert: hyperactor_config::global::get_cloned(crate::config::TLS_CERT),
1802        key: hyperactor_config::global::get_cloned(crate::config::TLS_KEY),
1803    }
1804}
1805
1806/// Try to find a usable TLS [`PemBundle`](crate::config::PemBundle)
1807/// by probing the same sources as [`try_tls_acceptor`] /
1808/// [`try_tls_connector`].
1809///
1810/// Returns the first bundle whose CA certificate is readable.
1811/// Only CA readability is checked — cert and key are returned as-is
1812/// and may not be valid. Callers that cannot use `tokio_rustls` types
1813/// directly (e.g. reqwest) can read the raw PEM bytes via
1814/// [`Pem::reader`](crate::config::Pem::reader).
1815pub fn try_tls_pem_bundle() -> Option<crate::config::PemBundle> {
1816    let oss_bundle = oss_pem_bundle();
1817    if oss_bundle.ca.reader().is_ok() {
1818        return Some(oss_bundle);
1819    }
1820    tracing::debug!("OSS TLS bundle: CA not readable, trying Meta paths");
1821
1822    let meta_bundle = meta::get_server_pem_bundle();
1823    if meta_bundle.ca.reader().is_ok() {
1824        return Some(meta_bundle);
1825    }
1826    tracing::debug!("Meta TLS bundle: CA not readable, no TLS available");
1827
1828    None
1829}
1830
1831/// Try to build a [`TlsAcceptor`](tokio_rustls::TlsAcceptor) for an
1832/// HTTP server by probing for available TLS certificates.
1833///
1834/// Detection order:
1835/// 1. **OSS / explicit config** — `HYPERACTOR_TLS_CERT`,
1836///    `HYPERACTOR_TLS_KEY`, and `HYPERACTOR_TLS_CA` (read via
1837///    [`hyperactor_config`]).
1838/// 2. **Meta default paths** —
1839///    `/var/facebook/x509_identities/server.pem` and
1840///    `/var/facebook/rootcanal/ca.pem`. These are present on
1841///    devservers and in MAST / Tupperware containers.
1842/// 3. **None** — no usable certificates found; caller should fall
1843///    back to plain HTTP.
1844///
1845/// When `enforce_client_tls` is `true`, the returned acceptor
1846/// requires clients to present a valid certificate signed by the
1847/// configured CA (mutual TLS via `WebPkiClientVerifier`). When
1848/// `false`, the acceptor authenticates itself but does not demand
1849/// client certificates.
1850pub fn try_tls_acceptor(enforce_client_tls: bool) -> Option<tokio_rustls::TlsAcceptor> {
1851    let oss_bundle = oss_pem_bundle();
1852    if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&oss_bundle, enforce_client_tls) {
1853        return Some(acceptor);
1854    }
1855    tracing::debug!("OSS TLS acceptor failed, trying Meta paths");
1856
1857    let meta_bundle = meta::get_server_pem_bundle();
1858    if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&meta_bundle, enforce_client_tls) {
1859        return Some(acceptor);
1860    }
1861    tracing::debug!("Meta TLS acceptor failed, no TLS available");
1862
1863    None
1864}
1865
1866/// Try to build a [`TlsConnector`](tokio_rustls::TlsConnector) for an
1867/// HTTP client that needs to connect to a TLS-enabled server.
1868///
1869/// Detection mirrors [`try_tls_acceptor`]:
1870/// 1. **OSS** — `HYPERACTOR_TLS_CA` (and optionally
1871///    `HYPERACTOR_TLS_CERT` + `HYPERACTOR_TLS_KEY` for mutual TLS).
1872/// 2. **Meta** — root CA at `/var/facebook/rootcanal/ca.pem`,
1873///    optional client certs from `THRIFT_TLS_CL_CERT_PATH` /
1874///    `THRIFT_TLS_CL_KEY_PATH`.
1875/// 3. **None** — no usable CA found; caller should fall back to plain
1876///    HTTP.
1877pub fn try_tls_connector() -> Option<tokio_rustls::TlsConnector> {
1878    let oss_bundle = oss_pem_bundle();
1879    if let Ok(connector) = tls::tls_connector_from_bundle(&oss_bundle) {
1880        return Some(connector);
1881    }
1882    tracing::debug!("OSS TLS connector failed, trying Meta paths");
1883
1884    if let Ok(connector) = meta::try_tls_connector() {
1885        return Some(connector);
1886    }
1887    tracing::debug!("Meta TLS connector failed, no TLS available");
1888
1889    None
1890}
1891
1892#[cfg(test)]
1893mod tests {
1894    use std::assert_matches::assert_matches;
1895    use std::collections::VecDeque;
1896    use std::marker::PhantomData;
1897    use std::sync::Arc;
1898    use std::sync::RwLock;
1899    use std::sync::atomic::AtomicBool;
1900    use std::sync::atomic::AtomicU64;
1901    use std::sync::atomic::Ordering;
1902    use std::time::Duration;
1903    #[cfg(target_os = "linux")] // uses abstract names
1904    use std::time::UNIX_EPOCH;
1905
1906    #[cfg(target_os = "linux")] // uses abstract names
1907    use anyhow::Result;
1908    use bytes::Bytes;
1909    use rand::Rng;
1910    use rand::SeedableRng;
1911    use rand::distr::Alphanumeric;
1912    use timed_test::async_timed_test;
1913    use tokio::io::AsyncWrite;
1914    use tokio::io::DuplexStream;
1915    use tokio::io::ReadHalf;
1916    use tokio::io::WriteHalf;
1917    use tokio::task::JoinHandle;
1918    use tokio_util::sync::CancellationToken;
1919
1920    use super::server;
1921    use super::*;
1922    use crate::channel;
1923    use crate::channel::net::framed::FrameReader;
1924    use crate::channel::net::framed::FrameWrite;
1925    use crate::channel::net::server::AcceptorLink;
1926    use crate::config;
1927    use crate::metrics;
1928    use crate::sync::mvar::MVar;
1929
1930    /// Like the `logs_assert` injected by `#[traced_test]`, but without scope
1931    /// filtering. Use when asserting on events emitted outside the test's span
1932    /// (e.g. from spawned tasks or panic hooks).
1933    fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
1934        let buf = tracing_test::internal::global_buf().lock().unwrap();
1935        let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
1936        let lines: Vec<&str> = logs_str.lines().collect();
1937        match f(&lines) {
1938            Ok(()) => {}
1939            Err(msg) => panic!("{}", msg),
1940        }
1941    }
1942
1943    #[cfg(target_os = "linux")] // uses abstract names
1944    #[tracing_test::traced_test]
1945    #[tokio::test]
1946    async fn test_unix_basic() -> Result<()> {
1947        let timestamp = std::time::SystemTime::now()
1948            .duration_since(UNIX_EPOCH)
1949            .unwrap()
1950            .as_nanos();
1951        let unique_address = format!("test_unix_basic_{}", timestamp);
1952
1953        let (addr, mut rx) = server::serve::<u64>(ChannelAddr::Unix(
1954            unix::SocketAddr::from_abstract_name(&unique_address)?,
1955        ))
1956        .unwrap();
1957
1958        // It is important to keep Tx alive until all expected messages are
1959        // received. Otherwise, the channel would be closed when Tx is dropped.
1960        // Although the messages are sent to the server's buffer before the
1961        // channel was closed, NetRx could still error out before taking them
1962        // out of the buffer because NetRx could not ack through the closed
1963        // channel.
1964        {
1965            let tx: ChannelTx<u64> = channel::dial::<u64>(addr.clone()).unwrap();
1966            tx.post(123);
1967            assert_eq!(rx.recv().await.unwrap(), 123);
1968        }
1969
1970        {
1971            let tx = channel::dial::<u64>(addr.clone()).unwrap();
1972            tx.post(321);
1973            tx.post(111);
1974            tx.post(444);
1975
1976            assert_eq!(rx.recv().await.unwrap(), 321);
1977            assert_eq!(rx.recv().await.unwrap(), 111);
1978            assert_eq!(rx.recv().await.unwrap(), 444);
1979        }
1980
1981        {
1982            let tx = channel::dial::<u64>(addr).unwrap();
1983            drop(rx);
1984
1985            let (return_tx, return_rx) = oneshot::channel();
1986            tx.try_post(123, return_tx);
1987            assert_matches!(
1988                return_rx.await,
1989                Ok(SendError {
1990                    error: ChannelError::Closed,
1991                    message: 123,
1992                    ..
1993                })
1994            );
1995        }
1996
1997        Ok(())
1998    }
1999
2000    #[cfg(target_os = "linux")] // uses abstract names
2001    #[tracing_test::traced_test]
2002    #[tokio::test]
2003    async fn test_unix_basic_client_before_server() -> Result<()> {
2004        // We run this test on Unix because we can pick our own port names more easily.
2005        let timestamp = std::time::SystemTime::now()
2006            .duration_since(UNIX_EPOCH)
2007            .unwrap()
2008            .as_nanos();
2009        let socket_addr =
2010            unix::SocketAddr::from_abstract_name(&format!("test_unix_basic_{}", timestamp))
2011                .unwrap();
2012
2013        // Dial the channel before we actually serve it.
2014        let addr = ChannelAddr::Unix(socket_addr.clone());
2015        let tx = crate::channel::dial::<u64>(addr.clone()).unwrap();
2016        tx.post(123);
2017
2018        let (_, mut rx) = server::serve::<u64>(ChannelAddr::Unix(socket_addr)).unwrap();
2019        assert_eq!(rx.recv().await.unwrap(), 123);
2020
2021        tx.post(321);
2022        tx.post(111);
2023        tx.post(444);
2024
2025        assert_eq!(rx.recv().await.unwrap(), 321);
2026        assert_eq!(rx.recv().await.unwrap(), 111);
2027        assert_eq!(rx.recv().await.unwrap(), 444);
2028
2029        Ok(())
2030    }
2031
2032    #[tracing_test::traced_test]
2033    #[async_timed_test(timeout_secs = 60)]
2034    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2035    #[cfg_attr(not(fbcode_build), ignore)]
2036    async fn test_tcp_basic() {
2037        let (addr, mut rx) =
2038            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
2039        {
2040            let tx = channel::dial::<u64>(addr.clone()).unwrap();
2041            tx.post(123);
2042            assert_eq!(rx.recv().await.unwrap(), 123);
2043        }
2044
2045        {
2046            let tx = channel::dial::<u64>(addr.clone()).unwrap();
2047            tx.post(321);
2048            tx.post(111);
2049            tx.post(444);
2050
2051            assert_eq!(rx.recv().await.unwrap(), 321);
2052            assert_eq!(rx.recv().await.unwrap(), 111);
2053            assert_eq!(rx.recv().await.unwrap(), 444);
2054        }
2055
2056        {
2057            let tx = channel::dial::<u64>(addr).unwrap();
2058            drop(rx);
2059
2060            let (return_tx, return_rx) = oneshot::channel();
2061            tx.try_post(123, return_tx);
2062            assert_matches!(
2063                return_rx.await,
2064                Ok(SendError {
2065                    error: ChannelError::Closed,
2066                    message: 123,
2067                    ..
2068                })
2069            );
2070        }
2071    }
2072
2073    // The message size is limited by CODEC_MAX_FRAME_LENGTH.
2074    #[async_timed_test(timeout_secs = 5)]
2075    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2076    #[cfg_attr(not(fbcode_build), ignore)]
2077    async fn test_tcp_message_size() {
2078        let default_size_in_bytes = 100 * 1024 * 1024;
2079        // Use temporary config for this test
2080        let config = hyperactor_config::global::lock();
2081        let _guard1 = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2082        let _guard2 = config.override_key(config::CODEC_MAX_FRAME_LENGTH, default_size_in_bytes);
2083
2084        let (addr, mut rx) =
2085            server::serve::<String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
2086
2087        let tx = channel::dial::<String>(addr.clone()).unwrap();
2088        // Default size is okay
2089        {
2090            // Leave some headroom because Tx will wrap the payload in Frame::Message.
2091            let message = "a".repeat(default_size_in_bytes - 1024);
2092            tx.post(message.clone());
2093            assert_eq!(rx.recv().await.unwrap(), message);
2094        }
2095        // Bigger than the default size will fail.
2096        {
2097            let (return_channel, return_receiver) = oneshot::channel();
2098            let message = "a".repeat(default_size_in_bytes + 1024);
2099            tx.try_post(message.clone(), return_channel);
2100            let returned = return_receiver.await.unwrap();
2101            assert_eq!(message, returned.message);
2102        }
2103    }
2104
2105    #[async_timed_test(timeout_secs = 30)]
2106    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2107    #[cfg_attr(not(fbcode_build), ignore)]
2108    async fn test_ack_flush() {
2109        let config = hyperactor_config::global::lock();
2110        // Set a large value to effectively prevent acks from being sent except
2111        // during shutdown flush.
2112        let _guard_message_ack =
2113            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
2114        let _guard_delivery_timeout =
2115            config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(5));
2116
2117        let (addr, mut net_rx) =
2118            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
2119        let net_tx = channel::dial::<u64>(addr.clone()).unwrap();
2120        let (tx, rx) = oneshot::channel();
2121        net_tx.try_post(1, tx);
2122        assert_eq!(net_rx.recv().await.unwrap(), 1);
2123        drop(net_rx);
2124        // Using `is_err` to confirm the message is delivered/acked is confusing,
2125        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
2126        assert!(rx.await.is_err());
2127    }
2128
2129    #[async_timed_test(timeout_secs = 60)]
2130    // TODO: OSS: failed to retrieve ipv6 address
2131    #[cfg_attr(not(fbcode_build), ignore)]
2132    async fn test_meta_tls_basic() {
2133        hyperactor_telemetry::initialize_logging_for_test();
2134
2135        let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
2136        let meta_addr = match addr {
2137            ChannelAddr::MetaTls(meta_addr) => meta_addr,
2138            _ => panic!("expected MetaTls address"),
2139        };
2140        let (local_addr, mut rx) = server::serve::<u64>(ChannelAddr::MetaTls(meta_addr)).unwrap();
2141        {
2142            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2143            tx.post(123);
2144        }
2145        assert_eq!(rx.recv().await.unwrap(), 123);
2146
2147        {
2148            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2149            tx.post(321);
2150            tx.post(111);
2151            tx.post(444);
2152            assert_eq!(rx.recv().await.unwrap(), 321);
2153            assert_eq!(rx.recv().await.unwrap(), 111);
2154            assert_eq!(rx.recv().await.unwrap(), 444);
2155        }
2156
2157        {
2158            let tx = channel::dial::<u64>(local_addr).unwrap();
2159            drop(rx);
2160
2161            let (return_tx, return_rx) = oneshot::channel();
2162            tx.try_post(123, return_tx);
2163            assert_matches!(
2164                return_rx.await,
2165                Ok(SendError {
2166                    error: ChannelError::Closed,
2167                    message: 123,
2168                    ..
2169                })
2170            );
2171        }
2172    }
2173
    /// Fault-injection settings carried by `MockLink`. With the `Default`
    /// value both fields are `None`, in which case `should_disconnect`
    /// always answers `false`.
    #[derive(Clone, Debug, Default)]
    struct NetworkFlakiness {
        // A tuple of:
        //   1. the probability of a network failure when sending a message.
        //   2. the max number of disconnections allowed.
        //   3. the minimum duration between disconnections.
        //
        //   2 and 3 are useful to prevent frequent disconnections leading to
        //   unacked messages being sent repeatedly.
        disconnect_params: Option<(f64, u64, Duration)>,
        // Latency bounds as a `(min, max)` pair;
        // `MockLink::with_network_flakiness` asserts `min < max`.
        // NOTE(review): presumably the latency applied to each message is
        // drawn from this range — confirm at the use site.
        latency_range: Option<(Duration, Duration)>,
    }
2188
2189    impl NetworkFlakiness {
2190        // Calculate whether to disconnect
2191        async fn should_disconnect(
2192            &self,
2193            rng: &mut impl rand::Rng,
2194            disconnected_count: u64,
2195            prev_disconnected_at: &RwLock<Instant>,
2196        ) -> bool {
2197            let Some((prob, max_disconnects, duration)) = &self.disconnect_params else {
2198                return false;
2199            };
2200
2201            let disconnected_at = prev_disconnected_at.read().unwrap();
2202            if disconnected_at.elapsed() > *duration && disconnected_count < *max_disconnects {
2203                rng.random_bool(*prob)
2204            } else {
2205                false
2206            }
2207        }
2208    }
2209
    /// Test implementation of `Link` whose streams are in-memory
    /// `DuplexStream`s, with switches for injecting connection failures,
    /// disconnects and network flakiness.
    struct MockLink<M> {
        // NOTE(review): presumably the buffer size of each duplex stream
        // created by `next()` — confirm at the use site. Defaults to 64.
        buffer_size: usize,
        // Returned by `link_id()`; randomly generated in `new()`.
        session_id: SessionId,
        // NOTE(review): presumably where the peer end of each new duplex
        // stream is deposited for the test to pick up — confirm in `next()`.
        receiver_storage: Arc<MVar<DuplexStream>>,
        // If true, `next()` on this link will always return an error.
        fail_connects: Arc<AtomicBool>,
        // Used to break the existing connection, if there is one. It still
        // allows reconnect.
        disconnect_signal: watch::Sender<()>,
        // Fault-injection settings; defaults to `NetworkFlakiness::default()`.
        network_flakiness: NetworkFlakiness,
        // Shared disconnect counter, starting at 0.
        disconnected_count: Arc<AtomicU64>,
        // Time of the previous disconnect; initialised to the creation time.
        prev_disconnected_at: Arc<RwLock<Instant>>,
        // If set, print logs every `debug_log_sampling_rate` messages. This
        // is normally set only when debugging a test failure.
        debug_log_sampling_rate: Option<u64>,
        // Ties the link to its message type without storing a value of it.
        _message_type: PhantomData<M>,
    }
2227
2228    impl<M> fmt::Debug for MockLink<M> {
2229        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2230            f.debug_struct("MockLink")
2231                .field("buffer_size", &self.buffer_size)
2232                .field("receiver_storage", &"<MVar<DuplexStream>>")
2233                .field("fail_connects", &self.fail_connects)
2234                .field("disconnect_signal", &"<watch::Sender>")
2235                .field("network_flakiness", &self.network_flakiness)
2236                .field("disconnected_count", &self.disconnected_count)
2237                .field("prev_disconnected_at", &"<RwLock<Instant>>")
2238                .field("debug_log_sampling_rate", &self.debug_log_sampling_rate)
2239                .finish()
2240        }
2241    }
2242
2243    impl<M: RemoteMessage> MockLink<M> {
2244        fn new() -> Self {
2245            let (sender, _) = watch::channel(());
2246            Self {
2247                buffer_size: 64,
2248                session_id: SessionId::random(),
2249                receiver_storage: Arc::new(MVar::empty()),
2250                fail_connects: Arc::new(AtomicBool::new(false)),
2251                disconnect_signal: sender,
2252                network_flakiness: NetworkFlakiness::default(),
2253                disconnected_count: Arc::new(AtomicU64::new(0)),
2254                prev_disconnected_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
2255                debug_log_sampling_rate: None,
2256                _message_type: PhantomData,
2257            }
2258        }
2259
2260        // If `fail_connects` is true, `next()` on this link will
2261        // always return an error.
2262        fn fail_connects() -> Self {
2263            Self {
2264                fail_connects: Arc::new(AtomicBool::new(true)),
2265                ..Self::new()
2266            }
2267        }
2268
2269        fn with_network_flakiness(network_flakiness: NetworkFlakiness) -> Self {
2270            if let Some((min, max)) = network_flakiness.latency_range {
2271                assert!(min < max);
2272            }
2273
2274            Self {
2275                network_flakiness,
2276                ..Self::new()
2277            }
2278        }
2279
2280        fn receiver_storage(&self) -> Arc<MVar<DuplexStream>> {
2281            self.receiver_storage.clone()
2282        }
2283
2284        fn disconnected_count(&self) -> Arc<AtomicU64> {
2285            self.disconnected_count.clone()
2286        }
2287
2288        fn disconnect_signal(&self) -> &watch::Sender<()> {
2289            &self.disconnect_signal
2290        }
2291
2292        fn fail_connects_switch(&self) -> Arc<AtomicBool> {
2293            self.fail_connects.clone()
2294        }
2295
2296        fn set_buffer_size(&mut self, size: usize) {
2297            self.buffer_size = size;
2298        }
2299
2300        fn set_sampling_rate(&mut self, sampling_rate: u64) {
2301            self.debug_log_sampling_rate = Some(sampling_rate);
2302        }
2303    }
2304
2305    #[async_trait]
2306    impl<M: RemoteMessage> Link for MockLink<M> {
2307        type Stream = DuplexStream;
2308
2309        fn dest(&self) -> ChannelAddr {
2310            ChannelAddr::Local(u64::MAX)
2311        }
2312
2313        fn link_id(&self) -> SessionId {
2314            self.session_id
2315        }
2316
2317        async fn next(&self) -> Result<Self::Stream, ClientError> {
2318            let session_id = self.session_id;
2319            tracing::debug!("MockLink starts to connect.");
2320            if self.fail_connects.load(Ordering::Acquire) {
2321                return Err(ClientError::Connect(
2322                    self.dest(),
2323                    std::io::Error::other("intentional error"),
2324                    "expected failure injected by the mock".to_string(),
2325                ));
2326            }
2327
2328            // Add relays between server and client streams. The
2329            // relays provides the place to inject network flakiness.
2330            // The message flow looks like:
2331            //
2332            // server <-> server relay <-> injection logic <-> client relay <-> client
2333            async fn relay_message<M: RemoteMessage>(
2334                mut disconnect_signal: watch::Receiver<()>,
2335                network_flakiness: NetworkFlakiness,
2336                disconnected_count: Arc<AtomicU64>,
2337                prev_disconnected_at: Arc<RwLock<Instant>>,
2338                mut reader: FrameReader<ReadHalf<DuplexStream>>,
2339                mut writer: WriteHalf<DuplexStream>,
2340                // Used by client and server tokio tasks to coordinate
2341                // stopping together.
2342                task_coordination_token: CancellationToken,
2343                debug_log_sampling_rate: Option<u64>,
2344                // Whether the relayed message is from client to
2345                // server.
2346                is_from_client: bool,
2347            ) {
2348                // Used to simulate latency. Briefly, messages are
2349                // buffered in the queue and wait for the expected
2350                // latency elapse.
2351                async fn wait_for_latency_elapse(
2352                    queue: &VecDeque<(Bytes, Instant)>,
2353                    network_flakiness: &NetworkFlakiness,
2354                    rng: &mut impl rand::Rng,
2355                ) {
2356                    if let Some((min, max)) = network_flakiness.latency_range {
2357                        let diff = max.abs_diff(min);
2358                        let factor = rng.random_range(0.0..=1.0);
2359                        let latency = min + diff.mul_f64(factor);
2360                        tokio::time::sleep_until(queue.front().unwrap().1 + latency).await;
2361                    }
2362                }
2363
2364                let mut rng = rand::rngs::SmallRng::from_os_rng();
2365                let mut queue: VecDeque<(Bytes, Instant)> = VecDeque::new();
2366                let mut send_count = 0u64;
2367
2368                loop {
2369                    tokio::select! {
2370                        read_res = reader.next() => {
2371                            match read_res {
2372                                Ok(Some((_, data))) => {
2373                                    queue.push_back((data, tokio::time::Instant::now()));
2374                                }
2375                                Ok(None) | Err(_) => {
2376                                        tracing::debug!("The upstream is closed or dropped. MockLink disconnects");
2377                                        break;
2378                                }
2379                            }
2380                        }
2381                        _ = wait_for_latency_elapse(&queue, &network_flakiness, &mut rng), if !queue.is_empty() => {
2382                            let count = disconnected_count.load(Ordering::Relaxed);
2383                            if network_flakiness.should_disconnect(&mut rng, count, &prev_disconnected_at).await {
2384                                tracing::debug!("MockLink disconnects");
2385                                disconnected_count.fetch_add(1, Ordering::Relaxed);
2386
2387                                metrics::CHANNEL_RECONNECTIONS.add(
2388                                    1,
2389                                    hyperactor_telemetry::kv_pairs!(
2390                                        "transport" => "mock",
2391                                        "reason" => "network_flakiness",
2392                                    ),
2393                                );
2394
2395                                let mut w = prev_disconnected_at.write().unwrap();
2396                                *w = tokio::time::Instant::now();
2397                                break;
2398                            }
2399                            let data = queue.pop_front().unwrap().0;
2400                            let is_sampled = debug_log_sampling_rate.is_some_and(|sample_rate| send_count % sample_rate == 1);
2401                            if is_sampled {
2402                                if is_from_client {
2403                                    if let Ok(Frame::Message(_seq, _msg)) = bincode::deserialize::<Frame<M>>(&data) {
2404                                        tracing::debug!("MockLink relays a msg from client. msg type: {}", std::any::type_name::<M>());
2405                                    }
2406                                } else {
2407                                    let result = deserialize_response(data.clone());
2408                                    if let Ok(NetRxResponse::Ack(seq)) = result {
2409                                        tracing::debug!("MockLink relays an ack from server. seq: {}", seq);
2410                                    }
2411                                }
2412                            }
2413                            let mut fw  = FrameWrite::new(writer, data, hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH), 0).unwrap();
2414                            if fw.send().await.is_err() {
2415                                break;
2416                            }
2417                            writer = fw.complete();
2418                            send_count += 1;
2419                        }
2420                        _ = task_coordination_token.cancelled() => break,
2421
2422                        changed = disconnect_signal.changed() => {
2423                            tracing::debug!("MockLink disconnects per disconnect_signal {:?}", changed);
2424                            break;
2425                        }
2426                    }
2427                }
2428
2429                task_coordination_token.cancel();
2430            }
2431
2432            let (server, mut server_relay) = tokio::io::duplex(self.buffer_size);
2433            let (client, client_relay) = tokio::io::duplex(self.buffer_size);
2434
2435            // Write LinkInit on server_relay so it's readable from `server`.
2436            // This simulates the client sending LinkInit over the wire before
2437            // the frame-level relay begins.
2438            write_link_init(&mut server_relay, session_id)
2439                .await
2440                .map_err(|err| ClientError::Io(self.dest(), err))?;
2441
2442            let (server_r, server_writer) = tokio::io::split(server_relay);
2443            let (client_r, client_writer) = tokio::io::split(client_relay);
2444
2445            let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
2446            let server_reader = FrameReader::new(server_r, max_len);
2447            let client_reader = FrameReader::new(client_r, max_len);
2448
2449            let task_coordination_token = CancellationToken::new();
2450            let _server_relay_task_handle = tokio::spawn(relay_message::<M>(
2451                self.disconnect_signal.subscribe(),
2452                self.network_flakiness.clone(),
2453                self.disconnected_count.clone(),
2454                self.prev_disconnected_at.clone(),
2455                server_reader,
2456                client_writer,
2457                task_coordination_token.clone(),
2458                self.debug_log_sampling_rate.clone(),
2459                /*is_from_client*/ false,
2460            ));
2461            let _client_relay_task_handle = tokio::spawn(relay_message::<M>(
2462                self.disconnect_signal.subscribe(),
2463                self.network_flakiness.clone(),
2464                self.disconnected_count.clone(),
2465                self.prev_disconnected_at.clone(),
2466                client_reader,
2467                server_writer,
2468                task_coordination_token,
2469                self.debug_log_sampling_rate.clone(),
2470                /*is_from_client*/ true,
2471            ));
2472
2473            self.receiver_storage.put(server).await;
2474            Ok(client)
2475        }
2476    }
2477
2478    struct MockLinkListener {
2479        receiver_storage: Arc<MVar<DuplexStream>>,
2480        channel_addr: ChannelAddr,
2481    }
2482
2483    impl MockLinkListener {
2484        fn new(receiver_storage: Arc<MVar<DuplexStream>>, channel_addr: ChannelAddr) -> Self {
2485            Self {
2486                receiver_storage,
2487                channel_addr,
2488            }
2489        }
2490    }
2491
2492    impl fmt::Debug for MockLinkListener {
2493        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2494            f.debug_struct("MockLinkListener")
2495                .field("channel_addr", &self.channel_addr)
2496                .finish()
2497        }
2498    }
2499
2500    #[async_trait]
2501    impl super::Listener for MockLinkListener {
2502        type Stream = DuplexStream;
2503
2504        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
2505            let stream = self.receiver_storage.take().await;
2506            Ok((stream, self.channel_addr.clone()))
2507        }
2508    }
2509
2510    /// Create an AcceptorLink-based server test rig. Returns the
2511    /// session task handle, an MVar for dispatching streams,
2512    /// the message receiver, and a cancellation token.
2513    fn serve_acceptor_test<M: RemoteMessage>(
2514        session_id: SessionId,
2515    ) -> (
2516        JoinHandle<()>,
2517        crate::sync::mvar::MVar<DuplexStream>,
2518        mpsc::Receiver<M>,
2519        CancellationToken,
2520    ) {
2521        let mvar = crate::sync::mvar::MVar::empty();
2522        let cancel_token = CancellationToken::new();
2523        let link = AcceptorLink {
2524            dest: ChannelAddr::Local(u64::MAX),
2525            session_id,
2526            stream: mvar.clone(),
2527            cancel: cancel_token.clone(),
2528        };
2529        let (tx, rx) = mpsc::channel::<M>(1024);
2530        let ct = cancel_token.clone();
2531        let handle = tokio::spawn(async move {
2532            let mut session = Session::new(link);
2533            let mut next = session::Next { seq: 0, ack: 0 };
2534
2535            loop {
2536                let connected = match session.connect().await {
2537                    Ok(s) => s,
2538                    Err(_) => break,
2539                };
2540
2541                let result = {
2542                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2543                    tokio::select! {
2544                        r = session::recv_connected::<M, _, _>(&stream, &tx, &mut next) => r,
2545                        _ = ct.cancelled() => Err(session::RecvLoopError::Cancelled),
2546                    }
2547                };
2548
2549                // Flush remaining ack if behind.
2550                if next.ack < next.seq {
2551                    let ack = serialize_response(NetRxResponse::Ack(next.seq - 1)).unwrap();
2552                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2553                    let mut completion = stream.write(ack);
2554                    match completion.drive().await {
2555                        Ok(()) => {
2556                            next.ack = next.seq;
2557                        }
2558                        Err(e) => {
2559                            tracing::debug!(
2560                                error = %e,
2561                                "failed to flush acks during cleanup"
2562                            );
2563                        }
2564                    }
2565                }
2566
2567                // Send reject or closed response if appropriate.
2568                let terminal_response = match &result {
2569                    Err(session::RecvLoopError::SequenceError(reason)) => {
2570                        Some(NetRxResponse::Reject(reason.clone()))
2571                    }
2572                    Err(session::RecvLoopError::Cancelled) => Some(NetRxResponse::Closed),
2573                    _ => None,
2574                };
2575                if let Some(rsp) = terminal_response {
2576                    let data = serialize_response(rsp).unwrap();
2577                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2578                    let mut completion = stream.write(data);
2579                    let _ = completion.drive().await;
2580                }
2581
2582                let recoverable = matches!(&result, Ok(()) | Err(session::RecvLoopError::Io(_)));
2583                session = connected.release();
2584                if recoverable {
2585                    continue;
2586                }
2587                break;
2588            }
2589        });
2590        (handle, mvar, rx, cancel_token)
2591    }
2592
2593    async fn write_stream<M, W>(
2594        mut writer: W,
2595        _session_id: u64,
2596        messages: &[(u64, M)],
2597        _init: bool,
2598    ) -> W
2599    where
2600        M: RemoteMessage + PartialEq + Clone,
2601        W: AsyncWrite + Unpin,
2602    {
2603        for (seq, message) in messages {
2604            let message =
2605                serde_multipart::serialize_bincode(&Frame::<M>::Message(*seq, message.clone()))
2606                    .unwrap();
2607            let mut fw = FrameWrite::new(
2608                writer,
2609                message.framed(),
2610                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2611                0,
2612            )
2613            .map_err(|(_w, e)| e)
2614            .unwrap();
2615            fw.send().await.unwrap();
2616            writer = fw.complete();
2617        }
2618
2619        writer
2620    }
2621
2622    #[async_timed_test(timeout_secs = 60)]
2623    async fn test_persistent_server_session() {
2624        let config = hyperactor_config::global::lock();
2625        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2626
2627        async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
2628            let mut last_acked: i128 = -1;
2629            loop {
2630                let (_, bytes) = reader.next().await.unwrap().unwrap();
2631                let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2632                assert!(
2633                    acked as i128 > last_acked,
2634                    "acks should be delivered in ascending order"
2635                );
2636                last_acked = acked as i128;
2637                assert!(acked <= expected_last);
2638                if acked == expected_last {
2639                    break;
2640                }
2641            }
2642        }
2643
2644        let session_id = SessionId(123);
2645        let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2646
2647        // First connection: send messages, verify delivery and ack.
2648        {
2649            let (sender, receiver) = tokio::io::duplex(5000);
2650            mvar.put(receiver).await;
2651
2652            let (r, writer) = tokio::io::split(sender);
2653            let mut reader = FrameReader::new(
2654                r,
2655                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2656            );
2657
2658            let _writer = write_stream(
2659                writer,
2660                123,
2661                &[
2662                    (0u64, 100u64),
2663                    (1u64, 101u64),
2664                    (2u64, 102u64),
2665                    (3u64, 103u64),
2666                ],
2667                true,
2668            )
2669            .await;
2670
2671            assert_eq!(rx.recv().await, Some(100));
2672            assert_eq!(rx.recv().await, Some(101));
2673            assert_eq!(rx.recv().await, Some(102));
2674            assert_eq!(rx.recv().await, Some(103));
2675
2676            verify_ack(&mut reader, 3).await;
2677            // Drop reader and writer to close the connection.
2678        }
2679
2680        // Second connection (reconnection): retransmitted messages are deduped.
2681        {
2682            let (sender2, receiver2) = tokio::io::duplex(5000);
2683            mvar.put(receiver2).await;
2684
2685            let (r2, writer2) = tokio::io::split(sender2);
2686            let mut reader2 = FrameReader::new(
2687                r2,
2688                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2689            );
2690
2691            let _ = write_stream(
2692                writer2,
2693                123,
2694                &[
2695                    (2u64, 102u64),
2696                    (3u64, 103u64),
2697                    (4u64, 104u64),
2698                    (5u64, 105u64),
2699                ],
2700                true,
2701            )
2702            .await;
2703
2704            // 102 and 103 are retransmits; only 104 and 105 are new.
2705            assert_eq!(rx.recv().await, Some(104));
2706            assert_eq!(rx.recv().await, Some(105));
2707
2708            verify_ack(&mut reader2, 5).await;
2709
2710            cancel_token.cancel();
2711        }
2712    }
2713
2714    #[async_timed_test(timeout_secs = 60)]
2715    async fn test_ack_from_server_session() {
2716        let config = hyperactor_config::global::lock();
2717        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2718        let session_id = SessionId(123);
2719        let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2720
2721        let (sender, receiver) = tokio::io::duplex(5000);
2722        mvar.put(receiver).await;
2723        let (r, mut writer) = tokio::io::split(sender);
2724        let mut reader = FrameReader::new(
2725            r,
2726            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2727        );
2728
2729        for i in 0u64..100u64 {
2730            writer = write_stream(writer, 123, &[(i, 100u64 + i)], /*init*/ i == 0u64).await;
2731            assert_eq!(rx.recv().await, Some(100u64 + i));
2732            let (_, bytes) = reader.next().await.unwrap().unwrap();
2733            let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2734            assert_eq!(acked, i);
2735        }
2736
2737        // Wait long enough to ensure server processed everything.
2738        tokio::time::sleep(Duration::from_secs(5)).await;
2739
2740        cancel_token.cancel();
2741
2742        // Should send NetRxResponse::Closed before stopping.
2743        let (_, bytes) = reader.next().await.unwrap().unwrap();
2744        assert!(deserialize_response(bytes).unwrap().is_closed());
2745    }
2746
2747    #[tracing_test::traced_test]
2748    async fn verify_tx_closed(tx_status: &mut watch::Receiver<TxStatus>, expected_log: &str) {
2749        match tokio::time::timeout(Duration::from_secs(5), tx_status.changed()).await {
2750            Ok(Ok(())) => {
2751                let current_status = tx_status.borrow().clone();
2752                assert!(current_status.is_closed());
2753                logs_assert_unscoped(|logs| {
2754                    if logs.iter().any(|log| log.contains(expected_log)) {
2755                        Ok(())
2756                    } else {
2757                        Err("expected log not found".to_string())
2758                    }
2759                });
2760            }
2761            Ok(Err(_)) => panic!("watch::Receiver::changed() failed because sender is dropped."),
2762            Err(_) => panic!("timeout before tx_status changed"),
2763        }
2764    }
2765
2766    #[tracing_test::traced_test]
2767    #[tokio::test]
2768    // TODO: OSS: The logs_assert function returned an error: expected log not found
2769    #[cfg_attr(not(fbcode_build), ignore)]
2770    async fn test_tcp_tx_delivery_timeout() {
2771        // This link always fails to connect.
2772        let link = MockLink::<u64>::fail_connects();
2773        let tx = spawn::<u64>(link);
2774        // Override the default (1m) for the purposes of this test.
2775        let config = hyperactor_config::global::lock();
2776        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2777        let mut tx_receiver = tx.status().clone();
2778        let (return_channel, _return_receiver) = oneshot::channel();
2779        tx.try_post(123, return_channel);
2780        verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await;
2781    }
2782
2783    async fn take_receiver(
2784        receiver_storage: &MVar<DuplexStream>,
2785    ) -> (FrameReader<ReadHalf<DuplexStream>>, WriteHalf<DuplexStream>) {
2786        let mut receiver = receiver_storage.take().await;
2787        // Read and discard the LinkInit header that MockLink::connect() writes.
2788        let _session_id = read_link_init(&mut receiver).await.expect("read LinkInit");
2789        let (r, writer) = tokio::io::split(receiver);
2790        let reader = FrameReader::new(
2791            r,
2792            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2793        );
2794        (reader, writer)
2795    }
2796
2797    async fn verify_message<M: RemoteMessage + PartialEq + std::fmt::Debug>(
2798        reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2799        expect: (u64, M),
2800        loc: u32,
2801    ) {
2802        let expected = Frame::Message(expect.0, expect.1);
2803        let (_, bytes) = reader.next().await.unwrap().expect("unexpected EOF");
2804        let message = serde_multipart::Message::from_framed(bytes).unwrap();
2805        let frame: Frame<M> = serde_multipart::deserialize_bincode(message).unwrap();
2806
2807        assert_eq!(frame, expected, "from ln={loc}");
2808    }
2809
2810    async fn verify_stream<M: RemoteMessage + PartialEq + std::fmt::Debug + Clone>(
2811        reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2812        expects: &[(u64, M)],
2813        _expect_session_id: Option<u64>,
2814        loc: u32,
2815    ) {
2816        for expect in expects {
2817            verify_message(reader, expect.clone(), loc).await;
2818        }
2819    }
2820
2821    async fn net_tx_send(tx: &NetTx<u64>, msgs: &[u64]) {
2822        for msg in msgs {
2823            tx.post(*msg);
2824        }
2825    }
2826
2827    // Happy path: all messages are acked.
2828    #[async_timed_test(timeout_secs = 30)]
2829    async fn test_ack_in_net_tx_basic() {
2830        let link = MockLink::<u64>::new();
2831        let receiver_storage = link.receiver_storage();
2832        let tx = spawn::<u64>(link);
2833
2834        // Send some messages, but not acking any of them.
2835        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
2836        {
2837            let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2838            verify_stream(
2839                &mut reader,
2840                &[
2841                    (0u64, 100u64),
2842                    (1u64, 101u64),
2843                    (2u64, 102u64),
2844                    (3u64, 103u64),
2845                    (4u64, 104u64),
2846                ],
2847                None,
2848                line!(),
2849            )
2850            .await;
2851
2852            for i in 0u64..5u64 {
2853                writer = FrameWrite::write_frame(
2854                    writer,
2855                    serialize_response(NetRxResponse::Ack(i)).unwrap(),
2856                    1024,
2857                    0,
2858                )
2859                .await
2860                .map_err(|(_, e)| e)
2861                .unwrap();
2862            }
2863            // Wait for the acks to be processed by NetTx.
2864            tokio::time::sleep(Duration::from_secs(3)).await;
2865            // Drop both halves to break the in-memory connection (parity with old drop of DuplexStream).
2866            drop(reader);
2867            drop(writer);
2868        };
2869
2870        // Sent a new message to verify all sent messages will not be resent.
2871        net_tx_send(&tx, &[105u64]).await;
2872        {
2873            let (mut reader, _writer) = take_receiver(&receiver_storage).await;
2874            verify_stream(&mut reader, &[(5u64, 105u64)], None, line!()).await;
2875            // Reader/writer dropped here. This breaks the connection.
2876        };
2877    }
2878
2879    // Verify unacked message will be resent after reconnection.
2880    #[async_timed_test(timeout_secs = 60)]
2881    async fn test_persistent_net_tx() {
2882        let link = MockLink::<u64>::new();
2883        let receiver_storage = link.receiver_storage();
2884
2885        let tx = spawn::<u64>(link);
2886
2887        // Send some messages, but not acking any of them.
2888        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
2889
2890        // How many times to reconnect. Keep this small because the send loop
2891        // applies exponential backoff between reconnections, and mock connections
2892        // are too short-lived to trigger the backoff reset.
2893        let n = 3;
2894
2895        // Reconnect multiple times. The messages should be resent every time
2896        // because none of them is acked.
2897        for i in 0..n {
2898            {
2899                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2900                verify_stream(
2901                    &mut reader,
2902                    &[
2903                        (0u64, 100u64),
2904                        (1u64, 101u64),
2905                        (2u64, 102u64),
2906                        (3u64, 103u64),
2907                        (4u64, 104u64),
2908                    ],
2909                    None,
2910                    line!(),
2911                )
2912                .await;
2913
2914                // In the last iteration, ack part of the messages. This should
2915                // prune them from future resent.
2916                if i == n - 1 {
2917                    writer = FrameWrite::write_frame(
2918                        writer,
2919                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
2920                        1024,
2921                        0,
2922                    )
2923                    .await
2924                    .map_err(|(_, e)| e)
2925                    .unwrap();
2926                    // Wait for the acks to be processed by NetTx.
2927                    tokio::time::sleep(Duration::from_secs(3)).await;
2928                }
2929                // client DuplexStream is dropped here. This breaks the connection.
2930                drop(reader);
2931                drop(writer);
2932            };
2933        }
2934
2935        // Verify only unacked are resent.
2936        for _ in 0..n {
2937            {
2938                let (mut reader, mut _writer) = take_receiver(&receiver_storage).await;
2939                verify_stream(
2940                    &mut reader,
2941                    &[(2u64, 102u64), (3u64, 103u64), (4u64, 104u64)],
2942                    None,
2943                    line!(),
2944                )
2945                .await;
2946                // drop(reader/_writer) at scope end
2947            };
2948        }
2949
2950        // Now send more messages.
2951        net_tx_send(&tx, &[105u64, 106u64, 107u64, 108u64, 109u64]).await;
2952        // Verify the unacked messages from the 1st send will be grouped with
2953        // the 2nd send.
2954        for i in 0..n {
2955            {
2956                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2957                verify_stream(
2958                    &mut reader,
2959                    &[
2960                        // From the 1st send.
2961                        (2u64, 102u64),
2962                        (3u64, 103u64),
2963                        (4u64, 104u64),
2964                        // From the 2nd send.
2965                        (5u64, 105u64),
2966                        (6u64, 106u64),
2967                        (7u64, 107u64),
2968                        (8u64, 108u64),
2969                        (9u64, 109u64),
2970                    ],
2971                    None,
2972                    line!(),
2973                )
2974                .await;
2975
2976                // In the last iteration, ack part of the messages from the 1st
2977                // sent.
2978                if i == n - 1 {
2979                    // Intentionally ack 1 again to verify it is okay to ack
2980                    // messages that was already acked.
2981                    writer = FrameWrite::write_frame(
2982                        writer,
2983                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
2984                        1024,
2985                        0,
2986                    )
2987                    .await
2988                    .map_err(|(_, e)| e)
2989                    .unwrap();
2990                    writer = FrameWrite::write_frame(
2991                        writer,
2992                        serialize_response(NetRxResponse::Ack(2)).unwrap(),
2993                        1024,
2994                        0,
2995                    )
2996                    .await
2997                    .map_err(|(_, e)| e)
2998                    .unwrap();
2999                    writer = FrameWrite::write_frame(
3000                        writer,
3001                        serialize_response(NetRxResponse::Ack(3)).unwrap(),
3002                        1024,
3003                        0,
3004                    )
3005                    .await
3006                    .map_err(|(_, e)| e)
3007                    .unwrap();
3008                    // Wait for the acks to be processed by NetTx.
3009                    tokio::time::sleep(Duration::from_secs(3)).await;
3010                }
3011                // client DuplexStream is dropped here. This breaks the connection.
3012                drop(reader);
3013                drop(writer);
3014            };
3015        }
3016
3017        for i in 0..n {
3018            {
3019                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3020                verify_stream(
3021                    &mut reader,
3022                    &[
3023                        // From the 1st send.
3024                        (4u64, 104),
3025                        // From the 2nd send.
3026                        (5u64, 105u64),
3027                        (6u64, 106u64),
3028                        (7u64, 107u64),
3029                        (8u64, 108u64),
3030                        (9u64, 109u64),
3031                    ],
3032                    None,
3033                    line!(),
3034                )
3035                .await;
3036
3037                // In the last iteration, ack part of the messages from the 2nd send.
3038                if i == n - 1 {
3039                    writer = FrameWrite::write_frame(
3040                        writer,
3041                        serialize_response(NetRxResponse::Ack(7)).unwrap(),
3042                        1024,
3043                        0,
3044                    )
3045                    .await
3046                    .map_err(|(_, e)| e)
3047                    .unwrap();
3048                    // Wait for the acks to be processed by NetTx.
3049                    tokio::time::sleep(Duration::from_secs(3)).await;
3050                }
3051                // client DuplexStream is dropped here. This breaks the connection.
3052                drop(reader);
3053                drop(writer);
3054            };
3055        }
3056
3057        for _ in 0..n {
3058            {
3059                let (mut reader, writer) = take_receiver(&receiver_storage).await;
3060                verify_stream(
3061                    &mut reader,
3062                    &[
3063                        // From the 2nd send.
3064                        (8u64, 108u64),
3065                        (9u64, 109u64),
3066                    ],
3067                    None,
3068                    line!(),
3069                )
3070                .await;
3071                // client DuplexStream is dropped here. This breaks the connection.
3072                drop(reader);
3073                drop(writer);
3074            };
3075        }
3076    }
3077
3078    #[async_timed_test(timeout_secs = 15)]
3079    async fn test_ack_before_redelivery_in_net_tx() {
3080        let link = MockLink::<u64>::new();
3081        let receiver_storage = link.receiver_storage();
3082        let net_tx = spawn::<u64>(link);
3083
3084        // Verify sent-and-ack a message. This is necessary for the test to
3085        // trigger a connection.
3086        let (return_channel_tx, return_channel_rx) = oneshot::channel();
3087        net_tx.try_post(100, return_channel_tx);
3088        let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3089        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3090        // ack it
3091        writer = FrameWrite::write_frame(
3092            writer,
3093            serialize_response(NetRxResponse::Ack(0)).unwrap(),
3094            1024,
3095            0,
3096        )
3097        .await
3098        .map_err(|(_, e)| e)
3099        .unwrap();
3100        // confirm Tx received ack
3101        //
3102        // Using `is_err` to confirm the message is delivered/acked is confusing,
3103        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
3104        assert!(return_channel_rx.await.is_err());
3105
3106        // Now fake an unknown delivery for Tx:
3107        // Although Tx did not actually send seq=1, we still ack it from Rx to
3108        // pretend Tx already sent it, just it did not know it was sent
3109        // successfully.
3110        let _ = FrameWrite::write_frame(
3111            writer,
3112            serialize_response(NetRxResponse::Ack(1)).unwrap(),
3113            1024,
3114            0,
3115        )
3116        .await
3117        .map_err(|(_, e)| e)
3118        .unwrap();
3119
3120        let (return_channel_tx, return_channel_rx) = oneshot::channel();
3121        net_tx.try_post(101, return_channel_tx);
3122        // Verify the message is sent to Rx.
3123        verify_message(&mut reader, (1u64, 101u64), line!()).await;
3124        // although we did not ack the message after it is sent, since we already
3125        // acked it previously, Tx will treat it as acked, and considered the
3126        // message delivered successfully.
3127        //
3128        // Using `is_err` to confirm the message is delivered/acked is confusing,
3129        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
3130        assert!(return_channel_rx.await.is_err());
3131    }
3132
3133    async fn verify_ack_exceeded_limit(disconnect_before_ack: bool) {
3134        // Use temporary config for this test
3135        let config = hyperactor_config::global::lock();
3136        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(2));
3137
3138        let link: MockLink<u64> = MockLink::<u64>::new();
3139        let disconnect_signal = link.disconnect_signal().clone();
3140        let fail_connect_switch = link.fail_connects_switch();
3141        let receiver_storage = link.receiver_storage();
3142        let tx = spawn::<u64>(link);
3143        let mut tx_status = tx.status().clone();
3144        // send a message
3145        tx.post(100);
3146        let (mut reader, writer) = take_receiver(&receiver_storage).await;
3147        // Confirm message is sent to rx.
3148        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3149        // ack it
3150        let _ = FrameWrite::write_frame(
3151            writer,
3152            serialize_response(NetRxResponse::Ack(0)).unwrap(),
3153            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3154            0,
3155        )
3156        .await
3157        .map_err(|(_, e)| e)
3158        .unwrap();
3159        tokio::time::sleep(Duration::from_secs(3)).await;
3160        // Channel should be still alive because ack was sent.
3161        assert!(!tx_status.has_changed().unwrap());
3162        assert_eq!(*tx_status.borrow(), TxStatus::Active);
3163
3164        tx.post(101);
3165        // Confirm message is sent to rx.
3166        verify_message(&mut reader, (1u64, 101u64), line!()).await;
3167
3168        if disconnect_before_ack {
3169            // Prevent link from reconnect
3170            fail_connect_switch.store(true, Ordering::Release);
3171            // Break the existing connection
3172            disconnect_signal.send(()).unwrap();
3173        }
3174
3175        // Verify the channel is closed due to ack timeout based on the log.
3176        let expected_log: &str = if disconnect_before_ack {
3177            "failed to receive ack within timeout 2s; link is currently broken"
3178        } else {
3179            "failed to receive ack within timeout 2s; link is currently connected"
3180        };
3181
3182        verify_tx_closed(&mut tx_status, expected_log).await;
3183    }
3184
3185    #[tracing_test::traced_test]
3186    #[async_timed_test(timeout_secs = 30)]
3187    // TODO: OSS: The logs_assert function returned an error: expected log not found
3188    #[cfg_attr(not(fbcode_build), ignore)]
3189    async fn test_ack_exceeded_limit_with_connected_link() {
3190        verify_ack_exceeded_limit(false).await;
3191    }
3192
3193    #[tracing_test::traced_test]
3194    #[async_timed_test(timeout_secs = 30)]
3195    // TODO: OSS: The logs_assert function returned an error: expected log not found
3196    #[cfg_attr(not(fbcode_build), ignore)]
3197    async fn test_ack_exceeded_limit_with_broken_link() {
3198        verify_ack_exceeded_limit(true).await;
3199    }
3200
3201    // Verify a large number of messages can be delivered and acked with the
3202    // presence of flakiness in the network, i.e. random delay and disconnection.
3203    #[async_timed_test(timeout_secs = 60)]
3204    async fn test_network_flakiness_in_channel() {
3205        hyperactor_telemetry::initialize_logging_for_test();
3206
3207        let sampling_rate = 100;
3208        let mut link = MockLink::<u64>::with_network_flakiness(NetworkFlakiness {
3209            disconnect_params: Some((0.001, 15, Duration::from_millis(400))),
3210            latency_range: Some((Duration::from_millis(100), Duration::from_millis(200))),
3211        });
3212        link.set_sampling_rate(sampling_rate);
3213        // Set a large buffer size to improve throughput.
3214        link.set_buffer_size(1024000);
3215        let disconnected_count = link.disconnected_count();
3216        let receiver_storage = link.receiver_storage();
3217        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3218        let local_addr = listener.channel_addr.clone();
3219        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3220            super::server::serve_with_listener(listener, local_addr).unwrap();
3221        let tx = spawn::<u64>(link);
3222        let messages: Vec<_> = (0..10001).collect();
3223        let messages_clone = messages.clone();
3224        // Put the sender side in a separate task so we can start the receiver
3225        // side concurrently.
3226        let send_task_handle = tokio::spawn(async move {
3227            for message in messages_clone {
3228                // Add a small delay between messages to give NetRx time to ack.
3229                // Technically, this test still can pass without this delay. But
3230                // the test will need a might larger timeout. The reason is
3231                // fairly convoluted:
3232                //
3233                // MockLink uses the number of delivery to calculate the disconnection
3234                // probability. If NetRx sends messages much faster than NetTx
3235                // can ack them, there is a higher chance that the messages are
3236                // not acked before reconnect. Then those message would be redelivered.
3237                // The repeated redelivery increases the total time of sending
3238                // these messages.
3239                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3240                tx.post(message);
3241            }
3242            tracing::debug!("NetTx sent all messages");
3243            // It is important to return tx instead of dropping it here, because
3244            // Rx might not receive all messages yet.
3245            tx
3246        });
3247
3248        for message in &messages {
3249            if message % sampling_rate == 0 {
3250                tracing::debug!("NetRx received a message: {message}");
3251            }
3252            assert_eq!(nx.recv().await.unwrap(), *message);
3253        }
3254        tracing::debug!("NetRx received all messages");
3255
3256        let send_result = send_task_handle.await;
3257        assert!(send_result.is_ok());
3258
3259        tracing::debug!(
3260            "MockLink disconnected {} times.",
3261            disconnected_count.load(Ordering::SeqCst)
3262        );
3263        // TODO(pzhang) after the return_handle work in NetTx is done, add a
3264        // check here to verify the messages are acked correctly.
3265    }
3266
3267    #[async_timed_test(timeout_secs = 60)]
3268    async fn test_ack_every_n_messages() {
3269        let config = hyperactor_config::global::lock();
3270        let _guard_message_ack = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 600);
3271        let _guard_time_interval =
3272            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(1000));
3273        sparse_ack().await;
3274    }
3275
3276    #[async_timed_test(timeout_secs = 60)]
3277    async fn test_ack_every_time_interval() {
3278        let config = hyperactor_config::global::lock();
3279        let _guard_message_ack =
3280            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
3281        let _guard_time_interval = config.override_key(
3282            config::MESSAGE_ACK_TIME_INTERVAL,
3283            Duration::from_millis(500),
3284        );
3285        sparse_ack().await;
3286    }
3287
3288    async fn sparse_ack() {
3289        let mut link = MockLink::<u64>::new();
3290        // Set a large buffer size to improve throughput.
3291        link.set_buffer_size(1024000);
3292        let disconnected_count = link.disconnected_count();
3293        let receiver_storage = link.receiver_storage();
3294        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3295        let local_addr = listener.channel_addr.clone();
3296        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3297            super::server::serve_with_listener(listener, local_addr).unwrap();
3298        let tx = spawn::<u64>(link);
3299        let messages: Vec<_> = (0..20001).collect();
3300        let messages_clone = messages.clone();
3301        // Put the sender side in a separate task so we can start the receiver
3302        // side concurrently.
3303        let send_task_handle = tokio::spawn(async move {
3304            for message in messages_clone {
3305                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3306                tx.post(message);
3307            }
3308            tokio::time::sleep(Duration::from_secs(5)).await;
3309            tracing::debug!("NetTx sent all messages");
3310            tx
3311        });
3312
3313        for message in &messages {
3314            assert_eq!(nx.recv().await.unwrap(), *message);
3315        }
3316        tracing::debug!("NetRx received all messages");
3317
3318        let send_result = send_task_handle.await;
3319        assert!(send_result.is_ok());
3320
3321        tracing::debug!(
3322            "MockLink disconnected {} times.",
3323            disconnected_count.load(Ordering::SeqCst)
3324        );
3325    }
3326
3327    #[test]
3328    fn test_metatls_parsing() {
3329        // host:port
3330        let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3331        assert_eq!(
3332            channel,
3333            ChannelAddr::MetaTls(TlsAddr::new("localhost", 1234))
3334        );
3335        // ipv4:port - parsed as hostname with ip normalization
3336        let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3337        assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("1.2.3.4", 1234)));
3338        // ipv6:port
3339        let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
3340            .parse()
3341            .unwrap();
3342        assert_eq!(
3343            channel,
3344            ChannelAddr::MetaTls(TlsAddr::new("2401:db00:33c:6902:face:0:2a2:0", 1234))
3345        );
3346
3347        let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3348        assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("::", 1234)));
3349    }
3350
3351    #[async_timed_test(timeout_secs = 300)]
3352    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
3353    #[cfg_attr(not(fbcode_build), ignore)]
3354    async fn test_tcp_throughput() {
3355        let config = hyperactor_config::global::lock();
3356        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3357
3358        let socket_addr: SocketAddr = "[::1]:0".parse().unwrap();
3359        let (local_addr, mut rx) = server::serve::<String>(ChannelAddr::Tcp(socket_addr)).unwrap();
3360
3361        // Test with 10 connections (senders), each sends 500K messages, 5M messages in total.
3362        let total_num_msgs = 500000;
3363
3364        let receive_handle = tokio::spawn(async move {
3365            let mut num = 0;
3366            for _ in 0..10 * total_num_msgs {
3367                rx.recv().await.unwrap();
3368                num += 1;
3369
3370                if num % 100000 == 0 {
3371                    tracing::info!("total number of received messages: {}", num);
3372                }
3373            }
3374        });
3375
3376        let mut tx_handles = vec![];
3377        let mut txs = vec![];
3378        for _ in 0..10 {
3379            let server_addr = local_addr.clone();
3380            let tx = Arc::new(channel::dial::<String>(server_addr).unwrap());
3381            let tx2 = Arc::clone(&tx);
3382            txs.push(tx);
3383            tx_handles.push(tokio::spawn(async move {
3384                let random_string = rand::rng()
3385                    .sample_iter(&Alphanumeric)
3386                    .take(2048)
3387                    .map(char::from)
3388                    .collect::<String>();
3389                for _ in 0..total_num_msgs {
3390                    tx2.post(random_string.clone());
3391                }
3392            }));
3393        }
3394
3395        receive_handle.await.unwrap();
3396        for handle in tx_handles {
3397            handle.await.unwrap();
3398        }
3399    }
3400
3401    #[tracing_test::traced_test]
3402    #[async_timed_test(timeout_secs = 60)]
3403    // TODO: OSS: The logs_assert function returned an error: expected log not found
3404    #[cfg_attr(not(fbcode_build), ignore)]
3405    async fn test_net_tx_closed_on_server_reject() {
3406        let link = MockLink::<u64>::new();
3407        let receiver_storage = link.receiver_storage();
3408        let mut tx = spawn::<u64>(link);
3409        net_tx_send(&tx, &[100]).await;
3410
3411        {
3412            let (_reader, writer) = take_receiver(&receiver_storage).await;
3413            let _ = FrameWrite::write_frame(
3414                writer,
3415                serialize_response(NetRxResponse::Reject("testing".to_string())).unwrap(),
3416                1024,
3417                0,
3418            )
3419            .await
3420            .map_err(|(_, e)| e);
3421
3422            // Wait for response to be processed by NetTx before dropping reader/writer. Otherwise
3423            // the channel will be closed and we will get the wrong error.
3424            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
3425        }
3426
3427        verify_tx_closed(&mut tx.status, "server rejected connection").await;
3428    }
3429
3430    #[async_timed_test(timeout_secs = 60)]
3431    async fn test_server_rejects_conn_on_out_of_sequence_message() {
3432        let config = hyperactor_config::global::lock();
3433        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3434        let session_id = SessionId(123);
3435        let (_handle, mvar, mut rx, _cancel_token) = serve_acceptor_test::<u64>(session_id);
3436
3437        let (sender, receiver) = tokio::io::duplex(5000);
3438        mvar.put(receiver).await;
3439        let (r, writer) = tokio::io::split(sender);
3440        let mut reader = FrameReader::new(
3441            r,
3442            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3443        );
3444
3445        let _ = write_stream(writer, 123, &[(0, 100u64), (1, 101u64), (3, 103u64)], true).await;
3446        assert_eq!(rx.recv().await, Some(100u64));
3447        assert_eq!(rx.recv().await, Some(101u64));
3448        let (_, bytes) = reader.next().await.unwrap().unwrap();
3449        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3450        assert_eq!(acked, 0);
3451        let (_, bytes) = reader.next().await.unwrap().unwrap();
3452        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3453        assert_eq!(acked, 1);
3454        let (_, bytes) = reader.next().await.unwrap().unwrap();
3455        assert!(deserialize_response(bytes).unwrap().is_reject());
3456    }
3457
3458    #[async_timed_test(timeout_secs = 60)]
3459    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
3460    #[cfg_attr(not(fbcode_build), ignore)]
3461    async fn test_stop_net_tx_after_stopping_net_rx() {
3462        hyperactor_telemetry::initialize_logging_for_test();
3463
3464        let config = hyperactor_config::global::lock();
3465        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3466        let (addr, mut rx) =
3467            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
3468        let socket_addr = match addr {
3469            ChannelAddr::Tcp(a) => a,
3470            _ => panic!("unexpected channel type"),
3471        };
3472        let tx: NetTx<u64> = spawn(tcp::link(socket_addr));
3473        // NetTx will not establish a connection until it sends the 1st message.
3474        // Without a live connection, NetTx cannot received the Closed message
3475        // from NetRx. Therefore, we need to send a message to establish the
3476        //connection.
3477        tx.send(100).await.unwrap();
3478        assert_eq!(rx.recv().await.unwrap(), 100);
3479        // Drop rx will close the NetRx server.
3480        rx.2.stop("testing");
3481        assert!(rx.recv().await.is_err());
3482
3483        // NetTx will only read from the stream when it needs to send a message
3484        // or wait for an ack. Therefore we need to send a message to trigger that.
3485        tx.post(101);
3486        let mut watcher = tx.status().clone();
3487        // When NetRx exits, it should notify NetTx to exit as well.
3488        let _ = watcher.wait_for(|val| val.is_closed()).await;
3489        // wait_for could return Err due to race between when watch's sender was
3490        // dropped and when wait_for was called. So we still need to do an
3491        // equality check.
3492        assert!(watcher.borrow().is_closed());
3493    }
3494}