hyperactor/channel/
net.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! A simple socket channel implementation using a single-stream
10//! framing protocol. Each frame is encoded as an 8-byte
11//! **big-endian** length prefix (u64), followed by exactly that many
12//! bytes of payload.
13//!
14//! Message frames carry a `serde_multipart::Message` (not raw
15//! bincode). In compat mode (current default), this is encoded as a
//! sentinel `u64::MAX` followed by a single bincode payload. Response frames
//! are a bincode-serialized `NetRxResponse` enum: `Ack` carries the acked
//! sequence number, `Reject` carries the reason the server rejected the
//! connection, and `Closed` signals that the channel has been closed.
20//!
21//! Message frame (compat/unipart) example:
22//! ```text
23//! +------------------ len: u64 (BE) ------------------+----------------------- data -----------------------+
24//! | \x00\x00\x00\x00\x00\x00\x00\x10                  | \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF | <bincode bytes> |
25//! |                       16                          |           u64::MAX             |                   |
26//! +---------------------------------------------------+-----------------------------------------------------+
27//! ```
28//!
29//! Response frame (wire format):
30//! ```text
31//! +------------------ len: u64 (BE) ------------------+---------------- data ------------------+
32//! | \x00\x00\x00\x00\x00\x00\x00\x??                  | <bincode acked sequence num or reject> |
33//! +---------------------------------------------------+----------------------------------------+
34//! ```
35//!
36//! I/O is handled by `FrameReader`/`FrameWrite`, which are
37//! cancellation-safe and avoid extra copies. Helper fns
38//! `serialize_response(NetRxResponse) -> Result<Bytes, bincode::Error>`
39//! and `deserialize_response(Bytes) -> Result<NetRxResponse, bincode::Error>`
40//! convert to/from the response payload.
41//!
42//! ### Limits & EOF semantics
43//! * **Max frame size:** frames larger than
44//!   `config::CODEC_MAX_FRAME_LENGTH` are rejected with
45//!   `io::ErrorKind::InvalidData`.
46//! * **EOF handling:** `FrameReader::next()` returns `Ok(None)` only
47//!   when EOF occurs exactly on a frame boundary. If EOF happens
48//!   mid-frame, it returns `Err(io::ErrorKind::UnexpectedEof)`.
49
50use std::fmt;
51use std::fmt::Debug;
52use std::net::ToSocketAddrs;
53use std::time::Duration;
54
55use backoff::ExponentialBackoffBuilder;
56use backoff::backoff::Backoff;
57use bytes::Bytes;
58use enum_as_inner::EnumAsInner;
59use serde::Deserialize;
60use serde::Serialize;
61use serde::de::Error;
62use tokio::io::AsyncRead;
63use tokio::io::AsyncReadExt;
64use tokio::io::AsyncWrite;
65use tokio::io::AsyncWriteExt;
66use tokio::sync::watch;
67use tokio::time::Instant;
68
69use super::*;
70use crate::RemoteMessage;
71
72pub mod duplex;
73mod framed;
74pub(super) mod server;
75pub(super) mod session;
76pub use server::ServerHandle;
77
78pub(crate) trait Stream:
79    AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static
80{
81}
82impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static> Stream for S {}
83
84/// Opaque identifier for a session. Generated by the client,
85/// sent to the server on each connect so the server can correlate
86/// reconnections to the same logical session.
87#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub(crate) struct SessionId(u64);
89
90impl SessionId {
91    /// Generate a new random session ID.
92    pub fn random() -> Self {
93        Self(rand::random())
94    }
95}
96
97impl fmt::Display for SessionId {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{:016x}", self.0)
100    }
101}
102
/// Tag of the logical channel carrying traffic from the initiator to
/// the acceptor.
pub(crate) const INITIATOR_TO_ACCEPTOR: u8 = 0;

/// Tag of the logical channel carrying traffic from the acceptor back
/// to the initiator.
pub(crate) const ACCEPTOR_TO_INITIATOR: u8 = 1;
108
/// Magic bytes that open the LinkInit header.
///
/// The LinkInit header is a fixed-size preamble sent at the start of
/// every physical connection. It goes directly on the wire (it is not
/// framed) before any session framing begins.
///
/// Wire format (12 bytes, big-endian):
/// ```text
/// [magic: 4B "LNK\0"] [session_id: 8B u64 BE]
/// ```
const LINK_INIT_MAGIC: [u8; 4] = *b"LNK\0";

