Skip to main content

hyperactor/channel/
net.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! A simple socket channel implementation using a single-stream
10//! framing protocol. Each frame is encoded as an 8-byte
11//! **big-endian** length prefix (u64), followed by exactly that many
12//! bytes of payload.
13//!
14//! Message frames carry a `serde_multipart::Message` (not raw
15//! bincode). In compat mode (current default), this is encoded as a
//! sentinel `u64::MAX` followed by a single bincode payload. Response frames
//! are a bincode-serialized `NetRxResponse` enum, containing the acked
//! sequence number (`Ack`), a `Reject` value indicating that the server
//! rejected the connection, or `Closed` when the server closed the channel.
20//!
21//! Message frame (compat/unipart) example:
22//! ```text
23//! +------------------ len: u64 (BE) ------------------+----------------------- data -----------------------+
24//! | \x00\x00\x00\x00\x00\x00\x00\x10                  | \xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF | <bincode bytes> |
25//! |                       16                          |           u64::MAX             |                   |
26//! +---------------------------------------------------+-----------------------------------------------------+
27//! ```
28//!
29//! Response frame (wire format):
30//! ```text
31//! +------------------ len: u64 (BE) ------------------+---------------- data ------------------+
32//! | \x00\x00\x00\x00\x00\x00\x00\x??                  | <bincode acked sequence num or reject> |
33//! +---------------------------------------------------+----------------------------------------+
34//! ```
35//!
36//! I/O is handled by `FrameReader`/`FrameWrite`, which are
37//! cancellation-safe and avoid extra copies. Helper fns
38//! `serialize_response(NetRxResponse) -> Result<Bytes, bincode::Error>`
39//! and `deserialize_response(Bytes) -> Result<NetRxResponse, bincode::Error>`
40//! convert to/from the response payload.
41//!
42//! ### Limits & EOF semantics
43//! * **Max frame size:** frames larger than
44//!   `config::CODEC_MAX_FRAME_LENGTH` are rejected with
45//!   `io::ErrorKind::InvalidData`.
46//! * **EOF handling:** `FrameReader::next()` returns `Ok(None)` only
47//!   when EOF occurs exactly on a frame boundary. If EOF happens
48//!   mid-frame, it returns `Err(io::ErrorKind::UnexpectedEof)`.
49
50use std::fmt;
51use std::fmt::Debug;
52use std::net::ToSocketAddrs;
53use std::time::Duration;
54
55use backoff::ExponentialBackoffBuilder;
56use backoff::backoff::Backoff;
57use bytes::Bytes;
58use enum_as_inner::EnumAsInner;
59use serde::Deserialize;
60use serde::Serialize;
61use serde::de::Error;
62use tokio::io::AsyncRead;
63use tokio::io::AsyncReadExt;
64use tokio::io::AsyncWrite;
65use tokio::io::AsyncWriteExt;
66use tokio::sync::watch;
67use tokio::time::Instant;
68
69use super::*;
70use crate::RemoteMessage;
71
72pub mod duplex;
73mod framed;
74pub(super) mod server;
75pub(super) mod session;
76pub use server::ServerHandle;
77
78pub(crate) trait Stream:
79    AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static
80{
81}
82impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static> Stream for S {}
83
84/// Opaque identifier for a session. Generated by the client,
85/// sent to the server on each connect so the server can correlate
86/// reconnections to the same logical session.
87#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub(crate) struct SessionId(u64);
89
90impl SessionId {
91    /// Generate a new random session ID.
92    pub fn random() -> Self {
93        Self(rand::random())
94    }
95}
96
97impl fmt::Display for SessionId {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{:016x}", self.0)
100    }
101}
102
/// Logical channel tag for initiator→acceptor traffic. This is the tag the
/// send paths in this file pass to `connected.stream(..)`.
pub(crate) const INITIATOR_TO_ACCEPTOR: u8 = 0;

/// Logical channel tag for acceptor→initiator traffic (the reverse
/// direction of [`INITIATOR_TO_ACCEPTOR`]).
pub(crate) const ACCEPTOR_TO_INITIATOR: u8 = 1;
108
/// Magic bytes that open every LinkInit header: ASCII `LNK` followed by NUL.
///
/// The LinkInit header is a fixed-size preamble sent at the start of every
/// physical connection. It is written/read directly on the wire (not
/// framed), before any session framing begins.
///
/// Wire format (13 bytes, big-endian):
/// ```text
/// [magic: 4B "LNK\0"] [session_id: 8B u64 BE] [stream_id: 1B u8]
/// ```
const LINK_INIT_MAGIC: [u8; 4] = [b'L', b'N', b'K', 0];
/// Total header length in bytes: magic + session id (`u64`) + stream id (`u8`).
const LINK_INIT_SIZE: usize =
    LINK_INIT_MAGIC.len() + std::mem::size_of::<u64>() + std::mem::size_of::<u8>();
119
/// Parsed LinkInit header: the fields that follow the magic bytes in the
/// 13-byte connection preamble.
#[derive(Debug, Clone, Copy)]
pub(crate) struct LinkInit {
    /// Session the connection belongs to (bytes 4..12 of the header,
    /// big-endian `u64`).
    pub session_id: SessionId,
    /// Stream tag carried in the last header byte.
    pub stream_id: u8,
}
126
127/// Write a LinkInit header to the stream.
128async fn write_link_init<S: AsyncWrite + Unpin>(
129    stream: &mut S,
130    session_id: SessionId,
131    stream_id: u8,
132) -> Result<(), std::io::Error> {
133    let mut buf = [0u8; LINK_INIT_SIZE];
134    buf[0..4].copy_from_slice(&LINK_INIT_MAGIC);
135    buf[4..12].copy_from_slice(&session_id.0.to_be_bytes());
136    buf[12] = stream_id;
137    stream.write_all(&buf).await
138}
139
140/// Read a LinkInit header from the stream.
141async fn read_link_init<S: AsyncRead + Unpin>(stream: &mut S) -> Result<LinkInit, std::io::Error> {
142    let mut buf = [0u8; LINK_INIT_SIZE];
143    stream.read_exact(&mut buf).await?;
144    if buf[0..4] != LINK_INIT_MAGIC {
145        return Err(std::io::Error::new(
146            std::io::ErrorKind::InvalidData,
147            format!("invalid LinkInit magic: expected LNK, got {:?}", &buf[0..4]),
148        ));
149    }
150    let session_id = SessionId(u64::from_be_bytes(buf[4..12].try_into().unwrap()));
151    let stream_id = buf[12];
152    Ok(LinkInit {
153        session_id,
154        stream_id,
155    })
156}
157
/// Link represents a network link through which connections may be
/// acquired. The session ID is baked in. Initiator links dial;
/// acceptor links wait for dispatched streams.
///
/// Implementations must be `Send + Sync + Debug + 'static` so a link can be
/// moved into a spawned task (as `spawn`/`spawn_unordered` do).
#[async_trait]
pub(crate) trait Link: Send + Sync + Debug + 'static {
    /// The underlying stream type.
    type Stream: Stream;

    /// The address of the link's destination.
    fn dest(&self) -> ChannelAddr;

    /// The session ID for this link.
    fn link_id(&self) -> SessionId;

    /// Acquire the next usable connection. For initiator links this
    /// dials; for acceptor links this waits on a dispatch channel.
    ///
    /// Returns a `ClientError` when no connection can be produced.
    async fn next(&mut self) -> Result<Self::Stream, ClientError>;
}
176
177use session::Session;
178
179use crate::config;
180use crate::metrics;
181
/// Connection history of a link, used to enrich log and error messages
/// (see the `Display` impl).
pub(crate) enum LinkStatus {
    /// No connection has been established yet.
    NeverConnected,
    /// Currently connected; holds the instant the connection was recorded.
    Connected(tokio::time::Instant),
    /// Was connected earlier but is not any more.
    Disconnected {
        /// Instant at which the most recent connection was recorded.
        last_connected: tokio::time::Instant,
        /// Instant at which the disconnect was recorded.
        since: tokio::time::Instant,
    },
}
190
191impl LinkStatus {
192    fn connected(&mut self) {
193        *self = LinkStatus::Connected(tokio::time::Instant::now());
194    }
195
196    fn disconnected(&mut self) {
197        match *self {
198            LinkStatus::Connected(at) => {
199                *self = LinkStatus::Disconnected {
200                    last_connected: at,
201                    since: tokio::time::Instant::now(),
202                };
203            }
204            // Already disconnected or never connected — leave as is.
205            _ => {}
206        }
207    }
208}
209
210impl std::fmt::Display for LinkStatus {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        match self {
213            LinkStatus::NeverConnected => write!(f, "never connected"),
214            LinkStatus::Connected(at) => {
215                write!(f, "connected for {:.1}s", at.elapsed().as_secs_f64())
216            }
217            LinkStatus::Disconnected {
218                last_connected,
219                since,
220            } => {
221                write!(
222                    f,
223                    "last connected {:.1}s ago, disconnected for {:.1}s",
224                    last_connected.elapsed().as_secs_f64(),
225                    since.elapsed().as_secs_f64(),
226                )
227            }
228        }
229    }
230}
231
232/// Log a send-loop error and return `true` if the error is terminal
233/// (caller should exit), `false` if recoverable (caller should reconnect).
234fn log_send_error(
235    error: &session::SendLoopError,
236    dest: &ChannelAddr,
237    session_id: u64,
238    mode: &str,
239    link_status: &LinkStatus,
240) -> bool {
241    match error {
242        session::SendLoopError::Io(err) => {
243            tracing::info!(dest = %dest, session_id, error = %err, mode, "send error; {link_status}");
244            metrics::CHANNEL_ERRORS.add(
245                1,
246                hyperactor_telemetry::kv_pairs!(
247                    "dest" => dest.to_string(),
248                    "session_id" => session_id.to_string(),
249                    "error_type" => metrics::ChannelErrorType::SendError.as_str(),
250                    "mode" => mode.to_string(),
251                ),
252            );
253            false
254        }
255        session::SendLoopError::AppClosed => true,
256        session::SendLoopError::Rejected(reason) => {
257            tracing::error!(dest = %dest, session_id, mode, "server rejected connection: {reason}; {link_status}");
258            true
259        }
260        session::SendLoopError::ServerClosed => {
261            tracing::info!(dest = %dest, session_id, mode, "server closed the channel; {link_status}");
262            true
263        }
264        session::SendLoopError::DeliveryTimeout => {
265            let timeout = hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
266            tracing::error!(
267                dest = %dest, session_id, mode,
268                "failed to receive ack within timeout {timeout:?}; link is currently connected; {link_status}"
269            );
270            true
271        }
272        session::SendLoopError::OversizedFrame(reason) => {
273            tracing::error!(dest = %dest, session_id, mode, "oversized frame: {reason}; {link_status}");
274            true
275        }
276    }
277}
278
279/// Establish a simplex (send-only) session over the given link. Returns a send handle.
280pub(crate) fn spawn<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
281    spawn_inner(link)
282}
283
284/// Establish a multi-stream (unordered) simplex session over N
285/// links sharing the same `SessionId`. Returns a single send handle
286/// that distributes frames across streams.
287pub(crate) fn spawn_unordered<M: RemoteMessage>(links: Vec<impl Link + 'static>) -> NetTx<M> {
288    assert!(!links.is_empty());
289    if links.len() == 1 {
290        return spawn(links.into_iter().next().unwrap());
291    }
292
293    let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
294    let dest = links[0].dest();
295    let session_id = links[0].link_id();
296    let (notify, status) = watch::channel(TxStatus::Active);
297    let tx = NetTx {
298        sender,
299        dest: dest.clone(),
300        status,
301    };
302
303    let num_streams = links.len();
304
305    crate::init::get_runtime().spawn(async move {
306        // Shared MPMC work queue. The dispatcher enqueues *unserialized*
307        // messages; each writer serializes its own pulls before writing.
308        // This spreads serialization cost across all N writer tasks
309        // instead of bottlenecking on the dispatcher.
310        let (queue_tx, queue_rx) =
311            async_channel::bounded::<session::PendingMessage<M>>(num_streams * 8);
312
313        // Shared unacked buffer: any writer's ack reader can prune it.
314        // BTreeMap keyed by seq — writers insert out of order
315        // (multiple streams, interleaved), so not a VecDeque.
316        let unacked: Arc<
317            tokio::sync::Mutex<std::collections::BTreeMap<u64, session::QueuedMessage<M>>>,
318        > = Arc::new(tokio::sync::Mutex::new(std::collections::BTreeMap::new()));
319
320        let mut writer_handles: Vec<tokio::task::JoinHandle<()>> = Vec::with_capacity(num_streams);
321        let log_id = format!("session {}.{:016x}", dest, session_id.0);
322
323        for (i, link) in links.into_iter().enumerate() {
324            let dest = dest.clone();
325            let unacked = unacked.clone();
326            let queue_rx = queue_rx.clone();
327            let log_id = log_id.clone();
328
329            writer_handles.push(tokio::spawn(async move {
330                let mut session = Session::new(link);
331                let mut reconnect_backoff = ExponentialBackoffBuilder::new()
332                    .with_initial_interval(Duration::from_millis(10))
333                    .with_multiplier(2.0)
334                    .with_randomization_factor(0.1)
335                    .with_max_interval(Duration::from_secs(5))
336                    .with_max_elapsed_time(None)
337                    .build();
338
339                loop {
340                    let connected = match session.connect().await {
341                        Ok(s) => s,
342                        Err(_) => {
343                            tracing::info!(
344                                dest = %dest, stream = i,
345                                "multi-stream writer {} connect failed", i
346                            );
347                            break;
348                        }
349                    };
350                    tracing::info!(
351                        dest = %dest, stream = i, "multi-stream writer {} connected", i
352                    );
353
354                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
355                    let connected_at = tokio::time::Instant::now();
356
357                    // Pull from the shared queue; serialize locally; write; interleave ack reads.
358                    let result: Result<(), session::SendLoopError> = async {
359                        loop {
360                            tokio::select! {
361                                biased;
362
363                                ack_result = stream.next() => {
364                                    match ack_result {
365                                        Ok(Some(buffer)) => {
366                                            let response = deserialize_response(buffer)
367                                                .map_err(|e| session::SendLoopError::Io(e.into()))?;
368                                            match response {
369                                                NetRxResponse::Ack(ack) => {
370                                                    let mut guard = unacked.lock().await;
371                                                    // Remove all entries with seq <= ack.
372                                                    let retain: std::collections::BTreeMap<u64, session::QueuedMessage<M>> = guard.split_off(&(ack + 1));
373                                                    drop(std::mem::replace(&mut *guard, retain));
374                                                }
375                                                NetRxResponse::Reject(reason) => {
376                                                    return Err(session::SendLoopError::Rejected(reason));
377                                                }
378                                                NetRxResponse::Closed => {
379                                                    return Err(session::SendLoopError::ServerClosed);
380                                                }
381                                            }
382                                        }
383                                        Ok(None) => return Ok(()),
384                                        Err(e) => return Err(session::SendLoopError::Io(e.into())),
385                                    }
386                                }
387
388                                msg = queue_rx.recv() => {
389                                    let pending = match msg {
390                                        Ok(m) => m,
391                                        // Dispatcher closed the queue: clean shutdown.
392                                        Err(_) => return Ok(()),
393                                    };
394                                    let session::PendingMessage {
395                                        seq,
396                                        message,
397                                        received_at,
398                                        return_channel,
399                                    } = pending;
400                                    let frame = Frame::Message(seq, message);
401                                    let serialized = match serde_multipart::serialize_bincode(&frame) {
402                                        Ok(m) => m,
403                                        Err(e) => {
404                                            tracing::error!(
405                                                "{log_id}: serialization error: {e}"
406                                            );
407                                            // Drops return_channel; sender perceives success
408                                            // (preserving prior behavior of the dispatcher-side
409                                            // serialize path).
410                                            continue;
411                                        }
412                                    };
413                                    let mut queued = session::QueuedMessage {
414                                        seq,
415                                        message: serialized,
416                                        received_at,
417                                        sent_at: None,
418                                        return_channel,
419                                    };
420                                    let framed = queued.message.clone().framed();
421                                    stream.write(framed).drive().await.map_err(|e| {
422                                        session::SendLoopError::Io(e.into())
423                                    })?;
424                                    queued.sent_at = Some(tokio::time::Instant::now());
425                                    unacked.lock().await.insert(queued.seq, queued);
426                                }
427                            }
428                        }
429                    }
430                    .await;
431
432                    session = connected.release();
433
434                    if connected_at.elapsed() > Duration::from_secs(1) {
435                        reconnect_backoff.reset();
436                    }
437
438                    match result {
439                        Ok(()) => {
440                            if queue_rx.is_closed() {
441                                // Dispatcher is gone and queue is drained.
442                                break;
443                            }
444                            if let Some(delay) = reconnect_backoff.next_backoff() {
445                                tokio::time::sleep(delay).await;
446                            }
447                        }
448                        Err(ref e) => {
449                            if log_send_error(e, &dest, session_id.0, "multi-stream", &LinkStatus::NeverConnected) {
450                                break;
451                            }
452                            if let Some(delay) = reconnect_backoff.next_backoff() {
453                                tokio::time::sleep(delay).await;
454                            }
455                        }
456                    }
457                }
458
459                tracing::info!(
460                    dest = %dest,
461                    stream = i,
462                    "multi-stream writer {} shutting down",
463                    i,
464                );
465            }));
466        }
467
468        // Drop our local receiver clone so the queue closes once the
469        // dispatcher's sender (queue_tx) is dropped at shutdown.
470        drop(queue_rx);
471
472        // Dispatcher: receive from app and enqueue for writers — no
473        // serialization here; the writer that pulls the item serializes
474        // it before writing.
475        let mut next_seq = 0u64;
476
477        tracing::info!(
478            %dest, session = %log_id, num_streams,
479            "multi-stream dispatcher started"
480        );
481
482        while let Some((message, return_channel, received_at)) = receiver.recv().await {
483            let pending = session::PendingMessage {
484                seq: next_seq,
485                message,
486                received_at,
487                return_channel,
488            };
489            next_seq += 1;
490
491            if queue_tx.send(pending).await.is_err() {
492                // All writers are gone.
493                break;
494            }
495        }
496
497        // Shutdown: close the shared queue and wait for writers to drain.
498        drop(queue_tx);
499        for handle in writer_handles {
500            let _ = handle.await;
501        }
502
503        let reason = format!("{log_id}: dispatcher closed");
504        let _ = notify.send(TxStatus::Closed(reason.into()));
505    });
506
507    tx
508}
509
/// Spawn the background task that drives a single-stream simplex session
/// over `link`, and return the `NetTx` handle that feeds it.
///
/// The task connects lazily (only after the first message arrives), then
/// loops: connect, requeue anything still unacked, run `send_connected`
/// until it returns, release the connection, and either reconnect after an
/// exponential backoff or exit on a terminal error. On exit, every message
/// still held (unacked, in the outbox, or left in the channel) is returned
/// to its sender with the closing reason, and the status watch is set to
/// `TxStatus::Closed`.
fn spawn_inner<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    let dest = link.dest();
    let session_id = link.link_id();
    let (notify, status) = watch::channel(TxStatus::Active);
    let tx = NetTx {
        sender,
        dest: dest.clone(),
        status,
    };
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        // Delivery state: messages waiting to be written (outbox) and
        // messages written but not yet acknowledged (unacked).
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id.clone()),
        };
        let mut receiver = receiver;

        // Lazy connect: wait for first message.
        match receiver.recv().await {
            Some(msg) => {
                if let Err(err) = deliveries.outbox.push_back(msg) {
                    tracing::error!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        "failed to push message to outbox"
                    );
                    let _ = notify.send(TxStatus::Closed("failed to push to outbox".into()));
                    return;
                }
            }
            None => {
                // Every sender handle was dropped before anything was sent.
                let _ = notify.send(TxStatus::Closed("sender dropped".into()));
                return;
            }
        }

        // Reconnect delay: starts at 10ms, doubles (with 10% jitter) up to
        // 5s, and never gives up on its own (no max elapsed time).
        let mut reconnect_backoff = ExponentialBackoffBuilder::new()
            .with_initial_interval(Duration::from_millis(10))
            .with_multiplier(2.0)
            .with_randomization_factor(0.1)
            .with_max_interval(Duration::from_secs(5))
            .with_max_elapsed_time(None)
            .build();

        let mut link_status = LinkStatus::NeverConnected;

        // Connect/send/reconnect loop. It only exits via `break 'outer`,
        // yielding the human-readable reason the session is closing.
        let reason: String = 'outer: loop {
            // When there is a delivery deadline, connecting is bounded by it;
            // otherwise wait for a connection indefinitely.
            let connected = match deliveries.expiry_time() {
                Some(deadline) => match session.connect_by(deadline).await {
                    Ok(s) => s,
                    Err(_) => {
                        let timeout =
                            hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
                        // Distinguish "never got written" (outbox expired)
                        // from "written but never acked".
                        let error_msg = if deliveries.outbox.is_expired(timeout) {
                            format!("failed to deliver message within timeout {timeout:?}; {link_status}")
                        } else {
                            format!(
                                "failed to receive ack within timeout {timeout:?}; \
                                 link is currently broken; {link_status}",
                            )
                        };
                        tracing::error!(
                            dest = %dest, session_id = session_id.0, "{}", error_msg
                        );
                        break 'outer format!("{log_id}: {error_msg}");
                    }
                },
                None => match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break 'outer "session shut down".into(),
                },
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "simplex",
                    "reason" => "link connected",
                ),
            );

            // Unacked messages at connect time mean this is a reconnect with
            // work in flight; count it before requeueing them.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "simplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            deliveries.requeue_unacked();

            link_status.connected();
            let connected_at = tokio::time::Instant::now();

            // Scope the stream borrow so `connected` can be released below.
            let result = {
                let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                session::send_connected(&stream, &mut deliveries, &mut receiver).await
            };
            session = connected.release();

            link_status.disconnected();

            // Reset backoff if the connection was alive long enough to have
            // been useful (i.e. not an immediate EOF/error).
            if connected_at.elapsed() > Duration::from_secs(1) {
                reconnect_backoff.reset();
            }

            match result {
                Ok(()) => {
                    // EOF — connection closed normally, reconnect after backoff.
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            "send_connected returned EOF, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
                Err(ref e) => {
                    // `log_send_error` returns true for terminal errors.
                    if log_send_error(e, &dest, session_id.0, "simplex", &link_status) {
                        break 'outer format!("{log_id}: {e}");
                    }
                    // Recoverable error — reconnect after backoff.
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            error = %e,
                            "send_connected returned recoverable error, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        };

        tracing::info!(
            dest = %dest, session_id = session_id.0, "NetTx closing: {reason}"
        );

        // Stop accepting new messages, then hand every message we still
        // hold back to its sender: unacked first, then the outbox, then
        // whatever is still buffered in the channel.
        receiver.close();
        deliveries
            .unacked
            .deque
            .drain(..)
            .chain(deliveries.outbox.deque.drain(..))
            .for_each(|queued| queued.try_return(Some(reason.clone())));
        while let Ok((msg, return_channel, _)) = receiver.try_recv() {
            let _ = return_channel.send(SendError {
                error: ChannelError::Closed,
                message: msg,
                reason: Some(reason.clone()),
            });
        }

        let _ = notify.send(TxStatus::Closed(reason.into()));
    });
    tx
}
680
/// Transport-agnostic link that dispatches to the appropriate
/// transport based on the channel address.
///
/// Built by [`link`]; note that both `ChannelAddr::Tls` and
/// `ChannelAddr::MetaTls` addresses map to the `Tls` variant.
#[derive(Debug)]
pub(crate) enum NetLink {
    /// Link over plain TCP.
    Tcp(tcp::TcpLink),
    /// Link over a Unix domain socket.
    Unix(unix::UnixLink),
    /// Link over TLS.
    Tls(tls::TlsLink),
}
689
690/// Create a link for the given channel address with the given
691/// `session_id` and `stream_id`. Single-stream callers pass a fresh
692/// `SessionId::random()` and `stream_id = 0`.
693pub(crate) fn link(
694    addr: ChannelAddr,
695    session_id: SessionId,
696    stream_id: u8,
697) -> Result<NetLink, ClientError> {
698    match addr {
699        ChannelAddr::Tcp(socket_addr) => {
700            Ok(NetLink::Tcp(tcp::link(socket_addr, session_id, stream_id)))
701        }
702        ChannelAddr::Unix(unix_addr) => {
703            Ok(NetLink::Unix(unix::link(unix_addr, session_id, stream_id)))
704        }
705        ChannelAddr::Tls(tls_addr) => Ok(NetLink::Tls(tls::link(tls_addr, session_id, stream_id)?)),
706        ChannelAddr::MetaTls(meta_addr) => {
707            Ok(NetLink::Tls(meta::link(meta_addr, session_id, stream_id)?))
708        }
709        other => Err(ClientError::Connect(
710            other,
711            std::io::Error::other("unsupported transport"),
712            "unsupported transport".into(),
713        )),
714    }
715}
716
717#[async_trait]
718impl Link for NetLink {
719    type Stream = Box<dyn Stream>;
720
721    fn dest(&self) -> ChannelAddr {
722        match self {
723            Self::Tcp(l) => l.dest(),
724            Self::Unix(l) => l.dest(),
725            Self::Tls(l) => l.dest(),
726        }
727    }
728
729    fn link_id(&self) -> SessionId {
730        match self {
731            Self::Tcp(l) => l.link_id(),
732            Self::Unix(l) => l.link_id(),
733            Self::Tls(l) => l.link_id(),
734        }
735    }
736
737    async fn next(&mut self) -> Result<Box<dyn Stream>, ClientError> {
738        match self {
739            Self::Tcp(l) => Ok(Box::new(l.next().await?)),
740            Self::Unix(l) => Ok(Box::new(l.next().await?)),
741            Self::Tls(l) => Ok(Box::new(l.next().await?)),
742        }
743    }
744}
745
/// Listener represents the server side of a network link: it accepts inbound connections.
///
/// This is the counterpart to [`Link`]. Each transport module (tcp, unix, tls)
/// provides both a `Link` impl (for dialing) and a `Listener` impl (for accepting).
///
/// Implementations must be `Send + Unpin + 'static`.
#[async_trait]
pub(crate) trait Listener: Send + Unpin + 'static {
    /// The underlying stream type produced by accepting a connection.
    type Stream: Stream;

    /// Accept the next inbound connection, returning the stream and the peer's address.
    ///
    /// Returns a `ServerError` when accepting fails.
    async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError>;
}
758
/// Transport-agnostic listener that dispatches to the appropriate
/// transport based on the channel address. TLS has no variant — it
/// uses `Tcp` under the hood (the TLS handshake happens in `prepare`,
/// not the listener).
#[derive(Debug)]
pub(crate) enum NetListener {
    /// Listener on a TCP socket (also used for TLS, per the note above).
    Tcp(tcp::TcpSocketListener),
    /// Listener on a Unix domain socket.
    Unix(unix::UnixSocketListener),
}
768
769#[async_trait]
770impl Listener for NetListener {
771    type Stream = Box<dyn Stream>;
772
773    async fn accept(&mut self) -> Result<(Box<dyn Stream>, ChannelAddr), ServerError> {
774        match self {
775            Self::Tcp(l) => {
776                let (stream, addr) = l.accept().await?;
777                Ok((Box::new(stream), addr))
778            }
779            Self::Unix(l) => {
780                let (stream, addr) = l.accept().await?;
781                Ok((Box::new(stream), addr))
782            }
783        }
784    }
785}
786
787/// Bind a listener for the given channel address, optionally using a pre-opened TCP listener.
788/// Returns the listener and the canonical address callers should advertise.
789/// When `prebound` is `Some`, it is used for TCP/TLS transports instead of binding a new socket.
790pub(crate) fn listen_with_prebound(
791    addr: ChannelAddr,
792    prebound: Option<std::net::TcpListener>,
793) -> Result<(NetListener, ChannelAddr), ServerError> {
794    match addr {
795        ChannelAddr::Tcp(socket_addr) => {
796            let std_listener = match prebound {
797                Some(l) => l,
798                None => std::net::TcpListener::bind(socket_addr)
799                    .map_err(|err| ServerError::Listen(ChannelAddr::Tcp(socket_addr), err))?,
800            };
801            std_listener
802                .set_nonblocking(true)
803                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
804            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
805                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
806            let local_addr = tokio_listener
807                .local_addr()
808                .map_err(|err| ServerError::Resolve(ChannelAddr::Tcp(socket_addr), err))?;
809            let listener = tcp::TcpSocketListener {
810                inner: tokio_listener,
811                addr: local_addr,
812            };
813            Ok((NetListener::Tcp(listener), ChannelAddr::Tcp(local_addr)))
814        }
815        ChannelAddr::Unix(ref unix_addr) => {
816            use std::os::unix::net::UnixDatagram as StdUnixDatagram;
817            use std::os::unix::net::UnixListener as StdUnixListener;
818
819            let caddr = addr.clone();
820            let maybe_listener = match unix_addr {
821                unix::SocketAddr::Bound(sock_addr) => StdUnixListener::bind_addr(sock_addr),
822                unix::SocketAddr::Unbound => StdUnixDatagram::unbound()
823                    .and_then(|u| u.local_addr())
824                    .and_then(|uaddr| StdUnixListener::bind_addr(&uaddr)),
825            };
826            let std_listener =
827                maybe_listener.map_err(|err| ServerError::Listen(caddr.clone(), err))?;
828            std_listener
829                .set_nonblocking(true)
830                .map_err(|err| ServerError::Listen(caddr.clone(), err))?;
831            let local_addr = std_listener
832                .local_addr()
833                .map_err(|err| ServerError::Resolve(caddr.clone(), err))?;
834            let tokio_listener = tokio::net::UnixListener::from_std(std_listener)
835                .map_err(|err| ServerError::Io(caddr, err))?;
836            let bound_addr = unix::SocketAddr::new(local_addr);
837            let listener = unix::UnixSocketListener {
838                inner: tokio_listener,
839                addr: bound_addr.clone(),
840            };
841            Ok((NetListener::Unix(listener), ChannelAddr::Unix(bound_addr)))
842        }
843        addr @ (ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_)) => {
844            let is_meta = matches!(addr, ChannelAddr::MetaTls(_));
845            let tls_addr = match addr {
846                ChannelAddr::Tls(a) | ChannelAddr::MetaTls(a) => a,
847                _ => unreachable!(),
848            };
849            let TlsAddr { hostname, port } = tls_addr;
850            let make_channel_addr = |h: &str, p: Port| {
851                if is_meta {
852                    ChannelAddr::MetaTls(TlsAddr::new(h, p))
853                } else {
854                    ChannelAddr::Tls(TlsAddr::new(h, p))
855                }
856            };
857
858            let addrs: Vec<core::net::SocketAddr> = (hostname.as_ref(), port)
859                .to_socket_addrs()
860                .map_err(|err| ServerError::Resolve(make_channel_addr(&hostname, port), err))?
861                .collect();
862
863            if addrs.is_empty() {
864                return Err(ServerError::Resolve(
865                    make_channel_addr(&hostname, port),
866                    std::io::Error::other("no available socket addr"),
867                ));
868            }
869
870            let channel_addr = make_channel_addr(&hostname, port);
871            let std_listener = match prebound {
872                Some(l) => l,
873                None => std::net::TcpListener::bind(&addrs[..])
874                    .map_err(|err| ServerError::Listen(channel_addr.clone(), err))?,
875            };
876            std_listener
877                .set_nonblocking(true)
878                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
879            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
880                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
881            let local_addr = tokio_listener
882                .local_addr()
883                .map_err(|err| ServerError::Resolve(channel_addr, err))?;
884            let listener = tcp::TcpSocketListener {
885                inner: tokio_listener,
886                addr: local_addr,
887            };
888            Ok((
889                NetListener::Tcp(listener),
890                make_channel_addr(&hostname, local_addr.port()),
891            ))
892        }
893        other => Err(ServerError::Listen(
894            other.clone(),
895            std::io::Error::other(format!("unsupported transport: {}", other)),
896        )),
897    }
898}
899
/// Bind a listener for the given channel address. Returns the listener
/// and the canonical address callers should advertise (which encodes
/// the transport — e.g. `ChannelAddr::Tls` for TLS).
///
/// Equivalent to calling [`listen_with_prebound`] with no pre-bound
/// listener; errors are those of that function.
#[expect(
    dead_code,
    reason = "canonical listen() entry point; callers currently route through listen_with_prebound"
)]
pub(crate) fn listen(addr: ChannelAddr) -> Result<(NetListener, ChannelAddr), ServerError> {
    listen_with_prebound(addr, None)
}
910
/// Frames are the messages sent between clients and servers over sessions.
#[derive(Debug, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub(super) enum Frame<M> {
    /// Send a message with the provided sequence number.
    ///
    /// Fields are, in order: the sequence number and the message payload.
    Message(u64, M),
}
917
/// Response sent from the receiving side back to the sender. It is
/// encoded/decoded with `serialize_response` / `deserialize_response`.
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
pub(super) enum NetRxResponse {
    /// Acknowledgement carrying the acked sequence number.
    Ack(u64),
    /// This session is rejected with the given reason. NetTx should stop reconnecting.
    Reject(String),
    /// This channel is closed.
    Closed,
}
926
927pub(super) fn serialize_response(
928    response: NetRxResponse,
929) -> Result<Bytes, bincode::error::EncodeError> {
930    bincode::serde::encode_to_vec(&response, bincode::config::legacy()).map(|bytes| bytes.into())
931}
932
933pub(super) fn deserialize_response(
934    data: Bytes,
935) -> Result<NetRxResponse, bincode::error::DecodeError> {
936    bincode::serde::decode_from_slice(&data, bincode::config::legacy()).map(|(v, _)| v)
937}
938
/// A Tx implemented on top of a Link. The Tx manages the link state,
/// reconnections, etc.
pub(crate) struct NetTx<M: RemoteMessage> {
    /// Queue feeding posted messages to the sending side. Each entry is the
    /// message, the channel on which a failed send is reported, and the
    /// instant at which the message was posted.
    sender: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
    /// Destination address, as reported by `addr()`.
    dest: ChannelAddr,
    /// Watch on the transmitter status, as exposed by `status()`.
    status: watch::Receiver<TxStatus>,
}
946
947#[async_trait]
948impl<M: RemoteMessage> Tx<M> for NetTx<M> {
949    fn addr(&self) -> ChannelAddr {
950        self.dest.clone()
951    }
952
953    fn status(&self) -> &watch::Receiver<TxStatus> {
954        &self.status
955    }
956
957    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
958        tracing::trace!(
959            name = "post",
960            dest = %self.dest,
961            "sending message"
962        );
963
964        let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
965        if let Err(mpsc::error::SendError((message, return_channel, _))) =
966            self.sender
967                .send((message, return_channel, tokio::time::Instant::now()))
968        {
969            let reason = self.status.borrow().as_closed().map(|r| r.to_string());
970            let _ = return_channel.send(SendError {
971                error: ChannelError::Closed,
972                message,
973                reason,
974            });
975        }
976    }
977}
978
/// Receiving end of a net channel.
///
/// Fields, in order: the queue that `recv` reads messages from, the channel
/// address returned by `addr`, and the server handle that is stopped when
/// the receiver is joined or dropped.
pub struct NetRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr, ServerHandle);
980
981#[async_trait]
982impl<M: RemoteMessage> Rx<M> for NetRx<M> {
983    async fn recv(&mut self) -> Result<M, ChannelError> {
984        tracing::trace!(
985            name = "recv",
986            dest = %self.1,
987            "receiving message"
988        );
989        self.0.recv().await.ok_or(ChannelError::Closed)
990    }
991
992    fn addr(&self) -> ChannelAddr {
993        self.1.clone()
994    }
995
996    /// Gracefully shut down the channel server, waiting for pending
997    /// acks to be flushed before returning.
998    async fn join(mut self) {
999        self.2
1000            .stop(&format!("NetRx joined; channel address: {}", self.1));
1001        let _ = (&mut self.2).await;
1002        // Drop will call stop() again which is harmless (token already cancelled).
1003    }
1004}
1005
1006impl<M: RemoteMessage> Drop for NetRx<M> {
1007    fn drop(&mut self) {
1008        self.2
1009            .stop(&format!("NetRx dropped; channel address: {}", self.1));
1010    }
1011}
1012
/// Error returned during server operations.
///
/// Every variant carries the channel address the operation was performed
/// on, followed by the underlying cause (exposed as the error source).
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
    /// An I/O error occurred while operating on the server at the given address.
    // NOTE(review): unlike the other variants, this message does not include
    // the address (`{0}`); confirm whether that is intentional.
    #[error("io: {1}")]
    Io(ChannelAddr, #[source] std::io::Error),
    /// Listening on the given address failed.
    #[error("listen: {0} {1}")]
    Listen(ChannelAddr, #[source] std::io::Error),
    /// Resolving the given address failed.
    #[error("resolve: {0} {1}")]
    Resolve(ChannelAddr, #[source] std::io::Error),
    /// An internal server error occurred for the given address.
    #[error("internal: {0} {1}")]
    Internal(ChannelAddr, #[source] anyhow::Error),
}
1029
/// Error returned during client (dialing) operations.
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    /// Connecting to the given address failed. Carries the underlying I/O
    /// error and a short description of what was being attempted.
    #[error("connection to {0} failed: {1}: {2}")]
    Connect(ChannelAddr, std::io::Error, String),
    /// The given address could not be resolved to something connectable.
    #[error("unable to resolve address: {0}")]
    Resolve(ChannelAddr),
    /// An I/O error occurred on the connection to the given address.
    #[error("io: {0} {1}")]
    Io(ChannelAddr, std::io::Error),
    /// Serializing a message destined for the given address failed.
    #[error("send {0}: serialize: {1}")]
    Serialize(ChannelAddr, bincode::error::EncodeError),
    /// The supplied address string is not valid.
    #[error("invalid address: {0}")]
    InvalidAddress(String),
}
1043
/// Tells whether the address is a 'net' address. These currently have different semantics
/// from local transports.
///
/// TCP, MetaTLS, TLS and Unix transports count as net transports; every
/// other transport does not.
#[cfg(test)]
pub(super) fn is_net_addr(addr: &ChannelAddr) -> bool {
    matches!(
        addr.transport(),
        ChannelTransport::Tcp(_)
            | ChannelTransport::MetaTls(_)
            | ChannelTransport::Tls
            | ChannelTransport::Unix
    )
}
1056
1057pub(crate) mod unix {
1058
1059    use core::str;
1060    use std::os::unix::net::SocketAddr as StdSocketAddr;
1061    use std::os::unix::net::UnixStream as StdUnixStream;
1062
1063    use rand::RngExt as _;
1064    use rand::distr::Alphanumeric;
1065    use tokio::net::UnixListener;
1066    use tokio::net::UnixStream;
1067
1068    use super::*;
1069
    /// Client-side link that dials a Unix domain socket.
    #[derive(Debug)]
    pub(crate) struct UnixLink {
        /// Socket address to connect to; must be `Bound` for `next` to succeed.
        pub(super) addr: SocketAddr,
        /// Session identifier written in the link-init header of every new stream.
        pub(super) session_id: SessionId,
        /// Stream identifier written in the link-init header of every new stream.
        pub(super) stream_id: u8,
    }
1076
1077    #[async_trait]
1078    impl Link for UnixLink {
1079        type Stream = UnixStream;
1080
1081        fn dest(&self) -> ChannelAddr {
1082            ChannelAddr::Unix(self.addr.clone())
1083        }
1084
1085        fn link_id(&self) -> SessionId {
1086            self.session_id
1087        }
1088
1089        async fn next(&mut self) -> Result<Self::Stream, ClientError> {
1090            let session_id = self.session_id;
1091            let sock_addr = match &self.addr {
1092                SocketAddr::Bound(a) => a,
1093                SocketAddr::Unbound => return Err(ClientError::Resolve(self.dest())),
1094            };
1095            let mut backoff = ExponentialBackoffBuilder::new()
1096                .with_initial_interval(Duration::from_millis(1))
1097                .with_multiplier(2.0)
1098                .with_randomization_factor(0.1)
1099                .with_max_interval(Duration::from_millis(1000))
1100                .with_max_elapsed_time(None)
1101                .build();
1102            loop {
1103                match StdUnixStream::connect_addr(sock_addr) {
1104                    Ok(std_stream) => {
1105                        std_stream
1106                            .set_nonblocking(true)
1107                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1108                        let mut stream = UnixStream::from_std(std_stream)
1109                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1110                        write_link_init(&mut stream, session_id, self.stream_id)
1111                            .await
1112                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1113                        return Ok(stream);
1114                    }
1115                    Err(err) => {
1116                        tracing::debug!(error = %err, "unix connect failed, backing off");
1117                        if let Some(delay) = backoff.next_backoff() {
1118                            tokio::time::sleep(delay).await;
1119                        }
1120                    }
1121                }
1122            }
1123        }
1124    }
1125
    /// Server-side listener for Unix domain sockets.
    #[derive(Debug)]
    pub(crate) struct UnixSocketListener {
        /// The tokio listener that connections are accepted from.
        pub(super) inner: UnixListener,
        /// The address this listener is bound to; used to label accept errors.
        pub(super) addr: SocketAddr,
    }
1132
1133    #[async_trait]
1134    impl super::Listener for UnixSocketListener {
1135        type Stream = UnixStream;
1136
1137        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1138            let (stream, peer_addr) = self
1139                .inner
1140                .accept()
1141                .await
1142                .map_err(|err| ServerError::Io(ChannelAddr::Unix(self.addr.clone()), err))?;
1143            // tokio::net::unix::SocketAddr -> std::os::unix::net::SocketAddr
1144            let std_addr: StdSocketAddr = peer_addr.into();
1145            Ok((stream, ChannelAddr::Unix(SocketAddr::new(std_addr))))
1146        }
1147    }
1148
    /// Create a unix link to the given socket address.
    ///
    /// This only records the parameters; no connection is attempted until
    /// `next` is called on the returned link.
    pub(crate) fn link(addr: SocketAddr, session_id: SessionId, stream_id: u8) -> UnixLink {
        UnixLink {
            addr,
            session_id,
            stream_id,
        }
    }
1157
    /// Wrapper around std-lib's unix::SocketAddr that lets us implement equality functions
    #[derive(Clone, Debug)]
    pub enum SocketAddr {
        /// A concrete std socket address (boxed).
        Bound(Box<StdSocketAddr>),
        /// No address has been assigned; rendered as `(unbound)`.
        Unbound,
    }
1164
1165    impl PartialOrd for SocketAddr {
1166        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1167            Some(self.cmp(other))
1168        }
1169    }
1170
1171    impl Ord for SocketAddr {
1172        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1173            self.to_string().cmp(&other.to_string())
1174        }
1175    }
1176
1177    impl<'de> Deserialize<'de> for SocketAddr {
1178        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
1179        where
1180            D: serde::Deserializer<'de>,
1181        {
1182            let s = String::deserialize(deserializer)?;
1183            Self::from_str(&s).map_err(D::Error::custom)
1184        }
1185    }
1186
1187    impl Serialize for SocketAddr {
1188        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
1189        where
1190            S: serde::Serializer,
1191        {
1192            serializer.serialize_str(String::from(self).as_str())
1193        }
1194    }
1195
1196    impl From<&SocketAddr> for String {
1197        fn from(value: &SocketAddr) -> Self {
1198            match value {
1199                SocketAddr::Bound(addr) => match addr.as_pathname() {
1200                    Some(path) => path
1201                        .to_str()
1202                        .expect("unable to get str for path")
1203                        .to_string(),
1204                    #[cfg(target_os = "linux")]
1205                    _ => match addr.as_abstract_name() {
1206                        Some(name) => format!("@{}", String::from_utf8_lossy(name)),
1207                        _ => String::from("(unnamed)"),
1208                    },
1209                    #[cfg(not(target_os = "linux"))]
1210                    _ => String::from("(unnamed)"),
1211                },
1212                SocketAddr::Unbound => String::from("(unbound)"),
1213            }
1214        }
1215    }
1216
1217    impl FromStr for SocketAddr {
1218        type Err = anyhow::Error;
1219
1220        fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
1221            match s {
1222                "" => {
1223                    // TODO: ensure this socket doesn't already exist. 24 bytes of randomness should be good for now but is not perfect.
1224                    // We can't use annon sockets because those are not valid across processes that aren't in the same process hierarchy aka forked.
1225                    let random_string = rand::rng()
1226                        .sample_iter(&Alphanumeric)
1227                        .take(24)
1228                        .map(char::from)
1229                        .collect::<String>();
1230                    SocketAddr::from_abstract_name(&random_string)
1231                }
1232                // by convention, named sockets are displayed with an '@' prefix
1233                name if name.starts_with("@") => {
1234                    SocketAddr::from_abstract_name(name.strip_prefix("@").unwrap())
1235                }
1236                path => SocketAddr::from_pathname(path),
1237            }
1238        }
1239    }
1240
    // NOTE(review): the `PartialEq` impl for this type returns false whenever
    // either side is `Unbound` or unnamed, so such a value does not compare
    // equal to itself even though `Eq` is implemented — confirm this is intended.
    impl Eq for SocketAddr {}
1242    impl std::hash::Hash for SocketAddr {
1243        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1244            String::from(self).hash(state);
1245        }
1246    }
1247    impl PartialEq for SocketAddr {
1248        fn eq(&self, other: &Self) -> bool {
1249            match (self, other) {
1250                (Self::Bound(saddr), Self::Bound(oaddr)) => {
1251                    if saddr.is_unnamed() || oaddr.is_unnamed() {
1252                        return false;
1253                    }
1254
1255                    #[cfg(target_os = "linux")]
1256                    {
1257                        saddr.as_pathname() == oaddr.as_pathname()
1258                            && saddr.as_abstract_name() == oaddr.as_abstract_name()
1259                    }
1260                    #[cfg(not(target_os = "linux"))]
1261                    {
1262                        // On non-Linux platforms, only compare pathname since no abstract names
1263                        saddr.as_pathname() == oaddr.as_pathname()
1264                    }
1265                }
1266                (Self::Unbound, _) | (_, Self::Unbound) => false,
1267            }
1268        }
1269    }
1270
1271    impl fmt::Display for SocketAddr {
1272        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1273            match self {
1274                Self::Bound(addr) => match addr.as_pathname() {
1275                    Some(path) => {
1276                        write!(f, "{}", path.to_string_lossy())
1277                    }
1278                    #[cfg(target_os = "linux")]
1279                    _ => match addr.as_abstract_name() {
1280                        Some(name) => {
1281                            if name.starts_with(b"@") {
1282                                return write!(f, "{}", String::from_utf8_lossy(name));
1283                            }
1284                            write!(f, "@{}", String::from_utf8_lossy(name))
1285                        }
1286                        _ => write!(f, "(unnamed)"),
1287                    },
1288                    #[cfg(not(target_os = "linux"))]
1289                    _ => write!(f, "(unnamed)"),
1290                },
1291                Self::Unbound => write!(f, "(unbound)"),
1292            }
1293        }
1294    }
1295
1296    impl SocketAddr {
1297        /// Wraps the stdlib socket address for use with this module
1298        pub fn new(addr: StdSocketAddr) -> Self {
1299            Self::Bound(Box::new(addr))
1300        }
1301
1302        /// Abstract socket names start with a "@" by convention when displayed. If there is an
1303        /// "@" prefix, it will be stripped from the name before used.
1304        #[cfg(target_os = "linux")]
1305        pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1306            Ok(Self::new(StdSocketAddr::from_abstract_name(
1307                name.strip_prefix("@").unwrap_or(name),
1308            )?))
1309        }
1310
1311        #[cfg(not(target_os = "linux"))]
1312        pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1313            // On non-Linux platforms, convert abstract names to filesystem paths
1314            let name = name.strip_prefix("@").unwrap_or(name);
1315            let path = Self::abstract_to_filesystem_path(name);
1316            Self::from_pathname(&path.to_string_lossy())
1317        }
1318
1319        #[cfg(not(target_os = "linux"))]
1320        fn abstract_to_filesystem_path(abstract_name: &str) -> std::path::PathBuf {
1321            use std::collections::hash_map::DefaultHasher;
1322            use std::hash::Hash;
1323            use std::hash::Hasher;
1324
1325            // Generate a stable hash of the abstract name for deterministic paths
1326            let mut hasher = DefaultHasher::new();
1327            abstract_name.hash(&mut hasher);
1328            let hash = hasher.finish();
1329
1330            // Include process ID to prevent inter-process conflicts
1331            let process_id = std::process::id();
1332
1333            // TODO: we just leak these. Should we do something smarter?
1334            std::path::PathBuf::from(format!("/tmp/hyperactor_{}_{:x}", process_id, hash))
1335        }
1336
1337        /// Pathnames may be absolute or relative.
1338        pub fn from_pathname(name: &str) -> anyhow::Result<Self> {
1339            Ok(Self::new(StdSocketAddr::from_pathname(name)?))
1340        }
1341    }
1342
1343    impl TryFrom<SocketAddr> for StdSocketAddr {
1344        type Error = anyhow::Error;
1345
1346        fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
1347            match value {
1348                SocketAddr::Bound(addr) => Ok(*addr),
1349                SocketAddr::Unbound => Err(anyhow::anyhow!(
1350                    "std::os::unix::SocketAddr must be a bound address"
1351                )),
1352            }
1353        }
1354    }
1355}
1356
1357pub(crate) mod tcp {
1358    use tokio::net::TcpListener;
1359    use tokio::net::TcpStream;
1360
1361    use super::*;
1362
    /// Client-side link that dials a TCP socket address.
    #[derive(Debug)]
    pub(crate) struct TcpLink {
        /// Socket address to connect to.
        pub(super) addr: SocketAddr,
        /// Session identifier written in the link-init header of every new stream.
        pub(super) session_id: SessionId,
        /// Stream identifier written in the link-init header of every new stream.
        pub(super) stream_id: u8,
    }
1369
1370    #[async_trait]
1371    impl Link for TcpLink {
1372        type Stream = TcpStream;
1373
1374        fn dest(&self) -> ChannelAddr {
1375            ChannelAddr::Tcp(self.addr)
1376        }
1377
1378        fn link_id(&self) -> SessionId {
1379            self.session_id
1380        }
1381
1382        async fn next(&mut self) -> Result<Self::Stream, ClientError> {
1383            let session_id = self.session_id;
1384            let mut backoff = ExponentialBackoffBuilder::new()
1385                .with_initial_interval(Duration::from_millis(1))
1386                .with_multiplier(2.0)
1387                .with_randomization_factor(0.1)
1388                .with_max_interval(Duration::from_millis(1000))
1389                .with_max_elapsed_time(None)
1390                .build();
1391            loop {
1392                match TcpStream::connect(&self.addr).await {
1393                    Ok(mut stream) => {
1394                        stream.set_nodelay(true).map_err(|err| {
1395                            ClientError::Connect(
1396                                self.dest(),
1397                                err,
1398                                "cannot disable Nagle algorithm".to_string(),
1399                            )
1400                        })?;
1401                        write_link_init(&mut stream, session_id, self.stream_id)
1402                            .await
1403                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1404                        return Ok(stream);
1405                    }
1406                    Err(err) => {
1407                        tracing::debug!(error = %err, "tcp connect failed, backing off");
1408                        if let Some(delay) = backoff.next_backoff() {
1409                            tokio::time::sleep(delay).await;
1410                        }
1411                    }
1412                }
1413            }
1414        }
1415    }
1416
    /// Server-side listener for TCP sockets.
    #[derive(Debug)]
    pub(crate) struct TcpSocketListener {
        /// The tokio listener that connections are accepted from.
        pub(super) inner: TcpListener,
        /// The local address of the listener; used to label accept errors.
        pub(super) addr: SocketAddr,
    }
1423
1424    #[async_trait]
1425    impl super::Listener for TcpSocketListener {
1426        type Stream = TcpStream;
1427
1428        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1429            let (stream, peer_addr) = self
1430                .inner
1431                .accept()
1432                .await
1433                .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1434            stream
1435                .set_nodelay(true)
1436                .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1437            Ok((stream, ChannelAddr::Tcp(peer_addr)))
1438        }
1439    }
1440
    /// Create a TCP link to the given socket address.
    ///
    /// This only records the parameters; no connection is attempted until
    /// `next` is called on the returned link.
    pub(crate) fn link(addr: SocketAddr, session_id: SessionId, stream_id: u8) -> TcpLink {
        TcpLink {
            addr,
            session_id,
            stream_id,
        }
    }
1449}
1450
1451// TODO: Try to simplify the TLS creation T208304433
1452pub(crate) mod meta {
1453    use std::io;
1454    use std::path::PathBuf;
1455    use std::sync::Arc;
1456
1457    use anyhow::Result;
1458    use tokio_rustls::TlsAcceptor;
1459    use tokio_rustls::TlsConnector;
1460
1461    use super::*;
1462    use crate::config::Pem;
1463    use crate::config::PemBundle;
1464
1465    const THRIFT_TLS_SRV_CA_PATH_ENV: &str = "THRIFT_TLS_SRV_CA_PATH";
1466    const DEFAULT_SRV_CA_PATH: &str = "/var/facebook/rootcanal/ca.pem";
1467    const THRIFT_TLS_CL_CERT_PATH_ENV: &str = "THRIFT_TLS_CL_CERT_PATH";
1468    const THRIFT_TLS_CL_KEY_PATH_ENV: &str = "THRIFT_TLS_CL_KEY_PATH";
1469    const DEFAULT_SERVER_PEM_PATH: &str = "/var/facebook/x509_identities/server.pem";
1470
1471    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
1472    pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1473        // Use right split to allow for ipv6 addresses where ":" is expected.
1474        let parts = addr_string.rsplit_once(":");
1475        match parts {
1476            Some((hostname, port_str)) => {
1477                let Ok(port) = port_str.parse() else {
1478                    return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1479                };
1480                Ok(ChannelAddr::MetaTls(TlsAddr::new(hostname, port)))
1481            }
1482            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1483        }
1484    }
1485
1486    /// Construct a PemBundle for server operations from Meta-specific paths.
1487    /// Server cert and key come from the same file (server.pem).
1488    pub(super) fn get_server_pem_bundle() -> PemBundle {
1489        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1490            .map(PathBuf::from)
1491            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1492        let server_pem_path = PathBuf::from(DEFAULT_SERVER_PEM_PATH);
1493        PemBundle {
1494            ca: Pem::File(ca_path),
1495            cert: Pem::File(server_pem_path.clone()),
1496            key: Pem::File(server_pem_path),
1497        }
1498    }
1499
1500    /// Construct a PemBundle for client operations from Meta-specific env vars.
1501    /// Returns None if client cert/key env vars are not set.
1502    fn get_client_pem_bundle() -> Option<PemBundle> {
1503        let cert_path = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV)?;
1504        let key_path = std::env::var_os(THRIFT_TLS_CL_KEY_PATH_ENV)?;
1505        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1506            .map(PathBuf::from)
1507            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1508        Some(PemBundle {
1509            ca: Pem::File(ca_path),
1510            cert: Pem::File(PathBuf::from(cert_path)),
1511            key: Pem::File(PathBuf::from(key_path)),
1512        })
1513    }
1514
1515    /// Creates a TLS acceptor by looking for necessary certs and keys in a Meta server environment.
1516    pub(crate) fn tls_acceptor(enforce_client_tls: bool) -> Result<TlsAcceptor> {
1517        let bundle = get_server_pem_bundle();
1518        tls::tls_acceptor_from_bundle(&bundle, enforce_client_tls)
1519    }
1520
1521    /// Try to create a TLS connector for Meta environments.
1522    ///
1523    /// Returns `Ok` when the root CA is present (optional client certs
1524    /// are added when `THRIFT_TLS_CL_CERT_PATH` / `THRIFT_TLS_CL_KEY_PATH`
1525    /// are set).
1526    pub(super) fn try_tls_connector() -> Result<TlsConnector> {
1527        tls_connector()
1528    }
1529
1530    /// Creates a TLS connector by looking for necessary certs and keys in a Meta server environment.
1531    /// Supports optional client authentication (unlike the tls module which always requires it).
1532    fn tls_connector() -> Result<TlsConnector> {
1533        // Ensure ring is installed as the process-level crypto provider.
1534        // No-op when already installed (e.g. under Buck with native-tls).
1535        let _ = rustls::crypto::ring::default_provider().install_default();
1536
1537        let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1538            .map(PathBuf::from)
1539            .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1540        let ca_pem = Pem::File(ca_path);
1541        let root_store = tls::build_root_store(&ca_pem)?;
1542
1543        // If client certs are available, use mutual TLS; otherwise, no client auth
1544        let config = rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_store));
1545
1546        let config = if let Some(bundle) = get_client_pem_bundle() {
1547            let certs = tls::load_certs(&bundle.cert)?;
1548            let key = tls::load_key(&bundle.key)?;
1549            config
1550                .with_client_auth_cert(certs, key)
1551                .map_err(|e| anyhow::anyhow!("load client certs: {}", e))?
1552        } else {
1553            config.with_no_client_auth()
1554        };
1555
1556        Ok(TlsConnector::from(Arc::new(config)))
1557    }
1558
1559    /// Create a MetaTLS link to the given address.
1560    pub fn link(
1561        addr: TlsAddr,
1562        session_id: SessionId,
1563        stream_id: u8,
1564    ) -> Result<tls::TlsLink, ClientError> {
1565        let connector = tls_connector().map_err(|e| {
1566            ClientError::Connect(
1567                ChannelAddr::MetaTls(addr.clone()),
1568                io::Error::other(e.to_string()),
1569                "failed to create TLS connector".to_string(),
1570            )
1571        })?;
1572        let TlsAddr { hostname, port } = addr;
1573        Ok(tls::TlsLink {
1574            hostname,
1575            port,
1576            connector,
1577            addr_type: tls::TlsAddrType::MetaTls,
1578            session_id,
1579            stream_id,
1580        })
1581    }
1582}
1583
1584/// TLS transport module using configurable certificates via hyperactor config attributes.
1585pub(crate) mod tls {
1586    use std::io;
1587    use std::io::BufReader;
1588    use std::sync::Arc;
1589
1590    use anyhow::Context;
1591    use anyhow::Result;
1592    use rustls::RootCertStore;
1593    use rustls::pki_types::CertificateDer;
1594    use rustls::pki_types::PrivateKeyDer;
1595    use rustls::pki_types::ServerName;
1596    use tokio::net::TcpStream;
1597    use tokio_rustls::TlsAcceptor;
1598    use tokio_rustls::TlsConnector;
1599    use tokio_rustls::client::TlsStream;
1600
1601    use super::*;
1602    use crate::channel::TlsAddr;
1603    use crate::config::Pem;
1604    use crate::config::PemBundle;
1605    use crate::config::TLS_CA;
1606    use crate::config::TLS_CERT;
1607    use crate::config::TLS_KEY;
1608
    /// Distinguishes between Tls and MetaTls for address construction.
    ///
    /// `TlsLink::dest` uses this tag to decide which `ChannelAddr` variant
    /// wraps the link's hostname and port.
    #[derive(Debug, Clone, Copy)]
    pub(crate) enum TlsAddrType {
        /// The destination is reported as `ChannelAddr::Tls`.
        Tls,
        /// The destination is reported as `ChannelAddr::MetaTls`.
        MetaTls,
    }
1615
1616    /// Parse an address string into a TlsAddr.
1617    #[allow(clippy::result_large_err)]
1618    pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1619        // Use right split to allow for ipv6 addresses where ":" is expected.
1620        let parts = addr_string.rsplit_once(":");
1621        match parts {
1622            Some((hostname, port_str)) => {
1623                let Ok(port) = port_str.parse() else {
1624                    return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1625                };
1626                Ok(ChannelAddr::Tls(TlsAddr::new(hostname, port)))
1627            }
1628            _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1629        }
1630    }
1631
1632    /// Load certificates from a Pem value.
1633    pub(super) fn load_certs(pem: &Pem) -> Result<Vec<CertificateDer<'static>>> {
1634        let mut reader = BufReader::new(pem.reader()?);
1635        let certs = rustls_pemfile::certs(&mut reader)
1636            .filter_map(Result::ok)
1637            .collect();
1638        Ok(certs)
1639    }
1640
1641    /// Load a private key from a Pem value.
1642    pub(super) fn load_key(pem: &Pem) -> Result<PrivateKeyDer<'static>> {
1643        let mut reader = BufReader::new(pem.reader()?);
1644        loop {
1645            break match rustls_pemfile::read_one(&mut reader)? {
1646                Some(rustls_pemfile::Item::Pkcs1Key(key)) => Ok(PrivateKeyDer::Pkcs1(key)),
1647                Some(rustls_pemfile::Item::Pkcs8Key(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
1648                Some(rustls_pemfile::Item::Sec1Key(key)) => Ok(PrivateKeyDer::Sec1(key)),
1649                Some(_) => continue,
1650                None => anyhow::bail!("no private key found in TLS key file"),
1651            };
1652        }
1653    }
1654
1655    /// Build root certificate store from the CA pem.
1656    pub(super) fn build_root_store(ca_pem: &Pem) -> Result<RootCertStore> {
1657        let mut root_store = RootCertStore::empty();
1658        let certs = load_certs(ca_pem)?;
1659        root_store.add_parsable_certificates(certs);
1660        Ok(root_store)
1661    }
1662
1663    /// Get the PEM bundle from configuration.
1664    fn get_pem_bundle() -> PemBundle {
1665        PemBundle {
1666            ca: hyperactor_config::global::get_cloned(TLS_CA),
1667            cert: hyperactor_config::global::get_cloned(TLS_CERT),
1668            key: hyperactor_config::global::get_cloned(TLS_KEY),
1669        }
1670    }
1671
1672    /// Creates a TLS acceptor using certificates from the provided PEM bundle.
1673    /// If `enforce_client_tls` is true, requires client certificates for mutual TLS.
1674    pub(super) fn tls_acceptor_from_bundle(
1675        bundle: &PemBundle,
1676        enforce_client_tls: bool,
1677    ) -> Result<TlsAcceptor> {
1678        // Ensure ring is installed as the process-level crypto provider.
1679        // No-op when already installed (e.g. under Buck with native-tls).
1680        let _ = rustls::crypto::ring::default_provider().install_default();
1681
1682        let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1683        let key = load_key(&bundle.key).context("load TLS key")?;
1684        let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1685
1686        let config = rustls::ServerConfig::builder();
1687        let config = if enforce_client_tls {
1688            // Build server config with mutual TLS (require client certs)
1689            let client_verifier =
1690                rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1691                    .build()
1692                    .map_err(|e| anyhow::anyhow!("failed to build client verifier: {}", e))?;
1693            config.with_client_cert_verifier(client_verifier)
1694        } else {
1695            config.with_no_client_auth()
1696        }
1697        .with_single_cert(certs, key)?;
1698
1699        Ok(TlsAcceptor::from(Arc::new(config)))
1700    }
1701
    /// Creates a TLS acceptor using certificates from config (always enforces mutual TLS).
    ///
    /// The bundle comes from `get_pem_bundle`; passing `true` to
    /// `tls_acceptor_from_bundle` makes client certificates mandatory.
    pub(crate) fn tls_acceptor() -> Result<TlsAcceptor> {
        tls_acceptor_from_bundle(&get_pem_bundle(), true)
    }
1706
1707    /// Creates a TLS connector using certificates from the provided PEM bundle.
1708    pub(super) fn tls_connector_from_bundle(bundle: &PemBundle) -> Result<TlsConnector> {
1709        // Ensure ring is installed as the process-level crypto provider.
1710        // No-op when already installed (e.g. under Buck with native-tls).
1711        let _ = rustls::crypto::ring::default_provider().install_default();
1712
1713        let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1714        let key = load_key(&bundle.key).context("load TLS key")?;
1715        let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1716
1717        let config = rustls::ClientConfig::builder()
1718            .with_root_certificates(Arc::new(root_store))
1719            .with_client_auth_cert(certs, key)
1720            .context("configure client auth")?;
1721
1722        Ok(TlsConnector::from(Arc::new(config)))
1723    }
1724
    /// Creates a TLS connector using certificates from config.
    ///
    /// The bundle comes from `get_pem_bundle` and is handed to
    /// `tls_connector_from_bundle`, which always configures client authentication.
    fn tls_connector() -> Result<TlsConnector> {
        tls_connector_from_bundle(&get_pem_bundle())
    }
1729
    /// Shared TLS link implementation used by both tls and metatls transports.
    pub(crate) struct TlsLink {
        /// Host to connect to; also used as the TLS server name for the handshake.
        pub(crate) hostname: Hostname,
        /// Port to connect to on `hostname`.
        pub(crate) port: Port,
        /// Connector used to perform the TLS handshake on each new TCP stream.
        pub(crate) connector: TlsConnector,
        /// Selects which `ChannelAddr` variant `dest()` reports.
        pub(crate) addr_type: TlsAddrType,
        /// Session id returned by `link_id()` and written by `write_link_init` on each new stream.
        pub(crate) session_id: SessionId,
        /// Stream id written by `write_link_init` on each new stream.
        pub(crate) stream_id: u8,
    }
1739
1740    impl std::fmt::Debug for TlsLink {
1741        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1742            f.debug_struct("TlsLink")
1743                .field("hostname", &self.hostname)
1744                .field("port", &self.port)
1745                .field("addr_type", &self.addr_type)
1746                .finish()
1747        }
1748    }
1749
1750    #[async_trait]
1751    impl Link for TlsLink {
1752        type Stream = TlsStream<TcpStream>;
1753
1754        fn dest(&self) -> ChannelAddr {
1755            let addr = TlsAddr::new(self.hostname.clone(), self.port);
1756            match self.addr_type {
1757                TlsAddrType::Tls => ChannelAddr::Tls(addr),
1758                TlsAddrType::MetaTls => ChannelAddr::MetaTls(addr),
1759            }
1760        }
1761
1762        fn link_id(&self) -> SessionId {
1763            self.session_id
1764        }
1765
1766        async fn next(&mut self) -> Result<Self::Stream, ClientError> {
1767            let session_id = self.session_id;
1768            let server_name = ServerName::try_from(self.hostname.clone()).map_err(|e| {
1769                ClientError::Connect(
1770                    self.dest(),
1771                    io::Error::other(e.to_string()),
1772                    "invalid server name".to_string(),
1773                )
1774            })?;
1775            let mut backoff = ExponentialBackoffBuilder::new()
1776                .with_initial_interval(Duration::from_millis(1))
1777                .with_multiplier(2.0)
1778                .with_randomization_factor(0.1)
1779                .with_max_interval(Duration::from_millis(1000))
1780                .with_max_elapsed_time(None)
1781                .build();
1782            loop {
1783                let mut addrs = (self.hostname.as_ref(), self.port)
1784                    .to_socket_addrs()
1785                    .map_err(|_| ClientError::Resolve(self.dest()))?;
1786                let addr = addrs.next().ok_or(ClientError::Resolve(self.dest()))?;
1787                match TcpStream::connect(&addr).await {
1788                    Ok(stream) => {
1789                        stream.set_nodelay(true).map_err(|err| {
1790                            ClientError::Connect(
1791                                self.dest(),
1792                                err,
1793                                "cannot disable Nagle algorithm".to_string(),
1794                            )
1795                        })?;
1796                        let mut tls_stream = self
1797                            .connector
1798                            .connect(server_name.clone(), stream)
1799                            .await
1800                            .map_err(|err| {
1801                                tracing::info!(
1802                                    dest = %self.dest(),
1803                                    error = %err,
1804                                    "TLS handshake failed"
1805                                );
1806                                ClientError::Connect(
1807                                    self.dest(),
1808                                    err,
1809                                    format!("cannot establish TLS connection to {:?}", server_name),
1810                                )
1811                            })?;
1812                        write_link_init(&mut tls_stream, session_id, self.stream_id)
1813                            .await
1814                            .map_err(|err| ClientError::Io(self.dest(), err))?;
1815                        return Ok(tls_stream);
1816                    }
1817                    Err(err) => {
1818                        tracing::debug!(error = %err, "tls connect failed, backing off");
1819                        if let Some(delay) = backoff.next_backoff() {
1820                            tokio::time::sleep(delay).await;
1821                        }
1822                    }
1823                }
1824            }
1825        }
1826    }
1827
1828    /// Create a TLS link to the given address.
1829    pub fn link(
1830        addr: TlsAddr,
1831        session_id: SessionId,
1832        stream_id: u8,
1833    ) -> Result<TlsLink, ClientError> {
1834        let connector = tls_connector().map_err(|e| {
1835            ClientError::Connect(
1836                ChannelAddr::Tls(addr.clone()),
1837                io::Error::other(e.to_string()),
1838                "failed to create TLS connector".to_string(),
1839            )
1840        })?;
1841        let TlsAddr { hostname, port } = addr;
1842        Ok(TlsLink {
1843            hostname,
1844            port,
1845            connector,
1846            addr_type: TlsAddrType::Tls,
1847            session_id,
1848            stream_id,
1849        })
1850    }
1851
1852    #[cfg(test)]
1853    mod tests {
1854        use timed_test::async_timed_test;
1855
1856        use super::*;
1857        use crate::channel::Rx;
1858        use crate::channel::net::server;
1859        use crate::config::Pem;
1860        use crate::config::TLS_CA;
1861        use crate::config::TLS_CERT;
1862        use crate::config::TLS_KEY;
1863
1864        // Dummy test certificates generated with openssl for testing only.
1865        // These certificates include Subject Alternative Names (SAN) for localhost, 127.0.0.1, and ::1
1866        // CA certificate
1867        const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1868MIIDBTCCAe2gAwIBAgIUaGNmboiIosG+8Up0vgDr/+cg+2IwDQYJKoZIhvcNAQEL
1869BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1870NzA4MzlaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
1871DwAwggEKAoIBAQC9RBoMYXCajklswt8Vi1JI1lEYzic0WNOmz45vG/7H6jTWkgL3
1872K5Ri+Seg3MobDNc48YHWXYm4hP9wCzkx8ih3ntT5XiY1My/G3jLUuoIEE9pF/BoJ
1873YQwZVoPNFhA9WhXNRsINf1cXFf8NzRfXpxBfKWtQJxYXU4JiDBQ6rLnQQABo8JmQ
1874vYFhJbBaYip5jTSiVNn7mB1zNr5jsVxuoSF53Pb7xQ76bwBdOq4zd6PSxL5/lr4G
1875cHSoxwZQdZMG7PL6hbxDQ2S2YI2lYVET1zwc2WPKCfjbEXBC/jzx828CInQtuksk
187618gJt6xHkTFEA8CSA29GM3lejnwYWf51xyyBAgMBAAGjUzBRMB0GA1UdDgQWBBRX
1877cbxSZ70NsUkAS3Hhy6irugywJDAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1878ugywJDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA7aAFfyW67
1879Z+uGSVYhpsT/uH/3Z3nr7X1smTz5CGEfq2czEcTC7gbYI2l8GZ47GPfnAvHTBZVm
1880V/XncBCsj7/thOh2jYEHFyCbPckoaSCRyCOnK7LPUlr4HN5uP9EFe45qBLCJDEoY
1881GTTw7MtzwdovfjchNfKQCTtkBJCXQ95WLCf6UOh02Sn28UTlgfXzF0X0FrcWqWa3
1882uJZd4XOo4O6hKKlHaBaQPiEr++1xc3SWPV7jZHbckI/vKBnDdEZ9JQX5fFZuypUI
1883sgomYHxvxrU2hWx+7k53CRdjfaIvT9Ie44z9sSdsU/+blw2S8f/ZTmuECoIAAXYO
18840qpzlxZMdr7T
1885-----END CERTIFICATE-----"#;
1886
1887        // Server certificate (signed by CA) with SAN for localhost, 127.0.0.1, ::1
1888        const TEST_SERVER_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1889MIIDJDCCAgygAwIBAgIUaz66DsWaH5ZXM4hCFnbVbMsyN1cwDQYJKoZIhvcNAQEL
1890BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1891NzA4MzlaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQAD
1892ggEPADCCAQoCggEBAKCbp++qNyTn5LOsV0h9gLKJALBcjg2A14I3804N9UyDhPW2
1893QKQ2W424u2P1MfKrw/2C+CErGlrADlnco2RQVDAarAIuGdFvBOt5UezqOS7Mk4OS
18949MlS7NZnMbc37KuM9UIG5ScJjXR/Z5z9dxeR0I9y3n0Ix6khbV7tOSHobiweI0FI
18958LftBS+CQnXr6vbWPcHcW6Z0FHUv7IWhqMWmv9PlZRGe9Y6VzXrRp0PBnZMOnAYf
1896aMQUwYRswWdm9j9Z1sMdTJ14G+KVmO3Vj6XI6Sm9uIcYhlwG/kORwogJFWlVuP9o
1897rloFRCjyHJ1d7GZqqnRyHHDDCBms8ed+3YfEYQECAwEAAaNwMG4wLAYDVR0RBCUw
1898I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMB0GA1UdDgQWBBQl
1899J4vxUoCzqqeTwQAiLqE8wYezKzAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1900ugywJDANBgkqhkiG9w0BAQsFAAOCAQEAnXHIBDQ4AHAMV71piTOuI41ShASQed6L
1901bi7XUMZgZDslLkfU1vnP3BlwpliraBsAytSYQC6kbytOuz1uQ4K7yLb2tAAmUgEO
1902EdIVt9SXr5tCcIPeLmInF0pysPqjZO8n7vtJyd9gryKqdhm1uzA7WQWq/Az8a9Sk
1903uW2J6Oc5p6P7Mf3/ixqXzvGRo8rzu0CUJOJ67UTE/HhbJuplQ5dep5CEEOAIsAtH
1904zn9O4rW92ueBkoBJM++YILS1vQ7jKc2N3RNrnHm7FeootBrtR9mBi0TH97K73ZPZ
19052Cdhnym0CsCJggrllFGH32cYo7+K2PO7/4oj5XbBCSWcssicvd8ovg==
1906-----END CERTIFICATE-----"#;
1907
1908        // Server private key
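        // Test-only embedded key. The key itself does not expire, but the test
        // certificates above have a fixed validity window, so these fixtures may
        // eventually need to be regenerated: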
1910        // @lint-ignore PRIVATEKEY insecure-private-key-storage
1911        const TEST_SERVER_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
1912MIIEugIBADANBgkqhkiG9w0BAQEFAASCBKQwggSgAgEAAoIBAQCgm6fvqjck5+Sz
1913rFdIfYCyiQCwXI4NgNeCN/NODfVMg4T1tkCkNluNuLtj9THyq8P9gvghKxpawA5Z
19143KNkUFQwGqwCLhnRbwTreVHs6jkuzJODkvTJUuzWZzG3N+yrjPVCBuUnCY10f2ec
1915/XcXkdCPct59CMepIW1e7Tkh6G4sHiNBSPC37QUvgkJ16+r21j3B3FumdBR1L+yF
1916oajFpr/T5WURnvWOlc160adDwZ2TDpwGH2jEFMGEbMFnZvY/WdbDHUydeBvilZjt
19171Y+lyOkpvbiHGIZcBv5DkcKICRVpVbj/aK5aBUQo8hydXexmaqp0chxwwwgZrPHn
1918ft2HxGEBAgMBAAECgf8G5qlQov+7ljs9fSpC8yGUik59RXzVF7Qq5DyQHglsQDp2
1919VF5yr+M/M7DZmq+KvdauDfKbej6np5j2Q4TByrHTX1IExfZWCW8srwnWJDpQyHmO
1920LcJW5DlI/SYluUFyHZxsOd+ezcpGNzM8i6eSW7GaeFUXCkmJ+uW4LnlF+7bALnnd
1921D6sak/58EsII+IJyd4lFn+voszlPn3CZGR0jkp21rvpaKgrMIsKVWWQO/sLDU5pr
1922VbpBThcLU5gRcnQouQX12e2VTCIlFu75WTsJ8V/KnEaOZUVlU/B/Bs+WQF3U+/Jo
1923eX4N+D6OsEcNQjERAFyWujxsl1WpD4uSsbFMN0ECgYEA2b7AdL+oKPQHku2KcBhr
1924Zw8K4tMDlr2VPPNwZcBTLo+O71vv/xXjMcXrXmowzkgEQckUmt1VB46riyydhwdP
1925/n9ciWcz0Va/nwHR6Y9F9unBiyUBP7PRhRyjQyRZZRGDSJvP+Xmc5UJFpRr07VLU
1926nfgMXDj37vXzKDpfhdEB2nkCgYEAvNMfA8P8w3+6246x5YHflvTkPdw+2oyge+LD
1927mphB/w7SF8mlyNGloj3+KBZmd9SkvT57wCvO96Y9/n+mBAVisRggc0hK4ymOVYhb
1928+im/JvqGQMbVeg6iCOHnWdaZf9tL8uVsegQy3kVTN7vAa+CMFgX1dt65cGBX6XkB
192944pYmMkCgYALhbiRdQLlB+TOtZs5y1EDpxwgXKI3+9hF3Wv5NnAwapBZwje0++eF
19303r9Rw7TJda4j/QwGFehF+hrBxp6fYpetE/hFnRx0225Qb7w368j8A+ql/lNOl6li
1931rd1F1EqWupKD6RrcTL8sspEU55RGaretlE6zIqCcGI/BdTVQ03qRoQKBgHDC3zWf
1932d7XD9HGjQGdfbIe4jQjIGxzmd/wjik4q+NZ5IkukVwWa9P/zZ3DHF8Ad05dT1hEH
19332FwaAdGWpyyljq9VSiOuG1KXAXHgsZSuE4ISf9P1KYzvaiJFzaPfvOEWs79E9MfU
19349A+6dJzG2X1SpjWMr26iSTlrv3QkmFUqzAfJAoGASBkn4wls+oC5rv/Mch43pBv5
1935UmKru4ltnEHJZdbSi2DJ+AnDLD222JCasb1VT1tm2XgW6DBqrdVRPPP6GOlB0MHU
1936+3ULtZxAczt7I+ST2bo0/DV2Hse89Cm63w4wLOiVZs7+1wrAzJZLokWF7Q5gesra
1937u19txmtkiMEH+aNmekk=
1938-----END PRIVATE KEY-----"#;
1939
1940        #[async_timed_test(timeout_secs = 30)]
1941        async fn test_tls_basic() {
1942            // Ensure ring is installed as the default crypto provider
1943            // (no-op if already installed, e.g. under Buck with native-tls).
1944            let _ = rustls::crypto::ring::default_provider().install_default();
1945
1946            // Set up TLS config using the standard override pattern
1947            let config = hyperactor_config::global::lock();
1948            let _guard_cert =
1949                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1950            let _guard_key =
1951                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1952            let _guard_ca =
1953                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1954
1955            // Create a TLS server bound to localhost with dynamic port
1956            let addr = TlsAddr::new("localhost", 0);
1957
1958            let (local_addr, mut rx) =
1959                server::serve::<u64>(ChannelAddr::Tls(addr), None).expect("failed to serve");
1960
1961            // Dial the server
1962            let tx: super::NetTx<u64> = super::spawn(
1963                link(
1964                    match &local_addr {
1965                        ChannelAddr::Tls(addr) => addr.clone(),
1966                        _ => panic!("unexpected address type"),
1967                    },
1968                    SessionId::random(),
1969                    0,
1970                )
1971                .expect("failed to create link"),
1972            );
1973
1974            // Send a message
1975            tx.post(42u64);
1976
1977            // Receive the message
1978            let received = rx.recv().await.expect("failed to receive");
1979            assert_eq!(received, 42u64);
1980        }
1981
1982        #[async_timed_test(timeout_secs = 30)]
1983        async fn test_tls_multiple_messages() {
1984            let _ = rustls::crypto::ring::default_provider().install_default();
1985
1986            // Set up TLS config using the standard override pattern
1987            let config = hyperactor_config::global::lock();
1988            let _guard_cert =
1989                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1990            let _guard_key =
1991                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1992            let _guard_ca =
1993                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1994
1995            let addr = TlsAddr::new("localhost", 0);
1996
1997            let (local_addr, mut rx) =
1998                server::serve::<String>(ChannelAddr::Tls(addr), None).expect("failed to serve");
1999            let tx: super::NetTx<String> = super::spawn(
2000                link(
2001                    match &local_addr {
2002                        ChannelAddr::Tls(addr) => addr.clone(),
2003                        _ => panic!("unexpected address type"),
2004                    },
2005                    SessionId::random(),
2006                    0,
2007                )
2008                .expect("failed to create link"),
2009            );
2010
2011            // Send multiple messages
2012            for i in 0..10 {
2013                tx.post(format!("message {}", i));
2014            }
2015
2016            // Receive all messages
2017            for i in 0..10 {
2018                let received = rx.recv().await.expect("failed to receive");
2019                assert_eq!(received, format!("message {}", i));
2020            }
2021        }
2022
2023        #[test]
2024        fn test_tls_parse_hostname_port() {
2025            let addr = parse("localhost:8080").expect("failed to parse");
2026            assert!(matches!(
2027                addr,
2028                ChannelAddr::Tls(TlsAddr { hostname, port })
2029                    if hostname == "localhost" && port == 8080
2030            ));
2031        }
2032
2033        #[test]
2034        fn test_tls_parse_socket_addr() {
2035            let addr = parse("127.0.0.1:8080").expect("failed to parse");
2036            assert!(matches!(
2037                addr,
2038                ChannelAddr::Tls(TlsAddr { hostname, port })
2039                    if hostname == "127.0.0.1" && port == 8080
2040            ));
2041        }
2042
2043        #[test]
2044        fn test_tls_certs_parsing() {
2045            // Verify that the test certificates can be parsed correctly
2046            let cert_pem = Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec());
2047            let key_pem = Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec());
2048            let ca_pem = Pem::Value(TEST_CA_CERT.as_bytes().to_vec());
2049
2050            let certs = super::load_certs(&cert_pem).expect("failed to load certs");
2051            assert!(!certs.is_empty(), "expected at least one certificate");
2052
2053            let _key = super::load_key(&key_pem).expect("failed to load key");
2054
2055            let root_store = super::build_root_store(&ca_pem).expect("failed to build root store");
2056            assert!(!root_store.is_empty(), "expected at least one CA cert");
2057        }
2058
2059        #[test]
2060        fn test_tls_acceptor_creation() {
2061            // Ensure ring is installed as the default crypto provider
2062            // (no-op if already installed, e.g. under Buck with native-tls).
2063            let _ = rustls::crypto::ring::default_provider().install_default();
2064
2065            // Set up TLS config using the standard override pattern
2066            let config = hyperactor_config::global::lock();
2067            let _guard_cert =
2068                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
2069            let _guard_key =
2070                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
2071            let _guard_ca =
2072                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
2073
2074            // Verify that we can create a TLS acceptor
2075            let _acceptor = super::tls_acceptor().expect("failed to create TLS acceptor");
2076        }
2077
2078        #[test]
2079        fn test_tls_connector_creation() {
2080            // Ensure ring is installed as the default crypto provider
2081            // (no-op if already installed, e.g. under Buck with native-tls).
2082            let _ = rustls::crypto::ring::default_provider().install_default();
2083
2084            // Set up TLS config using the standard override pattern
2085            let config = hyperactor_config::global::lock();
2086            let _guard_cert =
2087                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
2088            let _guard_key =
2089                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
2090            let _guard_ca =
2091                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
2092
2093            // Verify that we can create a TLS connector
2094            let _connector = super::tls_connector().expect("failed to create TLS connector");
2095        }
2096    }
2097}
2098
2099/// Build the OSS PemBundle from hyperactor_config attributes.
2100fn oss_pem_bundle() -> crate::config::PemBundle {
2101    crate::config::PemBundle {
2102        ca: hyperactor_config::global::get_cloned(crate::config::TLS_CA),
2103        cert: hyperactor_config::global::get_cloned(crate::config::TLS_CERT),
2104        key: hyperactor_config::global::get_cloned(crate::config::TLS_KEY),
2105    }
2106}
2107
2108/// Try to find a usable TLS [`PemBundle`](crate::config::PemBundle)
2109/// by probing the same sources as [`try_tls_acceptor`] /
2110/// [`try_tls_connector`].
2111///
2112/// Returns the first bundle whose CA certificate is readable.
2113/// Only CA readability is checked — cert and key are returned as-is
2114/// and may not be valid. Callers that cannot use `tokio_rustls` types
2115/// directly (e.g. reqwest) can read the raw PEM bytes via
2116/// [`Pem::reader`](crate::config::Pem::reader).
2117pub fn try_tls_pem_bundle() -> Option<crate::config::PemBundle> {
2118    let oss_bundle = oss_pem_bundle();
2119    if oss_bundle.ca.reader().is_ok() {
2120        return Some(oss_bundle);
2121    }
2122    tracing::debug!("OSS TLS bundle: CA not readable, trying Meta paths");
2123
2124    let meta_bundle = meta::get_server_pem_bundle();
2125    if meta_bundle.ca.reader().is_ok() {
2126        return Some(meta_bundle);
2127    }
2128    tracing::debug!("Meta TLS bundle: CA not readable, no TLS available");
2129
2130    None
2131}
2132
2133/// Try to build a [`TlsAcceptor`](tokio_rustls::TlsAcceptor) for an
2134/// HTTP server by probing for available TLS certificates.
2135///
2136/// Detection order:
2137/// 1. **OSS / explicit config** — `HYPERACTOR_TLS_CERT`,
2138///    `HYPERACTOR_TLS_KEY`, and `HYPERACTOR_TLS_CA` (read via
2139///    [`hyperactor_config`]).
2140/// 2. **Meta default paths** —
2141///    `/var/facebook/x509_identities/server.pem` and
2142///    `/var/facebook/rootcanal/ca.pem`. These are present on
2143///    devservers and in MAST / Tupperware containers.
2144/// 3. **None** — no usable certificates found; caller should fall
2145///    back to plain HTTP.
2146///
2147/// When `enforce_client_tls` is `true`, the returned acceptor
2148/// requires clients to present a valid certificate signed by the
2149/// configured CA (mutual TLS via `WebPkiClientVerifier`). When
2150/// `false`, the acceptor authenticates itself but does not demand
2151/// client certificates.
2152pub fn try_tls_acceptor(enforce_client_tls: bool) -> Option<tokio_rustls::TlsAcceptor> {
2153    let oss_bundle = oss_pem_bundle();
2154    if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&oss_bundle, enforce_client_tls) {
2155        return Some(acceptor);
2156    }
2157    tracing::debug!("OSS TLS acceptor failed, trying Meta paths");
2158
2159    let meta_bundle = meta::get_server_pem_bundle();
2160    if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&meta_bundle, enforce_client_tls) {
2161        return Some(acceptor);
2162    }
2163    tracing::debug!("Meta TLS acceptor failed, no TLS available");
2164
2165    None
2166}
2167
2168/// Try to build a [`TlsConnector`](tokio_rustls::TlsConnector) for an
2169/// HTTP client that needs to connect to a TLS-enabled server.
2170///
2171/// Detection mirrors [`try_tls_acceptor`]:
2172/// 1. **OSS** — `HYPERACTOR_TLS_CA` (and optionally
2173///    `HYPERACTOR_TLS_CERT` + `HYPERACTOR_TLS_KEY` for mutual TLS).
2174/// 2. **Meta** — root CA at `/var/facebook/rootcanal/ca.pem`,
2175///    optional client certs from `THRIFT_TLS_CL_CERT_PATH` /
2176///    `THRIFT_TLS_CL_KEY_PATH`.
2177/// 3. **None** — no usable CA found; caller should fall back to plain
2178///    HTTP.
2179pub fn try_tls_connector() -> Option<tokio_rustls::TlsConnector> {
2180    let oss_bundle = oss_pem_bundle();
2181    if let Ok(connector) = tls::tls_connector_from_bundle(&oss_bundle) {
2182        return Some(connector);
2183    }
2184    tracing::debug!("OSS TLS connector failed, trying Meta paths");
2185
2186    if let Ok(connector) = meta::try_tls_connector() {
2187        return Some(connector);
2188    }
2189    tracing::debug!("Meta TLS connector failed, no TLS available");
2190
2191    None
2192}
2193
2194#[cfg(test)]
2195mod tests {
2196
2197    #![expect(
2198        clippy::await_holding_invalid_type,
2199        reason = "tracing_test::traced_test macro expansion holds tracing::span::Entered across awaits; can't be fixed in our code"
2200    )]
2201
2202    use std::assert_matches;
2203    use std::collections::VecDeque;
2204    use std::marker::PhantomData;
2205    use std::sync::Arc;
2206    use std::sync::RwLock;
2207    use std::sync::atomic::AtomicBool;
2208    use std::sync::atomic::AtomicU64;
2209    use std::sync::atomic::Ordering;
2210    use std::time::Duration;
2211    #[cfg(target_os = "linux")] // uses abstract names
2212    use std::time::UNIX_EPOCH;
2213
2214    #[cfg(target_os = "linux")] // uses abstract names
2215    use anyhow::Result;
2216    use bytes::Bytes;
2217    use rand::RngExt as _;
2218    use rand::SeedableRng as _;
2219    use rand::distr::Alphanumeric;
2220    use rand::rngs::SysRng;
2221    use timed_test::async_timed_test;
2222    use tokio::io::AsyncWrite;
2223    use tokio::io::DuplexStream;
2224    use tokio::io::ReadHalf;
2225    use tokio::io::WriteHalf;
2226    use tokio::task::JoinHandle;
2227    use tokio_util::sync::CancellationToken;
2228
2229    use super::server;
2230    use super::*;
2231    use crate::channel;
2232    use crate::channel::net::framed::FrameReader;
2233    use crate::channel::net::framed::FrameWrite;
2234    use crate::channel::net::server::AcceptorLink;
2235    use crate::config;
2236    use crate::metrics;
2237    use crate::sync::mvar::MVar;
2238
2239    /// Like the `logs_assert` injected by `#[traced_test]`, but without scope
2240    /// filtering. Use when asserting on events emitted outside the test's span
2241    /// (e.g. from spawned tasks or panic hooks).
2242    fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
2243        let buf = tracing_test::internal::global_buf().lock().unwrap();
2244        let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
2245        let lines: Vec<&str> = logs_str.lines().collect();
2246        match f(&lines) {
2247            Ok(()) => {}
2248            Err(msg) => panic!("{}", msg),
2249        }
2250    }
2251
2252    #[cfg(target_os = "linux")] // uses abstract names
2253    #[tracing_test::traced_test]
2254    #[tokio::test]
2255    async fn test_unix_basic() -> Result<()> {
2256        let timestamp = std::time::SystemTime::now()
2257            .duration_since(UNIX_EPOCH)
2258            .unwrap()
2259            .as_nanos();
2260        let unique_address = format!("test_unix_basic_{}", timestamp);
2261
2262        let (addr, mut rx) = server::serve::<u64>(
2263            ChannelAddr::Unix(unix::SocketAddr::from_abstract_name(&unique_address)?),
2264            None,
2265        )
2266        .unwrap();
2267
2268        // It is important to keep Tx alive until all expected messages are
2269        // received. Otherwise, the channel would be closed when Tx is dropped.
2270        // Although the messages are sent to the server's buffer before the
2271        // channel was closed, NetRx could still error out before taking them
2272        // out of the buffer because NetRx could not ack through the closed
2273        // channel.
2274        {
2275            let tx: ChannelTx<u64> = channel::dial::<u64>(addr.clone()).unwrap();
2276            tx.post(123);
2277            assert_eq!(rx.recv().await.unwrap(), 123);
2278        }
2279
2280        {
2281            let tx = channel::dial::<u64>(addr.clone()).unwrap();
2282            tx.post(321);
2283            tx.post(111);
2284            tx.post(444);
2285
2286            assert_eq!(rx.recv().await.unwrap(), 321);
2287            assert_eq!(rx.recv().await.unwrap(), 111);
2288            assert_eq!(rx.recv().await.unwrap(), 444);
2289        }
2290
2291        {
2292            let tx = channel::dial::<u64>(addr).unwrap();
2293            drop(rx);
2294
2295            let (return_tx, return_rx) = oneshot::channel();
2296            tx.try_post(123, return_tx);
2297            assert_matches!(
2298                return_rx.await,
2299                Ok(SendError {
2300                    error: ChannelError::Closed,
2301                    message: 123,
2302                    ..
2303                })
2304            );
2305        }
2306
2307        Ok(())
2308    }
2309
2310    #[cfg(target_os = "linux")] // uses abstract names
2311    #[tracing_test::traced_test]
2312    #[tokio::test]
2313    async fn test_unix_basic_client_before_server() -> Result<()> {
2314        // We run this test on Unix because we can pick our own port names more easily.
2315        let timestamp = std::time::SystemTime::now()
2316            .duration_since(UNIX_EPOCH)
2317            .unwrap()
2318            .as_nanos();
2319        let socket_addr =
2320            unix::SocketAddr::from_abstract_name(&format!("test_unix_basic_{}", timestamp))
2321                .unwrap();
2322
2323        // Dial the channel before we actually serve it.
2324        let addr = ChannelAddr::Unix(socket_addr.clone());
2325        let tx = crate::channel::dial::<u64>(addr.clone()).unwrap();
2326        tx.post(123);
2327
2328        let (_, mut rx) = server::serve::<u64>(ChannelAddr::Unix(socket_addr), None).unwrap();
2329        assert_eq!(rx.recv().await.unwrap(), 123);
2330
2331        tx.post(321);
2332        tx.post(111);
2333        tx.post(444);
2334
2335        assert_eq!(rx.recv().await.unwrap(), 321);
2336        assert_eq!(rx.recv().await.unwrap(), 111);
2337        assert_eq!(rx.recv().await.unwrap(), 444);
2338
2339        Ok(())
2340    }
2341
2342    #[tracing_test::traced_test]
2343    #[async_timed_test(timeout_secs = 60)]
2344    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2345    #[cfg_attr(not(fbcode_build), ignore)]
2346    async fn test_tcp_basic() {
2347        let (addr, mut rx) =
2348            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
2349        {
2350            let tx = channel::dial::<u64>(addr.clone()).unwrap();
2351            tx.post(123);
2352            assert_eq!(rx.recv().await.unwrap(), 123);
2353        }
2354
2355        {
2356            let tx = channel::dial::<u64>(addr.clone()).unwrap();
2357            tx.post(321);
2358            tx.post(111);
2359            tx.post(444);
2360
2361            assert_eq!(rx.recv().await.unwrap(), 321);
2362            assert_eq!(rx.recv().await.unwrap(), 111);
2363            assert_eq!(rx.recv().await.unwrap(), 444);
2364        }
2365
2366        {
2367            let tx = channel::dial::<u64>(addr).unwrap();
2368            drop(rx);
2369
2370            let (return_tx, return_rx) = oneshot::channel();
2371            tx.try_post(123, return_tx);
2372            assert_matches!(
2373                return_rx.await,
2374                Ok(SendError {
2375                    error: ChannelError::Closed,
2376                    message: 123,
2377                    ..
2378                })
2379            );
2380        }
2381    }
2382
2383    // The message size is limited by CODEC_MAX_FRAME_LENGTH.
2384    #[async_timed_test(timeout_secs = 5)]
2385    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2386    #[cfg_attr(not(fbcode_build), ignore)]
2387    async fn test_tcp_message_size() {
2388        let default_size_in_bytes = 100 * 1024 * 1024;
2389        // Use temporary config for this test
2390        let config = hyperactor_config::global::lock();
2391        let _guard1 = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2392        let _guard2 = config.override_key(config::CODEC_MAX_FRAME_LENGTH, default_size_in_bytes);
2393
2394        let (addr, mut rx) =
2395            server::serve::<String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
2396
2397        let tx = channel::dial::<String>(addr.clone()).unwrap();
2398        // Default size is okay
2399        {
2400            // Leave some headroom because Tx will wrap the payload in Frame::Message.
2401            let message = "a".repeat(default_size_in_bytes - 1024);
2402            tx.post(message.clone());
2403            assert_eq!(rx.recv().await.unwrap(), message);
2404        }
2405        // Bigger than the default size will fail.
2406        {
2407            let (return_channel, return_receiver) = oneshot::channel();
2408            let message = "a".repeat(default_size_in_bytes + 1024);
2409            tx.try_post(message.clone(), return_channel);
2410            let returned = return_receiver.await.unwrap();
2411            assert_eq!(message, returned.message);
2412        }
2413    }
2414
2415    #[async_timed_test(timeout_secs = 30)]
2416    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
2417    #[cfg_attr(not(fbcode_build), ignore)]
2418    async fn test_ack_flush() {
2419        let config = hyperactor_config::global::lock();
2420        // Set a large value to effectively prevent acks from being sent except
2421        // during shutdown flush.
2422        let _guard_message_ack =
2423            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
2424        let _guard_delivery_timeout =
2425            config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(5));
2426
2427        let (addr, mut net_rx) =
2428            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
2429        let net_tx = channel::dial::<u64>(addr.clone()).unwrap();
2430        let (tx, rx) = oneshot::channel();
2431        net_tx.try_post(1, tx);
2432        assert_eq!(net_rx.recv().await.unwrap(), 1);
2433        drop(net_rx);
2434        // Using `is_err` to confirm the message is delivered/acked is confusing,
2435        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
2436        assert!(rx.await.is_err());
2437    }
2438
2439    #[async_timed_test(timeout_secs = 60)]
2440    // TODO: OSS: failed to retrieve ipv6 address
2441    #[cfg_attr(not(fbcode_build), ignore)]
2442    async fn test_meta_tls_basic() {
2443        hyperactor_telemetry::initialize_logging_for_test();
2444
2445        let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
2446        let meta_addr = match addr {
2447            ChannelAddr::MetaTls(meta_addr) => meta_addr,
2448            _ => panic!("expected MetaTls address"),
2449        };
2450        let (local_addr, mut rx) =
2451            server::serve::<u64>(ChannelAddr::MetaTls(meta_addr), None).unwrap();
2452        {
2453            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2454            tx.post(123);
2455        }
2456        assert_eq!(rx.recv().await.unwrap(), 123);
2457
2458        {
2459            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2460            tx.post(321);
2461            tx.post(111);
2462            tx.post(444);
2463            assert_eq!(rx.recv().await.unwrap(), 321);
2464            assert_eq!(rx.recv().await.unwrap(), 111);
2465            assert_eq!(rx.recv().await.unwrap(), 444);
2466        }
2467
2468        {
2469            let tx = channel::dial::<u64>(local_addr).unwrap();
2470            drop(rx);
2471
2472            let (return_tx, return_rx) = oneshot::channel();
2473            tx.try_post(123, return_tx);
2474            assert_matches!(
2475                return_rx.await,
2476                Ok(SendError {
2477                    error: ChannelError::Closed,
2478                    message: 123,
2479                    ..
2480                })
2481            );
2482        }
2483    }
2484
    /// Fault-injection settings used by `MockLink`'s relay tasks: random
    /// disconnections and artificial per-message latency. The default value
    /// (both fields `None`) injects nothing.
    #[derive(Clone, Debug, Default)]
    struct NetworkFlakiness {
        // A tuple of:
        //   1. the probability of a network failure when sending a message.
        //   2. the max number of disconnections allowed.
        //   3. the minimum duration between disconnections.
        //
        //   2 and 3 are useful to prevent frequent disconnections leading to
        //   unacked messages being sent repeatedly.
        disconnect_params: Option<(f64, u64, Duration)>,
        // The `(min, max)` latency bounds when sending a message. The actual
        // latency of each message is randomly generated between `min` and
        // `max`. `MockLink::with_network_flakiness` asserts `min < max`.
        latency_range: Option<(Duration, Duration)>,
    }
2499
2500    impl NetworkFlakiness {
2501        // Calculate whether to disconnect
2502        async fn should_disconnect(
2503            &self,
2504            rng: &mut impl rand::Rng,
2505            disconnected_count: u64,
2506            prev_disconnected_at: &RwLock<Instant>,
2507        ) -> bool {
2508            let Some((prob, max_disconnects, duration)) = &self.disconnect_params else {
2509                return false;
2510            };
2511
2512            let disconnected_at = prev_disconnected_at.read().unwrap();
2513            if disconnected_at.elapsed() > *duration && disconnected_count < *max_disconnects {
2514                rng.random_bool(*prob)
2515            } else {
2516                false
2517            }
2518        }
2519    }
2520
    /// Test-only `Link` whose connections are in-memory duplex streams. Each
    /// call to `next()` creates a client stream and a server stream joined by
    /// relay tasks, which is where disconnections and latency are injected.
    struct MockLink<M> {
        // Buffer size passed to `tokio::io::duplex` for the streams created
        // by `next()`.
        buffer_size: usize,
        // Returned by `link_id()` and written as part of the LinkInit that
        // `next()` sends on every new connection.
        session_id: SessionId,
        // Where `next()` puts the server end of each new connection; a
        // `MockLinkListener` sharing this MVar takes it out in `accept()`.
        receiver_storage: Arc<MVar<DuplexStream>>,
        // If true, `next()` on this link will always return an error.
        fail_connects: Arc<AtomicBool>,
        // Used to break the existing connection, if there is one. It still
        // allows reconnect.
        disconnect_signal: watch::Sender<()>,
        // Disconnect and latency injection settings handed to the relay tasks.
        network_flakiness: NetworkFlakiness,
        // Number of disconnections injected so far by the relay tasks.
        disconnected_count: Arc<AtomicU64>,
        // When the relay tasks last injected a disconnection (initialized to
        // the link's creation time).
        prev_disconnected_at: Arc<RwLock<Instant>>,
        // If set, print logs every `debug_log_sampling_rate` messages. This
        // is normally set only when debugging a test failure.
        debug_log_sampling_rate: Option<u64>,
        // Ties the link to the message type `M` without storing a value.
        _message_type: PhantomData<M>,
    }
2538
2539    impl<M> fmt::Debug for MockLink<M> {
2540        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2541            f.debug_struct("MockLink")
2542                .field("buffer_size", &self.buffer_size)
2543                .field("receiver_storage", &"<MVar<DuplexStream>>")
2544                .field("fail_connects", &self.fail_connects)
2545                .field("disconnect_signal", &"<watch::Sender>")
2546                .field("network_flakiness", &self.network_flakiness)
2547                .field("disconnected_count", &self.disconnected_count)
2548                .field("prev_disconnected_at", &"<RwLock<Instant>>")
2549                .field("debug_log_sampling_rate", &self.debug_log_sampling_rate)
2550                .finish()
2551        }
2552    }
2553
2554    impl<M: RemoteMessage> MockLink<M> {
2555        fn new() -> Self {
2556            let (sender, _) = watch::channel(());
2557            Self {
2558                buffer_size: 64,
2559                session_id: SessionId::random(),
2560                receiver_storage: Arc::new(MVar::empty()),
2561                fail_connects: Arc::new(AtomicBool::new(false)),
2562                disconnect_signal: sender,
2563                network_flakiness: NetworkFlakiness::default(),
2564                disconnected_count: Arc::new(AtomicU64::new(0)),
2565                prev_disconnected_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
2566                debug_log_sampling_rate: None,
2567                _message_type: PhantomData,
2568            }
2569        }
2570
2571        // If `fail_connects` is true, `next()` on this link will
2572        // always return an error.
2573        fn fail_connects() -> Self {
2574            Self {
2575                fail_connects: Arc::new(AtomicBool::new(true)),
2576                ..Self::new()
2577            }
2578        }
2579
2580        fn with_network_flakiness(network_flakiness: NetworkFlakiness) -> Self {
2581            if let Some((min, max)) = network_flakiness.latency_range {
2582                assert!(min < max);
2583            }
2584
2585            Self {
2586                network_flakiness,
2587                ..Self::new()
2588            }
2589        }
2590
2591        fn receiver_storage(&self) -> Arc<MVar<DuplexStream>> {
2592            self.receiver_storage.clone()
2593        }
2594
2595        fn disconnected_count(&self) -> Arc<AtomicU64> {
2596            self.disconnected_count.clone()
2597        }
2598
2599        fn disconnect_signal(&self) -> &watch::Sender<()> {
2600            &self.disconnect_signal
2601        }
2602
2603        fn fail_connects_switch(&self) -> Arc<AtomicBool> {
2604            self.fail_connects.clone()
2605        }
2606
2607        fn set_buffer_size(&mut self, size: usize) {
2608            self.buffer_size = size;
2609        }
2610
2611        fn set_sampling_rate(&mut self, sampling_rate: u64) {
2612            self.debug_log_sampling_rate = Some(sampling_rate);
2613        }
2614    }
2615
2616    #[async_trait]
2617    impl<M: RemoteMessage> Link for MockLink<M> {
2618        type Stream = DuplexStream;
2619
2620        fn dest(&self) -> ChannelAddr {
2621            ChannelAddr::Local(u64::MAX)
2622        }
2623
2624        fn link_id(&self) -> SessionId {
2625            self.session_id
2626        }
2627
2628        async fn next(&mut self) -> Result<Self::Stream, ClientError> {
2629            let session_id = self.session_id;
2630            tracing::debug!("MockLink starts to connect.");
2631            if self.fail_connects.load(Ordering::Acquire) {
2632                return Err(ClientError::Connect(
2633                    self.dest(),
2634                    std::io::Error::other("intentional error"),
2635                    "expected failure injected by the mock".to_string(),
2636                ));
2637            }
2638
2639            // Add relays between server and client streams. The
2640            // relays provides the place to inject network flakiness.
2641            // The message flow looks like:
2642            //
2643            // server <-> server relay <-> injection logic <-> client relay <-> client
2644            async fn relay_message<M: RemoteMessage>(
2645                mut disconnect_signal: watch::Receiver<()>,
2646                network_flakiness: NetworkFlakiness,
2647                disconnected_count: Arc<AtomicU64>,
2648                prev_disconnected_at: Arc<RwLock<Instant>>,
2649                mut reader: FrameReader<ReadHalf<DuplexStream>>,
2650                mut writer: WriteHalf<DuplexStream>,
2651                // Used by client and server tokio tasks to coordinate
2652                // stopping together.
2653                task_coordination_token: CancellationToken,
2654                debug_log_sampling_rate: Option<u64>,
2655                // Whether the relayed message is from client to
2656                // server.
2657                is_from_client: bool,
2658            ) {
2659                // Used to simulate latency. Briefly, messages are
2660                // buffered in the queue and wait for the expected
2661                // latency elapse.
2662                async fn wait_for_latency_elapse(
2663                    queue: &VecDeque<(Bytes, Instant)>,
2664                    network_flakiness: &NetworkFlakiness,
2665                    rng: &mut impl rand::Rng,
2666                ) {
2667                    if let Some((min, max)) = network_flakiness.latency_range {
2668                        let diff = max.abs_diff(min);
2669                        let factor = rng.random_range(0.0..=1.0);
2670                        let latency = min + diff.mul_f64(factor);
2671                        tokio::time::sleep_until(queue.front().unwrap().1 + latency).await;
2672                    }
2673                }
2674
2675                let mut rng = rand::rngs::SmallRng::try_from_rng(&mut SysRng).unwrap();
2676                let mut queue: VecDeque<(Bytes, Instant)> = VecDeque::new();
2677                let mut send_count = 0u64;
2678
2679                loop {
2680                    tokio::select! {
2681                        read_res = reader.next() => {
2682                            match read_res {
2683                                Ok(Some((_, data))) => {
2684                                    queue.push_back((data, tokio::time::Instant::now()));
2685                                }
2686                                Ok(None) | Err(_) => {
2687                                        tracing::debug!("The upstream is closed or dropped. MockLink disconnects");
2688                                        break;
2689                                }
2690                            }
2691                        }
2692                        _ = wait_for_latency_elapse(&queue, &network_flakiness, &mut rng), if !queue.is_empty() => {
2693                            let count = disconnected_count.load(Ordering::Relaxed);
2694                            if network_flakiness.should_disconnect(&mut rng, count, &prev_disconnected_at).await {
2695                                tracing::debug!("MockLink disconnects");
2696                                disconnected_count.fetch_add(1, Ordering::Relaxed);
2697
2698                                metrics::CHANNEL_RECONNECTIONS.add(
2699                                    1,
2700                                    hyperactor_telemetry::kv_pairs!(
2701                                        "transport" => "mock",
2702                                        "reason" => "network_flakiness",
2703                                    ),
2704                                );
2705
2706                                let mut w = prev_disconnected_at.write().unwrap();
2707                                *w = tokio::time::Instant::now();
2708                                break;
2709                            }
2710                            let data = queue.pop_front().unwrap().0;
2711                            let is_sampled = debug_log_sampling_rate.is_some_and(|sample_rate| send_count % sample_rate == 1);
2712                            if is_sampled {
2713                                if is_from_client {
2714                                    if let Ok((Frame::Message(_seq, _msg), _)) = bincode::serde::decode_from_slice::<Frame<M>, _>(&data, bincode::config::legacy()) {
2715                                        tracing::debug!("MockLink relays a msg from client. msg type: {}", std::any::type_name::<M>());
2716                                    }
2717                                } else {
2718                                    let result = deserialize_response(data.clone());
2719                                    if let Ok(NetRxResponse::Ack(seq)) = result {
2720                                        tracing::debug!("MockLink relays an ack from server. seq: {}", seq);
2721                                    }
2722                                }
2723                            }
2724                            let mut fw  = FrameWrite::new(writer, data, hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH), 0).unwrap();
2725                            if fw.send().await.is_err() {
2726                                break;
2727                            }
2728                            writer = fw.complete();
2729                            send_count += 1;
2730                        }
2731                        _ = task_coordination_token.cancelled() => break,
2732
2733                        changed = disconnect_signal.changed() => {
2734                            tracing::debug!("MockLink disconnects per disconnect_signal {:?}", changed);
2735                            break;
2736                        }
2737                    }
2738                }
2739
2740                task_coordination_token.cancel();
2741            }
2742
2743            let (server, mut server_relay) = tokio::io::duplex(self.buffer_size);
2744            let (client, client_relay) = tokio::io::duplex(self.buffer_size);
2745
2746            // Write LinkInit on server_relay so it's readable from `server`.
2747            // This simulates the client sending LinkInit over the wire before
2748            // the frame-level relay begins.
2749            write_link_init(&mut server_relay, session_id, 0)
2750                .await
2751                .map_err(|err| ClientError::Io(self.dest(), err))?;
2752
2753            let (server_r, server_writer) = tokio::io::split(server_relay);
2754            let (client_r, client_writer) = tokio::io::split(client_relay);
2755
2756            let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
2757            let server_reader = FrameReader::new(server_r, max_len);
2758            let client_reader = FrameReader::new(client_r, max_len);
2759
2760            let task_coordination_token = CancellationToken::new();
2761            let _server_relay_task_handle = tokio::spawn(relay_message::<M>(
2762                self.disconnect_signal.subscribe(),
2763                self.network_flakiness.clone(),
2764                self.disconnected_count.clone(),
2765                self.prev_disconnected_at.clone(),
2766                server_reader,
2767                client_writer,
2768                task_coordination_token.clone(),
2769                self.debug_log_sampling_rate,
2770                /*is_from_client*/ false,
2771            ));
2772            let _client_relay_task_handle = tokio::spawn(relay_message::<M>(
2773                self.disconnect_signal.subscribe(),
2774                self.network_flakiness.clone(),
2775                self.disconnected_count.clone(),
2776                self.prev_disconnected_at.clone(),
2777                client_reader,
2778                server_writer,
2779                task_coordination_token,
2780                self.debug_log_sampling_rate,
2781                /*is_from_client*/ true,
2782            ));
2783
2784            self.receiver_storage.put(server).await;
2785            Ok(client)
2786        }
2787    }
2788
    /// Test-only listener whose `accept()` yields the streams found in a
    /// shared `MVar` instead of accepting from a real socket.
    struct MockLinkListener {
        // Slot from which `accept()` takes each incoming stream.
        receiver_storage: Arc<MVar<DuplexStream>>,
        // Address returned alongside every accepted stream.
        channel_addr: ChannelAddr,
    }
2793
2794    impl MockLinkListener {
2795        fn new(receiver_storage: Arc<MVar<DuplexStream>>, channel_addr: ChannelAddr) -> Self {
2796            Self {
2797                receiver_storage,
2798                channel_addr,
2799            }
2800        }
2801    }
2802
2803    impl fmt::Debug for MockLinkListener {
2804        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2805            f.debug_struct("MockLinkListener")
2806                .field("channel_addr", &self.channel_addr)
2807                .finish()
2808        }
2809    }
2810
2811    #[async_trait]
2812    impl super::Listener for MockLinkListener {
2813        type Stream = DuplexStream;
2814
2815        async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
2816            let stream = self.receiver_storage.take().await;
2817            Ok((stream, self.channel_addr.clone()))
2818        }
2819    }
2820
2821    /// Create an AcceptorLink-based server test rig. Returns the
2822    /// session task handle, the channel sender for dispatching
2823    /// streams, the message receiver, and a cancellation token.
2824    fn serve_acceptor_test<M: RemoteMessage>(
2825        session_id: SessionId,
2826    ) -> (
2827        JoinHandle<()>,
2828        mpsc::UnboundedSender<DuplexStream>,
2829        mpsc::Receiver<M>,
2830        CancellationToken,
2831    ) {
2832        let (acceptor_tx, acceptor_rx) = mpsc::unbounded_channel::<DuplexStream>();
2833        let cancel_token = CancellationToken::new();
2834        let link = AcceptorLink {
2835            dest: ChannelAddr::Local(u64::MAX),
2836            session_id,
2837            stream: acceptor_rx,
2838            cancel: cancel_token.clone(),
2839        };
2840        let (tx, rx) = mpsc::channel::<M>(1024);
2841        let ct = cancel_token.clone();
2842        let handle = tokio::spawn(async move {
2843            let mut session = Session::new(link);
2844            let mut next = session::Next { seq: 0, ack: 0 };
2845
2846            loop {
2847                let connected = match session.connect().await {
2848                    Ok(s) => s,
2849                    Err(_) => break,
2850                };
2851
2852                let result = {
2853                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2854                    tokio::select! {
2855                        r = session::recv_connected::<M, _, _>(&stream, &tx, &mut next) => r,
2856                        _ = ct.cancelled() => Err(session::RecvLoopError::Cancelled),
2857                    }
2858                };
2859
2860                // Flush remaining ack if behind.
2861                if next.ack < next.seq {
2862                    let ack = serialize_response(NetRxResponse::Ack(next.seq - 1)).unwrap();
2863                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2864                    let mut completion = stream.write(ack);
2865                    match completion.drive().await {
2866                        Ok(()) => {
2867                            next.ack = next.seq;
2868                        }
2869                        Err(e) => {
2870                            tracing::debug!(
2871                                error = %e,
2872                                "failed to flush acks during cleanup"
2873                            );
2874                        }
2875                    }
2876                }
2877
2878                // Send reject or closed response if appropriate.
2879                let terminal_response = match &result {
2880                    Err(session::RecvLoopError::SequenceError(reason)) => {
2881                        Some(NetRxResponse::Reject(reason.clone()))
2882                    }
2883                    Err(session::RecvLoopError::Cancelled) => Some(NetRxResponse::Closed),
2884                    _ => None,
2885                };
2886                if let Some(rsp) = terminal_response {
2887                    let data = serialize_response(rsp).unwrap();
2888                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2889                    let mut completion = stream.write(data);
2890                    let _ = completion.drive().await;
2891                }
2892
2893                let recoverable = matches!(&result, Ok(()) | Err(session::RecvLoopError::Io(_)));
2894                session = connected.release();
2895                if recoverable {
2896                    continue;
2897                }
2898                break;
2899            }
2900        });
2901        (handle, acceptor_tx, rx, cancel_token)
2902    }
2903
2904    async fn write_stream<M, W>(
2905        mut writer: W,
2906        _session_id: u64,
2907        messages: &[(u64, M)],
2908        _init: bool,
2909    ) -> W
2910    where
2911        M: RemoteMessage + PartialEq + Clone,
2912        W: AsyncWrite + Unpin,
2913    {
2914        for (seq, message) in messages {
2915            let message =
2916                serde_multipart::serialize_bincode(&Frame::<M>::Message(*seq, message.clone()))
2917                    .unwrap();
2918            let mut fw = FrameWrite::new(
2919                writer,
2920                message.framed(),
2921                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2922                0,
2923            )
2924            .map_err(|(_w, e)| e)
2925            .unwrap();
2926            fw.send().await.unwrap();
2927            writer = fw.complete();
2928        }
2929
2930        writer
2931    }
2932
2933    #[async_timed_test(timeout_secs = 60)]
2934    async fn test_persistent_server_session() {
2935        let config = hyperactor_config::global::lock();
2936        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2937
2938        async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
2939            let mut last_acked: i128 = -1;
2940            loop {
2941                let (_, bytes) = reader.next().await.unwrap().unwrap();
2942                let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2943                assert!(
2944                    acked as i128 > last_acked,
2945                    "acks should be delivered in ascending order"
2946                );
2947                last_acked = acked as i128;
2948                assert!(acked <= expected_last);
2949                if acked == expected_last {
2950                    break;
2951                }
2952            }
2953        }
2954
2955        let session_id = SessionId(123);
2956        let (_handle, acceptor_tx, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2957
2958        // First connection: send messages, verify delivery and ack.
2959        {
2960            let (sender, receiver) = tokio::io::duplex(5000);
2961            acceptor_tx.send(receiver).unwrap();
2962
2963            let (r, writer) = tokio::io::split(sender);
2964            let mut reader = FrameReader::new(
2965                r,
2966                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2967            );
2968
2969            let _writer = write_stream(
2970                writer,
2971                123,
2972                &[
2973                    (0u64, 100u64),
2974                    (1u64, 101u64),
2975                    (2u64, 102u64),
2976                    (3u64, 103u64),
2977                ],
2978                true,
2979            )
2980            .await;
2981
2982            assert_eq!(rx.recv().await, Some(100));
2983            assert_eq!(rx.recv().await, Some(101));
2984            assert_eq!(rx.recv().await, Some(102));
2985            assert_eq!(rx.recv().await, Some(103));
2986
2987            verify_ack(&mut reader, 3).await;
2988            // Drop reader and writer to close the connection.
2989        }
2990
2991        // Second connection (reconnection): retransmitted messages are deduped.
2992        {
2993            let (sender2, receiver2) = tokio::io::duplex(5000);
2994            acceptor_tx.send(receiver2).unwrap();
2995
2996            let (r2, writer2) = tokio::io::split(sender2);
2997            let mut reader2 = FrameReader::new(
2998                r2,
2999                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3000            );
3001
3002            let _ = write_stream(
3003                writer2,
3004                123,
3005                &[
3006                    (2u64, 102u64),
3007                    (3u64, 103u64),
3008                    (4u64, 104u64),
3009                    (5u64, 105u64),
3010                ],
3011                true,
3012            )
3013            .await;
3014
3015            // 102 and 103 are retransmits; only 104 and 105 are new.
3016            assert_eq!(rx.recv().await, Some(104));
3017            assert_eq!(rx.recv().await, Some(105));
3018
3019            verify_ack(&mut reader2, 5).await;
3020
3021            cancel_token.cancel();
3022        }
3023    }
3024
3025    #[async_timed_test(timeout_secs = 60)]
3026    async fn test_ack_from_server_session() {
3027        let config = hyperactor_config::global::lock();
3028        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3029        let session_id = SessionId(123);
3030        let (_handle, acceptor_tx, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
3031
3032        let (sender, receiver) = tokio::io::duplex(5000);
3033        acceptor_tx.send(receiver).unwrap();
3034        let (r, mut writer) = tokio::io::split(sender);
3035        let mut reader = FrameReader::new(
3036            r,
3037            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3038        );
3039
3040        for i in 0u64..100u64 {
3041            writer = write_stream(writer, 123, &[(i, 100u64 + i)], /*init*/ i == 0u64).await;
3042            assert_eq!(rx.recv().await, Some(100u64 + i));
3043            let (_, bytes) = reader.next().await.unwrap().unwrap();
3044            let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3045            assert_eq!(acked, i);
3046        }
3047
3048        // Wait long enough to ensure server processed everything.
3049        tokio::time::sleep(Duration::from_secs(5)).await;
3050
3051        cancel_token.cancel();
3052
3053        // Should send NetRxResponse::Closed before stopping.
3054        let (_, bytes) = reader.next().await.unwrap().unwrap();
3055        assert!(deserialize_response(bytes).unwrap().is_closed());
3056    }
3057
3058    #[tracing_test::traced_test]
3059    async fn verify_tx_closed(tx_status: &mut watch::Receiver<TxStatus>, expected_log: &str) {
3060        match tokio::time::timeout(Duration::from_secs(5), tx_status.changed()).await {
3061            Ok(Ok(())) => {
3062                let current_status = tx_status.borrow().clone();
3063                assert!(current_status.is_closed());
3064                logs_assert_unscoped(|logs| {
3065                    if logs.iter().any(|log| log.contains(expected_log)) {
3066                        Ok(())
3067                    } else {
3068                        Err("expected log not found".to_string())
3069                    }
3070                });
3071            }
3072            Ok(Err(_)) => panic!("watch::Receiver::changed() failed because sender is dropped."),
3073            Err(_) => panic!("timeout before tx_status changed"),
3074        }
3075    }
3076
3077    #[tracing_test::traced_test]
3078    #[tokio::test]
3079    // TODO: OSS: The logs_assert function returned an error: expected log not found
3080    #[cfg_attr(not(fbcode_build), ignore)]
3081    async fn test_tcp_tx_delivery_timeout() {
3082        // This link always fails to connect.
3083        let link = MockLink::<u64>::fail_connects();
3084        let tx = spawn::<u64>(link);
3085        // Override the default (1m) for the purposes of this test.
3086        let config = hyperactor_config::global::lock();
3087        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
3088        let mut tx_receiver = tx.status().clone();
3089        let (return_channel, _return_receiver) = oneshot::channel();
3090        tx.try_post(123, return_channel);
3091        verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await;
3092    }
3093
3094    async fn take_receiver(
3095        receiver_storage: &MVar<DuplexStream>,
3096    ) -> (FrameReader<ReadHalf<DuplexStream>>, WriteHalf<DuplexStream>) {
3097        let mut receiver = receiver_storage.take().await;
3098        // Read and discard the LinkInit header that MockLink::connect() writes.
3099        let _link_init = read_link_init(&mut receiver).await.expect("read LinkInit");
3100        let (r, writer) = tokio::io::split(receiver);
3101        let reader = FrameReader::new(
3102            r,
3103            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3104        );
3105        (reader, writer)
3106    }
3107
3108    async fn verify_message<M: RemoteMessage + PartialEq + std::fmt::Debug>(
3109        reader: &mut FrameReader<ReadHalf<DuplexStream>>,
3110        expect: (u64, M),
3111        loc: u32,
3112    ) {
3113        let expected = Frame::Message(expect.0, expect.1);
3114        let (_, bytes) = reader.next().await.unwrap().expect("unexpected EOF");
3115        let message = serde_multipart::Message::from_framed(bytes).unwrap();
3116        let frame: Frame<M> = serde_multipart::deserialize_bincode(message).unwrap();
3117
3118        assert_eq!(frame, expected, "from ln={loc}");
3119    }
3120
3121    async fn verify_stream<M: RemoteMessage + PartialEq + std::fmt::Debug + Clone>(
3122        reader: &mut FrameReader<ReadHalf<DuplexStream>>,
3123        expects: &[(u64, M)],
3124        _expect_session_id: Option<u64>,
3125        loc: u32,
3126    ) {
3127        for expect in expects {
3128            verify_message(reader, expect.clone(), loc).await;
3129        }
3130    }
3131
3132    async fn net_tx_send(tx: &NetTx<u64>, msgs: &[u64]) {
3133        for msg in msgs {
3134            tx.post(*msg);
3135        }
3136    }
3137
3138    // Happy path: all messages are acked.
3139    #[async_timed_test(timeout_secs = 30)]
3140    async fn test_ack_in_net_tx_basic() {
3141        let link = MockLink::<u64>::new();
3142        let receiver_storage = link.receiver_storage();
3143        let tx = spawn::<u64>(link);
3144
3145        // Send some messages, but not acking any of them.
3146        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
3147        {
3148            let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3149            verify_stream(
3150                &mut reader,
3151                &[
3152                    (0u64, 100u64),
3153                    (1u64, 101u64),
3154                    (2u64, 102u64),
3155                    (3u64, 103u64),
3156                    (4u64, 104u64),
3157                ],
3158                None,
3159                line!(),
3160            )
3161            .await;
3162
3163            for i in 0u64..5u64 {
3164                writer = FrameWrite::write_frame(
3165                    writer,
3166                    serialize_response(NetRxResponse::Ack(i)).unwrap(),
3167                    1024,
3168                    0,
3169                )
3170                .await
3171                .map_err(|(_, e)| e)
3172                .unwrap();
3173            }
3174            // Wait for the acks to be processed by NetTx.
3175            tokio::time::sleep(Duration::from_secs(3)).await;
3176            // Drop both halves to break the in-memory connection (parity with old drop of DuplexStream).
3177            drop(reader);
3178            drop(writer);
3179        };
3180
3181        // Sent a new message to verify all sent messages will not be resent.
3182        net_tx_send(&tx, &[105u64]).await;
3183        {
3184            let (mut reader, _writer) = take_receiver(&receiver_storage).await;
3185            verify_stream(&mut reader, &[(5u64, 105u64)], None, line!()).await;
3186            // Reader/writer dropped here. This breaks the connection.
3187        };
3188    }
3189
3190    // Verify unacked message will be resent after reconnection.
3191    #[async_timed_test(timeout_secs = 60)]
3192    async fn test_persistent_net_tx() {
3193        let link = MockLink::<u64>::new();
3194        let receiver_storage = link.receiver_storage();
3195
3196        let tx = spawn::<u64>(link);
3197
3198        // Send some messages, but not acking any of them.
3199        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
3200
3201        // How many times to reconnect. Keep this small because the send loop
3202        // applies exponential backoff between reconnections, and mock connections
3203        // are too short-lived to trigger the backoff reset.
3204        let n = 3;
3205
3206        // Reconnect multiple times. The messages should be resent every time
3207        // because none of them is acked.
3208        for i in 0..n {
3209            {
3210                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3211                verify_stream(
3212                    &mut reader,
3213                    &[
3214                        (0u64, 100u64),
3215                        (1u64, 101u64),
3216                        (2u64, 102u64),
3217                        (3u64, 103u64),
3218                        (4u64, 104u64),
3219                    ],
3220                    None,
3221                    line!(),
3222                )
3223                .await;
3224
3225                // In the last iteration, ack part of the messages. This should
3226                // prune them from future resent.
3227                if i == n - 1 {
3228                    writer = FrameWrite::write_frame(
3229                        writer,
3230                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
3231                        1024,
3232                        0,
3233                    )
3234                    .await
3235                    .map_err(|(_, e)| e)
3236                    .unwrap();
3237                    // Wait for the acks to be processed by NetTx.
3238                    tokio::time::sleep(Duration::from_secs(3)).await;
3239                }
3240                // client DuplexStream is dropped here. This breaks the connection.
3241                drop(reader);
3242                drop(writer);
3243            };
3244        }
3245
3246        // Verify only unacked are resent.
3247        for _ in 0..n {
3248            {
3249                let (mut reader, mut _writer) = take_receiver(&receiver_storage).await;
3250                verify_stream(
3251                    &mut reader,
3252                    &[(2u64, 102u64), (3u64, 103u64), (4u64, 104u64)],
3253                    None,
3254                    line!(),
3255                )
3256                .await;
3257                // drop(reader/_writer) at scope end
3258            };
3259        }
3260
3261        // Now send more messages.
3262        net_tx_send(&tx, &[105u64, 106u64, 107u64, 108u64, 109u64]).await;
3263        // Verify the unacked messages from the 1st send will be grouped with
3264        // the 2nd send.
3265        for i in 0..n {
3266            {
3267                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3268                verify_stream(
3269                    &mut reader,
3270                    &[
3271                        // From the 1st send.
3272                        (2u64, 102u64),
3273                        (3u64, 103u64),
3274                        (4u64, 104u64),
3275                        // From the 2nd send.
3276                        (5u64, 105u64),
3277                        (6u64, 106u64),
3278                        (7u64, 107u64),
3279                        (8u64, 108u64),
3280                        (9u64, 109u64),
3281                    ],
3282                    None,
3283                    line!(),
3284                )
3285                .await;
3286
3287                // In the last iteration, ack part of the messages from the 1st
3288                // sent.
3289                if i == n - 1 {
3290                    // Intentionally ack 1 again to verify it is okay to ack
3291                    // messages that was already acked.
3292                    writer = FrameWrite::write_frame(
3293                        writer,
3294                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
3295                        1024,
3296                        0,
3297                    )
3298                    .await
3299                    .map_err(|(_, e)| e)
3300                    .unwrap();
3301                    writer = FrameWrite::write_frame(
3302                        writer,
3303                        serialize_response(NetRxResponse::Ack(2)).unwrap(),
3304                        1024,
3305                        0,
3306                    )
3307                    .await
3308                    .map_err(|(_, e)| e)
3309                    .unwrap();
3310                    writer = FrameWrite::write_frame(
3311                        writer,
3312                        serialize_response(NetRxResponse::Ack(3)).unwrap(),
3313                        1024,
3314                        0,
3315                    )
3316                    .await
3317                    .map_err(|(_, e)| e)
3318                    .unwrap();
3319                    // Wait for the acks to be processed by NetTx.
3320                    tokio::time::sleep(Duration::from_secs(3)).await;
3321                }
3322                // client DuplexStream is dropped here. This breaks the connection.
3323                drop(reader);
3324                drop(writer);
3325            };
3326        }
3327
3328        for i in 0..n {
3329            {
3330                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3331                verify_stream(
3332                    &mut reader,
3333                    &[
3334                        // From the 1st send.
3335                        (4u64, 104),
3336                        // From the 2nd send.
3337                        (5u64, 105u64),
3338                        (6u64, 106u64),
3339                        (7u64, 107u64),
3340                        (8u64, 108u64),
3341                        (9u64, 109u64),
3342                    ],
3343                    None,
3344                    line!(),
3345                )
3346                .await;
3347
3348                // In the last iteration, ack part of the messages from the 2nd send.
3349                if i == n - 1 {
3350                    writer = FrameWrite::write_frame(
3351                        writer,
3352                        serialize_response(NetRxResponse::Ack(7)).unwrap(),
3353                        1024,
3354                        0,
3355                    )
3356                    .await
3357                    .map_err(|(_, e)| e)
3358                    .unwrap();
3359                    // Wait for the acks to be processed by NetTx.
3360                    tokio::time::sleep(Duration::from_secs(3)).await;
3361                }
3362                // client DuplexStream is dropped here. This breaks the connection.
3363                drop(reader);
3364                drop(writer);
3365            };
3366        }
3367
3368        for _ in 0..n {
3369            {
3370                let (mut reader, writer) = take_receiver(&receiver_storage).await;
3371                verify_stream(
3372                    &mut reader,
3373                    &[
3374                        // From the 2nd send.
3375                        (8u64, 108u64),
3376                        (9u64, 109u64),
3377                    ],
3378                    None,
3379                    line!(),
3380                )
3381                .await;
3382                // client DuplexStream is dropped here. This breaks the connection.
3383                drop(reader);
3384                drop(writer);
3385            };
3386        }
3387    }
3388
3389    #[async_timed_test(timeout_secs = 15)]
3390    async fn test_ack_before_redelivery_in_net_tx() {
3391        let link = MockLink::<u64>::new();
3392        let receiver_storage = link.receiver_storage();
3393        let net_tx = spawn::<u64>(link);
3394
3395        // Verify sent-and-ack a message. This is necessary for the test to
3396        // trigger a connection.
3397        let (return_channel_tx, return_channel_rx) = oneshot::channel();
3398        net_tx.try_post(100, return_channel_tx);
3399        let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3400        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3401        // ack it
3402        writer = FrameWrite::write_frame(
3403            writer,
3404            serialize_response(NetRxResponse::Ack(0)).unwrap(),
3405            1024,
3406            0,
3407        )
3408        .await
3409        .map_err(|(_, e)| e)
3410        .unwrap();
3411        // confirm Tx received ack
3412        //
3413        // Using `is_err` to confirm the message is delivered/acked is confusing,
3414        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
3415        assert!(return_channel_rx.await.is_err());
3416
3417        // Now fake an unknown delivery for Tx:
3418        // Although Tx did not actually send seq=1, we still ack it from Rx to
3419        // pretend Tx already sent it, just it did not know it was sent
3420        // successfully.
3421        let _ = FrameWrite::write_frame(
3422            writer,
3423            serialize_response(NetRxResponse::Ack(1)).unwrap(),
3424            1024,
3425            0,
3426        )
3427        .await
3428        .map_err(|(_, e)| e)
3429        .unwrap();
3430
3431        let (return_channel_tx, return_channel_rx) = oneshot::channel();
3432        net_tx.try_post(101, return_channel_tx);
3433        // Verify the message is sent to Rx.
3434        verify_message(&mut reader, (1u64, 101u64), line!()).await;
3435        // although we did not ack the message after it is sent, since we already
3436        // acked it previously, Tx will treat it as acked, and considered the
3437        // message delivered successfully.
3438        //
3439        // Using `is_err` to confirm the message is delivered/acked is confusing,
3440        // but is correct. See how send is implemented: https://fburl.com/code/ywt8lip2
3441        assert!(return_channel_rx.await.is_err());
3442    }
3443
3444    async fn verify_ack_exceeded_limit(disconnect_before_ack: bool) {
3445        // Use temporary config for this test
3446        let config = hyperactor_config::global::lock();
3447        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(2));
3448
3449        let link: MockLink<u64> = MockLink::<u64>::new();
3450        let disconnect_signal = link.disconnect_signal().clone();
3451        let fail_connect_switch = link.fail_connects_switch();
3452        let receiver_storage = link.receiver_storage();
3453        let tx = spawn::<u64>(link);
3454        let mut tx_status = tx.status().clone();
3455        // send a message
3456        tx.post(100);
3457        let (mut reader, writer) = take_receiver(&receiver_storage).await;
3458        // Confirm message is sent to rx.
3459        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3460        // ack it
3461        let _ = FrameWrite::write_frame(
3462            writer,
3463            serialize_response(NetRxResponse::Ack(0)).unwrap(),
3464            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3465            0,
3466        )
3467        .await
3468        .map_err(|(_, e)| e)
3469        .unwrap();
3470        tokio::time::sleep(Duration::from_secs(3)).await;
3471        // Channel should be still alive because ack was sent.
3472        assert!(!tx_status.has_changed().unwrap());
3473        assert_eq!(*tx_status.borrow(), TxStatus::Active);
3474
3475        tx.post(101);
3476        // Confirm message is sent to rx.
3477        verify_message(&mut reader, (1u64, 101u64), line!()).await;
3478
3479        if disconnect_before_ack {
3480            // Prevent link from reconnect
3481            fail_connect_switch.store(true, Ordering::Release);
3482            // Break the existing connection
3483            disconnect_signal.send(()).unwrap();
3484        }
3485
3486        // Verify the channel is closed due to ack timeout based on the log.
3487        let expected_log: &str = if disconnect_before_ack {
3488            "failed to receive ack within timeout 2s; link is currently broken"
3489        } else {
3490            "failed to receive ack within timeout 2s; link is currently connected"
3491        };
3492
3493        verify_tx_closed(&mut tx_status, expected_log).await;
3494    }
3495
3496    #[tracing_test::traced_test]
3497    #[async_timed_test(timeout_secs = 30)]
3498    // TODO: OSS: The logs_assert function returned an error: expected log not found
3499    #[cfg_attr(not(fbcode_build), ignore)]
3500    async fn test_ack_exceeded_limit_with_connected_link() {
3501        verify_ack_exceeded_limit(false).await;
3502    }
3503
3504    #[tracing_test::traced_test]
3505    #[async_timed_test(timeout_secs = 30)]
3506    // TODO: OSS: The logs_assert function returned an error: expected log not found
3507    #[cfg_attr(not(fbcode_build), ignore)]
3508    async fn test_ack_exceeded_limit_with_broken_link() {
3509        verify_ack_exceeded_limit(true).await;
3510    }
3511
3512    // Verify a large number of messages can be delivered and acked with the
3513    // presence of flakiness in the network, i.e. random delay and disconnection.
3514    #[async_timed_test(timeout_secs = 60)]
3515    async fn test_network_flakiness_in_channel() {
3516        hyperactor_telemetry::initialize_logging_for_test();
3517
3518        let sampling_rate = 100;
3519        let mut link = MockLink::<u64>::with_network_flakiness(NetworkFlakiness {
3520            disconnect_params: Some((0.001, 15, Duration::from_millis(400))),
3521            latency_range: Some((Duration::from_millis(100), Duration::from_millis(200))),
3522        });
3523        link.set_sampling_rate(sampling_rate);
3524        // Set a large buffer size to improve throughput.
3525        link.set_buffer_size(1024000);
3526        let disconnected_count = link.disconnected_count();
3527        let receiver_storage = link.receiver_storage();
3528        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3529        let local_addr = listener.channel_addr.clone();
3530        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3531            super::server::serve_with_listener(listener, local_addr).unwrap();
3532        let tx = spawn::<u64>(link);
3533        let messages: Vec<_> = (0..10001).collect();
3534        let messages_clone = messages.clone();
3535        // Put the sender side in a separate task so we can start the receiver
3536        // side concurrently.
3537        let send_task_handle = tokio::spawn(async move {
3538            for message in messages_clone {
3539                // Add a small delay between messages to give NetRx time to ack.
3540                // Technically, this test still can pass without this delay. But
3541                // the test will need a might larger timeout. The reason is
3542                // fairly convoluted:
3543                //
3544                // MockLink uses the number of delivery to calculate the disconnection
3545                // probability. If NetRx sends messages much faster than NetTx
3546                // can ack them, there is a higher chance that the messages are
3547                // not acked before reconnect. Then those message would be redelivered.
3548                // The repeated redelivery increases the total time of sending
3549                // these messages.
3550                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3551                tx.post(message);
3552            }
3553            tracing::debug!("NetTx sent all messages");
3554            // It is important to return tx instead of dropping it here, because
3555            // Rx might not receive all messages yet.
3556            tx
3557        });
3558
3559        for message in &messages {
3560            if message % sampling_rate == 0 {
3561                tracing::debug!("NetRx received a message: {message}");
3562            }
3563            assert_eq!(nx.recv().await.unwrap(), *message);
3564        }
3565        tracing::debug!("NetRx received all messages");
3566
3567        let send_result = send_task_handle.await;
3568        assert!(send_result.is_ok());
3569
3570        tracing::debug!(
3571            "MockLink disconnected {} times.",
3572            disconnected_count.load(Ordering::SeqCst)
3573        );
3574        // TODO(pzhang) after the return_handle work in NetTx is done, add a
3575        // check here to verify the messages are acked correctly.
3576    }
3577
3578    #[async_timed_test(timeout_secs = 60)]
3579    async fn test_ack_every_n_messages() {
3580        let config = hyperactor_config::global::lock();
3581        let _guard_message_ack = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 600);
3582        let _guard_time_interval =
3583            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(1000));
3584        sparse_ack().await;
3585    }
3586
3587    #[async_timed_test(timeout_secs = 60)]
3588    async fn test_ack_every_time_interval() {
3589        let config = hyperactor_config::global::lock();
3590        let _guard_message_ack =
3591            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
3592        let _guard_time_interval = config.override_key(
3593            config::MESSAGE_ACK_TIME_INTERVAL,
3594            Duration::from_millis(500),
3595        );
3596        sparse_ack().await;
3597    }
3598
3599    async fn sparse_ack() {
3600        let mut link = MockLink::<u64>::new();
3601        // Set a large buffer size to improve throughput.
3602        link.set_buffer_size(1024000);
3603        let disconnected_count = link.disconnected_count();
3604        let receiver_storage = link.receiver_storage();
3605        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3606        let local_addr = listener.channel_addr.clone();
3607        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3608            super::server::serve_with_listener(listener, local_addr).unwrap();
3609        let tx = spawn::<u64>(link);
3610        let messages: Vec<_> = (0..20001).collect();
3611        let messages_clone = messages.clone();
3612        // Put the sender side in a separate task so we can start the receiver
3613        // side concurrently.
3614        let send_task_handle = tokio::spawn(async move {
3615            for message in messages_clone {
3616                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3617                tx.post(message);
3618            }
3619            tokio::time::sleep(Duration::from_secs(5)).await;
3620            tracing::debug!("NetTx sent all messages");
3621            tx
3622        });
3623
3624        for message in &messages {
3625            assert_eq!(nx.recv().await.unwrap(), *message);
3626        }
3627        tracing::debug!("NetRx received all messages");
3628
3629        let send_result = send_task_handle.await;
3630        assert!(send_result.is_ok());
3631
3632        tracing::debug!(
3633            "MockLink disconnected {} times.",
3634            disconnected_count.load(Ordering::SeqCst)
3635        );
3636    }
3637
3638    #[test]
3639    fn test_metatls_parsing() {
3640        // host:port
3641        let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3642        assert_eq!(
3643            channel,
3644            ChannelAddr::MetaTls(TlsAddr::new("localhost", 1234))
3645        );
3646        // ipv4:port - parsed as hostname with ip normalization
3647        let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3648        assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("1.2.3.4", 1234)));
3649        // ipv6:port
3650        let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
3651            .parse()
3652            .unwrap();
3653        assert_eq!(
3654            channel,
3655            ChannelAddr::MetaTls(TlsAddr::new("2401:db00:33c:6902:face:0:2a2:0", 1234))
3656        );
3657
3658        let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3659        assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("::", 1234)));
3660    }
3661
3662    #[async_timed_test(timeout_secs = 300)]
3663    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
3664    #[cfg_attr(not(fbcode_build), ignore)]
3665    async fn test_tcp_throughput() {
3666        let config = hyperactor_config::global::lock();
3667        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3668
3669        let socket_addr: SocketAddr = "[::1]:0".parse().unwrap();
3670        let (local_addr, mut rx) =
3671            server::serve::<String>(ChannelAddr::Tcp(socket_addr), None).unwrap();
3672
3673        // Test with 10 connections (senders), each sends 500K messages, 5M messages in total.
3674        let total_num_msgs = 500000;
3675
3676        let receive_handle = tokio::spawn(async move {
3677            let mut num = 0;
3678            for _ in 0..10 * total_num_msgs {
3679                rx.recv().await.unwrap();
3680                num += 1;
3681
3682                if num % 100000 == 0 {
3683                    tracing::info!("total number of received messages: {}", num);
3684                }
3685            }
3686        });
3687
3688        let mut tx_handles = vec![];
3689        let mut txs = vec![];
3690        for _ in 0..10 {
3691            let server_addr = local_addr.clone();
3692            let tx = Arc::new(channel::dial::<String>(server_addr).unwrap());
3693            let tx2 = Arc::clone(&tx);
3694            txs.push(tx);
3695            tx_handles.push(tokio::spawn(async move {
3696                let random_string = rand::rng()
3697                    .sample_iter(&Alphanumeric)
3698                    .take(2048)
3699                    .map(char::from)
3700                    .collect::<String>();
3701                for _ in 0..total_num_msgs {
3702                    tx2.post(random_string.clone());
3703                }
3704            }));
3705        }
3706
3707        receive_handle.await.unwrap();
3708        for handle in tx_handles {
3709            handle.await.unwrap();
3710        }
3711    }
3712
3713    #[tracing_test::traced_test]
3714    #[async_timed_test(timeout_secs = 60)]
3715    // TODO: OSS: The logs_assert function returned an error: expected log not found
3716    #[cfg_attr(not(fbcode_build), ignore)]
3717    async fn test_net_tx_closed_on_server_reject() {
3718        let link = MockLink::<u64>::new();
3719        let receiver_storage = link.receiver_storage();
3720        let mut tx = spawn::<u64>(link);
3721        net_tx_send(&tx, &[100]).await;
3722
3723        {
3724            let (_reader, writer) = take_receiver(&receiver_storage).await;
3725            let _ = FrameWrite::write_frame(
3726                writer,
3727                serialize_response(NetRxResponse::Reject("testing".to_string())).unwrap(),
3728                1024,
3729                0,
3730            )
3731            .await
3732            .map_err(|(_, e)| e);
3733
3734            // Wait for response to be processed by NetTx before dropping reader/writer. Otherwise
3735            // the channel will be closed and we will get the wrong error.
3736            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
3737        }
3738
3739        verify_tx_closed(&mut tx.status, "server rejected connection").await;
3740    }
3741
3742    #[async_timed_test(timeout_secs = 60)]
3743    async fn test_server_rejects_conn_on_out_of_sequence_message() {
3744        let config = hyperactor_config::global::lock();
3745        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3746        let session_id = SessionId(123);
3747        let (_handle, acceptor_tx, mut rx, _cancel_token) = serve_acceptor_test::<u64>(session_id);
3748
3749        let (sender, receiver) = tokio::io::duplex(5000);
3750        acceptor_tx.send(receiver).unwrap();
3751        let (r, writer) = tokio::io::split(sender);
3752        let mut reader = FrameReader::new(
3753            r,
3754            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3755        );
3756
3757        let _ = write_stream(writer, 123, &[(0, 100u64), (1, 101u64), (3, 103u64)], true).await;
3758        assert_eq!(rx.recv().await, Some(100u64));
3759        assert_eq!(rx.recv().await, Some(101u64));
3760        let (_, bytes) = reader.next().await.unwrap().unwrap();
3761        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3762        assert_eq!(acked, 0);
3763        let (_, bytes) = reader.next().await.unwrap().unwrap();
3764        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3765        assert_eq!(acked, 1);
3766        let (_, bytes) = reader.next().await.unwrap().unwrap();
3767        assert!(deserialize_response(bytes).unwrap().is_reject());
3768    }
3769
3770    #[async_timed_test(timeout_secs = 60)]
3771    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" })
3772    #[cfg_attr(not(fbcode_build), ignore)]
3773    async fn test_stop_net_tx_after_stopping_net_rx() {
3774        hyperactor_telemetry::initialize_logging_for_test();
3775
3776        let config = hyperactor_config::global::lock();
3777        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3778        let (addr, mut rx) =
3779            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
3780        let socket_addr = match addr {
3781            ChannelAddr::Tcp(a) => a,
3782            _ => panic!("unexpected channel type"),
3783        };
3784        let tx: NetTx<u64> = spawn(tcp::link(socket_addr, SessionId::random(), 0));
3785        // NetTx will not establish a connection until it sends the 1st message.
3786        // Without a live connection, NetTx cannot received the Closed message
3787        // from NetRx. Therefore, we need to send a message to establish the
3788        //connection.
3789        tx.send(100).await.unwrap();
3790        assert_eq!(rx.recv().await.unwrap(), 100);
3791        // Drop rx will close the NetRx server.
3792        rx.2.stop("testing");
3793        assert!(rx.recv().await.is_err());
3794
3795        // NetTx will only read from the stream when it needs to send a message
3796        // or wait for an ack. Therefore we need to send a message to trigger that.
3797        tx.post(101);
3798        let mut watcher = tx.status().clone();
3799        // When NetRx exits, it should notify NetTx to exit as well.
3800        let _ = watcher.wait_for(|val| val.is_closed()).await;
3801        // wait_for could return Err due to race between when watch's sender was
3802        // dropped and when wait_for was called. So we still need to do an
3803        // equality check.
3804        assert!(watcher.borrow().is_closed());
3805    }
3806
3807    /// Yields pre-built `DuplexStream`s to the accept loop and
3808    /// blocks once drained. Lets the `rx_join_flushes_pending_ack_*`
3809    /// tests inspect the wire from the other end.
3810    struct QueueListener {
3811        streams: std::collections::VecDeque<DuplexStream>,
3812        addr: ChannelAddr,
3813    }
3814
3815    #[async_trait]
3816    impl super::Listener for QueueListener {
3817        type Stream = DuplexStream;
3818
3819        async fn accept(&mut self) -> Result<(DuplexStream, ChannelAddr), ServerError> {
3820            match self.streams.pop_front() {
3821                Some(s) => Ok((s, self.addr.clone())),
3822                None => std::future::pending().await,
3823            }
3824        }
3825    }
3826
    /// In-memory connection: server end goes into the listener; the
    /// test reads from `client_r`.
    ///
    /// Built by `prepare_connection`, which splits the client end of a
    /// `tokio::io::duplex` pair into the two client halves stored here.
    struct PreparedConnection {
        // Server end of the duplex pair; tests move it into a
        // `QueueListener` before serving.
        server_side: DuplexStream,
        // Kept alive so the server's recv-loop stays in its `select!`
        // on cancellation rather than exiting on EOF. Tests must
        // exercise the cancel-flush path.
        _client_w: tokio::io::WriteHalf<DuplexStream>,
        // Client read half; tests wrap it in a `FrameReader` to observe
        // the response frames the server writes.
        client_r: ReadHalf<DuplexStream>,
    }
3837
3838    /// Write `LinkInit` and the framed `Frame::Message(seq, value)`
3839    /// payloads on the client side; return both halves.
3840    async fn prepare_connection(
3841        session_id: SessionId,
3842        stream_id: u8,
3843        messages: &[(u64, u64)],
3844    ) -> PreparedConnection {
3845        let (client_side, server_side) = tokio::io::duplex(8192);
3846        let (client_r, mut client_w) = tokio::io::split(client_side);
3847
3848        super::write_link_init(&mut client_w, session_id, stream_id)
3849            .await
3850            .unwrap();
3851        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
3852        for (seq, value) in messages {
3853            let payload =
3854                serde_multipart::serialize_bincode(&Frame::<u64>::Message(*seq, *value)).unwrap();
3855            let mut fw = FrameWrite::new(client_w, payload.framed(), max_len, 0)
3856                .map_err(|(_w, e)| e)
3857                .unwrap();
3858            fw.send().await.unwrap();
3859            client_w = fw.complete();
3860        }
3861
3862        PreparedConnection {
3863            server_side,
3864            _client_w: client_w,
3865            client_r,
3866        }
3867    }
3868
    /// Test plan for `run_separate_sessions_flush_test`: each entry
    /// describes one connection that lives on its own session.
    struct SeparateSessionPlan {
        // Session id passed to `prepare_connection` for this connection.
        session_id: SessionId,
        // Stream id passed to `prepare_connection` for this connection.
        stream_id: u8,
        // `(seq, value)` pairs to send; the largest `seq` is the ack the
        // test expects after `rx.join()`. Must be non-empty, since the
        // driver unwraps the maximum.
        messages: Vec<(u64, u64)>,
    }
3876
3877    /// Drive the rx.join flush test across multiple connections,
3878    /// each on its own session. Verifies:
3879    ///
3880    /// 1. Every message sent reaches the application via `rx.recv()`.
3881    /// 2. After application delivery completes, *no* ack frames have
3882    ///    been emitted on any connection's read side — the policy
3883    ///    thresholds are out of reach.
3884    /// 3. After `rx.join()` returns, every connection has exactly
3885    ///    one `NetRxResponse::Ack(highest_seq)` frame on its read
3886    ///    side, followed by a `NetRxResponse::Closed` terminal frame.
3887    async fn run_separate_sessions_flush_test(
3888        plans: Vec<SeparateSessionPlan>,
3889        stream_id_label: &str,
3890    ) {
3891        let config = hyperactor_config::global::lock();
3892        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
3893        let _g_time =
3894            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));
3895
3896        // Build all connections and stage them into a `QueueListener`.
3897        let mut conns: Vec<PreparedConnection> = Vec::with_capacity(plans.len());
3898        let mut expected_messages: std::collections::HashSet<u64> =
3899            std::collections::HashSet::new();
3900        let mut expected_acks: Vec<u64> = Vec::with_capacity(plans.len());
3901        for plan in &plans {
3902            for (_seq, value) in &plan.messages {
3903                expected_messages.insert(*value);
3904            }
3905            expected_acks.push(plan.messages.iter().map(|(s, _)| *s).max().unwrap());
3906            conns.push(prepare_connection(plan.session_id, plan.stream_id, &plan.messages).await);
3907        }
3908
3909        let addr = ChannelAddr::Local(u64::MAX);
3910        let listener = QueueListener {
3911            streams: conns
3912                .iter_mut()
3913                .map(|c| {
3914                    // Move server_side out of each PreparedConnection by replacing it with a placeholder.
3915                    std::mem::replace(&mut c.server_side, tokio::io::duplex(1).0)
3916                })
3917                .collect(),
3918            addr: addr.clone(),
3919        };
3920        let (_addr, mut rx) = super::server::serve_with_listener::<u64, _>(listener, addr).unwrap();
3921
3922        // Drain every expected message off the application channel.
3923        let mut received: std::collections::HashSet<u64> = std::collections::HashSet::new();
3924        for _ in 0..expected_messages.len() {
3925            received.insert(rx.recv().await.unwrap());
3926        }
3927        assert_eq!(
3928            received, expected_messages,
3929            "{stream_id_label}: every produced message should reach the application"
3930        );
3931
3932        // Give any policy-driven ack timer a generous grace period to
3933        // fire. Because `MESSAGE_ACK_EVERY_N_MESSAGES` and
3934        // `MESSAGE_ACK_TIME_INTERVAL` are out of reach, no ack should
3935        // be emitted yet — but if a regression makes them reachable
3936        // (or adds a new spontaneous emission path), this sleep gives
3937        // it room to do so before the assertion below.
3938        tokio::time::sleep(Duration::from_millis(100)).await;
3939
3940        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
3941        let mut readers: Vec<FrameReader<ReadHalf<DuplexStream>>> = conns
3942            .into_iter()
3943            .map(|c| FrameReader::new(c.client_r, max_len))
3944            .collect();
3945        for (idx, reader) in readers.iter_mut().enumerate() {
3946            match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
3947                Err(_) => {} // timeout — no frame, expected.
3948                Ok(Err(e)) => panic!(
3949                    "{stream_id_label}: connection {idx} frame reader error before rx.join: {e}"
3950                ),
3951                Ok(Ok(None)) => {
3952                    panic!("{stream_id_label}: connection {idx} closed before rx.join()")
3953                }
3954                Ok(Ok(Some((_, bytes)))) => {
3955                    let resp = super::deserialize_response(bytes).unwrap();
3956                    panic!(
3957                        "{stream_id_label}: connection {idx} unexpectedly received {resp:?} \
3958                         before rx.join()"
3959                    );
3960                }
3961            }
3962        }
3963
3964        rx.join().await;
3965
3966        // After rx.join() returns, every connection must have its
3967        // terminal cleanup frames already written: an `Ack` covering
3968        // the highest seq it sent, then a `Closed`.
3969        for (idx, (reader, expected_ack)) in readers.iter_mut().zip(&expected_acks).enumerate() {
3970            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
3971                .await
3972                .unwrap_or_else(|_| {
3973                    panic!(
3974                        "{stream_id_label}: connection {idx} produced no Ack frame within 50ms \
3975                         after rx.join()"
3976                    )
3977                })
3978                .expect("frame reader error")
3979                .expect("frame reader returned None");
3980            let acked = super::deserialize_response(bytes.1)
3981                .unwrap()
3982                .into_ack()
3983                .unwrap_or_else(|other| {
3984                    panic!("{stream_id_label}: connection {idx} expected Ack, got {other:?}")
3985                });
3986            assert_eq!(
3987                acked, *expected_ack,
3988                "{stream_id_label}: connection {idx} ack mismatch"
3989            );
3990
3991            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
3992                .await
3993                .unwrap_or_else(|_| {
3994                    panic!(
3995                        "{stream_id_label}: connection {idx} produced no Closed frame within 50ms"
3996                    )
3997                })
3998                .expect("frame reader error")
3999                .expect("frame reader returned None");
4000            assert!(
4001                super::deserialize_response(bytes.1).unwrap().is_closed(),
4002                "{stream_id_label}: connection {idx} expected Closed terminal frame"
4003            );
4004        }
4005    }
4006
4007    #[async_timed_test(timeout_secs = 30)]
4008    async fn rx_join_flushes_pending_ack_single_stream() {
4009        // Three independent single-stream sessions, each with three
4010        // framed messages. Verifies every recv-loop's terminal flush
4011        // ran by the time `rx.join()` returns.
4012        let plans = (1u64..=3)
4013            .map(|sid| SeparateSessionPlan {
4014                session_id: SessionId(sid),
4015                stream_id: 0,
4016                messages: (0u64..3).map(|seq| (seq, sid * 100 + seq)).collect(),
4017            })
4018            .collect();
4019        run_separate_sessions_flush_test(plans, "single-stream").await;
4020    }
4021
4022    #[async_timed_test(timeout_secs = 30)]
4023    async fn rx_join_flushes_pending_ack_multi_stream() {
4024        // Three independent multi-stream sessions, each with three
4025        // framed messages, each on its own session_id with a single
4026        // stream_id of 1. Exercises `dispatch_multi_stream`'s terminal
4027        // cleanup once per session.
4028        let plans = (1u64..=3)
4029            .map(|sid| SeparateSessionPlan {
4030                session_id: SessionId(sid),
4031                stream_id: 1,
4032                messages: (0u64..3).map(|seq| (seq, sid * 100 + seq)).collect(),
4033            })
4034            .collect();
4035        run_separate_sessions_flush_test(plans, "multi-stream").await;
4036    }
4037
4038    /// One session, multiple stream_ids — exercises `AckWatermark`'s
4039    /// shared-watermark path. Three streams in the same session each
4040    /// send three messages with disjoint seqs filling the contiguous
4041    /// range 0..=8. Each stream's cleanup reads `highest_uncommitted`
4042    /// and emits `Ack(8)` on its own wire so the peer's per-wire
4043    /// NetTx sees an ack for messages it sent there; the receiver
4044    /// discards duplicates. Every stream also emits its own `Closed`.
4045    #[async_timed_test(timeout_secs = 30)]
4046    async fn rx_join_flushes_pending_ack_shared_multi_stream_session() {
4047        let config = hyperactor_config::global::lock();
4048        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
4049        let _g_time =
4050            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));
4051
4052        let session_id = SessionId(99);
4053        let num_streams = 3u8;
4054        let msgs_per_stream = 3u64;
4055        let mut conns: Vec<PreparedConnection> = Vec::with_capacity(num_streams as usize);
4056        let mut expected_messages: std::collections::HashSet<u64> =
4057            std::collections::HashSet::new();
4058        for stream_id in 1..=num_streams {
4059            let messages: Vec<(u64, u64)> = (0u64..msgs_per_stream)
4060                .map(|i| {
4061                    let seq = (stream_id as u64 - 1) * msgs_per_stream + i;
4062                    (seq, 1000 + seq)
4063                })
4064                .collect();
4065            for (_, v) in &messages {
4066                expected_messages.insert(*v);
4067            }
4068            conns.push(prepare_connection(session_id, stream_id, &messages).await);
4069        }
4070        let highest_seq = num_streams as u64 * msgs_per_stream - 1;
4071
4072        let addr = ChannelAddr::Local(u64::MAX);
4073        let listener = QueueListener {
4074            streams: conns
4075                .iter_mut()
4076                .map(|c| std::mem::replace(&mut c.server_side, tokio::io::duplex(1).0))
4077                .collect(),
4078            addr: addr.clone(),
4079        };
4080        let (_addr, mut rx) = super::server::serve_with_listener::<u64, _>(listener, addr).unwrap();
4081
4082        let mut received: std::collections::HashSet<u64> = std::collections::HashSet::new();
4083        for _ in 0..expected_messages.len() {
4084            received.insert(rx.recv().await.unwrap());
4085        }
4086        assert_eq!(
4087            received, expected_messages,
4088            "shared-session multi-stream: every message reaches the application"
4089        );
4090
4091        tokio::time::sleep(Duration::from_millis(100)).await;
4092
4093        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
4094        let mut readers: Vec<FrameReader<ReadHalf<DuplexStream>>> = conns
4095            .into_iter()
4096            .map(|c| FrameReader::new(c.client_r, max_len))
4097            .collect();
4098        for (idx, reader) in readers.iter_mut().enumerate() {
4099            match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
4100                Err(_) | Ok(Ok(None)) => {} // timeout / EOF — no early ack, good.
4101                Ok(Err(e)) => {
4102                    panic!("shared-session multi-stream: stream {idx} frame reader error: {e}")
4103                }
4104                Ok(Ok(Some((_, bytes)))) => {
4105                    let resp = super::deserialize_response(bytes).unwrap();
4106                    panic!(
4107                        "shared-session multi-stream: stream {idx} unexpectedly received \
4108                         {resp:?} before rx.join()"
4109                    );
4110                }
4111            }
4112        }
4113
4114        rx.join().await;
4115
4116        // Drain every frame on every reader. Terminal cleanup may emit
4117        // `Ack(highest_seq)` on one or more wires before `Closed`, but
4118        // once one cleanup commits the shared watermark, later streams
4119        // may emit only `Closed`.
4120        let mut ack_count = 0;
4121        let mut closed_count = 0;
4122        for (idx, reader) in readers.iter_mut().enumerate() {
4123            loop {
4124                match tokio::time::timeout(Duration::from_millis(50), reader.next()).await {
4125                    Err(_) => panic!(
4126                        "shared-session multi-stream: stream {idx} did not yield expected \
4127                         frames within 50ms after rx.join()"
4128                    ),
4129                    Ok(Err(e)) => panic!("frame reader error: {e}"),
4130                    Ok(Ok(None)) => break,
4131                    Ok(Ok(Some((_, bytes)))) => {
4132                        let resp = super::deserialize_response(bytes).unwrap();
4133                        match resp {
4134                            NetRxResponse::Ack(seq) => {
4135                                assert_eq!(
4136                                    seq, highest_seq,
4137                                    "shared-session multi-stream: ack should cover the full \
4138                                     contiguous range 0..={highest_seq}"
4139                                );
4140                                ack_count += 1;
4141                            }
4142                            NetRxResponse::Closed => {
4143                                closed_count += 1;
4144                                break;
4145                            }
4146                            other => panic!(
4147                                "shared-session multi-stream: stream {idx} unexpected {other:?}"
4148                            ),
4149                        }
4150                    }
4151                }
4152            }
4153        }
4154        assert!(
4155            ack_count >= 1,
4156            "shared-session multi-stream: expected at least one Ack({highest_seq}); \
4157             got {ack_count}"
4158        );
4159        assert!(
4160            ack_count <= num_streams as usize,
4161            "shared-session multi-stream: expected at most {num_streams} Ack({highest_seq}) \
4162             frames; got {ack_count}"
4163        );
4164        assert_eq!(
4165            closed_count, num_streams as usize,
4166            "shared-session multi-stream: every stream should emit its own Closed frame"
4167        );
4168    }
4169
4170    /// Duplex analog of `rx_join_flushes_pending_ack_single_stream`.
4171    /// Three independent duplex sessions, each with three framed
4172    /// messages. Verifies every `dispatch_duplex_stream`'s terminal
4173    /// flush ran by the time `DuplexServer::join()` returns —
4174    /// structured concurrency makes the listener task await every
4175    /// inline recv/send loop before resolving.
4176    #[async_timed_test(timeout_secs = 30)]
4177    async fn server_join_flushes_pending_ack_duplex_session() {
4178        let config = hyperactor_config::global::lock();
4179        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
4180        let _g_time =
4181            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));
4182
4183        let session_count = 3u64;
4184        let msgs_per_session = 3u64;
4185        let mut conns: Vec<PreparedConnection> = Vec::with_capacity(session_count as usize);
4186        let mut expected_messages: std::collections::HashSet<u64> =
4187            std::collections::HashSet::new();
4188        let mut expected_acks: Vec<u64> = Vec::with_capacity(session_count as usize);
4189        for sid in 1..=session_count {
4190            let messages: Vec<(u64, u64)> = (0u64..msgs_per_session)
4191                .map(|seq| (seq, sid * 100 + seq))
4192                .collect();
4193            for (_, v) in &messages {
4194                expected_messages.insert(*v);
4195            }
4196            expected_acks.push(messages.iter().map(|(s, _)| *s).max().unwrap());
4197            conns.push(prepare_connection(SessionId(sid), 0, &messages).await);
4198        }
4199
4200        let addr = ChannelAddr::Local(u64::MAX);
4201        let listener = QueueListener {
4202            streams: conns
4203                .iter_mut()
4204                .map(|c| std::mem::replace(&mut c.server_side, tokio::io::duplex(1).0))
4205                .collect(),
4206            addr: addr.clone(),
4207        };
4208        let mut server = super::duplex::serve_with_listener::<u64, u64, _>(listener, addr).unwrap();
4209
4210        // Accept one (rx, tx) pair per session. Hold the tx halves
4211        // alive so the server's send-loop stays parked in `select!`
4212        // (it never sees an `AppClosed` terminal) — the test must
4213        // exercise the cancel-driven flush path, not an
4214        // app-disconnect one.
4215        let mut all_rx: Vec<super::duplex::DuplexRx<u64>> =
4216            Vec::with_capacity(session_count as usize);
4217        let mut all_tx: Vec<super::duplex::DuplexTx<u64>> =
4218            Vec::with_capacity(session_count as usize);
4219        for _ in 0..session_count {
4220            let (rx, tx) = server.accept().await.unwrap();
4221            all_rx.push(rx);
4222            all_tx.push(tx);
4223        }
4224
4225        // Drain the messages each session sent. Each session's
4226        // dispatch task delivers exactly its own `msgs_per_session`
4227        // values to its own `rx`, but the order in which sessions
4228        // are accepted is non-deterministic — collect into a set.
4229        let mut received: std::collections::HashSet<u64> = std::collections::HashSet::new();
4230        for rx in all_rx.iter_mut() {
4231            for _ in 0..msgs_per_session {
4232                received.insert(rx.recv().await.unwrap());
4233            }
4234        }
4235        assert_eq!(
4236            received, expected_messages,
4237            "duplex: every produced message should reach the application"
4238        );
4239
4240        // Policy thresholds are out of reach, so no ack should fire
4241        // spontaneously. The sleep gives any regression that adds a
4242        // new emission path room to surface.
4243        tokio::time::sleep(Duration::from_millis(100)).await;
4244
4245        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
4246        let mut readers: Vec<FrameReader<ReadHalf<DuplexStream>>> = conns
4247            .into_iter()
4248            .map(|c| FrameReader::new(c.client_r, max_len))
4249            .collect();
4250        for (idx, reader) in readers.iter_mut().enumerate() {
4251            match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
4252                Err(_) => {} // timeout — no frame, expected.
4253                Ok(Err(e)) => {
4254                    panic!("duplex: connection {idx} frame reader error before join: {e}")
4255                }
4256                Ok(Ok(None)) => {
4257                    panic!("duplex: connection {idx} closed before server.join()")
4258                }
4259                Ok(Ok(Some((_, bytes)))) => {
4260                    let resp = super::deserialize_response(bytes).unwrap();
4261                    panic!(
4262                        "duplex: connection {idx} unexpectedly received {resp:?} \
4263                         before server.join()"
4264                    );
4265                }
4266            }
4267        }
4268
4269        // Trigger graceful shutdown. Holding `all_tx` / `all_rx`
4270        // alive keeps the send-loop parked and `inbound_tx` valid,
4271        // so the dispatch task only exits via the cancel branch of
4272        // its `select!` — exercising the structured-concurrency
4273        // flush-on-cancel path.
4274        server.join().await;
4275
4276        // After server.join() returns, every connection must have
4277        // its terminal cleanup frames already on the wire: an
4278        // `Ack(highest_seq)` covering the messages it sent, then a
4279        // `Closed`.
4280        for (idx, (reader, expected_ack)) in readers.iter_mut().zip(&expected_acks).enumerate() {
4281            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
4282                .await
4283                .unwrap_or_else(|_| {
4284                    panic!(
4285                        "duplex: connection {idx} produced no Ack frame within 50ms after \
4286                         server.join()"
4287                    )
4288                })
4289                .expect("frame reader error")
4290                .expect("frame reader returned None");
4291            let acked = super::deserialize_response(bytes.1)
4292                .unwrap()
4293                .into_ack()
4294                .unwrap_or_else(|other| {
4295                    panic!("duplex: connection {idx} expected Ack, got {other:?}")
4296                });
4297            assert_eq!(
4298                acked, *expected_ack,
4299                "duplex: connection {idx} ack mismatch"
4300            );
4301
4302            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
4303                .await
4304                .unwrap_or_else(|_| {
4305                    panic!("duplex: connection {idx} produced no Closed frame within 50ms")
4306                })
4307                .expect("frame reader error")
4308                .expect("frame reader returned None");
4309            assert!(
4310                super::deserialize_response(bytes.1).unwrap().is_closed(),
4311                "duplex: connection {idx} expected Closed terminal frame"
4312            );
4313        }
4314    }
4315
4316    /// Test-only [`Link`] that yields each pre-built `DuplexStream`
4317    /// once. Writes `LinkInit` on the stream before returning it so
4318    /// the test (which holds the other end) sees the same wire
4319    /// format a real `TcpLink` produces.
4320    struct DuplexDialMockLink {
4321        session_id: SessionId,
4322        streams: std::collections::VecDeque<DuplexStream>,
4323    }
4324
4325    impl fmt::Debug for DuplexDialMockLink {
4326        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4327            f.debug_struct("DuplexDialMockLink")
4328                .field("session_id", &self.session_id)
4329                .field("remaining_streams", &self.streams.len())
4330                .finish()
4331        }
4332    }
4333
4334    #[async_trait]
4335    impl super::Link for DuplexDialMockLink {
4336        type Stream = DuplexStream;
4337
4338        fn dest(&self) -> ChannelAddr {
4339            ChannelAddr::Local(u64::MAX)
4340        }
4341
4342        fn link_id(&self) -> SessionId {
4343            self.session_id
4344        }
4345
4346        async fn next(&mut self) -> Result<DuplexStream, ClientError> {
4347            match self.streams.pop_front() {
4348                Some(mut stream) => {
4349                    super::write_link_init(&mut stream, self.session_id, 0)
4350                        .await
4351                        .map_err(|err| ClientError::Io(self.dest(), err))?;
4352                    Ok(stream)
4353                }
4354                None => Err(ClientError::Connect(
4355                    self.dest(),
4356                    std::io::Error::other("mock link exhausted"),
4357                    "no more streams".into(),
4358                )),
4359            }
4360        }
4361    }
4362
4363    /// Acceptor-side counterpart of
4364    /// [`duplex_dial_flushes_pending_ack_on_app_closed`]. The
4365    /// application drops its `DuplexTx`, the dispatch task's
4366    /// `send_connected` returns `SendLoopError::AppClosed` —
4367    /// terminal — so `dispatch_duplex_stream`'s loop exits via the
4368    /// flush + `Closed` + break path. Verifies the cumulative ack
4369    /// and the terminal `Closed` are written on
4370    /// `INITIATOR_TO_ACCEPTOR` before the dispatch task ends.
4371    #[async_timed_test(timeout_secs = 30)]
4372    async fn duplex_serve_flushes_pending_ack_on_app_closed() {
4373        let config = hyperactor_config::global::lock();
4374        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
4375        let _g_time =
4376            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));
4377
4378        let session_id = SessionId(1);
4379        let messages: Vec<(u64, u64)> = vec![(0, 100), (1, 200), (2, 300)];
4380        let expected_ack: u64 = 2;
4381        let conn = prepare_connection(session_id, 0, &messages).await;
4382
4383        let addr = ChannelAddr::Local(u64::MAX);
4384        let listener = QueueListener {
4385            streams: std::collections::VecDeque::from([conn.server_side]),
4386            addr: addr.clone(),
4387        };
4388
4389        let mut server = super::duplex::serve_with_listener::<u64, u64, _>(listener, addr).unwrap();
4390
4391        let (mut server_rx, server_tx) = server.accept().await.unwrap();
4392
4393        // Drain the messages the test wrote on `INITIATOR_TO_ACCEPTOR`.
4394        let mut received: Vec<u64> = Vec::with_capacity(messages.len());
4395        for _ in &messages {
4396            received.push(server_rx.recv().await.unwrap());
4397        }
4398        let expected_values: Vec<u64> = messages.iter().map(|(_, v)| *v).collect();
4399        assert_eq!(
4400            received, expected_values,
4401            "duplex serve: every message should reach the application"
4402        );
4403
4404        // Policy thresholds are unreachable, so no ack should fire
4405        // spontaneously.
4406        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
4407        tokio::time::sleep(Duration::from_millis(100)).await;
4408        let mut reader = FrameReader::new(conn.client_r, max_len);
4409        match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
4410            Err(_) => {} // timeout — expected.
4411            Ok(Err(e)) => panic!("duplex serve: frame reader error before app close: {e}"),
4412            Ok(Ok(None)) => panic!("duplex serve: wire closed before app close"),
4413            Ok(Ok(Some((_, bytes)))) => {
4414                let resp = super::deserialize_response(bytes).unwrap();
4415                panic!("duplex serve: unexpectedly received {resp:?} before app close");
4416            }
4417        }
4418
4419        // Drop the application's `DuplexTx` to trigger `AppClosed`
4420        // in `send_connected` — terminal — so the dispatch task
4421        // exits. The flush logic must write the cumulative ack on
4422        // `INITIATOR_TO_ACCEPTOR` and the terminal `Closed` before
4423        // the loop breaks.
4424        drop(server_tx);
4425
4426        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
4427            .await
4428            .unwrap_or_else(|_| panic!("duplex serve: produced no Ack frame within 100ms"))
4429            .expect("frame reader error")
4430            .expect("frame reader returned None");
4431        let acked = super::deserialize_response(bytes.1)
4432            .unwrap()
4433            .into_ack()
4434            .unwrap_or_else(|other| panic!("duplex serve: expected Ack, got {other:?}"));
4435        assert_eq!(
4436            acked, expected_ack,
4437            "duplex serve: ack should cover the highest seq received"
4438        );
4439
4440        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
4441            .await
4442            .unwrap_or_else(|_| panic!("duplex serve: produced no Closed frame within 100ms"))
4443            .expect("frame reader error")
4444            .expect("frame reader returned None");
4445        assert!(
4446            super::deserialize_response(bytes.1).unwrap().is_closed(),
4447            "duplex serve: expected Closed terminal frame after Ack"
4448        );
4449
4450        drop(server_rx);
4451        // `server.join()` triggers the listener-task cancel so the
4452        // outer `accept_loop` returns. The dispatch task already
4453        // exited via `AppClosed` and was drained from the JoinSet,
4454        // so this just stops the listener.
4455        server.join().await;
4456    }
4457
4458    /// [`DuplexClient::join`] cancels the recv/send loop's
4459    /// cancellation token. The `select!`s observe cancel and the
4460    /// loop exits via the flush + `Closed` + break path. Verifies
4461    /// the cumulative ack and the terminal `Closed` are written on
4462    /// `ACCEPTOR_TO_INITIATOR` before the spawned task ends.
4463    #[async_timed_test(timeout_secs = 30)]
4464    async fn duplex_client_join_flushes_pending_ack() {
4465        let config = hyperactor_config::global::lock();
4466        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
4467        let _g_time =
4468            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));
4469
4470        let session_id = SessionId(123);
4471        let (client_side, server_side) = tokio::io::duplex(8192);
4472        let (mut test_r, mut test_w) = tokio::io::split(server_side);
4473
4474        let link = DuplexDialMockLink {
4475            session_id,
4476            streams: std::collections::VecDeque::from([client_side]),
4477        };
4478
4479        let mut dial_client = super::duplex::spawn::<u64, u64>(link);
4480        let _dial_tx = dial_client.tx();
4481        let mut dial_rx = dial_client.take_rx().unwrap();
4482
4483        // Drain the LinkInit the dial-side wrote on connect.
4484        super::read_link_init(&mut test_r).await.unwrap();
4485
4486        // Test acts as the acceptor: write framed `Frame::Message`s
4487        // on `ACCEPTOR_TO_INITIATOR` so the dial-side recv-loop
4488        // reads them and forwards to the application via
4489        // `inbound_tx`.
4490        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
4491        let messages: Vec<(u64, u64)> = vec![(0, 100), (1, 200), (2, 300)];
4492        let expected_ack: u64 = 2;
4493        for (seq, value) in &messages {
4494            let payload =
4495                serde_multipart::serialize_bincode(&Frame::<u64>::Message(*seq, *value)).unwrap();
4496            let mut fw = FrameWrite::new(
4497                test_w,
4498                payload.framed(),
4499                max_len,
4500                super::ACCEPTOR_TO_INITIATOR,
4501            )
4502            .map_err(|(_w, e)| e)
4503            .unwrap();
4504            fw.send().await.unwrap();
4505            test_w = fw.complete();
4506        }
4507
4508        // Drain the messages on the dial-side rx so the dial-side's
4509        // `recv_next.seq` advances past every seq.
4510        let mut received: Vec<u64> = Vec::with_capacity(messages.len());
4511        for _ in &messages {
4512            received.push(dial_rx.recv().await.unwrap());
4513        }
4514        let expected_values: Vec<u64> = messages.iter().map(|(_, v)| *v).collect();
4515        assert_eq!(
4516            received, expected_values,
4517            "dial: every message should reach the application"
4518        );
4519
4520        // Policy thresholds are unreachable, so no ack should fire
4521        // spontaneously.
4522        tokio::time::sleep(Duration::from_millis(100)).await;
4523        let mut reader = FrameReader::new(test_r, max_len);
4524        match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
4525            Err(_) => {} // timeout — expected.
4526            Ok(Err(e)) => panic!("dial: frame reader error before join: {e}"),
4527            Ok(Ok(None)) => panic!("dial: wire closed before join"),
4528            Ok(Ok(Some((_, bytes)))) => {
4529                let resp = super::deserialize_response(bytes).unwrap();
4530                panic!("dial: unexpectedly received {resp:?} before join");
4531            }
4532        }
4533
4534        // Trigger graceful shutdown via the structured-concurrency
4535        // join handle. Cancel propagates into the spawn loop's
4536        // `select!`s; the loop exits via Cancelled (terminal) and
4537        // the flush logic writes the cumulative ack and the
4538        // `Closed` terminal frame before the task ends.
4539        dial_client.join().await;
4540
4541        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
4542            .await
4543            .unwrap_or_else(|_| panic!("dial: produced no Ack frame within 100ms after join"))
4544            .expect("frame reader error")
4545            .expect("frame reader returned None");
4546        let acked = super::deserialize_response(bytes.1)
4547            .unwrap()
4548            .into_ack()
4549            .unwrap_or_else(|other| panic!("dial: expected Ack, got {other:?}"));
4550        assert_eq!(
4551            acked, expected_ack,
4552            "dial: ack should cover the highest seq received"
4553        );
4554
4555        // After the ack, the dial-side writes the terminal
4556        // `Closed` response on `ACCEPTOR_TO_INITIATOR` (mirrors
4557        // `dispatch_duplex_stream`), then releases the connection.
4558        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
4559            .await
4560            .unwrap_or_else(|_| panic!("dial: produced no Closed frame within 100ms after join"))
4561            .expect("frame reader error")
4562            .expect("frame reader returned None");
4563        assert!(
4564            super::deserialize_response(bytes.1).unwrap().is_closed(),
4565            "dial: expected Closed terminal frame after Ack"
4566        );
4567
4568        drop(dial_rx);
4569        drop(test_w);
4570    }
4571
4572    /// Verifies that an in-progress [`DuplexRx::recv`] on the
4573    /// receiver returned by [`DuplexClient::take_rx`] resolves with
4574    /// [`ChannelError::Closed`] when [`DuplexClient::join`] is
4575    /// called concurrently. Structured concurrency guarantees the
4576    /// spawned task drops `inbound_tx` before `join` returns, so
4577    /// the receiver observes the close deterministically.
4578    #[async_timed_test(timeout_secs = 30)]
4579    async fn duplex_client_join_terminates_in_progress_recv() {
4580        let session_id = SessionId(123);
4581        let (client_side, server_side) = tokio::io::duplex(8192);
4582        let (mut test_r, _test_w) = tokio::io::split(server_side);
4583
4584        let link = DuplexDialMockLink {
4585            session_id,
4586            streams: std::collections::VecDeque::from([client_side]),
4587        };
4588
4589        let mut dial_client = super::duplex::spawn::<u64, u64>(link);
4590        let mut dial_rx = dial_client.take_rx().unwrap();
4591
4592        // Drain the LinkInit so the dial-side has finished its
4593        // setup before we kick off the recv() under test.
4594        super::read_link_init(&mut test_r).await.unwrap();
4595
4596        // Park a recv() in a separate task; the dial-side hasn't
4597        // forwarded any inbound frames so this future will sit in
4598        // `inbound_rx.recv().await` indefinitely until `join`
4599        // drops the spawned task's `inbound_tx`.
4600        let recv_handle: tokio::task::JoinHandle<Result<u64, ChannelError>> =
4601            tokio::spawn(async move { dial_rx.recv().await });
4602
4603        // `join` cancels the spawned task; on exit the task drops
4604        // its `inbound_tx`, which closes `inbound_rx` and resolves
4605        // the recv() with `ChannelError::Closed`.
4606        dial_client.join().await;
4607
4608        let result = tokio::time::timeout(Duration::from_millis(100), recv_handle)
4609            .await
4610            .expect("parked recv should resolve within 100ms after join")
4611            .expect("recv task should not panic");
4612        assert!(
4613            matches!(result, Err(ChannelError::Closed)),
4614            "in-progress recv should resolve with ChannelError::Closed after join, got {result:?}"
4615        );
4616    }
4617}