/// Total size of the LinkInit header: the magic followed by the
/// session ID encoded as a `u64`.
const LINK_INIT_SIZE: usize = LINK_INIT_MAGIC.len() + std::mem::size_of::<u64>();
119
120/// Write a LinkInit header to the stream.
121async fn write_link_init<S: AsyncWrite + Unpin>(
122    stream: &mut S,
123    session_id: SessionId,
124) -> Result<(), std::io::Error> {
125    let mut buf = [0u8; LINK_INIT_SIZE];
126    buf[0..4].copy_from_slice(&LINK_INIT_MAGIC);
127    buf[4..12].copy_from_slice(&session_id.0.to_be_bytes());
128    stream.write_all(&buf).await
129}
130
131/// Read a LinkInit header from the stream.
132async fn read_link_init<S: AsyncRead + Unpin>(stream: &mut S) -> Result<SessionId, std::io::Error> {
133    let mut buf = [0u8; LINK_INIT_SIZE];
134    stream.read_exact(&mut buf).await?;
135    if buf[0..4] != LINK_INIT_MAGIC {
136        return Err(std::io::Error::new(
137            std::io::ErrorKind::InvalidData,
138            format!(
139                "invalid LinkInit magic: expected {:?}, got {:?}",
140                LINK_INIT_MAGIC,
141                &buf[0..4]
142            ),
143        ));
144    }
145    let session_id = SessionId(u64::from_be_bytes(buf[4..12].try_into().unwrap()));
146    Ok(session_id)
147}
148
149/// Link represents a network link through which connections may be
150/// acquired. The session ID is baked in. Initiator links dial;
151/// acceptor links wait for dispatched streams.
152#[async_trait]
153pub(crate) trait Link: Send + Sync + Debug + 'static {
154    /// The underlying stream type.
155    type Stream: Stream;
156
157    /// The address of the link's destination.
158    fn dest(&self) -> ChannelAddr;
159
160    /// The session ID for this link.
161    fn link_id(&self) -> SessionId;
162
163    /// Acquire the next usable connection. For initiator links this
164    /// dials; for acceptor links this waits on a dispatch channel.
165    async fn next(&self) -> Result<Self::Stream, ClientError>;
166}
167
168use session::Session;
169
170use crate::config;
171use crate::metrics;
172
173/// Log a send-loop error and return `true` if the error is terminal
174/// (caller should exit), `false` if recoverable (caller should reconnect).
175fn log_send_error(
176    error: &session::SendLoopError,
177    dest: &ChannelAddr,
178    session_id: u64,
179    mode: &str,
180) -> bool {
181    match error {
182        session::SendLoopError::Io(err) => {
183            tracing::info!(dest = %dest, session_id, error = %err, mode, "send error");
184            metrics::CHANNEL_ERRORS.add(
185                1,
186                hyperactor_telemetry::kv_pairs!(
187                    "dest" => dest.to_string(),
188                    "session_id" => session_id.to_string(),
189                    "error_type" => metrics::ChannelErrorType::SendError.as_str(),
190                    "mode" => mode.to_string(),
191                ),
192            );
193            false
194        }
195        session::SendLoopError::AppClosed => true,
196        session::SendLoopError::Rejected(reason) => {
197            tracing::error!(dest = %dest, session_id, mode, "server rejected connection: {reason}");
198            true
199        }
200        session::SendLoopError::ServerClosed => {
201            tracing::info!(dest = %dest, session_id, mode, "server closed the channel");
202            true
203        }
204        session::SendLoopError::DeliveryTimeout => {
205            let timeout = hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
206            tracing::error!(
207                dest = %dest, session_id, mode,
208                "failed to receive ack within timeout {timeout:?}; link is currently connected"
209            );
210            true
211        }
212        session::SendLoopError::OversizedFrame(reason) => {
213            tracing::error!(dest = %dest, session_id, mode, "oversized frame: {reason}");
214            true
215        }
216    }
217}
218
/// Establish a simplex (send-only) session over the given link. Returns a send handle.
///
/// The returned [`NetTx`] feeds an unbounded queue that is drained by a
/// background task spawned on `crate::init::get_runtime()`. That task:
///
/// 1. Waits for the first posted message before connecting (lazy
///    connect). If the queue closes first, it marks the handle
///    `TxStatus::Closed` and exits.
/// 2. Loops: connect the session (bounded by the deliveries' expiry
///    time when there is one), run `session::send_connected` on the
///    initiator→acceptor stream, then take the session back for the next
///    attempt. A clean return or a recoverable error reconnects; a
///    terminal error (as classified by `log_send_error`) or a failed
///    connect ends the loop with a reason string.
/// 3. On exit, hands every undelivered message back to its sender
///    together with that reason, and publishes `TxStatus::Closed`.
pub(crate) fn spawn<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    let dest = link.dest();
    let session_id = link.link_id();
    // `notify` stays with the background task; `status` is exposed through the handle.
    let (notify, status) = watch::channel(TxStatus::Active);
    let tx = NetTx {
        sender,
        dest: dest.clone(),
        status,
    };
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id.clone()),
        };
        let mut receiver = receiver;

        // Lazy connect: wait for first message.
        match receiver.recv().await {
            Some(msg) => {
                if let Err(err) = deliveries.outbox.push_back(msg) {
                    tracing::error!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        "failed to push message to outbox"
                    );
                    let _ = notify.send(TxStatus::Closed);
                    return;
                }
            }
            None => {
                // The queue closed before anything was posted; there is nothing to send.
                let _ = notify.send(TxStatus::Closed);
                return;
            }
        }

        // Connect/send loop. It only exits via `break 'outer <reason>`;
        // the reason is reported back to senders below.
        let reason: String = 'outer: loop {
            let connected = match deliveries.expiry_time() {
                // There is a deadline: give up connecting once it passes.
                Some(deadline) => match session.connect_by(deadline).await {
                    Ok(s) => s,
                    Err(_) => {
                        let timeout =
                            hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
                        // Pick the message depending on whether the outbox reports itself expired.
                        let error_msg = if deliveries.outbox.is_expired(timeout) {
                            format!("failed to deliver message within timeout {timeout:?}",)
                        } else {
                            format!(
                                "failed to receive ack within timeout {timeout:?}; \
                                 link is currently broken",
                            )
                        };
                        tracing::error!(
                            dest = %dest, session_id = session_id.0, "{}", error_msg
                        );
                        break 'outer format!("{log_id}: {error_msg}");
                    }
                },
                // No deadline: connect without a time bound.
                None => match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break 'outer "session shut down".into(),
                },
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "simplex",
                    "reason" => "link connected",
                ),
            );

            // Connecting while messages are still unacked is counted as a reconnection.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "simplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            // NOTE(review): presumably moves unacked messages back into the
            // outbox so they are re-sent on the new connection; confirm in `session`.
            deliveries.requeue_unacked();

            // Scope the stream so it is dropped before `connected` is released.
            let result = {
                let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                session::send_connected(&stream, &mut deliveries, &mut receiver).await
            };
            // Take the session back so the next iteration can connect again.
            session = connected.release();

            match result {
                Ok(()) => {
                    // EOF — connection closed normally, reconnect.
                }
                Err(ref e) => {
                    // `log_send_error` returns true for terminal errors; otherwise reconnect.
                    if log_send_error(e, &dest, session_id.0, "simplex") {
                        break 'outer format!("{log_id}: {e}");
                    }
                }
            }
        };

        tracing::info!(
            dest = %dest, session_id = session_id.0, "NetTx closing: {reason}"
        );

        // Stop accepting new posts, then return everything still pending:
        // first the unacked messages, then the outbox, then whatever is
        // still buffered in the queue.
        receiver.close();
        deliveries
            .unacked
            .deque
            .drain(..)
            .chain(deliveries.outbox.deque.drain(..))
            .for_each(|queued| queued.try_return(Some(reason.clone())));
        while let Ok((msg, return_channel, _)) = receiver.try_recv() {
            let _ = return_channel.send(SendError {
                error: ChannelError::Closed,
                message: msg,
                reason: Some(reason.clone()),
            });
        }

        let _ = notify.send(TxStatus::Closed);
    });
    tx
}
349
350/// Transport-agnostic link that dispatches to the appropriate
351/// transport based on the channel address.
352#[derive(Debug)]
353pub(crate) enum NetLink {
354    Tcp(tcp::TcpLink),
355    Unix(unix::UnixLink),
356    Tls(tls::TlsLink),
357}
358
359/// Create a link for the given channel address.
360pub(crate) fn link(addr: ChannelAddr) -> Result<NetLink, ClientError> {
361    match addr {
362        ChannelAddr::Tcp(socket_addr) => Ok(NetLink::Tcp(tcp::link(socket_addr))),
363        ChannelAddr::Unix(unix_addr) => Ok(NetLink::Unix(unix::link(unix_addr))),
364        ChannelAddr::Tls(tls_addr) => Ok(NetLink::Tls(tls::link(tls_addr)?)),
365        ChannelAddr::MetaTls(meta_addr) => Ok(NetLink::Tls(meta::link(meta_addr)?)),
366        other => Err(ClientError::Connect(
367            other,
368            std::io::Error::other("unsupported transport"),
369            "unsupported transport".into(),
370        )),
371    }
372}
373
374#[async_trait]
375impl Link for NetLink {
376    type Stream = Box<dyn Stream>;
377
378    fn dest(&self) -> ChannelAddr {
379        match self {
380            Self::Tcp(l) => l.dest(),
381            Self::Unix(l) => l.dest(),
382            Self::Tls(l) => l.dest(),
383        }
384    }
385
386    fn link_id(&self) -> SessionId {
387        match self {
388            Self::Tcp(l) => l.link_id(),
389            Self::Unix(l) => l.link_id(),
390            Self::Tls(l) => l.link_id(),
391        }
392    }
393
394    async fn next(&self) -> Result<Box<dyn Stream>, ClientError> {
395        match self {
396            Self::Tcp(l) => Ok(Box::new(l.next().await?)),
397            Self::Unix(l) => Ok(Box::new(l.next().await?)),
398            Self::Tls(l) => Ok(Box::new(l.next().await?)),
399        }
400    }
401}
402
403/// Listener represents the server side of a network link: it accepts inbound connections.
404///
405/// This is the counterpart to [`Link`]. Each transport module (tcp, unix, tls)
406/// provides both a `Link` impl (for dialing) and a `Listener` impl (for accepting).
407#[async_trait]
408pub(crate) trait Listener: Send + Unpin + 'static {
409    /// The underlying stream type produced by accepting a connection.
410    type Stream: Stream;
411
412    /// Accept the next inbound connection, returning the stream and the peer's address.
413    async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError>;
414}
415
416/// Transport-agnostic listener that dispatches to the appropriate
417/// transport based on the channel address. TLS has no variant — it
418/// uses `Tcp` under the hood (the TLS handshake happens in `prepare`,
419/// not the listener).
420#[derive(Debug)]
421pub(crate) enum NetListener {
422    Tcp(tcp::TcpSocketListener),
423    Unix(unix::UnixSocketListener),
424}
425
426#[async_trait]
427impl Listener for NetListener {
428    type Stream = Box<dyn Stream>;
429
430    async fn accept(&mut self) -> Result<(Box<dyn Stream>, ChannelAddr), ServerError> {
431        match self {
432            Self::Tcp(l) => {
433                let (stream, addr) = l.accept().await?;
434                Ok((Box::new(stream), addr))
435            }
436            Self::Unix(l) => {
437                let (stream, addr) = l.accept().await?;
438                Ok((Box::new(stream), addr))
439            }
440        }
441    }
442}
443
444/// Bind a listener for the given channel address. Returns the listener
445/// and the canonical address callers should advertise (which encodes
446/// the transport — e.g. `ChannelAddr::Tls` for TLS).
447pub(crate) fn listen(addr: ChannelAddr) -> Result<(NetListener, ChannelAddr), ServerError> {
448    match addr {
449        ChannelAddr::Tcp(socket_addr) => {
450            let std_listener = std::net::TcpListener::bind(socket_addr)
451                .map_err(|err| ServerError::Listen(ChannelAddr::Tcp(socket_addr), err))?;
452            std_listener
453                .set_nonblocking(true)
454                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
455            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
456                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
457            let local_addr = tokio_listener
458                .local_addr()
459                .map_err(|err| ServerError::Resolve(ChannelAddr::Tcp(socket_addr), err))?;
460            let listener = tcp::TcpSocketListener {
461                inner: tokio_listener,
462                addr: local_addr,
463            };
464            Ok((NetListener::Tcp(listener), ChannelAddr::Tcp(local_addr)))
465        }
466        ChannelAddr::Unix(ref unix_addr) => {
467            use std::os::unix::net::UnixDatagram as StdUnixDatagram;
468            use std::os::unix::net::UnixListener as StdUnixListener;
469
470            let caddr = addr.clone();
471            let maybe_listener = match unix_addr {
472                unix::SocketAddr::Bound(sock_addr) => StdUnixListener::bind_addr(sock_addr),
473                unix::SocketAddr::Unbound => StdUnixDatagram::unbound()
474                    .and_then(|u| u.local_addr())
475                    .and_then(|uaddr| StdUnixListener::bind_addr(&uaddr)),
476            };
477            let std_listener =
478                maybe_listener.map_err(|err| ServerError::Listen(caddr.clone(), err))?;
479            std_listener
480                .set_nonblocking(true)
481                .map_err(|err| ServerError::Listen(caddr.clone(), err))?;
482            let local_addr = std_listener
483                .local_addr()
484                .map_err(|err| ServerError::Resolve(caddr.clone(), err))?;
485            let tokio_listener = tokio::net::UnixListener::from_std(std_listener)
486                .map_err(|err| ServerError::Io(caddr, err))?;
487            let bound_addr = unix::SocketAddr::new(local_addr);
488            let listener = unix::UnixSocketListener {
489                inner: tokio_listener,
490                addr: bound_addr.clone(),
491            };
492            Ok((NetListener::Unix(listener), ChannelAddr::Unix(bound_addr)))
493        }
494        addr @ (ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_)) => {
495            let is_meta = matches!(addr, ChannelAddr::MetaTls(_));
496            let tls_addr = match addr {
497                ChannelAddr::Tls(a) | ChannelAddr::MetaTls(a) => a,
498                _ => unreachable!(),
499            };
500            let TlsAddr { hostname, port } = tls_addr;
501            let make_channel_addr = |h: &str, p: Port| {
502                if is_meta {
503                    ChannelAddr::MetaTls(TlsAddr::new(h, p))
504                } else {
505                    ChannelAddr::Tls(TlsAddr::new(h, p))
506                }
507            };
508
509            let addrs: Vec<core::net::SocketAddr> = (hostname.as_ref(), port)
510                .to_socket_addrs()
511                .map_err(|err| ServerError::Resolve(make_channel_addr(&hostname, port), err))?
512                .collect();
513
514            if addrs.is_empty() {
515                return Err(ServerError::Resolve(
516                    make_channel_addr(&hostname, port),
517                    std::io::Error::other("no available socket addr"),
518                ));
519            }
520
521            let channel_addr = make_channel_addr(&hostname, port);
522            let std_listener = std::net::TcpListener::bind(&addrs[..])
523                .map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
524            std_listener
525                .set_nonblocking(true)
526                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
527            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
528                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
529            let local_addr = tokio_listener
530                .local_addr()
531                .map_err(|err| ServerError::Resolve(channel_addr, err))?;
532            let listener = tcp::TcpSocketListener {
533                inner: tokio_listener,
534                addr: local_addr,
535            };
536            Ok((
537                NetListener::Tcp(listener),
538                make_channel_addr(&hostname, local_addr.port()),
539            ))
540        }
541        other => Err(ServerError::Listen(
542            other.clone(),
543            std::io::Error::other(format!("unsupported transport: {}", other)),
544        )),
545    }
546}
547
548/// Frames are the messages sent between clients and servers over sessions.
549#[derive(Debug, Serialize, Deserialize, EnumAsInner, PartialEq)]
550pub(super) enum Frame<M> {
551    /// Send a message with the provided sequence number.
552    Message(u64, M),
553}
554
555#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
556pub(super) enum NetRxResponse {
557    Ack(u64),
558    /// This session is rejected with the given reason. NetTx should stop reconnecting.
559    Reject(String),
560    /// This channel is closed.
561    Closed,
562}
563
564pub(super) fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
565    bincode::serialize(&response).map(|bytes| bytes.into())
566}
567
568pub(super) fn deserialize_response(data: Bytes) -> Result<NetRxResponse, bincode::Error> {
569    bincode::deserialize(&data)
570}
571
572/// A Tx implemented on top of a Link. The Tx manages the link state,
573/// reconnections, etc.
574pub(crate) struct NetTx<M: RemoteMessage> {
575    sender: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
576    dest: ChannelAddr,
577    status: watch::Receiver<TxStatus>,
578}
579
580#[async_trait]
581impl<M: RemoteMessage> Tx<M> for NetTx<M> {
582    fn addr(&self) -> ChannelAddr {
583        self.dest.clone()
584    }
585
586    fn status(&self) -> &watch::Receiver<TxStatus> {
587        &self.status
588    }
589
590    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
591        tracing::trace!(
592            name = "post",
593            dest = %self.dest,
594            "sending message"
595        );
596
597        let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
598        if let Err(mpsc::error::SendError((message, return_channel, _))) =
599            self.sender
600                .send((message, return_channel, tokio::time::Instant::now()))
601        {
602            let _ = return_channel.send(SendError {
603                error: ChannelError::Closed,
604                message,
605                reason: None,
606            });
607        }
608    }
609}
610
611pub struct NetRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr, ServerHandle);
612
613#[async_trait]
614impl<M: RemoteMessage> Rx<M> for NetRx<M> {
615    async fn recv(&mut self) -> Result<M, ChannelError> {
616        tracing::trace!(
617            name = "recv",
618            dest = %self.1,
619            "receiving message"
620        );
621        self.0.recv().await.ok_or(ChannelError::Closed)
622    }
623
624    fn addr(&self) -> ChannelAddr {
625        self.1.clone()
626    }
627
628    /// Gracefully shut down the channel server, waiting for pending
629    /// acks to be flushed before returning.
630    async fn join(mut self) {
631        self.2
632            .stop(&format!("NetRx joined; channel address: {}", self.1));
633        let _ = (&mut self.2).await;
634        // Drop will call stop() again which is harmless (token already cancelled).
635    }
636}
637
638impl<M: RemoteMessage> Drop for NetRx<M> {
639    fn drop(&mut self) {
640        self.2
641            .stop(&format!("NetRx dropped; channel address: {}", self.1));
642    }
643}
644
645/// Error returned during server operations.
646#[derive(Debug, thiserror::Error)]
647pub enum ServerError {
648    #[error("io: {1}")]
649    Io(ChannelAddr, #[source] std::io::Error),
650    #[error("listen: {0} {1}")]
651    Listen(ChannelAddr, #[source] std::io::Error),
652    #[error("resolve: {0} {1}")]
653    Resolve(ChannelAddr, #[source] std::io::Error),
654    #[error("internal: {0} {1}")]
655    Internal(ChannelAddr, #[source] anyhow::Error),
656}
657
658#[derive(thiserror::Error, Debug)]
659pub enum ClientError {
660    #[error("connection to {0} failed: {1}: {2}")]
661    Connect(ChannelAddr, std::io::Error, String),
662    #[error("unable to resolve address: {0}")]
663    Resolve(ChannelAddr),
664    #[error("io: {0} {1}")]
665    Io(ChannelAddr, std::io::Error),
666    #[error("send {0}: serialize: {1}")]
667    Serialize(ChannelAddr, bincode::ErrorKind),
668    #[error("invalid address: {0}")]
669    InvalidAddress(String),
670}
671
/// Tells whether the address is a 'net' address. These currently have different semantics
/// from local transports.
///
/// Returns `true` for the TCP, MetaTLS, TLS and Unix transports and
/// `false` for every other transport.
#[cfg(test)]
pub(super) fn is_net_addr(addr: &ChannelAddr) -> bool {
    matches!(
        addr.transport(),
        ChannelTransport::Tcp(_)
            | ChannelTransport::MetaTls(_)
            | ChannelTransport::Tls
            | ChannelTransport::Unix
    )
}
684
685pub(crate) mod unix {
686
687    use core::str;
688    use std::os::unix::net::SocketAddr as StdSocketAddr;
689    use std::os::unix::net::UnixStream as StdUnixStream;
690
691    use rand::Rng;
692    use rand::distributions::Alphanumeric;
693    use tokio::net::UnixListener;
694    use tokio::net::UnixStream;
695
696    use super::*;
697
698    #[derive(Debug)]
699    pub(crate) struct UnixLink {
700        pub(super) addr: SocketAddr,
701        pub(super) session_id: SessionId,
702    }
703
704    #[async_trait]
705    impl Link for UnixLink {
706        type Stream = UnixStream;
707
708        fn dest(&self) -> ChannelAddr {
709            ChannelAddr::Unix(self.addr.clone())
710        }
711
712        fn link_id(&self) -> SessionId {
713            self.session_id
714        }
715
716        async fn next(&self) -> Result<Self::Stream, ClientError> {
717            let session_id = self.session_id;
718            let sock_addr = match &self.addr {
719                SocketAddr::Bound(a) => a,
720                SocketAddr::Unbound => return Err(ClientError::Resolve(self.dest())),
721            };
722            let mut backoff = ExponentialBackoffBuilder::new()
723                .with_initial_interval(Duration::from_millis(1))
724                .with_multiplier(2.0)
725                .with_randomization_factor(0.1)
726                .with_max_interval(Duration::from_millis(1000))
727                .with_max_elapsed_time(None)
728                .build();
729            loop {
730                match StdUnixStream::connect_addr(sock_addr) {
731                    Ok(std_stream) => {
732                        std_stream
733                            .set_nonblocking(true)
734                            .map_err(|err| ClientError::Io(self.dest(), err))?;
735                        let mut stream = UnixStream::from_std(std_stream)
736                            .map_err(|err| ClientError::Io(self.dest(), err))?;
737                        write_link_init(&mut stream, session_id)
738                            .await
739                            .map_err(|err| ClientError::Io(self.dest(), err))?;
740                        return Ok(stream);
741                    }
742                    Err(err) => {
743                        tracing::debug!(error = %err, "unix connect failed, backing off");
744                        if let Some(delay) = backoff.next_backoff() {
745                            tokio::time::sleep(delay).await;
746                        }
747                    }
748                }
749            }
750        }
751    }
752
753    /// Server-side listener for Unix domain sockets.
754    #[derive(Debug)]
755    pub(crate) struct UnixSocketListener {
756        pub(super) inner: UnixListener,
757        pub(super) addr: SocketAddr,
758    }
759
760    #[async_trait]
761    impl super::Listener for UnixSocketListener {
762        type Stream = UnixStream;
763
764        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
765            let (stream, peer_addr) = self
766                .inner
767                .accept()
768                .await
769                .map_err(|err| ServerError::Io(ChannelAddr::Unix(self.addr.clone()), err))?;
770            // tokio::net::unix::SocketAddr -> std::os::unix::net::SocketAddr
771            let std_addr: StdSocketAddr = peer_addr.into();
772            Ok((stream, ChannelAddr::Unix(SocketAddr::new(std_addr))))
773        }
774    }
775
776    /// Create a unix link to the given socket address.
777    pub(crate) fn link(addr: SocketAddr) -> UnixLink {
778        UnixLink {
779            addr,
780            session_id: SessionId::random(),
781        }
782    }
783
784    /// Wrapper around std-lib's unix::SocketAddr that lets us implement equality functions
785    #[derive(Clone, Debug)]
786    pub enum SocketAddr {
787        Bound(Box<StdSocketAddr>),
788        Unbound,
789    }
790
791    impl PartialOrd for SocketAddr {
792        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
793            Some(self.cmp(other))
794        }
795    }
796
797    impl Ord for SocketAddr {
798        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
799            self.to_string().cmp(&other.to_string())
800        }
801    }
802
803    impl<'de> Deserialize<'de> for SocketAddr {
804        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
805        where
806            D: serde::Deserializer<'de>,
807        {
808            let s = String::deserialize(deserializer)?;
809            Self::from_str(&s).map_err(D::Error::custom)
810        }
811    }
812
813    impl Serialize for SocketAddr {
814        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
815        where
816            S: serde::Serializer,
817        {
818            serializer.serialize_str(String::from(self).as_str())
819        }
820    }
821
822    impl From<&SocketAddr> for String {
823        fn from(value: &SocketAddr) -> Self {
824            match value {
825                SocketAddr::Bound(addr) => match addr.as_pathname() {
826                    Some(path) => path
827                        .to_str()
828                        .expect("unable to get str for path")
829                        .to_string(),
830                    #[cfg(target_os = "linux")]
831                    _ => match addr.as_abstract_name() {
832                        Some(name) => format!("@{}", String::from_utf8_lossy(name)),
833                        _ => String::from("(unnamed)"),
834                    },
835                    #[cfg(not(target_os = "linux"))]
836                    _ => String::from("(unnamed)"),
837                },
838                SocketAddr::Unbound => String::from("(unbound)"),
839            }
840        }
841    }
842
843    impl FromStr for SocketAddr {
844        type Err = anyhow::Error;
845
846        fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
847            match s {
848                "" => {
849                    // TODO: ensure this socket doesn't already exist. 24 bytes of randomness should be good for now but is not perfect.
850                    // We can't use annon sockets because those are not valid across processes that aren't in the same process hierarchy aka forked.
851                    let random_string = rand::thread_rng()
852                        .sample_iter(&Alphanumeric)
853                        .take(24)
854                        .map(char::from)
855                        .collect::<String>();
856                    SocketAddr::from_abstract_name(&random_string)
857                }
858                // by convention, named sockets are displayed with an '@' prefix
859                name if name.starts_with("@") => {
860                    SocketAddr::from_abstract_name(name.strip_prefix("@").unwrap())
861                }
862                path => SocketAddr::from_pathname(path),
863            }
864        }
865    }
866
    // Marker impl: the comparison itself is the hand-written `PartialEq` in this module.
    // NOTE(review): that `PartialEq` reports unbound and unnamed addresses as unequal even
    // to themselves, so the reflexivity `Eq` promises does not hold for them — confirm
    // this is intended.
    impl Eq for SocketAddr {}
868    impl std::hash::Hash for SocketAddr {
869        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
870            String::from(self).hash(state);
871        }
872    }
873    impl PartialEq for SocketAddr {
874        fn eq(&self, other: &Self) -> bool {
875            match (self, other) {
876                (Self::Bound(saddr), Self::Bound(oaddr)) => {
877                    if saddr.is_unnamed() || oaddr.is_unnamed() {
878                        return false;
879                    }
880
881                    #[cfg(target_os = "linux")]
882                    {
883                        saddr.as_pathname() == oaddr.as_pathname()
884                            && saddr.as_abstract_name() == oaddr.as_abstract_name()
885                    }
886                    #[cfg(not(target_os = "linux"))]
887                    {
888                        // On non-Linux platforms, only compare pathname since no abstract names
889                        saddr.as_pathname() == oaddr.as_pathname()
890                    }
891                }
892                (Self::Unbound, _) | (_, Self::Unbound) => false,
893            }
894        }
895    }
896
897    impl fmt::Display for SocketAddr {
898        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899            match self {
900                Self::Bound(addr) => match addr.as_pathname() {
901                    Some(path) => {
902                        write!(f, "{}", path.to_string_lossy())
903                    }
904                    #[cfg(target_os = "linux")]
905                    _ => match addr.as_abstract_name() {
906                        Some(name) => {
907                            if name.starts_with(b"@") {
908                                return write!(f, "{}", String::from_utf8_lossy(name));
909                            }
910                            write!(f, "@{}", String::from_utf8_lossy(name))
911                        }
912                        _ => write!(f, "(unnamed)"),
913                    },
914                    #[cfg(not(target_os = "linux"))]
915                    _ => write!(f, "(unnamed)"),
916                },
917                Self::Unbound => write!(f, "(unbound)"),
918            }
919        }
920    }
921
922    impl SocketAddr {
923        /// Wraps the stdlib socket address for use with this module
924        pub fn new(addr: StdSocketAddr) -> Self {
925            Self::Bound(Box::new(addr))
926        }
927
928        /// Abstract socket names start with a "@" by convention when displayed. If there is an
929        /// "@" prefix, it will be stripped from the name before used.
930        #[cfg(target_os = "linux")]
931        pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
932            Ok(Self::new(StdSocketAddr::from_abstract_name(
933                name.strip_prefix("@").unwrap_or(name),
934            )?))
935        }
936
937        #[cfg(not(target_os = "linux"))]
938        pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
939            // On non-Linux platforms, convert abstract names to filesystem paths
940            let name = name.strip_prefix("@").unwrap_or(name);
941            let path = Self::abstract_to_filesystem_path(name);
942            Self::from_pathname(&path.to_string_lossy())
943        }
944
945        #[cfg(not(target_os = "linux"))]
946        fn abstract_to_filesystem_path(abstract_name: &str) -> std::path::PathBuf {
947            use std::collections::hash_map::DefaultHasher;
948            use std::hash::Hash;
949            use std::hash::Hasher;
950
951            // Generate a stable hash of the abstract name for deterministic paths
952            let mut hasher = DefaultHasher::new();
953            abstract_name.hash(&mut hasher);
954            let hash = hasher.finish();
955
956            // Include process ID to prevent inter-process conflicts
957            let process_id = std::process::id();
958
959            // TODO: we just leak these. Should we do something smarter?
960            std::path::PathBuf::from(format!("/tmp/hyperactor_{}_{:x}", process_id, hash))
961        }
962
963        /// Pathnames may be absolute or relative.
964        pub fn from_pathname(name: &str) -> anyhow::Result<Self> {
965            Ok(Self::new(StdSocketAddr::from_pathname(name)?))
966        }
967    }
968
969    impl TryFrom<SocketAddr> for StdSocketAddr {
970        type Error = anyhow::Error;
971
972        fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
973            match value {
974                SocketAddr::Bound(addr) => Ok(*addr),
975                SocketAddr::Unbound => Err(anyhow::anyhow!(
976                    "std::os::unix::SocketAddr must be a bound address"
977                )),
978            }
979        }
980    }
981}
982
983pub(crate) mod tcp {
984    use tokio::net::TcpListener;
985    use tokio::net::TcpStream;
986
987    use super::*;
988
989    #[derive(Debug)]
990    pub(crate) struct TcpLink {
991        pub(super) addr: SocketAddr,
992        pub(super) session_id: SessionId,
993    }
994
995    #[async_trait]
996    impl Link for TcpLink {
997        type Stream = TcpStream;
998
999        fn dest(&self) -> ChannelAddr {
1000            ChannelAddr::Tcp(self.addr)
1001        }
1002
1003        fn link_id(&self) -> SessionId {
1004            self.session_id
1005        }
1006
1007        async fn next(&self) -> Result<Self::Stream, ClientError> {
1008            let session_id = self.session_id;
1009            let mut backoff = ExponentialBackoffBuilder::new()
1010                .with_initial_interval(Duration::from_millis(1))
1011                .with_multiplier(2.0)
1012                .with_randomization_factor(0.1)
1013                .with_max_interval(Duration::from_millis(1000))
1014                .with_max_elapsed_time(None)
1015                .build();
1016            loop {
1017                match TcpStream::connect(&self.addr).await {
1018                    Ok(mut stream) => {
1019                        stream.set_nodelay(true).map_err(|err| {
1020                            ClientError::Connect(
1021                                self.dest(),
1022                                err,
1023                                "cannot disable Nagle algorithm".to_string(),
1024                            )
1025                        })?;
1026                        write_link_init(&mut stream, session_id)
1027                            .await
1028                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1029                        return Ok(stream);
1030                    }
1031                    Err(err) => {
1032                        tracing::debug!(error = %err, "tcp connect failed, backing off");
1033                        if let Some(delay) = backoff.next_backoff() {
1034                            tokio::time::sleep(delay).await;
1035                        }
1036                    }
1037                }
1038            }
1039        }
1040    }
1041
1042    /// Server-side listener for TCP sockets.
1043    #[derive(Debug)]
1044    pub(crate) struct TcpSocketListener {
1045        pub(super) inner: TcpListener,
1046        pub(super) addr: SocketAddr,
1047    }
1048
1049    #[async_trait]
1050    impl super::Listener for TcpSocketListener {
1051        type Stream = TcpStream;
1052
1053        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1054            let (stream, peer_addr) = self
1055                .inner
1056                .accept()
1057                .await
1058                .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1059            stream
1060                .set_nodelay(true)
1061                .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1062            Ok((stream, ChannelAddr::Tcp(peer_addr)))
1063        }
1064    }
1065
1066    /// Create a TCP link to the given socket address.
1067    pub(crate) fn link(addr: SocketAddr) -> TcpLink {
1068        TcpLink {
1069            addr,
1070            session_id: SessionId::random(),
1071        }
1072    }
1073}
1074
1075// TODO: Try to simplify the TLS creation T208304433
1076pub(crate) mod meta {
1077    use std::io;
1078    use std::path::PathBuf;
1079    use std::sync::Arc;
1080
1081    use anyhow::Result;
1082    use tokio_rustls::TlsAcceptor;
1083    use tokio_rustls::TlsConnector;
1084
1085    use super::*;
1086    use crate::config::Pem;
1087    use crate::config::PemBundle;
1088
1089    const THRIFT_TLS_SRV_CA_PATH_ENV: &str = "THRIFT_TLS_SRV_CA_PATH";
1090    const DEFAULT_SRV_CA_PATH: &str = "/var/facebook/rootcanal/ca.pem";
1091    const THRIFT_TLS_CL_CERT_PATH_ENV: &str = "THRIFT_TLS_CL_CERT_PATH";
1092    const THRIFT_TLS_CL_KEY_PATH_ENV: &str = "THRIFT_TLS_CL_KEY_PATH";
1093    const DEFAULT_SERVER_PEM_PATH: &str = "/var/facebook/x509_identities/server.pem";
1094
1095    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
1096    pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1097        // Use right split to allow for ipv6 addresses where ":" is expected.
1098        let parts = addr_string.rsplit_once(":");
1099        match parts {
1100            Some((hostname, port_str)) => {
1101                let Ok(port) = port_str.parse() else {
1102                    return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1103                };
1104                Ok(ChannelAddr::MetaTls(TlsAddr::new(hostname, port)))
1105            }
1106            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1107        }
1108    }
1109
1110    /// Construct a PemBundle for server operations from Meta-specific paths.
1111    /// Server cert and key come from the same file (server.pem).
1112    pub(super) fn get_server_pem_bundle() -> PemBundle {
1113        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1114            .map(PathBuf::from)
1115            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1116        let server_pem_path = PathBuf::from(DEFAULT_SERVER_PEM_PATH);
1117        PemBundle {
1118            ca: Pem::File(ca_path),
1119            cert: Pem::File(server_pem_path.clone()),
1120            key: Pem::File(server_pem_path),
1121        }
1122    }
1123
1124    /// Construct a PemBundle for client operations from Meta-specific env vars.
1125    /// Returns None if client cert/key env vars are not set.
1126    fn get_client_pem_bundle() -> Option<PemBundle> {
1127        let cert_path = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV)?;
1128        let key_path = std::env::var_os(THRIFT_TLS_CL_KEY_PATH_ENV)?;
1129        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1130            .map(PathBuf::from)
1131            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1132        Some(PemBundle {
1133            ca: Pem::File(ca_path),
1134            cert: Pem::File(PathBuf::from(cert_path)),
1135            key: Pem::File(PathBuf::from(key_path)),
1136        })
1137    }
1138
1139    /// Creates a TLS acceptor by looking for necessary certs and keys in a Meta server environment.
1140    pub(crate) fn tls_acceptor(enforce_client_tls: bool) -> Result<TlsAcceptor> {
1141        let bundle = get_server_pem_bundle();
1142        tls::tls_acceptor_from_bundle(&bundle, enforce_client_tls)
1143    }
1144
1145    /// Try to create a TLS connector for Meta environments.
1146    ///
1147    /// Returns `Ok` when the root CA is present (optional client certs
1148    /// are added when `THRIFT_TLS_CL_CERT_PATH` / `THRIFT_TLS_CL_KEY_PATH`
1149    /// are set).
1150    pub(super) fn try_tls_connector() -> Result<TlsConnector> {
1151        tls_connector()
1152    }
1153
1154    /// Creates a TLS connector by looking for necessary certs and keys in a Meta server environment.
1155    /// Supports optional client authentication (unlike the tls module which always requires it).
1156    fn tls_connector() -> Result<TlsConnector> {
1157        // Ensure ring is installed as the process-level crypto provider.
1158        // No-op when already installed (e.g. under Buck with native-tls).
1159        let _ = rustls::crypto::ring::default_provider().install_default();
1160
1161        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1162            .map(PathBuf::from)
1163            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1164        let ca_pem = Pem::File(ca_path);
1165        let root_store = tls::build_root_store(&ca_pem)?;
1166
1167        // If client certs are available, use mutual TLS; otherwise, no client auth
1168        let config = rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_store));
1169
1170        let config = if let Some(bundle) = get_client_pem_bundle() {
1171            let certs = tls::load_certs(&bundle.cert)?;
1172            let key = tls::load_key(&bundle.key)?;
1173            config
1174                .with_client_auth_cert(certs, key)
1175                .map_err(|e| anyhow::anyhow!("load client certs: {}", e))?
1176        } else {
1177            config.with_no_client_auth()
1178        };
1179
1180        Ok(TlsConnector::from(Arc::new(config)))
1181    }
1182
1183    /// Create a MetaTLS link to the given address.
1184    pub fn link(addr: TlsAddr) -> Result<tls::TlsLink, ClientError> {
1185        let connector = tls_connector().map_err(|e| {
1186            ClientError::Connect(
1187                ChannelAddr::MetaTls(addr.clone()),
1188                io::Error::other(e.to_string()),
1189                "failed to create TLS connector".to_string(),
1190            )
1191        })?;
1192        let TlsAddr { hostname, port } = addr;
1193        Ok(tls::TlsLink {
1194            hostname,
1195            port,
1196            connector,
1197            addr_type: tls::TlsAddrType::MetaTls,
1198            session_id: SessionId::random(),
1199        })
1200    }
1201}
1202
1203/// TLS transport module using configurable certificates via hyperactor config attributes.
1204pub(crate) mod tls {
1205    use std::io;
1206    use std::io::BufReader;
1207    use std::sync::Arc;
1208
1209    use anyhow::Context;
1210    use anyhow::Result;
1211    use rustls::RootCertStore;
1212    use rustls::pki_types::CertificateDer;
1213    use rustls::pki_types::PrivateKeyDer;
1214    use rustls::pki_types::ServerName;
1215    use tokio::net::TcpStream;
1216    use tokio_rustls::TlsAcceptor;
1217    use tokio_rustls::TlsConnector;
1218    use tokio_rustls::client::TlsStream;
1219
1220    use super::*;
1221    use crate::channel::TlsAddr;
1222    use crate::config::Pem;
1223    use crate::config::PemBundle;
1224    use crate::config::TLS_CA;
1225    use crate::config::TLS_CERT;
1226    use crate::config::TLS_KEY;
1227
    /// Distinguishes between Tls and MetaTls for address construction.
    #[derive(Debug, Clone, Copy)]
    pub(crate) enum TlsAddrType {
        /// The link's destination is reported as `ChannelAddr::Tls`.
        Tls,
        /// The link's destination is reported as `ChannelAddr::MetaTls`.
        MetaTls,
    }
1234
1235    /// Parse an address string into a TlsAddr.
1236    #[allow(clippy::result_large_err)]
1237    pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1238        // Use right split to allow for ipv6 addresses where ":" is expected.
1239        let parts = addr_string.rsplit_once(":");
1240        match parts {
1241            Some((hostname, port_str)) => {
1242                let Ok(port) = port_str.parse() else {
1243                    return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1244                };
1245                Ok(ChannelAddr::Tls(TlsAddr::new(hostname, port)))
1246            }
1247            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1248        }
1249    }
1250
1251    /// Load certificates from a Pem value.
1252    pub(super) fn load_certs(pem: &Pem) -> Result<Vec<CertificateDer<'static>>> {
1253        let mut reader = BufReader::new(pem.reader()?);
1254        let certs = rustls_pemfile::certs(&mut reader)
1255            .filter_map(Result::ok)
1256            .collect();
1257        Ok(certs)
1258    }
1259
1260    /// Load a private key from a Pem value.
1261    pub(super) fn load_key(pem: &Pem) -> Result<PrivateKeyDer<'static>> {
1262        let mut reader = BufReader::new(pem.reader()?);
1263        loop {
1264            break match rustls_pemfile::read_one(&mut reader)? {
1265                Some(rustls_pemfile::Item::Pkcs1Key(key)) => Ok(PrivateKeyDer::Pkcs1(key)),
1266                Some(rustls_pemfile::Item::Pkcs8Key(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
1267                Some(rustls_pemfile::Item::Sec1Key(key)) => Ok(PrivateKeyDer::Sec1(key)),
1268                Some(_) => continue,
1269                None => anyhow::bail!("no private key found in TLS key file"),
1270            };
1271        }
1272    }
1273
1274    /// Build root certificate store from the CA pem.
1275    pub(super) fn build_root_store(ca_pem: &Pem) -> Result<RootCertStore> {
1276        let mut root_store = RootCertStore::empty();
1277        let certs = load_certs(ca_pem)?;
1278        root_store.add_parsable_certificates(certs);
1279        Ok(root_store)
1280    }
1281
1282    /// Get the PEM bundle from configuration.
1283    fn get_pem_bundle() -> PemBundle {
1284        PemBundle {
1285            ca: hyperactor_config::global::get_cloned(TLS_CA),
1286            cert: hyperactor_config::global::get_cloned(TLS_CERT),
1287            key: hyperactor_config::global::get_cloned(TLS_KEY),
1288        }
1289    }
1290
1291    /// Creates a TLS acceptor using certificates from the provided PEM bundle.
1292    /// If `enforce_client_tls` is true, requires client certificates for mutual TLS.
1293    pub(super) fn tls_acceptor_from_bundle(
1294        bundle: &PemBundle,
1295        enforce_client_tls: bool,
1296    ) -> Result<TlsAcceptor> {
1297        // Ensure ring is installed as the process-level crypto provider.
1298        // No-op when already installed (e.g. under Buck with native-tls).
1299        let _ = rustls::crypto::ring::default_provider().install_default();
1300
1301        let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1302        let key = load_key(&bundle.key).context("load TLS key")?;
1303        let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1304
1305        let config = rustls::ServerConfig::builder();
1306        let config = if enforce_client_tls {
1307            // Build server config with mutual TLS (require client certs)
1308            let client_verifier =
1309                rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1310                    .build()
1311                    .map_err(|e| anyhow::anyhow!("failed to build client verifier: {}", e))?;
1312            config.with_client_cert_verifier(client_verifier)
1313        } else {
1314            config.with_no_client_auth()
1315        }
1316        .with_single_cert(certs, key)?;
1317
1318        Ok(TlsAcceptor::from(Arc::new(config)))
1319    }
1320
1321    /// Creates a TLS acceptor using certificates from config (always enforces mutual TLS).
1322    pub(crate) fn tls_acceptor() -> Result<TlsAcceptor> {
1323        tls_acceptor_from_bundle(&get_pem_bundle(), true)
1324    }
1325
1326    /// Creates a TLS connector using certificates from the provided PEM bundle.
1327    pub(super) fn tls_connector_from_bundle(bundle: &PemBundle) -> Result<TlsConnector> {
1328        // Ensure ring is installed as the process-level crypto provider.
1329        // No-op when already installed (e.g. under Buck with native-tls).
1330        let _ = rustls::crypto::ring::default_provider().install_default();
1331
1332        let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1333        let key = load_key(&bundle.key).context("load TLS key")?;
1334        let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1335
1336        let config = rustls::ClientConfig::builder()
1337            .with_root_certificates(Arc::new(root_store))
1338            .with_client_auth_cert(certs, key)
1339            .context("configure client auth")?;
1340
1341        Ok(TlsConnector::from(Arc::new(config)))
1342    }
1343
1344    /// Creates a TLS connector using certificates from config.
1345    fn tls_connector() -> Result<TlsConnector> {
1346        tls_connector_from_bundle(&get_pem_bundle())
1347    }
1348
    /// Shared TLS link implementation used by both tls and metatls transports.
    pub(crate) struct TlsLink {
        /// Host to connect to; also used as the TLS server name.
        pub(crate) hostname: Hostname,
        /// Port to connect to.
        pub(crate) port: Port,
        /// Connector that performs the TLS handshake on each new TCP stream.
        pub(crate) connector: TlsConnector,
        /// Selects which `ChannelAddr` variant the destination is reported as.
        pub(crate) addr_type: TlsAddrType,
        /// Session identifier written to each new stream after the handshake.
        pub(crate) session_id: SessionId,
    }
1357
1358    impl std::fmt::Debug for TlsLink {
1359        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360            f.debug_struct("TlsLink")
1361                .field("hostname", &self.hostname)
1362                .field("port", &self.port)
1363                .field("addr_type", &self.addr_type)
1364                .finish()
1365        }
1366    }
1367
1368    #[async_trait]
1369    impl Link for TlsLink {
1370        type Stream = TlsStream<TcpStream>;
1371
1372        fn dest(&self) -> ChannelAddr {
1373            let addr = TlsAddr::new(self.hostname.clone(), self.port);
1374            match self.addr_type {
1375                TlsAddrType::Tls => ChannelAddr::Tls(addr),
1376                TlsAddrType::MetaTls => ChannelAddr::MetaTls(addr),
1377            }
1378        }
1379
1380        fn link_id(&self) -> SessionId {
1381            self.session_id
1382        }
1383
1384        async fn next(&self) -> Result<Self::Stream, ClientError> {
1385            let session_id = self.session_id;
1386            let server_name = ServerName::try_from(self.hostname.clone()).map_err(|e| {
1387                ClientError::Connect(
1388                    self.dest(),
1389                    io::Error::other(e.to_string()),
1390                    "invalid server name".to_string(),
1391                )
1392            })?;
1393            let mut backoff = ExponentialBackoffBuilder::new()
1394                .with_initial_interval(Duration::from_millis(1))
1395                .with_multiplier(2.0)
1396                .with_randomization_factor(0.1)
1397                .with_max_interval(Duration::from_millis(1000))
1398                .with_max_elapsed_time(None)
1399                .build();
1400            loop {
1401                let mut addrs = (self.hostname.as_ref(), self.port)
1402                    .to_socket_addrs()
1403                    .map_err(|_| ClientError::Resolve(self.dest()))?;
1404                let addr = addrs.next().ok_or(ClientError::Resolve(self.dest()))?;
1405                match TcpStream::connect(&addr).await {
1406                    Ok(stream) => {
1407                        stream.set_nodelay(true).map_err(|err| {
1408                            ClientError::Connect(
1409                                self.dest(),
1410                                err,
1411                                "cannot disable Nagle algorithm".to_string(),
1412                            )
1413                        })?;
1414                        let mut tls_stream = self
1415                            .connector
1416                            .connect(server_name.clone(), stream)
1417                            .await
1418                            .map_err(|err| {
1419                                ClientError::Connect(
1420                                    self.dest(),
1421                                    err,
1422                                    format!("cannot establish TLS connection to {:?}", server_name),
1423                                )
1424                            })?;
1425                        write_link_init(&mut tls_stream, session_id)
1426                            .await
1427                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1428                        return Ok(tls_stream);
1429                    }
1430                    Err(err) => {
1431                        tracing::debug!(error = %err, "tls connect failed, backing off");
1432                        if let Some(delay) = backoff.next_backoff() {
1433                            tokio::time::sleep(delay).await;
1434                        }
1435                    }
1436                }
1437            }
1438        }
1439    }
1440
1441    /// Create a TLS link to the given address.
1442    pub fn link(addr: TlsAddr) -> Result<TlsLink, ClientError> {
1443        let connector = tls_connector().map_err(|e| {
1444            ClientError::Connect(
1445                ChannelAddr::Tls(addr.clone()),
1446                io::Error::other(e.to_string()),
1447                "failed to create TLS connector".to_string(),
1448            )
1449        })?;
1450        let TlsAddr { hostname, port } = addr;
1451        Ok(TlsLink {
1452            hostname,
1453            port,
1454            connector,
1455            addr_type: TlsAddrType::Tls,
1456            session_id: SessionId::random(),
1457        })
1458    }
1459
1460    #[cfg(test)]
1461    mod tests {
1462        use timed_test::async_timed_test;
1463
1464        use super::*;
1465        use crate::channel::Rx;
1466        use crate::channel::net::server;
1467        use crate::config::Pem;
1468        use crate::config::TLS_CA;
1469        use crate::config::TLS_CERT;
1470        use crate::config::TLS_KEY;
1471
1472        // Dummy test certificates generated with openssl for testing only.
1473        // These certificates include Subject Alternative Names (SAN) for localhost, 127.0.0.1, and ::1
1474        // CA certificate
1475        const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1476MIIDBTCCAe2gAwIBAgIUaGNmboiIosG+8Up0vgDr/+cg+2IwDQYJKoZIhvcNAQEL
1477BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1478NzA4MzlaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
1479DwAwggEKAoIBAQC9RBoMYXCajklswt8Vi1JI1lEYzic0WNOmz45vG/7H6jTWkgL3
1480K5Ri+Seg3MobDNc48YHWXYm4hP9wCzkx8ih3ntT5XiY1My/G3jLUuoIEE9pF/BoJ
1481YQwZVoPNFhA9WhXNRsINf1cXFf8NzRfXpxBfKWtQJxYXU4JiDBQ6rLnQQABo8JmQ
1482vYFhJbBaYip5jTSiVNn7mB1zNr5jsVxuoSF53Pb7xQ76bwBdOq4zd6PSxL5/lr4G
1483cHSoxwZQdZMG7PL6hbxDQ2S2YI2lYVET1zwc2WPKCfjbEXBC/jzx828CInQtuksk
148418gJt6xHkTFEA8CSA29GM3lejnwYWf51xyyBAgMBAAGjUzBRMB0GA1UdDgQWBBRX
1485cbxSZ70NsUkAS3Hhy6irugywJDAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1486ugywJDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA7aAFfyW67
1487Z+uGSVYhpsT/uH/3Z3nr7X1smTz5CGEfq2czEcTC7gbYI2l8GZ47GPfnAvHTBZVm
1488V/XncBCsj7/thOh2jYEHFyCbPckoaSCRyCOnK7LPUlr4HN5uP9EFe45qBLCJDEoY
1489GTTw7MtzwdovfjchNfKQCTtkBJCXQ95WLCf6UOh02Sn28UTlgfXzF0X0FrcWqWa3
1490uJZd4XOo4O6hKKlHaBaQPiEr++1xc3SWPV7jZHbckI/vKBnDdEZ9JQX5fFZuypUI
1491sgomYHxvxrU2hWx+7k53CRdjfaIvT9Ie44z9sSdsU/+blw2S8f/ZTmuECoIAAXYO
14920qpzlxZMdr7T
1493-----END CERTIFICATE-----"#;
1494
1495        // Server certificate (signed by CA) with SAN for localhost, 127.0.0.1, ::1
1496        const TEST_SERVER_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1497MIIDJDCCAgygAwIBAgIUaz66DsWaH5ZXM4hCFnbVbMsyN1cwDQYJKoZIhvcNAQEL
1498BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1499NzA4MzlaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQAD
1500ggEPADCCAQoCggEBAKCbp++qNyTn5LOsV0h9gLKJALBcjg2A14I3804N9UyDhPW2
1501QKQ2W424u2P1MfKrw/2C+CErGlrADlnco2RQVDAarAIuGdFvBOt5UezqOS7Mk4OS
15029MlS7NZnMbc37KuM9UIG5ScJjXR/Z5z9dxeR0I9y3n0Ix6khbV7tOSHobiweI0FI
15038LftBS+CQnXr6vbWPcHcW6Z0FHUv7IWhqMWmv9PlZRGe9Y6VzXrRp0PBnZMOnAYf
1504aMQUwYRswWdm9j9Z1sMdTJ14G+KVmO3Vj6XI6Sm9uIcYhlwG/kORwogJFWlVuP9o
1505rloFRCjyHJ1d7GZqqnRyHHDDCBms8ed+3YfEYQECAwEAAaNwMG4wLAYDVR0RBCUw
1506I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMB0GA1UdDgQWBBQl
1507J4vxUoCzqqeTwQAiLqE8wYezKzAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1508ugywJDANBgkqhkiG9w0BAQsFAAOCAQEAnXHIBDQ4AHAMV71piTOuI41ShASQed6L
1509bi7XUMZgZDslLkfU1vnP3BlwpliraBsAytSYQC6kbytOuz1uQ4K7yLb2tAAmUgEO
1510EdIVt9SXr5tCcIPeLmInF0pysPqjZO8n7vtJyd9gryKqdhm1uzA7WQWq/Az8a9Sk
1511uW2J6Oc5p6P7Mf3/ixqXzvGRo8rzu0CUJOJ67UTE/HhbJuplQ5dep5CEEOAIsAtH
1512zn9O4rW92ueBkoBJM++YILS1vQ7jKc2N3RNrnHm7FeootBrtR9mBi0TH97K73ZPZ
15132Cdhnym0CsCJggrllFGH32cYo7+K2PO7/4oj5XbBCSWcssicvd8ovg==
1514-----END CERTIFICATE-----"#;
1515
1516        // Server private key
1517        // test-only embedded key -- though it might expire at some point:
1518        // @lint-ignore PRIVATEKEY insecure-private-key-storage
1519        const TEST_SERVER_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
1520MIIEugIBADANBgkqhkiG9w0BAQEFAASCBKQwggSgAgEAAoIBAQCgm6fvqjck5+Sz
1521rFdIfYCyiQCwXI4NgNeCN/NODfVMg4T1tkCkNluNuLtj9THyq8P9gvghKxpawA5Z
15223KNkUFQwGqwCLhnRbwTreVHs6jkuzJODkvTJUuzWZzG3N+yrjPVCBuUnCY10f2ec
1523/XcXkdCPct59CMepIW1e7Tkh6G4sHiNBSPC37QUvgkJ16+r21j3B3FumdBR1L+yF
1524oajFpr/T5WURnvWOlc160adDwZ2TDpwGH2jEFMGEbMFnZvY/WdbDHUydeBvilZjt
15251Y+lyOkpvbiHGIZcBv5DkcKICRVpVbj/aK5aBUQo8hydXexmaqp0chxwwwgZrPHn
1526ft2HxGEBAgMBAAECgf8G5qlQov+7ljs9fSpC8yGUik59RXzVF7Qq5DyQHglsQDp2
1527VF5yr+M/M7DZmq+KvdauDfKbej6np5j2Q4TByrHTX1IExfZWCW8srwnWJDpQyHmO
1528LcJW5DlI/SYluUFyHZxsOd+ezcpGNzM8i6eSW7GaeFUXCkmJ+uW4LnlF+7bALnnd
1529D6sak/58EsII+IJyd4lFn+voszlPn3CZGR0jkp21rvpaKgrMIsKVWWQO/sLDU5pr
1530VbpBThcLU5gRcnQouQX12e2VTCIlFu75WTsJ8V/KnEaOZUVlU/B/Bs+WQF3U+/Jo
1531eX4N+D6OsEcNQjERAFyWujxsl1WpD4uSsbFMN0ECgYEA2b7AdL+oKPQHku2KcBhr
1532Zw8K4tMDlr2VPPNwZcBTLo+O71vv/xXjMcXrXmowzkgEQckUmt1VB46riyydhwdP
1533/n9ciWcz0Va/nwHR6Y9F9unBiyUBP7PRhRyjQyRZZRGDSJvP+Xmc5UJFpRr07VLU
1534nfgMXDj37vXzKDpfhdEB2nkCgYEAvNMfA8P8w3+6246x5YHflvTkPdw+2oyge+LD
1535mphB/w7SF8mlyNGloj3+KBZmd9SkvT57wCvO96Y9/n+mBAVisRggc0hK4ymOVYhb
1536+im/JvqGQMbVeg6iCOHnWdaZf9tL8uVsegQy3kVTN7vAa+CMFgX1dt65cGBX6XkB
153744pYmMkCgYALhbiRdQLlB+TOtZs5y1EDpxwgXKI3+9hF3Wv5NnAwapBZwje0++eF
15383r9Rw7TJda4j/QwGFehF+hrBxp6fYpetE/hFnRx0225Qb7w368j8A+ql/lNOl6li
1539rd1F1EqWupKD6RrcTL8sspEU55RGaretlE6zIqCcGI/BdTVQ03qRoQKBgHDC3zWf
1540d7XD9HGjQGdfbIe4jQjIGxzmd/wjik4q+NZ5IkukVwWa9P/zZ3DHF8Ad05dT1hEH
15412FwaAdGWpyyljq9VSiOuG1KXAXHgsZSuE4ISf9P1KYzvaiJFzaPfvOEWs79E9MfU
15429A+6dJzG2X1SpjWMr26iSTlrv3QkmFUqzAfJAoGASBkn4wls+oC5rv/Mch43pBv5
1543UmKru4ltnEHJZdbSi2DJ+AnDLD222JCasb1VT1tm2XgW6DBqrdVRPPP6GOlB0MHU
1544+3ULtZxAczt7I+ST2bo0/DV2Hse89Cm63w4wLOiVZs7+1wrAzJZLokWF7Q5gesra
1545u19txmtkiMEH+aNmekk=
1546-----END PRIVATE KEY-----"#;
1547
1548        #[async_timed_test(timeout_secs = 30)]
1549        async fn test_tls_basic() {
1550            // Ensure ring is installed as the default crypto provider
1551            // (no-op if already installed, e.g. under Buck with native-tls).
1552            let _ = rustls::crypto::ring::default_provider().install_default();
1553
1554            // Set up TLS config using the standard override pattern
1555            let config = hyperactor_config::global::lock();
1556            let _guard_cert =
1557                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1558            let _guard_key =
1559                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1560            let _guard_ca =
1561                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1562
1563            // Create a TLS server bound to localhost with dynamic port
1564            let addr = TlsAddr::new("localhost", 0);
1565
1566            let (local_addr, mut rx) =
1567                server::serve::<u64>(ChannelAddr::Tls(addr)).expect("failed to serve");
1568
1569            // Dial the server
1570            let tx: super::NetTx<u64> = super::spawn(
1571                link(match &local_addr {
1572                    ChannelAddr::Tls(addr) => addr.clone(),
1573                    _ => panic!("unexpected address type"),
1574                })
1575                .expect("failed to create link"),
1576            );
1577
1578            // Send a message
1579            tx.post(42u64);
1580
1581            // Receive the message
1582            let received = rx.recv().await.expect("failed to receive");
1583            assert_eq!(received, 42u64);
1584        }
1585
1586        #[async_timed_test(timeout_secs = 30)]
1587        async fn test_tls_multiple_messages() {
1588            let _ = rustls::crypto::ring::default_provider().install_default();
1589
1590            // Set up TLS config using the standard override pattern
1591            let config = hyperactor_config::global::lock();
1592            let _guard_cert =
1593                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1594            let _guard_key =
1595                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1596            let _guard_ca =
1597                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1598
1599            let addr = TlsAddr::new("localhost", 0);
1600
1601            let (local_addr, mut rx) =
1602                server::serve::<String>(ChannelAddr::Tls(addr)).expect("failed to serve");
1603            let tx: super::NetTx<String> = super::spawn(
1604                link(match &local_addr {
1605                    ChannelAddr::Tls(addr) => addr.clone(),
1606                    _ => panic!("unexpected address type"),
1607                })
1608                .expect("failed to create link"),
1609            );
1610
1611            // Send multiple messages
1612            for i in 0..10 {
1613                tx.post(format!("message {}", i));
1614            }
1615
1616            // Receive all messages
1617            for i in 0..10 {
1618                let received = rx.recv().await.expect("failed to receive");
1619                assert_eq!(received, format!("message {}", i));
1620            }
1621        }
1622
1623        #[test]
1624        fn test_tls_parse_hostname_port() {
1625            let addr = parse("localhost:8080").expect("failed to parse");
1626            assert!(matches!(
1627                addr,
1628                ChannelAddr::Tls(TlsAddr { hostname, port })
1629                    if hostname == "localhost" && port == 8080
1630            ));
1631        }
1632
1633        #[test]
1634        fn test_tls_parse_socket_addr() {
1635            let addr = parse("127.0.0.1:8080").expect("failed to parse");
1636            assert!(matches!(
1637                addr,
1638                ChannelAddr::Tls(TlsAddr { hostname, port })
1639                    if hostname == "127.0.0.1" && port == 8080
1640            ));
1641        }
1642
1643        #[test]
1644        fn test_tls_certs_parsing() {
1645            // Verify that the test certificates can be parsed correctly
1646            let cert_pem = Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec());
1647            let key_pem = Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec());
1648            let ca_pem = Pem::Value(TEST_CA_CERT.as_bytes().to_vec());
1649
1650            let certs = super::load_certs(&cert_pem).expect("failed to load certs");
1651            assert!(!certs.is_empty(), "expected at least one certificate");
1652
1653            let _key = super::load_key(&key_pem).expect("failed to load key");
1654
1655            let root_store = super::build_root_store(&ca_pem).expect("failed to build root store");
1656            assert!(!root_store.is_empty(), "expected at least one CA cert");
1657        }
1658
1659        #[test]
1660        fn test_tls_acceptor_creation() {
1661            // Ensure ring is installed as the default crypto provider
1662            // (no-op if already installed, e.g. under Buck with native-tls).
1663            let _ = rustls::crypto::ring::default_provider().install_default();
1664
1665            // Set up TLS config using the standard override pattern
1666            let config = hyperactor_config::global::lock();
1667            let _guard_cert =
1668                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1669            let _guard_key =
1670                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1671            let _guard_ca =
1672                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1673
1674            // Verify that we can create a TLS acceptor
1675            let _acceptor = super::tls_acceptor().expect("failed to create TLS acceptor");
1676        }
1677
1678        #[test]
1679        fn test_tls_connector_creation() {
1680            // Ensure ring is installed as the default crypto provider
1681            // (no-op if already installed, e.g. under Buck with native-tls).
1682            let _ = rustls::crypto::ring::default_provider().install_default();
1683
1684            // Set up TLS config using the standard override pattern
1685            let config = hyperactor_config::global::lock();
1686            let _guard_cert =
1687                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1688            let _guard_key =
1689                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1690            let _guard_ca =
1691                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1692
1693            // Verify that we can create a TLS connector
1694            let _connector = super::tls_connector().expect("failed to create TLS connector");
1695        }
1696    }
1697}
1698
1699/// Build the OSS PemBundle from hyperactor_config attributes.
1700fn oss_pem_bundle() -> crate::config::PemBundle {
1701    crate::config::PemBundle {
1702        ca: hyperactor_config::global::get_cloned(crate::config::TLS_CA),
1703        cert: hyperactor_config::global::get_cloned(crate::config::TLS_CERT),
1704        key: hyperactor_config::global::get_cloned(crate::config::TLS_KEY),
1705    }
1706}
1707
1708/// Try to find a usable TLS [`PemBundle`](crate::config::PemBundle)
1709/// by probing the same sources as [`try_tls_acceptor`] /
1710/// [`try_tls_connector`].
1711///
1712/// Returns the first bundle whose CA certificate is readable.
1713/// Only CA readability is checked — cert and key are returned as-is
1714/// and may not be valid. Callers that cannot use `tokio_rustls` types
1715/// directly (e.g. reqwest) can read the raw PEM bytes via
1716/// [`Pem::reader`](crate::config::Pem::reader).
1717pub fn try_tls_pem_bundle() -> Option<crate::config::PemBundle> {
1718    let oss_bundle = oss_pem_bundle();
1719    if oss_bundle.ca.reader().is_ok() {
1720        return Some(oss_bundle);
1721    }
1722    tracing::debug!("OSS TLS bundle: CA not readable, trying Meta paths");
1723
1724    let meta_bundle = meta::get_server_pem_bundle();
1725    if meta_bundle.ca.reader().is_ok() {
1726        return Some(meta_bundle);
1727    }
1728    tracing::debug!("Meta TLS bundle: CA not readable, no TLS available");
1729
1730    None
1731}
1732
1733/// Try to build a [`TlsAcceptor`](tokio_rustls::TlsAcceptor) for an
1734/// HTTP server by probing for available TLS certificates.
1735///
1736/// Detection order:
1737/// 1. **OSS / explicit config** — `HYPERACTOR_TLS_CERT`,
1738///    `HYPERACTOR_TLS_KEY`, and `HYPERACTOR_TLS_CA` (read via
1739///    [`hyperactor_config`]).
1740/// 2. **Meta default paths** —
1741///    `/var/facebook/x509_identities/server.pem` and
1742///    `/var/facebook/rootcanal/ca.pem`. These are present on
1743///    devservers and in MAST / Tupperware containers.
1744/// 3. **None** — no usable certificates found; caller should fall
1745///    back to plain HTTP.
1746///
1747/// When `enforce_client_tls` is `true`, the returned acceptor
1748/// requires clients to present a valid certificate signed by the
1749/// configured CA (mutual TLS via `WebPkiClientVerifier`). When
1750/// `false`, the acceptor authenticates itself but does not demand
1751/// client certificates.
1752pub fn try_tls_acceptor(enforce_client_tls: bool) -> Option<tokio_rustls::TlsAcceptor> {
1753    let oss_bundle = oss_pem_bundle();
1754    if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&oss_bundle, enforce_client_tls) {
1755        return Some(acceptor);
1756    }
1757    tracing::debug!("OSS TLS acceptor failed, trying Meta paths");
1758
1759    let meta_bundle = meta::get_server_pem_bundle();
1760    if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&meta_bundle, enforce_client_tls) {
1761        return Some(acceptor);
1762    }
1763    tracing::debug!("Meta TLS acceptor failed, no TLS available");
1764
1765    None
1766}
1767
1768/// Try to build a [`TlsConnector`](tokio_rustls::TlsConnector) for an
1769/// HTTP client that needs to connect to a TLS-enabled server.
1770///
1771/// Detection mirrors [`try_tls_acceptor`]:
1772/// 1. **OSS** — `HYPERACTOR_TLS_CA` (and optionally
1773///    `HYPERACTOR_TLS_CERT` + `HYPERACTOR_TLS_KEY` for mutual TLS).
1774/// 2. **Meta** — root CA at `/var/facebook/rootcanal/ca.pem`,
1775///    optional client certs from `THRIFT_TLS_CL_CERT_PATH` /
1776///    `THRIFT_TLS_CL_KEY_PATH`.
1777/// 3. **None** — no usable CA found; caller should fall back to plain
1778///    HTTP.
1779pub fn try_tls_connector() -> Option<tokio_rustls::TlsConnector> {
1780    let oss_bundle = oss_pem_bundle();
1781    if let Ok(connector) = tls::tls_connector_from_bundle(&oss_bundle) {
1782        return Some(connector);
1783    }
1784    tracing::debug!("OSS TLS connector failed, trying Meta paths");
1785
1786    if let Ok(connector) = meta::try_tls_connector() {
1787        return Some(connector);
1788    }
1789    tracing::debug!("Meta TLS connector failed, no TLS available");
1790
1791    None
1792}
1793
1794#[cfg(test)]
1795mod tests {
1796    use std::assert_matches::assert_matches;
1797    use std::collections::VecDeque;
1798    use std::marker::PhantomData;
1799    use std::sync::Arc;
1800    use std::sync::RwLock;
1801    use std::sync::atomic::AtomicBool;
1802    use std::sync::atomic::AtomicU64;
1803    use std::sync::atomic::Ordering;
1804    use std::time::Duration;
1805    #[cfg(target_os = "linux")] // uses abstract names
1806    use std::time::UNIX_EPOCH;
1807
1808    #[cfg(target_os = "linux")] // uses abstract names
1809    use anyhow::Result;
1810    use bytes::Bytes;
1811    use rand::Rng;
1812    use rand::SeedableRng;
1813    use rand::distributions::Alphanumeric;
1814    use timed_test::async_timed_test;
1815    use tokio::io::AsyncWrite;
1816    use tokio::io::DuplexStream;
1817    use tokio::io::ReadHalf;
1818    use tokio::io::WriteHalf;
1819    use tokio::task::JoinHandle;
1820    use tokio_util::sync::CancellationToken;
1821
1822    use super::server;
1823    use super::*;
1824    use crate::channel;
1825    use crate::channel::net::framed::FrameReader;
1826    use crate::channel::net::framed::FrameWrite;
1827    use crate::channel::net::server::AcceptorLink;
1828    use crate::config;
1829    use crate::metrics;
1830    use crate::sync::mvar::MVar;
1831
1832    /// Like the `logs_assert` injected by `#[traced_test]`, but without scope
1833    /// filtering. Use when asserting on events emitted outside the test's span
1834    /// (e.g. from spawned tasks or panic hooks).
1835    fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
1836        let buf = tracing_test::internal::global_buf().lock().unwrap();
1837        let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
1838        let lines: Vec<&str> = logs_str.lines().collect();
1839        match f(&lines) {
1840            Ok(()) => {}
1841            Err(msg) => panic!("{}", msg),
1842        }
1843    }
1844
1845    #[cfg(target_os = "linux")] // uses abstract names
1846    #[tracing_test::traced_test]
1847    #[tokio::test]
1848    async fn test_unix_basic() -> Result<()> {
1849        let timestamp = std::time::SystemTime::now()
1850            .duration_since(UNIX_EPOCH)
1851            .unwrap()
1852            .as_nanos();
1853        let unique_address = format!("test_unix_basic_{}", timestamp);
1854
1855        let (addr, mut rx) = server::serve::<u64>(ChannelAddr::Unix(
1856            unix::SocketAddr::from_abstract_name(&unique_address)?,
1857        ))
1858        .unwrap();
1859
1860        // It is important to keep Tx alive until all expected messages are
1861        // received. Otherwise, the channel would be closed when Tx is dropped.
1862        // Although the messages are sent to the server's buffer before the
1863        // channel was closed, NetRx could still error out before taking them
1864        // out of the buffer because NetRx could not ack through the closed
1865        // channel.
1866        {
1867            let tx: ChannelTx<u64> = channel::dial::<u64>(addr.clone()).unwrap();
1868            tx.post(123);
1869            assert_eq!(rx.recv().await.unwrap(), 123);
1870        }
1871
1872        {
1873            let tx = channel::dial::<u64>(addr.clone()).unwrap();
1874            tx.post(321);
1875            tx.post(111);
1876            tx.post(444);
1877
1878            assert_eq!(rx.recv().await.unwrap(), 321);
1879            assert_eq!(rx.recv().await.unwrap(), 111);
1880            assert_eq!(rx.recv().await.unwrap(), 444);
1881        }
1882
1883        {
1884            let tx = channel::dial::<u64>(addr).unwrap();
1885            drop(rx);
1886
1887            let (return_tx, return_rx) = oneshot::channel();
1888            tx.try_post(123, return_tx);
1889            assert_matches!(
1890                return_rx.await,
1891                Ok(SendError {
1892                    error: ChannelError::Closed,
1893                    message: 123,
1894                    ..
1895                })
1896            );
1897        }
1898
1899        Ok(())
1900    }
1901
1902    #[cfg(target_os = "linux")] // uses abstract names
1903    #[tracing_test::traced_test]
1904    #[tokio::test]
1905    async fn test_unix_basic_client_before_server() -> Result<()> {
1906        // We run this test on Unix because we can pick our own port names more easily.
1907        let timestamp = std::time::SystemTime::now()
1908            .duration_since(UNIX_EPOCH)
1909            .unwrap()
1910            .as_nanos();
1911        let socket_addr =
1912            unix::SocketAddr::from_abstract_name(&format!("test_unix_basic_{}", timestamp))
1913                .unwrap();
1914
1915        // Dial the channel before we actually serve it.
1916        let addr = ChannelAddr::Unix(socket_addr.clone());
1917        let tx = crate::channel::dial::<u64>(addr.clone()).unwrap();
1918        tx.post(123);
1919
1920        let (_, mut rx) = server::serve::<u64>(ChannelAddr::Unix(socket_addr)).unwrap();
1921        assert_eq!(rx.recv().await.unwrap(), 123);
1922
1923        tx.post(321);
1924        tx.post(111);
1925        tx.post(444);
1926
1927        assert_eq!(rx.recv().await.unwrap(), 321);
1928        assert_eq!(rx.recv().await.unwrap(), 111);
1929        assert_eq!(rx.recv().await.unwrap(), 444);
1930
1931        Ok(())
1932    }
1933
1934    #[tracing_test::traced_test]
1935    #[async_timed_test(timeout_secs = 60)]
1936    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
1937    #[cfg_attr(not(fbcode_build), ignore)]
1938    async fn test_tcp_basic() {
1939        let (addr, mut rx) =
1940            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
1941        {
1942            let tx = channel::dial::<u64>(addr.clone()).unwrap();
1943            tx.post(123);
1944            assert_eq!(rx.recv().await.unwrap(), 123);
1945        }
1946
1947        {
1948            let tx = channel::dial::<u64>(addr.clone()).unwrap();
1949            tx.post(321);
1950            tx.post(111);
1951            tx.post(444);
1952
1953            assert_eq!(rx.recv().await.unwrap(), 321);
1954            assert_eq!(rx.recv().await.unwrap(), 111);
1955            assert_eq!(rx.recv().await.unwrap(), 444);
1956        }
1957
1958        {
1959            let tx = channel::dial::<u64>(addr).unwrap();
1960            drop(rx);
1961
1962            let (return_tx, return_rx) = oneshot::channel();
1963            tx.try_post(123, return_tx);
1964            assert_matches!(
1965                return_rx.await,
1966                Ok(SendError {
1967                    error: ChannelError::Closed,
1968                    message: 123,
1969                    ..
1970                })
1971            );
1972        }
1973    }
1974
1975    // The message size is limited by CODEC_MAX_FRAME_LENGTH.
1976    #[async_timed_test(timeout_secs = 5)]
1977    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
1978    #[cfg_attr(not(fbcode_build), ignore)]
1979    async fn test_tcp_message_size() {
1980        let default_size_in_bytes = 100 * 1024 * 1024;
1981        // Use temporary config for this test
1982        let config = hyperactor_config::global::lock();
1983        let _guard1 = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
1984        let _guard2 = config.override_key(config::CODEC_MAX_FRAME_LENGTH, default_size_in_bytes);
1985
1986        let (addr, mut rx) =
1987            server::serve::<String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
1988
1989        let tx = channel::dial::<String>(addr.clone()).unwrap();
1990        // Default size is okay
1991        {
1992            // Leave some headroom because Tx will wrap the payload in Frame::Message.
1993            let message = "a".repeat(default_size_in_bytes - 1024);
1994            tx.post(message.clone());
1995            assert_eq!(rx.recv().await.unwrap(), message);
1996        }
1997        // Bigger than the default size will fail.
1998        {
1999            let (return_channel, return_receiver) = oneshot::channel();
2000            let message = "a".repeat(default_size_in_bytes + 1024);
2001            tx.try_post(message.clone(), return_channel);
2002            let returned = return_receiver.await.unwrap();
2003            assert_eq!(message, returned.message);
2004        }
2005    }
2006
2007    #[async_timed_test(timeout_secs = 30)]
2008    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2009    #[cfg_attr(not(fbcode_build), ignore)]
2010    async fn test_ack_flush() {
2011        let config = hyperactor_config::global::lock();
2012        // Set a large value to effectively prevent acks from being sent except
2013        // during shutdown flush.
2014        let _guard_message_ack =
2015            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
2016        let _guard_delivery_timeout =
2017            config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(5));
2018
2019        let (addr, mut net_rx) =
2020            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
2021        let net_tx = channel::dial::<u64>(addr.clone()).unwrap();
2022        let (tx, rx) = oneshot::channel();
2023        net_tx.try_post(1, tx);
2024        assert_eq!(net_rx.recv().await.unwrap(), 1);
2025        drop(net_rx);
2026        // Using `is_err` to confirm the message is delivered/acked is confusing,
2027        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
2028        assert!(rx.await.is_err());
2029    }
2030
2031    #[async_timed_test(timeout_secs = 60)]
2032    // TODO: OSS: failed to retrieve ipv6 address
2033    #[cfg_attr(not(fbcode_build), ignore)]
2034    async fn test_meta_tls_basic() {
2035        hyperactor_telemetry::initialize_logging_for_test();
2036
2037        let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
2038        let meta_addr = match addr {
2039            ChannelAddr::MetaTls(meta_addr) => meta_addr,
2040            _ => panic!("expected MetaTls address"),
2041        };
2042        let (local_addr, mut rx) = server::serve::<u64>(ChannelAddr::MetaTls(meta_addr)).unwrap();
2043        {
2044            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2045            tx.post(123);
2046        }
2047        assert_eq!(rx.recv().await.unwrap(), 123);
2048
2049        {
2050            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2051            tx.post(321);
2052            tx.post(111);
2053            tx.post(444);
2054            assert_eq!(rx.recv().await.unwrap(), 321);
2055            assert_eq!(rx.recv().await.unwrap(), 111);
2056            assert_eq!(rx.recv().await.unwrap(), 444);
2057        }
2058
2059        {
2060            let tx = channel::dial::<u64>(local_addr).unwrap();
2061            drop(rx);
2062
2063            let (return_tx, return_rx) = oneshot::channel();
2064            tx.try_post(123, return_tx);
2065            assert_matches!(
2066                return_rx.await,
2067                Ok(SendError {
2068                    error: ChannelError::Closed,
2069                    message: 123,
2070                    ..
2071                })
2072            );
2073        }
2074    }
2075
    // Fault-injection settings for the mock link used by these tests. The
    // default value has both fields set to `None`.
    #[derive(Clone, Debug, Default)]
    struct NetworkFlakiness {
        // A tuple of:
        //   1. the probability of a network failure when sending a message.
        //   2. the max number of disconnections allowed.
        //   3. the minimum duration between disconnections.
        //
        //   2 and 3 are useful to prevent frequent disconnections leading to
        //   unacked messages being sent repeatedly.
        //
        // `None` means `should_disconnect` always returns false.
        disconnect_params: Option<(f64, u64, Duration)>,
        // The `(min, max)` latency bounds when sending a message. The actual
        // latency of each message is randomly generated between `min` and
        // `max` (inclusive). `None` means no latency is simulated.
        latency_range: Option<(Duration, Duration)>,
    }
2090
2091    impl NetworkFlakiness {
2092        // Calculate whether to disconnect
2093        async fn should_disconnect(
2094            &self,
2095            rng: &mut impl rand::Rng,
2096            disconnected_count: u64,
2097            prev_disconnected_at: &RwLock<Instant>,
2098        ) -> bool {
2099            let Some((prob, max_disconnects, duration)) = &self.disconnect_params else {
2100                return false;
2101            };
2102
2103            let disconnected_at = prev_disconnected_at.read().unwrap();
2104            if disconnected_at.elapsed() > *duration && disconnected_count < *max_disconnects {
2105                rng.gen_bool(*prob)
2106            } else {
2107                false
2108            }
2109        }
2110    }
2111
    // Test-only `Link` implementation whose streams are in-memory
    // `DuplexStream`s, used to inject connection failures, disconnects and
    // latency into a channel under test.
    struct MockLink<M> {
        // Defaults to 64 and can be changed with `set_buffer_size`.
        // NOTE(review): presumably the capacity of the duplex streams that
        // `next()` creates — confirm against `next()`.
        buffer_size: usize,
        // Returned by `link_id()`; a random id is generated in `new()`.
        session_id: SessionId,
        // Shared slot for a `DuplexStream`, handed out via
        // `receiver_storage()`.
        // NOTE(review): presumably holds the server-side end of each new
        // connection for the test to pick up — confirm against `next()`.
        receiver_storage: Arc<MVar<DuplexStream>>,
        // If true, `next()` on this link will always return an error.
        fail_connects: Arc<AtomicBool>,
        // Used to break the existing connection, if there is one. It still
        // allows reconnect.
        disconnect_signal: watch::Sender<()>,
        // Disconnect/latency injection settings applied by the relay tasks.
        network_flakiness: NetworkFlakiness,
        // Number of disconnects injected so far because of
        // `network_flakiness`; shared with the relay tasks.
        disconnected_count: Arc<AtomicU64>,
        // Time of the most recent injected disconnect; initialized to the
        // link's creation time.
        prev_disconnected_at: Arc<RwLock<Instant>>,
        // If set, print logs every `debug_log_sampling_rate` messages. This
        // is normally set only when debugging a test failure.
        debug_log_sampling_rate: Option<u64>,
        // Ties the link to its message type `M` without storing a value.
        _message_type: PhantomData<M>,
    }
2129
2130    impl<M> fmt::Debug for MockLink<M> {
2131        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2132            f.debug_struct("MockLink")
2133                .field("buffer_size", &self.buffer_size)
2134                .field("receiver_storage", &"<MVar<DuplexStream>>")
2135                .field("fail_connects", &self.fail_connects)
2136                .field("disconnect_signal", &"<watch::Sender>")
2137                .field("network_flakiness", &self.network_flakiness)
2138                .field("disconnected_count", &self.disconnected_count)
2139                .field("prev_disconnected_at", &"<RwLock<Instant>>")
2140                .field("debug_log_sampling_rate", &self.debug_log_sampling_rate)
2141                .finish()
2142        }
2143    }
2144
2145    impl<M: RemoteMessage> MockLink<M> {
2146        fn new() -> Self {
2147            let (sender, _) = watch::channel(());
2148            Self {
2149                buffer_size: 64,
2150                session_id: SessionId::random(),
2151                receiver_storage: Arc::new(MVar::empty()),
2152                fail_connects: Arc::new(AtomicBool::new(false)),
2153                disconnect_signal: sender,
2154                network_flakiness: NetworkFlakiness::default(),
2155                disconnected_count: Arc::new(AtomicU64::new(0)),
2156                prev_disconnected_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
2157                debug_log_sampling_rate: None,
2158                _message_type: PhantomData,
2159            }
2160        }
2161
2162        // If `fail_connects` is true, `next()` on this link will
2163        // always return an error.
2164        fn fail_connects() -> Self {
2165            Self {
2166                fail_connects: Arc::new(AtomicBool::new(true)),
2167                ..Self::new()
2168            }
2169        }
2170
2171        fn with_network_flakiness(network_flakiness: NetworkFlakiness) -> Self {
2172            if let Some((min, max)) = network_flakiness.latency_range {
2173                assert!(min < max);
2174            }
2175
2176            Self {
2177                network_flakiness,
2178                ..Self::new()
2179            }
2180        }
2181
2182        fn receiver_storage(&self) -> Arc<MVar<DuplexStream>> {
2183            self.receiver_storage.clone()
2184        }
2185
2186        fn disconnected_count(&self) -> Arc<AtomicU64> {
2187            self.disconnected_count.clone()
2188        }
2189
2190        fn disconnect_signal(&self) -> &watch::Sender<()> {
2191            &self.disconnect_signal
2192        }
2193
2194        fn fail_connects_switch(&self) -> Arc<AtomicBool> {
2195            self.fail_connects.clone()
2196        }
2197
2198        fn set_buffer_size(&mut self, size: usize) {
2199            self.buffer_size = size;
2200        }
2201
2202        fn set_sampling_rate(&mut self, sampling_rate: u64) {
2203            self.debug_log_sampling_rate = Some(sampling_rate);
2204        }
2205    }
2206
2207    #[async_trait]
2208    impl<M: RemoteMessage> Link for MockLink<M> {
2209        type Stream = DuplexStream;
2210
2211        fn dest(&self) -> ChannelAddr {
2212            ChannelAddr::Local(u64::MAX)
2213        }
2214
2215        fn link_id(&self) -> SessionId {
2216            self.session_id
2217        }
2218
2219        async fn next(&self) -> Result<Self::Stream, ClientError> {
2220            let session_id = self.session_id;
2221            tracing::debug!("MockLink starts to connect.");
2222            if self.fail_connects.load(Ordering::Acquire) {
2223                return Err(ClientError::Connect(
2224                    self.dest(),
2225                    std::io::Error::other("intentional error"),
2226                    "expected failure injected by the mock".to_string(),
2227                ));
2228            }
2229
2230            // Add relays between server and client streams. The
2231            // relays provides the place to inject network flakiness.
2232            // The message flow looks like:
2233            //
2234            // server <-> server relay <-> injection logic <-> client relay <-> client
2235            async fn relay_message<M: RemoteMessage>(
2236                mut disconnect_signal: watch::Receiver<()>,
2237                network_flakiness: NetworkFlakiness,
2238                disconnected_count: Arc<AtomicU64>,
2239                prev_disconnected_at: Arc<RwLock<Instant>>,
2240                mut reader: FrameReader<ReadHalf<DuplexStream>>,
2241                mut writer: WriteHalf<DuplexStream>,
2242                // Used by client and server tokio tasks to coordinate
2243                // stopping together.
2244                task_coordination_token: CancellationToken,
2245                debug_log_sampling_rate: Option<u64>,
2246                // Whether the relayed message is from client to
2247                // server.
2248                is_from_client: bool,
2249            ) {
2250                // Used to simulate latency. Briefly, messages are
2251                // buffered in the queue and wait for the expected
2252                // latency elapse.
2253                async fn wait_for_latency_elapse(
2254                    queue: &VecDeque<(Bytes, Instant)>,
2255                    network_flakiness: &NetworkFlakiness,
2256                    rng: &mut impl rand::Rng,
2257                ) {
2258                    if let Some((min, max)) = network_flakiness.latency_range {
2259                        let diff = max.abs_diff(min);
2260                        let factor = rng.gen_range(0.0..=1.0);
2261                        let latency = min + diff.mul_f64(factor);
2262                        tokio::time::sleep_until(queue.front().unwrap().1 + latency).await;
2263                    }
2264                }
2265
2266                let mut rng = rand::rngs::SmallRng::from_entropy();
2267                let mut queue: VecDeque<(Bytes, Instant)> = VecDeque::new();
2268                let mut send_count = 0u64;
2269
2270                loop {
2271                    tokio::select! {
2272                        read_res = reader.next() => {
2273                            match read_res {
2274                                Ok(Some((_, data))) => {
2275                                    queue.push_back((data, tokio::time::Instant::now()));
2276                                }
2277                                Ok(None) | Err(_) => {
2278                                        tracing::debug!("The upstream is closed or dropped. MockLink disconnects");
2279                                        break;
2280                                }
2281                            }
2282                        }
2283                        _ = wait_for_latency_elapse(&queue, &network_flakiness, &mut rng), if !queue.is_empty() => {
2284                            let count = disconnected_count.load(Ordering::Relaxed);
2285                            if network_flakiness.should_disconnect(&mut rng, count, &prev_disconnected_at).await {
2286                                tracing::debug!("MockLink disconnects");
2287                                disconnected_count.fetch_add(1, Ordering::Relaxed);
2288
2289                                metrics::CHANNEL_RECONNECTIONS.add(
2290                                    1,
2291                                    hyperactor_telemetry::kv_pairs!(
2292                                        "transport" => "mock",
2293                                        "reason" => "network_flakiness",
2294                                    ),
2295                                );
2296
2297                                let mut w = prev_disconnected_at.write().unwrap();
2298                                *w = tokio::time::Instant::now();
2299                                break;
2300                            }
2301                            let data = queue.pop_front().unwrap().0;
2302                            let is_sampled = debug_log_sampling_rate.is_some_and(|sample_rate| send_count % sample_rate == 1);
2303                            if is_sampled {
2304                                if is_from_client {
2305                                    if let Ok(Frame::Message(_seq, _msg)) = bincode::deserialize::<Frame<M>>(&data) {
2306                                        tracing::debug!("MockLink relays a msg from client. msg type: {}", std::any::type_name::<M>());
2307                                    }
2308                                } else {
2309                                    let result = deserialize_response(data.clone());
2310                                    if let Ok(NetRxResponse::Ack(seq)) = result {
2311                                        tracing::debug!("MockLink relays an ack from server. seq: {}", seq);
2312                                    }
2313                                }
2314                            }
2315                            let mut fw  = FrameWrite::new(writer, data, hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH), 0).unwrap();
2316                            if fw.send().await.is_err() {
2317                                break;
2318                            }
2319                            writer = fw.complete();
2320                            send_count += 1;
2321                        }
2322                        _ = task_coordination_token.cancelled() => break,
2323
2324                        changed = disconnect_signal.changed() => {
2325                            tracing::debug!("MockLink disconnects per disconnect_signal {:?}", changed);
2326                            break;
2327                        }
2328                    }
2329                }
2330
2331                task_coordination_token.cancel();
2332            }
2333
2334            let (server, mut server_relay) = tokio::io::duplex(self.buffer_size);
2335            let (client, client_relay) = tokio::io::duplex(self.buffer_size);
2336
2337            // Write LinkInit on server_relay so it's readable from `server`.
2338            // This simulates the client sending LinkInit over the wire before
2339            // the frame-level relay begins.
2340            write_link_init(&mut server_relay, session_id)
2341                .await
2342                .map_err(|err| ClientError::Io(self.dest(), err))?;
2343
2344            let (server_r, server_writer) = tokio::io::split(server_relay);
2345            let (client_r, client_writer) = tokio::io::split(client_relay);
2346
2347            let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
2348            let server_reader = FrameReader::new(server_r, max_len);
2349            let client_reader = FrameReader::new(client_r, max_len);
2350
2351            let task_coordination_token = CancellationToken::new();
2352            let _server_relay_task_handle = tokio::spawn(relay_message::<M>(
2353                self.disconnect_signal.subscribe(),
2354                self.network_flakiness.clone(),
2355                self.disconnected_count.clone(),
2356                self.prev_disconnected_at.clone(),
2357                server_reader,
2358                client_writer,
2359                task_coordination_token.clone(),
2360                self.debug_log_sampling_rate.clone(),
2361                /*is_from_client*/ false,
2362            ));
2363            let _client_relay_task_handle = tokio::spawn(relay_message::<M>(
2364                self.disconnect_signal.subscribe(),
2365                self.network_flakiness.clone(),
2366                self.disconnected_count.clone(),
2367                self.prev_disconnected_at.clone(),
2368                client_reader,
2369                server_writer,
2370                task_coordination_token,
2371                self.debug_log_sampling_rate.clone(),
2372                /*is_from_client*/ true,
2373            ));
2374
2375            self.receiver_storage.put(server).await;
2376            Ok(client)
2377        }
2378    }
2379
2380    struct MockLinkListener {
2381        receiver_storage: Arc<MVar<DuplexStream>>,
2382        channel_addr: ChannelAddr,
2383    }
2384
2385    impl MockLinkListener {
2386        fn new(receiver_storage: Arc<MVar<DuplexStream>>, channel_addr: ChannelAddr) -> Self {
2387            Self {
2388                receiver_storage,
2389                channel_addr,
2390            }
2391        }
2392    }
2393
2394    impl fmt::Debug for MockLinkListener {
2395        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2396            f.debug_struct("MockLinkListener")
2397                .field("channel_addr", &self.channel_addr)
2398                .finish()
2399        }
2400    }
2401
2402    #[async_trait]
2403    impl super::Listener for MockLinkListener {
2404        type Stream = DuplexStream;
2405
2406        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
2407            let stream = self.receiver_storage.take().await;
2408            Ok((stream, self.channel_addr.clone()))
2409        }
2410    }
2411
2412    /// Create an AcceptorLink-based server test rig. Returns the
2413    /// session task handle, an MVar for dispatching streams,
2414    /// the message receiver, and a cancellation token.
2415    fn serve_acceptor_test<M: RemoteMessage>(
2416        session_id: SessionId,
2417    ) -> (
2418        JoinHandle<()>,
2419        crate::sync::mvar::MVar<DuplexStream>,
2420        mpsc::Receiver<M>,
2421        CancellationToken,
2422    ) {
2423        let mvar = crate::sync::mvar::MVar::empty();
2424        let cancel_token = CancellationToken::new();
2425        let link = AcceptorLink {
2426            dest: ChannelAddr::Local(u64::MAX),
2427            session_id,
2428            stream: mvar.clone(),
2429            cancel: cancel_token.clone(),
2430        };
2431        let (tx, rx) = mpsc::channel::<M>(1024);
2432        let ct = cancel_token.clone();
2433        let handle = tokio::spawn(async move {
2434            let mut session = Session::new(link);
2435            let mut next = session::Next { seq: 0, ack: 0 };
2436
2437            loop {
2438                let connected = match session.connect().await {
2439                    Ok(s) => s,
2440                    Err(_) => break,
2441                };
2442
2443                let result = {
2444                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2445                    tokio::select! {
2446                        r = session::recv_connected::<M, _, _>(&stream, &tx, &mut next) => r,
2447                        _ = ct.cancelled() => Err(session::RecvLoopError::Cancelled),
2448                    }
2449                };
2450
2451                // Flush remaining ack if behind.
2452                if next.ack < next.seq {
2453                    let ack = serialize_response(NetRxResponse::Ack(next.seq - 1)).unwrap();
2454                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2455                    let mut completion = stream.write(ack);
2456                    match completion.drive().await {
2457                        Ok(()) => {
2458                            next.ack = next.seq;
2459                        }
2460                        Err(e) => {
2461                            tracing::debug!(
2462                                error = %e,
2463                                "failed to flush acks during cleanup"
2464                            );
2465                        }
2466                    }
2467                }
2468
2469                // Send reject or closed response if appropriate.
2470                let terminal_response = match &result {
2471                    Err(session::RecvLoopError::SequenceError(reason)) => {
2472                        Some(NetRxResponse::Reject(reason.clone()))
2473                    }
2474                    Err(session::RecvLoopError::Cancelled) => Some(NetRxResponse::Closed),
2475                    _ => None,
2476                };
2477                if let Some(rsp) = terminal_response {
2478                    let data = serialize_response(rsp).unwrap();
2479                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2480                    let mut completion = stream.write(data);
2481                    let _ = completion.drive().await;
2482                }
2483
2484                let recoverable = matches!(&result, Ok(()) | Err(session::RecvLoopError::Io(_)));
2485                session = connected.release();
2486                if recoverable {
2487                    continue;
2488                }
2489                break;
2490            }
2491        });
2492        (handle, mvar, rx, cancel_token)
2493    }
2494
2495    async fn write_stream<M, W>(
2496        mut writer: W,
2497        _session_id: u64,
2498        messages: &[(u64, M)],
2499        _init: bool,
2500    ) -> W
2501    where
2502        M: RemoteMessage + PartialEq + Clone,
2503        W: AsyncWrite + Unpin,
2504    {
2505        for (seq, message) in messages {
2506            let message =
2507                serde_multipart::serialize_bincode(&Frame::<M>::Message(*seq, message.clone()))
2508                    .unwrap();
2509            let mut fw = FrameWrite::new(
2510                writer,
2511                message.framed(),
2512                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2513                0,
2514            )
2515            .map_err(|(_w, e)| e)
2516            .unwrap();
2517            fw.send().await.unwrap();
2518            writer = fw.complete();
2519        }
2520
2521        writer
2522    }
2523
2524    #[async_timed_test(timeout_secs = 60)]
2525    async fn test_persistent_server_session() {
2526        let config = hyperactor_config::global::lock();
2527        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2528
2529        async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
2530            let mut last_acked: i128 = -1;
2531            loop {
2532                let (_, bytes) = reader.next().await.unwrap().unwrap();
2533                let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2534                assert!(
2535                    acked as i128 > last_acked,
2536                    "acks should be delivered in ascending order"
2537                );
2538                last_acked = acked as i128;
2539                assert!(acked <= expected_last);
2540                if acked == expected_last {
2541                    break;
2542                }
2543            }
2544        }
2545
2546        let session_id = SessionId(123);
2547        let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2548
2549        // First connection: send messages, verify delivery and ack.
2550        {
2551            let (sender, receiver) = tokio::io::duplex(5000);
2552            mvar.put(receiver).await;
2553
2554            let (r, writer) = tokio::io::split(sender);
2555            let mut reader = FrameReader::new(
2556                r,
2557                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2558            );
2559
2560            let _writer = write_stream(
2561                writer,
2562                123,
2563                &[
2564                    (0u64, 100u64),
2565                    (1u64, 101u64),
2566                    (2u64, 102u64),
2567                    (3u64, 103u64),
2568                ],
2569                true,
2570            )
2571            .await;
2572
2573            assert_eq!(rx.recv().await, Some(100));
2574            assert_eq!(rx.recv().await, Some(101));
2575            assert_eq!(rx.recv().await, Some(102));
2576            assert_eq!(rx.recv().await, Some(103));
2577
2578            verify_ack(&mut reader, 3).await;
2579            // Drop reader and writer to close the connection.
2580        }
2581
2582        // Second connection (reconnection): retransmitted messages are deduped.
2583        {
2584            let (sender2, receiver2) = tokio::io::duplex(5000);
2585            mvar.put(receiver2).await;
2586
2587            let (r2, writer2) = tokio::io::split(sender2);
2588            let mut reader2 = FrameReader::new(
2589                r2,
2590                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2591            );
2592
2593            let _ = write_stream(
2594                writer2,
2595                123,
2596                &[
2597                    (2u64, 102u64),
2598                    (3u64, 103u64),
2599                    (4u64, 104u64),
2600                    (5u64, 105u64),
2601                ],
2602                true,
2603            )
2604            .await;
2605
2606            // 102 and 103 are retransmits; only 104 and 105 are new.
2607            assert_eq!(rx.recv().await, Some(104));
2608            assert_eq!(rx.recv().await, Some(105));
2609
2610            verify_ack(&mut reader2, 5).await;
2611
2612            cancel_token.cancel();
2613        }
2614    }
2615
2616    #[async_timed_test(timeout_secs = 60)]
2617    async fn test_ack_from_server_session() {
2618        let config = hyperactor_config::global::lock();
2619        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2620        let session_id = SessionId(123);
2621        let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2622
2623        let (sender, receiver) = tokio::io::duplex(5000);
2624        mvar.put(receiver).await;
2625        let (r, mut writer) = tokio::io::split(sender);
2626        let mut reader = FrameReader::new(
2627            r,
2628            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2629        );
2630
2631        for i in 0u64..100u64 {
2632            writer = write_stream(writer, 123, &[(i, 100u64 + i)], /*init*/ i == 0u64).await;
2633            assert_eq!(rx.recv().await, Some(100u64 + i));
2634            let (_, bytes) = reader.next().await.unwrap().unwrap();
2635            let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2636            assert_eq!(acked, i);
2637        }
2638
2639        // Wait long enough to ensure server processed everything.
2640        tokio::time::sleep(Duration::from_secs(5)).await;
2641
2642        cancel_token.cancel();
2643
2644        // Should send NetRxResponse::Closed before stopping.
2645        let (_, bytes) = reader.next().await.unwrap().unwrap();
2646        assert!(deserialize_response(bytes).unwrap().is_closed());
2647    }
2648
2649    #[tracing_test::traced_test]
2650    async fn verify_tx_closed(tx_status: &mut watch::Receiver<TxStatus>, expected_log: &str) {
2651        match tokio::time::timeout(Duration::from_secs(5), tx_status.changed()).await {
2652            Ok(Ok(())) => {
2653                let current_status = *tx_status.borrow();
2654                assert_eq!(current_status, TxStatus::Closed);
2655                logs_assert_unscoped(|logs| {
2656                    if logs.iter().any(|log| log.contains(expected_log)) {
2657                        Ok(())
2658                    } else {
2659                        Err("expected log not found".to_string())
2660                    }
2661                });
2662            }
2663            Ok(Err(_)) => panic!("watch::Receiver::changed() failed because sender is dropped."),
2664            Err(_) => panic!("timeout before tx_status changed"),
2665        }
2666    }
2667
2668    #[tracing_test::traced_test]
2669    #[tokio::test]
2670    // TODO: OSS: The logs_assert function returned an error: expected log not found
2671    #[cfg_attr(not(fbcode_build), ignore)]
2672    async fn test_tcp_tx_delivery_timeout() {
2673        // This link always fails to connect.
2674        let link = MockLink::<u64>::fail_connects();
2675        let tx = spawn::<u64>(link);
2676        // Override the default (1m) for the purposes of this test.
2677        let config = hyperactor_config::global::lock();
2678        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2679        let mut tx_receiver = tx.status().clone();
2680        let (return_channel, _return_receiver) = oneshot::channel();
2681        tx.try_post(123, return_channel);
2682        verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await;
2683    }
2684
2685    async fn take_receiver(
2686        receiver_storage: &MVar<DuplexStream>,
2687    ) -> (FrameReader<ReadHalf<DuplexStream>>, WriteHalf<DuplexStream>) {
2688        let mut receiver = receiver_storage.take().await;
2689        // Read and discard the LinkInit header that MockLink::connect() writes.
2690        let _session_id = read_link_init(&mut receiver).await.expect("read LinkInit");
2691        let (r, writer) = tokio::io::split(receiver);
2692        let reader = FrameReader::new(
2693            r,
2694            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2695        );
2696        (reader, writer)
2697    }
2698
2699    async fn verify_message<M: RemoteMessage + PartialEq + std::fmt::Debug>(
2700        reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2701        expect: (u64, M),
2702        loc: u32,
2703    ) {
2704        let expected = Frame::Message(expect.0, expect.1);
2705        let (_, bytes) = reader.next().await.unwrap().expect("unexpected EOF");
2706        let message = serde_multipart::Message::from_framed(bytes).unwrap();
2707        let frame: Frame<M> = serde_multipart::deserialize_bincode(message).unwrap();
2708
2709        assert_eq!(frame, expected, "from ln={loc}");
2710    }
2711
2712    async fn verify_stream<M: RemoteMessage + PartialEq + std::fmt::Debug + Clone>(
2713        reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2714        expects: &[(u64, M)],
2715        _expect_session_id: Option<u64>,
2716        loc: u32,
2717    ) {
2718        for expect in expects {
2719            verify_message(reader, expect.clone(), loc).await;
2720        }
2721    }
2722
2723    async fn net_tx_send(tx: &NetTx<u64>, msgs: &[u64]) {
2724        for msg in msgs {
2725            tx.post(*msg);
2726        }
2727    }
2728
2729    // Happy path: all messages are acked.
2730    #[async_timed_test(timeout_secs = 30)]
2731    async fn test_ack_in_net_tx_basic() {
2732        let link = MockLink::<u64>::new();
2733        let receiver_storage = link.receiver_storage();
2734        let tx = spawn::<u64>(link);
2735
2736        // Send some messages, but not acking any of them.
2737        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
2738        {
2739            let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2740            verify_stream(
2741                &mut reader,
2742                &[
2743                    (0u64, 100u64),
2744                    (1u64, 101u64),
2745                    (2u64, 102u64),
2746                    (3u64, 103u64),
2747                    (4u64, 104u64),
2748                ],
2749                None,
2750                line!(),
2751            )
2752            .await;
2753
2754            for i in 0u64..5u64 {
2755                writer = FrameWrite::write_frame(
2756                    writer,
2757                    serialize_response(NetRxResponse::Ack(i)).unwrap(),
2758                    1024,
2759                    0,
2760                )
2761                .await
2762                .map_err(|(_, e)| e)
2763                .unwrap();
2764            }
2765            // Wait for the acks to be processed by NetTx.
2766            tokio::time::sleep(Duration::from_secs(3)).await;
2767            // Drop both halves to break the in-memory connection (parity with old drop of DuplexStream).
2768            drop(reader);
2769            drop(writer);
2770        };
2771
2772        // Sent a new message to verify all sent messages will not be resent.
2773        net_tx_send(&tx, &[105u64]).await;
2774        {
2775            let (mut reader, _writer) = take_receiver(&receiver_storage).await;
2776            verify_stream(&mut reader, &[(5u64, 105u64)], None, line!()).await;
2777            // Reader/writer dropped here. This breaks the connection.
2778        };
2779    }
2780
2781    // Verify unacked message will be resent after reconnection.
2782    #[async_timed_test(timeout_secs = 60)]
2783    async fn test_persistent_net_tx() {
2784        let link = MockLink::<u64>::new();
2785        let receiver_storage = link.receiver_storage();
2786
2787        let tx = spawn::<u64>(link);
2788
2789        // Send some messages, but not acking any of them.
2790        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
2791
2792        // How many times to reconnect.
2793        let n = 10;
2794
2795        // Reconnect multiple times. The messages should be resent every time
2796        // because none of them is acked.
2797        for i in 0..n {
2798            {
2799                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2800                verify_stream(
2801                    &mut reader,
2802                    &[
2803                        (0u64, 100u64),
2804                        (1u64, 101u64),
2805                        (2u64, 102u64),
2806                        (3u64, 103u64),
2807                        (4u64, 104u64),
2808                    ],
2809                    None,
2810                    line!(),
2811                )
2812                .await;
2813
2814                // In the last iteration, ack part of the messages. This should
2815                // prune them from future resent.
2816                if i == n - 1 {
2817                    writer = FrameWrite::write_frame(
2818                        writer,
2819                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
2820                        1024,
2821                        0,
2822                    )
2823                    .await
2824                    .map_err(|(_, e)| e)
2825                    .unwrap();
2826                    // Wait for the acks to be processed by NetTx.
2827                    tokio::time::sleep(Duration::from_secs(3)).await;
2828                }
2829                // client DuplexStream is dropped here. This breaks the connection.
2830                drop(reader);
2831                drop(writer);
2832            };
2833        }
2834
2835        // Verify only unacked are resent.
2836        for _ in 0..n {
2837            {
2838                let (mut reader, mut _writer) = take_receiver(&receiver_storage).await;
2839                verify_stream(
2840                    &mut reader,
2841                    &[(2u64, 102u64), (3u64, 103u64), (4u64, 104u64)],
2842                    None,
2843                    line!(),
2844                )
2845                .await;
2846                // drop(reader/_writer) at scope end
2847            };
2848        }
2849
2850        // Now send more messages.
2851        net_tx_send(&tx, &[105u64, 106u64, 107u64, 108u64, 109u64]).await;
2852        // Verify the unacked messages from the 1st send will be grouped with
2853        // the 2nd send.
2854        for i in 0..n {
2855            {
2856                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2857                verify_stream(
2858                    &mut reader,
2859                    &[
2860                        // From the 1st send.
2861                        (2u64, 102u64),
2862                        (3u64, 103u64),
2863                        (4u64, 104u64),
2864                        // From the 2nd send.
2865                        (5u64, 105u64),
2866                        (6u64, 106u64),
2867                        (7u64, 107u64),
2868                        (8u64, 108u64),
2869                        (9u64, 109u64),
2870                    ],
2871                    None,
2872                    line!(),
2873                )
2874                .await;
2875
2876                // In the last iteration, ack part of the messages from the 1st
2877                // sent.
2878                if i == n - 1 {
2879                    // Intentionally ack 1 again to verify it is okay to ack
2880                    // messages that was already acked.
2881                    writer = FrameWrite::write_frame(
2882                        writer,
2883                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
2884                        1024,
2885                        0,
2886                    )
2887                    .await
2888                    .map_err(|(_, e)| e)
2889                    .unwrap();
2890                    writer = FrameWrite::write_frame(
2891                        writer,
2892                        serialize_response(NetRxResponse::Ack(2)).unwrap(),
2893                        1024,
2894                        0,
2895                    )
2896                    .await
2897                    .map_err(|(_, e)| e)
2898                    .unwrap();
2899                    writer = FrameWrite::write_frame(
2900                        writer,
2901                        serialize_response(NetRxResponse::Ack(3)).unwrap(),
2902                        1024,
2903                        0,
2904                    )
2905                    .await
2906                    .map_err(|(_, e)| e)
2907                    .unwrap();
2908                    // Wait for the acks to be processed by NetTx.
2909                    tokio::time::sleep(Duration::from_secs(3)).await;
2910                }
2911                // client DuplexStream is dropped here. This breaks the connection.
2912                drop(reader);
2913                drop(writer);
2914            };
2915        }
2916
2917        for i in 0..n {
2918            {
2919                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2920                verify_stream(
2921                    &mut reader,
2922                    &[
2923                        // From the 1st send.
2924                        (4u64, 104),
2925                        // From the 2nd send.
2926                        (5u64, 105u64),
2927                        (6u64, 106u64),
2928                        (7u64, 107u64),
2929                        (8u64, 108u64),
2930                        (9u64, 109u64),
2931                    ],
2932                    None,
2933                    line!(),
2934                )
2935                .await;
2936
2937                // In the last iteration, ack part of the messages from the 2nd send.
2938                if i == n - 1 {
2939                    writer = FrameWrite::write_frame(
2940                        writer,
2941                        serialize_response(NetRxResponse::Ack(7)).unwrap(),
2942                        1024,
2943                        0,
2944                    )
2945                    .await
2946                    .map_err(|(_, e)| e)
2947                    .unwrap();
2948                    // Wait for the acks to be processed by NetTx.
2949                    tokio::time::sleep(Duration::from_secs(3)).await;
2950                }
2951                // client DuplexStream is dropped here. This breaks the connection.
2952                drop(reader);
2953                drop(writer);
2954            };
2955        }
2956
2957        for _ in 0..n {
2958            {
2959                let (mut reader, writer) = take_receiver(&receiver_storage).await;
2960                verify_stream(
2961                    &mut reader,
2962                    &[
2963                        // From the 2nd send.
2964                        (8u64, 108u64),
2965                        (9u64, 109u64),
2966                    ],
2967                    None,
2968                    line!(),
2969                )
2970                .await;
2971                // client DuplexStream is dropped here. This breaks the connection.
2972                drop(reader);
2973                drop(writer);
2974            };
2975        }
2976    }
2977
2978    #[async_timed_test(timeout_secs = 15)]
2979    async fn test_ack_before_redelivery_in_net_tx() {
2980        let link = MockLink::<u64>::new();
2981        let receiver_storage = link.receiver_storage();
2982        let net_tx = spawn::<u64>(link);
2983
2984        // Verify sent-and-ack a message. This is necessary for the test to
2985        // trigger a connection.
2986        let (return_channel_tx, return_channel_rx) = oneshot::channel();
2987        net_tx.try_post(100, return_channel_tx);
2988        let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2989        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
2990        // ack it
2991        writer = FrameWrite::write_frame(
2992            writer,
2993            serialize_response(NetRxResponse::Ack(0)).unwrap(),
2994            1024,
2995            0,
2996        )
2997        .await
2998        .map_err(|(_, e)| e)
2999        .unwrap();
3000        // confirm Tx received ack
3001        //
3002        // Using `is_err` to confirm the message is delivered/acked is confusing,
3003        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
3004        assert!(return_channel_rx.await.is_err());
3005
3006        // Now fake an unknown delivery for Tx:
3007        // Although Tx did not actually send seq=1, we still ack it from Rx to
3008        // pretend Tx already sent it, just it did not know it was sent
3009        // successfully.
3010        let _ = FrameWrite::write_frame(
3011            writer,
3012            serialize_response(NetRxResponse::Ack(1)).unwrap(),
3013            1024,
3014            0,
3015        )
3016        .await
3017        .map_err(|(_, e)| e)
3018        .unwrap();
3019
3020        let (return_channel_tx, return_channel_rx) = oneshot::channel();
3021        net_tx.try_post(101, return_channel_tx);
3022        // Verify the message is sent to Rx.
3023        verify_message(&mut reader, (1u64, 101u64), line!()).await;
3024        // although we did not ack the message after it is sent, since we already
3025        // acked it previously, Tx will treat it as acked, and considered the
3026        // message delivered successfully.
3027        //
3028        // Using `is_err` to confirm the message is delivered/acked is confusing,
3029        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
3030        assert!(return_channel_rx.await.is_err());
3031    }
3032
3033    async fn verify_ack_exceeded_limit(disconnect_before_ack: bool) {
3034        // Use temporary config for this test
3035        let config = hyperactor_config::global::lock();
3036        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(2));
3037
3038        let link: MockLink<u64> = MockLink::<u64>::new();
3039        let disconnect_signal = link.disconnect_signal().clone();
3040        let fail_connect_switch = link.fail_connects_switch();
3041        let receiver_storage = link.receiver_storage();
3042        let tx = spawn::<u64>(link);
3043        let mut tx_status = tx.status().clone();
3044        // send a message
3045        tx.post(100);
3046        let (mut reader, writer) = take_receiver(&receiver_storage).await;
3047        // Confirm message is sent to rx.
3048        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3049        // ack it
3050        let _ = FrameWrite::write_frame(
3051            writer,
3052            serialize_response(NetRxResponse::Ack(0)).unwrap(),
3053            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3054            0,
3055        )
3056        .await
3057        .map_err(|(_, e)| e)
3058        .unwrap();
3059        tokio::time::sleep(Duration::from_secs(3)).await;
3060        // Channel should be still alive because ack was sent.
3061        assert!(!tx_status.has_changed().unwrap());
3062        assert_eq!(*tx_status.borrow(), TxStatus::Active);
3063
3064        tx.post(101);
3065        // Confirm message is sent to rx.
3066        verify_message(&mut reader, (1u64, 101u64), line!()).await;
3067
3068        if disconnect_before_ack {
3069            // Prevent link from reconnect
3070            fail_connect_switch.store(true, Ordering::Release);
3071            // Break the existing connection
3072            disconnect_signal.send(()).unwrap();
3073        }
3074
3075        // Verify the channel is closed due to ack timeout based on the log.
3076        let expected_log: &str = if disconnect_before_ack {
3077            "failed to receive ack within timeout 2s; link is currently broken"
3078        } else {
3079            "failed to receive ack within timeout 2s; link is currently connected"
3080        };
3081
3082        verify_tx_closed(&mut tx_status, expected_log).await;
3083    }
3084
3085    #[tracing_test::traced_test]
3086    #[async_timed_test(timeout_secs = 30)]
3087    // TODO: OSS: The logs_assert function returned an error: expected log not found
3088    #[cfg_attr(not(fbcode_build), ignore)]
3089    async fn test_ack_exceeded_limit_with_connected_link() {
3090        verify_ack_exceeded_limit(false).await;
3091    }
3092
3093    #[tracing_test::traced_test]
3094    #[async_timed_test(timeout_secs = 30)]
3095    // TODO: OSS: The logs_assert function returned an error: expected log not found
3096    #[cfg_attr(not(fbcode_build), ignore)]
3097    async fn test_ack_exceeded_limit_with_broken_link() {
3098        verify_ack_exceeded_limit(true).await;
3099    }
3100
3101    // Verify a large number of messages can be delivered and acked with the
3102    // presence of flakiness in the network, i.e. random delay and disconnection.
3103    #[async_timed_test(timeout_secs = 60)]
3104    async fn test_network_flakiness_in_channel() {
3105        hyperactor_telemetry::initialize_logging_for_test();
3106
3107        let sampling_rate = 100;
3108        let mut link = MockLink::<u64>::with_network_flakiness(NetworkFlakiness {
3109            disconnect_params: Some((0.001, 15, Duration::from_millis(400))),
3110            latency_range: Some((Duration::from_millis(100), Duration::from_millis(200))),
3111        });
3112        link.set_sampling_rate(sampling_rate);
3113        // Set a large buffer size to improve throughput.
3114        link.set_buffer_size(1024000);
3115        let disconnected_count = link.disconnected_count();
3116        let receiver_storage = link.receiver_storage();
3117        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3118        let local_addr = listener.channel_addr.clone();
3119        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3120            super::server::serve_with_listener(listener, local_addr).unwrap();
3121        let tx = spawn::<u64>(link);
3122        let messages: Vec<_> = (0..10001).collect();
3123        let messages_clone = messages.clone();
3124        // Put the sender side in a separate task so we can start the receiver
3125        // side concurrently.
3126        let send_task_handle = tokio::spawn(async move {
3127            for message in messages_clone {
3128                // Add a small delay between messages to give NetRx time to ack.
3129                // Technically, this test still can pass without this delay. But
3130                // the test will need a might larger timeout. The reason is
3131                // fairly convoluted:
3132                //
3133                // MockLink uses the number of delivery to calculate the disconnection
3134                // probability. If NetRx sends messages much faster than NetTx
3135                // can ack them, there is a higher chance that the messages are
3136                // not acked before reconnect. Then those message would be redelivered.
3137                // The repeated redelivery increases the total time of sending
3138                // these messages.
3139                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3140                tx.post(message);
3141            }
3142            tracing::debug!("NetTx sent all messages");
3143            // It is important to return tx instead of dropping it here, because
3144            // Rx might not receive all messages yet.
3145            tx
3146        });
3147
3148        for message in &messages {
3149            if message % sampling_rate == 0 {
3150                tracing::debug!("NetRx received a message: {message}");
3151            }
3152            assert_eq!(nx.recv().await.unwrap(), *message);
3153        }
3154        tracing::debug!("NetRx received all messages");
3155
3156        let send_result = send_task_handle.await;
3157        assert!(send_result.is_ok());
3158
3159        tracing::debug!(
3160            "MockLink disconnected {} times.",
3161            disconnected_count.load(Ordering::SeqCst)
3162        );
3163        // TODO(pzhang) after the return_handle work in NetTx is done, add a
3164        // check here to verify the messages are acked correctly.
3165    }
3166
3167    #[async_timed_test(timeout_secs = 60)]
3168    async fn test_ack_every_n_messages() {
3169        let config = hyperactor_config::global::lock();
3170        let _guard_message_ack = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 600);
3171        let _guard_time_interval =
3172            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(1000));
3173        sparse_ack().await;
3174    }
3175
3176    #[async_timed_test(timeout_secs = 60)]
3177    async fn test_ack_every_time_interval() {
3178        let config = hyperactor_config::global::lock();
3179        let _guard_message_ack =
3180            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
3181        let _guard_time_interval = config.override_key(
3182            config::MESSAGE_ACK_TIME_INTERVAL,
3183            Duration::from_millis(500),
3184        );
3185        sparse_ack().await;
3186    }
3187
3188    async fn sparse_ack() {
3189        let mut link = MockLink::<u64>::new();
3190        // Set a large buffer size to improve throughput.
3191        link.set_buffer_size(1024000);
3192        let disconnected_count = link.disconnected_count();
3193        let receiver_storage = link.receiver_storage();
3194        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3195        let local_addr = listener.channel_addr.clone();
3196        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3197            super::server::serve_with_listener(listener, local_addr).unwrap();
3198        let tx = spawn::<u64>(link);
3199        let messages: Vec<_> = (0..20001).collect();
3200        let messages_clone = messages.clone();
3201        // Put the sender side in a separate task so we can start the receiver
3202        // side concurrently.
3203        let send_task_handle = tokio::spawn(async move {
3204            for message in messages_clone {
3205                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3206                tx.post(message);
3207            }
3208            tokio::time::sleep(Duration::from_secs(5)).await;
3209            tracing::debug!("NetTx sent all messages");
3210            tx
3211        });
3212
3213        for message in &messages {
3214            assert_eq!(nx.recv().await.unwrap(), *message);
3215        }
3216        tracing::debug!("NetRx received all messages");
3217
3218        let send_result = send_task_handle.await;
3219        assert!(send_result.is_ok());
3220
3221        tracing::debug!(
3222            "MockLink disconnected {} times.",
3223            disconnected_count.load(Ordering::SeqCst)
3224        );
3225    }
3226
3227    #[test]
3228    fn test_metatls_parsing() {
3229        // host:port
3230        let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3231        assert_eq!(
3232            channel,
3233            ChannelAddr::MetaTls(TlsAddr::new("localhost", 1234))
3234        );
3235        // ipv4:port - parsed as hostname with ip normalization
3236        let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3237        assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("1.2.3.4", 1234)));
3238        // ipv6:port
3239        let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
3240            .parse()
3241            .unwrap();
3242        assert_eq!(
3243            channel,
3244            ChannelAddr::MetaTls(TlsAddr::new("2401:db00:33c:6902:face:0:2a2:0", 1234))
3245        );
3246
3247        let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3248        assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("::", 1234)));
3249    }
3250
3251    #[async_timed_test(timeout_secs = 300)]
3252    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
3253    #[cfg_attr(not(fbcode_build), ignore)]
3254    async fn test_tcp_throughput() {
3255        let config = hyperactor_config::global::lock();
3256        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3257
3258        let socket_addr: SocketAddr = "[::1]:0".parse().unwrap();
3259        let (local_addr, mut rx) = server::serve::<String>(ChannelAddr::Tcp(socket_addr)).unwrap();
3260
3261        // Test with 10 connections (senders), each sends 500K messages, 5M messages in total.
3262        let total_num_msgs = 500000;
3263
3264        let receive_handle = tokio::spawn(async move {
3265            let mut num = 0;
3266            for _ in 0..10 * total_num_msgs {
3267                rx.recv().await.unwrap();
3268                num += 1;
3269
3270                if num % 100000 == 0 {
3271                    tracing::info!("total number of received messages: {}", num);
3272                }
3273            }
3274        });
3275
3276        let mut tx_handles = vec![];
3277        let mut txs = vec![];
3278        for _ in 0..10 {
3279            let server_addr = local_addr.clone();
3280            let tx = Arc::new(channel::dial::<String>(server_addr).unwrap());
3281            let tx2 = Arc::clone(&tx);
3282            txs.push(tx);
3283            tx_handles.push(tokio::spawn(async move {
3284                let random_string = rand::thread_rng()
3285                    .sample_iter(&Alphanumeric)
3286                    .take(2048)
3287                    .map(char::from)
3288                    .collect::<String>();
3289                for _ in 0..total_num_msgs {
3290                    tx2.post(random_string.clone());
3291                }
3292            }));
3293        }
3294
3295        receive_handle.await.unwrap();
3296        for handle in tx_handles {
3297            handle.await.unwrap();
3298        }
3299    }
3300
3301    #[tracing_test::traced_test]
3302    #[async_timed_test(timeout_secs = 60)]
3303    // TODO: OSS: The logs_assert function returned an error: expected log not found
3304    #[cfg_attr(not(fbcode_build), ignore)]
3305    async fn test_net_tx_closed_on_server_reject() {
3306        let link = MockLink::<u64>::new();
3307        let receiver_storage = link.receiver_storage();
3308        let mut tx = spawn::<u64>(link);
3309        net_tx_send(&tx, &[100]).await;
3310
3311        {
3312            let (_reader, writer) = take_receiver(&receiver_storage).await;
3313            let _ = FrameWrite::write_frame(
3314                writer,
3315                serialize_response(NetRxResponse::Reject("testing".to_string())).unwrap(),
3316                1024,
3317                0,
3318            )
3319            .await
3320            .map_err(|(_, e)| e);
3321
3322            // Wait for response to be processed by NetTx before dropping reader/writer. Otherwise
3323            // the channel will be closed and we will get the wrong error.
3324            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
3325        }
3326
3327        verify_tx_closed(&mut tx.status, "server rejected connection").await;
3328    }
3329
3330    #[async_timed_test(timeout_secs = 60)]
3331    async fn test_server_rejects_conn_on_out_of_sequence_message() {
3332        let config = hyperactor_config::global::lock();
3333        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3334        let session_id = SessionId(123);
3335        let (_handle, mvar, mut rx, _cancel_token) = serve_acceptor_test::<u64>(session_id);
3336
3337        let (sender, receiver) = tokio::io::duplex(5000);
3338        mvar.put(receiver).await;
3339        let (r, writer) = tokio::io::split(sender);
3340        let mut reader = FrameReader::new(
3341            r,
3342            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3343        );
3344
3345        let _ = write_stream(writer, 123, &[(0, 100u64), (1, 101u64), (3, 103u64)], true).await;
3346        assert_eq!(rx.recv().await, Some(100u64));
3347        assert_eq!(rx.recv().await, Some(101u64));
3348        let (_, bytes) = reader.next().await.unwrap().unwrap();
3349        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3350        assert_eq!(acked, 0);
3351        let (_, bytes) = reader.next().await.unwrap().unwrap();
3352        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3353        assert_eq!(acked, 1);
3354        let (_, bytes) = reader.next().await.unwrap().unwrap();
3355        assert!(deserialize_response(bytes).unwrap().is_reject());
3356    }
3357
3358    #[async_timed_test(timeout_secs = 60)]
3359    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
3360    #[cfg_attr(not(fbcode_build), ignore)]
3361    async fn test_stop_net_tx_after_stopping_net_rx() {
3362        hyperactor_telemetry::initialize_logging_for_test();
3363
3364        let config = hyperactor_config::global::lock();
3365        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3366        let (addr, mut rx) =
3367            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
3368        let socket_addr = match addr {
3369            ChannelAddr::Tcp(a) => a,
3370            _ => panic!("unexpected channel type"),
3371        };
3372        let tx: NetTx<u64> = spawn(tcp::link(socket_addr));
3373        // NetTx will not establish a connection until it sends the 1st message.
3374        // Without a live connection, NetTx cannot received the Closed message
3375        // from NetRx. Therefore, we need to send a message to establish the
3376        //connection.
3377        tx.send(100).await.unwrap();
3378        assert_eq!(rx.recv().await.unwrap(), 100);
3379        // Drop rx will close the NetRx server.
3380        rx.2.stop("testing");
3381        assert!(rx.recv().await.is_err());
3382
3383        // NetTx will only read from the stream when it needs to send a message
3384        // or wait for an ack. Therefore we need to send a message to trigger that.
3385        tx.post(101);
3386        let mut watcher = tx.status().clone();
3387        // When NetRx exits, it should notify NetTx to exit as well.
3388        let _ = watcher.wait_for(|val| *val == TxStatus::Closed).await;
3389        // wait_for could return Err due to race between when watch's sender was
3390        // dropped and when wait_for was called. So we still need to do an
3391        // equality check.
3392        assert_eq!(*watcher.borrow(), TxStatus::Closed);
3393    }
3394}