Skip to main content

hyperactor/
channel.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! One-way, multi-process, typed communication channels. These are used
10//! to send messages between mailboxes residing in different processes.
11
12use core::net::SocketAddr;
13use std::fmt;
14use std::net::IpAddr;
15use std::net::Ipv6Addr;
16#[cfg(target_os = "linux")]
17use std::os::linux::net::SocketAddrExt;
18use std::os::unix::io::FromRawFd;
19use std::os::unix::io::RawFd;
20use std::panic::Location;
21use std::str::FromStr;
22use std::sync::Arc;
23
24use async_trait::async_trait;
25use enum_as_inner::EnumAsInner;
26use hyperactor_config::attrs::AttrValue;
27use serde::Deserialize;
28use serde::Serialize;
29use tokio::sync::mpsc;
30use tokio::sync::oneshot;
31use tokio::sync::watch;
32
33use crate as hyperactor;
34use crate::RemoteMessage;
35pub(crate) mod local;
36pub(crate) mod net;
37
38// Public TLS API for HTTP services (mesh admin, TUI, etc.). The
39// implementation lives in `net` but we re-export here to keep `net`'s
40// internal types out of the public API surface.
41pub use net::ServerError;
42pub use net::try_tls_acceptor;
43pub use net::try_tls_connector;
44pub use net::try_tls_pem_bundle;
45
46/// Duplex channel API: a single connection carries messages in both directions.
47pub mod duplex {
48    pub use super::net::duplex::DuplexClient;
49    pub use super::net::duplex::DuplexRx;
50    pub use super::net::duplex::DuplexServer;
51    pub use super::net::duplex::DuplexTx;
52    pub use super::net::duplex::dial;
53    pub use super::net::duplex::serve;
54}
55
56/// The type of error that can occur on channel operations.
57#[derive(thiserror::Error, Debug)]
58pub enum ChannelError {
59    /// An operation was attempted on a closed channel.
60    #[error("channel closed")]
61    Closed,
62
63    /// An error occurred during send.
64    #[error("send: {0}")]
65    Send(#[source] anyhow::Error),
66
67    /// A network client error.
68    #[error(transparent)]
69    Client(#[from] net::ClientError),
70
71    /// The address was not valid.
72    #[error("invalid address {0:?}")]
73    InvalidAddress(String),
74
75    /// A serving error was encountered.
76    #[error(transparent)]
77    Server(#[from] net::ServerError),
78
79    /// A bincode encoding error occurred.
80    #[error(transparent)]
81    BincodeEncode(#[from] bincode::error::EncodeError),
82
83    /// A bincode decoding error occurred.
84    #[error(transparent)]
85    BincodeDecode(#[from] bincode::error::DecodeError),
86
87    /// Data encoding errors.
88    #[error(transparent)]
89    Data(#[from] wirevalue::Error),
90
91    /// Some other error.
92    #[error(transparent)]
93    Other(#[from] anyhow::Error),
94
95    /// An operation timeout occurred.
96    #[error("operation timed out after {0:?}")]
97    Timeout(std::time::Duration),
98}
99
100/// An error that occurred during send. Returns the message that failed to send.
101#[derive(thiserror::Error, Debug)]
102#[error("{error} for reason {reason:?}")]
103pub struct SendError<M: RemoteMessage> {
104    /// Inner channel error
105    #[source]
106    pub error: ChannelError,
107    /// Message that couldn't be sent
108    pub message: M,
109    /// Reason that message couldn't be sent, if any.
110    pub reason: Option<String>,
111}
112
113impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
114    fn from(error: SendError<M>) -> Self {
115        error.error
116    }
117}
118
119/// The possible states of a `Tx`.
120#[derive(Debug, Clone, PartialEq, EnumAsInner)]
121pub enum TxStatus {
122    /// The tx is good.
123    Active,
124    /// The tx cannot be used for message delivery.
125    Closed(Arc<str>),
126}
127
128/// The transmit end of an M-typed channel.
129#[async_trait]
130pub trait Tx<M: RemoteMessage> {
131    /// Post a message; returning failed deliveries on the return channel, if provided.
132    /// If provided, the sender is dropped when the message has been
133    /// enqueued at the channel endpoint.
134    ///
135    /// Users should use the `try_post`, and `post` variants directly.
136    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
137
138    /// Enqueue a `message` on the local end of the channel. The
139    /// message is either delivered, or we eventually discover that
140    /// the channel has failed and it will be sent back on `return_channel`.
141    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
142    #[tracing::instrument(level = "debug", skip_all)]
143    fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
144        self.do_post(message, Some(return_channel));
145    }
146
147    /// Enqueue a message to be sent on the channel.
148    #[hyperactor::instrument_infallible]
149    fn post(&self, message: M) {
150        self.do_post(message, None);
151    }
152
153    /// Send a message synchronously, returning when the message has
154    /// been delivered to the remote end of the channel.
155    async fn send(&self, message: M) -> Result<(), SendError<M>> {
156        let (tx, rx) = oneshot::channel();
157        self.try_post(message, tx);
158        match rx.await {
159            // Channel was closed; the message was not delivered.
160            Ok(err) => Err(err),
161
162            // Channel was dropped; the message was successfully enqueued
163            // on the remote end of the channel.
164            Err(_) => Ok(()),
165        }
166    }
167
168    /// The channel address to which this Tx is sending.
169    fn addr(&self) -> ChannelAddr;
170
171    /// A means to monitor the health of a `Tx`.
172    fn status(&self) -> &watch::Receiver<TxStatus>;
173}
174
175/// The receive end of an M-typed channel.
176#[async_trait]
177pub trait Rx<M: RemoteMessage> {
178    /// Receive the next message from the channel. If the channel returns
179    /// an error it is considered broken and should be discarded.
180    async fn recv(&mut self) -> Result<M, ChannelError>;
181
182    /// The channel address from which this Rx is receiving.
183    fn addr(&self) -> ChannelAddr;
184
185    /// Gracefully shut down the channel receiver, flushing any pending
186    /// acks before returning. Implementations must ensure all pending
187    /// acks are sent before this method returns.
188    async fn join(self)
189    where
190        Self: Sized;
191}
192
193#[allow(dead_code)] // Not used outside tests.
194struct MpscTx<M: RemoteMessage> {
195    tx: mpsc::UnboundedSender<M>,
196    addr: ChannelAddr,
197    status: watch::Receiver<TxStatus>,
198}
199
200impl<M: RemoteMessage> MpscTx<M> {
201    #[allow(dead_code)] // Not used outside tests.
202    pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
203        let (sender, receiver) = watch::channel(TxStatus::Active);
204        (
205            Self {
206                tx,
207                addr,
208                status: receiver,
209            },
210            sender,
211        )
212    }
213}
214
215#[async_trait]
216impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
217    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
218        if let Err(mpsc::error::SendError(message)) = self.tx.send(message)
219            && let Some(return_channel) = return_channel
220        {
221            return_channel
222                .send(SendError {
223                    error: ChannelError::Closed,
224                    message,
225                    reason: None,
226                })
227                .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
228        }
229    }
230
231    fn addr(&self) -> ChannelAddr {
232        self.addr.clone()
233    }
234
235    fn status(&self) -> &watch::Receiver<TxStatus> {
236        &self.status
237    }
238}
239
240#[allow(dead_code)] // Not used outside tests.
241struct MpscRx<M: RemoteMessage> {
242    rx: mpsc::UnboundedReceiver<M>,
243    addr: ChannelAddr,
244    // Used to report the status to the Tx side.
245    status_sender: watch::Sender<TxStatus>,
246}
247
248impl<M: RemoteMessage> MpscRx<M> {
249    #[allow(dead_code)] // Not used outside tests.
250    pub fn new(
251        rx: mpsc::UnboundedReceiver<M>,
252        addr: ChannelAddr,
253        status_sender: watch::Sender<TxStatus>,
254    ) -> Self {
255        Self {
256            rx,
257            addr,
258            status_sender,
259        }
260    }
261}
262
263impl<M: RemoteMessage> Drop for MpscRx<M> {
264    fn drop(&mut self) {
265        let _ = self
266            .status_sender
267            .send(TxStatus::Closed("receiver dropped".into()));
268    }
269}
270
271#[async_trait]
272impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
273    async fn recv(&mut self) -> Result<M, ChannelError> {
274        self.rx.recv().await.ok_or(ChannelError::Closed)
275    }
276
277    fn addr(&self) -> ChannelAddr {
278        self.addr.clone()
279    }
280
281    async fn join(self) {}
282}
283
284/// The hostname to use for TLS connections.
285#[derive(
286    Clone,
287    Debug,
288    PartialEq,
289    Eq,
290    Hash,
291    Serialize,
292    Deserialize,
293    strum::EnumIter,
294    strum::Display,
295    strum::EnumString
296)]
297pub enum TcpMode {
298    /// Use localhost/loopback for the connection.
299    Localhost,
300    /// Use host domain name for the connection.
301    Hostname,
302}
303
304/// The hostname to use for TLS connections.
305#[derive(
306    Clone,
307    Debug,
308    PartialEq,
309    Eq,
310    Hash,
311    Serialize,
312    Deserialize,
313    strum::EnumIter,
314    strum::Display,
315    strum::EnumString
316)]
317pub enum TlsMode {
318    /// Use IpV6 address for TLS connections.
319    IpV6,
320    /// Use host domain name for TLS connections.
321    Hostname,
322    // TODO: consider adding IpV4 support.
323}
324
325/// Address format for TLS channels.
326#[derive(
327    Clone,
328    Debug,
329    PartialEq,
330    Eq,
331    Hash,
332    Serialize,
333    Deserialize,
334    Ord,
335    PartialOrd
336)]
337pub struct TlsAddr {
338    /// The hostname to connect to.
339    pub hostname: Hostname,
340    /// The port to connect to.
341    pub port: Port,
342}
343
344impl TlsAddr {
345    /// Creates a new TLS address with a normalized hostname.
346    pub fn new(hostname: impl Into<Hostname>, port: Port) -> Self {
347        Self {
348            hostname: normalize_host(&hostname.into()),
349            port,
350        }
351    }
352
353    /// Returns the port number for this address.
354    pub fn port(&self) -> Port {
355        self.port
356    }
357
358    /// Returns the hostname for this address.
359    pub fn hostname(&self) -> &str {
360        &self.hostname
361    }
362}
363
364impl fmt::Display for TlsAddr {
365    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366        write!(f, "{}:{}", self.hostname, self.port)
367    }
368}
369
370/// Types of channel transports.
371#[derive(
372    Clone,
373    Debug,
374    PartialEq,
375    Eq,
376    Hash,
377    Serialize,
378    Deserialize,
379    typeuri::Named
380)]
381pub enum ChannelTransport {
382    /// Transport over a TCP connection.
383    Tcp(TcpMode),
384
385    /// Transport over a TCP connection with TLS support within Meta
386    MetaTls(TlsMode),
387
388    /// Transport over a TCP connection with configurable TLS support
389    Tls,
390
391    /// Local transports uses an in-process registry and mpsc channels.
392    Local,
393
394    /// Transport over unix domain socket.
395    Unix,
396}
397
398impl fmt::Display for ChannelTransport {
399    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
400        match self {
401            Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
402            Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
403            Self::Tls => write!(f, "tls"),
404            Self::Local => write!(f, "local"),
405            Self::Unix => write!(f, "unix"),
406        }
407    }
408}
409
410impl FromStr for ChannelTransport {
411    type Err = anyhow::Error;
412
413    fn from_str(s: &str) -> Result<Self, Self::Err> {
414        match s {
415            // Default to TcpMode::Hostname, if the mode isn't set
416            "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
417            s if s.starts_with("tcp(") => {
418                let inner = &s["tcp(".len()..s.len() - 1];
419                let mode = inner.parse()?;
420                Ok(ChannelTransport::Tcp(mode))
421            }
422            "local" => Ok(ChannelTransport::Local),
423            "unix" => Ok(ChannelTransport::Unix),
424            "tls" => Ok(ChannelTransport::Tls),
425            s if s.starts_with("metatls(") && s.ends_with(")") => {
426                let inner = &s["metatls(".len()..s.len() - 1];
427                let mode = inner.parse()?;
428                Ok(ChannelTransport::MetaTls(mode))
429            }
430            unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
431        }
432    }
433}
434
435impl ChannelTransport {
436    /// All known channel transports.
437    pub fn all() -> [ChannelTransport; 3] {
438        [
439            // TODO: @rusch add back once figuring out unspecified override for OSS CI
440            // ChannelTransport::Tcp(TcpMode::Localhost),
441            ChannelTransport::Tcp(TcpMode::Hostname),
442            ChannelTransport::Local,
443            ChannelTransport::Unix,
444            // Tls requires certificate configuration, tested separately in tls::tests
445            // TODO add MetaTls (T208303369)
446        ]
447    }
448
449    /// Return an "any" address for this transport.
450    pub fn any(&self) -> ChannelAddr {
451        ChannelAddr::any(self.clone())
452    }
453
454    /// Returns true if this transport type represents a remote channel.
455    pub fn is_remote(&self) -> bool {
456        match self {
457            ChannelTransport::Tcp(_) => true,
458            ChannelTransport::MetaTls(_) => true,
459            ChannelTransport::Tls => true,
460            ChannelTransport::Local => false,
461            ChannelTransport::Unix => false,
462        }
463    }
464
465    /// Returns true if this transport can carry the duplex byte-stream
466    /// protocol (see [`crate::channel::net::duplex`]). In-process
467    /// transports cannot carry a duplex wire protocol and must fall
468    /// back to a simplex channel.
469    pub fn supports_duplex(&self) -> bool {
470        match self {
471            ChannelTransport::Tcp(_) => true,
472            ChannelTransport::MetaTls(_) => true,
473            ChannelTransport::Tls => true,
474            ChannelTransport::Unix => true,
475            ChannelTransport::Local => false,
476        }
477    }
478}
479
480impl AttrValue for ChannelTransport {
481    fn display(&self) -> String {
482        self.to_string()
483    }
484
485    fn parse(s: &str) -> Result<Self, anyhow::Error> {
486        s.parse()
487    }
488}
489
490/// Specifies how to bind a channel server.
491#[derive(
492    Clone,
493    Debug,
494    PartialEq,
495    Eq,
496    Hash,
497    Serialize,
498    Deserialize,
499    typeuri::Named
500)]
501pub enum BindSpec {
502    /// Bind to any available address for the given transport.
503    Any(ChannelTransport),
504
505    /// Bind to a specific channel address.
506    Addr(ChannelAddr),
507}
508
509impl BindSpec {
510    /// Return an "any" address for this bind spec.
511    pub fn binding_addr(&self) -> ChannelAddr {
512        match self {
513            BindSpec::Any(transport) => ChannelAddr::any(transport.clone()),
514            BindSpec::Addr(addr) => addr.clone(),
515        }
516    }
517}
518
519impl From<ChannelTransport> for BindSpec {
520    fn from(transport: ChannelTransport) -> Self {
521        BindSpec::Any(transport)
522    }
523}
524
525impl From<ChannelAddr> for BindSpec {
526    fn from(addr: ChannelAddr) -> Self {
527        BindSpec::Addr(addr)
528    }
529}
530
531impl fmt::Display for BindSpec {
532    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
533        match self {
534            Self::Any(transport) => write!(f, "{}", transport),
535            Self::Addr(addr) => write!(f, "{}", addr),
536        }
537    }
538}
539
540impl FromStr for BindSpec {
541    type Err = anyhow::Error;
542
543    fn from_str(s: &str) -> Result<Self, Self::Err> {
544        if let Ok(transport) = ChannelTransport::from_str(s) {
545            Ok(BindSpec::Any(transport))
546        } else if let Ok(addr) = ChannelAddr::from_zmq_url(s) {
547            Ok(BindSpec::Addr(addr))
548        } else if let Ok(addr) = ChannelAddr::from_str(s) {
549            Ok(BindSpec::Addr(addr))
550        } else {
551            Err(anyhow::anyhow!("invalid bind spec: {}", s))
552        }
553    }
554}
555
556impl AttrValue for BindSpec {
557    fn display(&self) -> String {
558        self.to_string()
559    }
560
561    fn parse(s: &str) -> Result<Self, anyhow::Error> {
562        Self::from_str(s)
563    }
564}
565
/// A (TCP) hostname.
pub type Hostname = String;

/// A (TCP) port number.
pub type Port = u16;
571
572/// The type of a channel address, used to multiplex different underlying
573/// channel implementations. ChannelAddrs also have a concrete syntax:
574/// the address type (e.g., "tcp" or "local"), followed by ":", and an address
575/// parseable to that type. For example:
576///
577/// - `tcp:127.0.0.1:1234` - localhost port 1234 over TCP
578/// - `tcp:192.168.0.1:1111` - 192.168.0.1 port 1111 over TCP
579/// - `local:123` - the (in-process) local port 123
580/// - `unix:/some/path` - the Unix socket at `/some/path`
581///
582/// Both local and TCP ports 0 are reserved to indicate "any available
583/// port" when serving.
584///
585/// ```
586/// # use hyperactor::channel::ChannelAddr;
587/// let addr: ChannelAddr = "tcp:127.0.0.1:1234".parse().unwrap();
588/// let ChannelAddr::Tcp(socket_addr) = addr else {
589///     panic!()
590/// };
591/// assert_eq!(socket_addr.port(), 1234);
592/// assert_eq!(socket_addr.is_ipv4(), true);
593/// ```
594#[derive(
595    Clone,
596    Debug,
597    PartialEq,
598    Eq,
599    Ord,
600    PartialOrd,
601    Serialize,
602    Deserialize,
603    Hash,
604    typeuri::Named
605)]
606pub enum ChannelAddr {
607    /// A socket address used to establish TCP channels. Supports
608    /// both  IPv4 and IPv6 address / port pairs.
609    Tcp(SocketAddr),
610
611    /// An address to establish TCP channels with TLS support within Meta.
612    /// Uses TlsAddr with hostname and port.
613    MetaTls(TlsAddr),
614
615    /// An address to establish TCP channels with configurable TLS support.
616    /// Uses TlsAddr with hostname and port.
617    Tls(TlsAddr),
618
619    /// Local addresses are registered in-process and given an integral
620    /// index.
621    Local(u64),
622
623    /// A unix domain socket address. Supports both absolute path names as
624    ///  well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets
625    Unix(net::unix::SocketAddr),
626
627    /// A pair of addresses, one for the client and one for the server:
628    ///   - The client should dial to the `dial_to` address.
629    ///   - The server should bind to the `bind_to` address.
630    ///
631    /// The user is responsible for ensuring the traffic to the `dial_to` address
632    /// is routed to the `bind_to` address.
633    ///
634    /// This is useful for scenarios where the network is configured in a way,
635    /// that the bound address is not directly accessible from the client.
636    ///
637    /// For example, in AWS, the client could be provided with the public IP
638    /// address, yet the server is bound to a private IP address or simply
639    /// INADDR_ANY. Traffic to the public IP address is mapped to the private
640    /// IP address through network address translation (NAT).
641    Alias {
642        /// The address to which the client should dial to.
643        dial_to: Box<ChannelAddr>,
644        /// The address to which the server should bind to.
645        bind_to: Box<ChannelAddr>,
646    },
647}
648
649impl From<SocketAddr> for ChannelAddr {
650    fn from(value: SocketAddr) -> Self {
651        Self::Tcp(value)
652    }
653}
654
655impl From<net::unix::SocketAddr> for ChannelAddr {
656    fn from(value: net::unix::SocketAddr) -> Self {
657        Self::Unix(value)
658    }
659}
660
661impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
662    fn from(value: std::os::unix::net::SocketAddr) -> Self {
663        Self::Unix(net::unix::SocketAddr::new(value))
664    }
665}
666
667impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
668    fn from(value: tokio::net::unix::SocketAddr) -> Self {
669        std::os::unix::net::SocketAddr::from(value).into()
670    }
671}
672
/// Picks the first address in `addresses` that is not link-local, or
/// `None` when every address is link-local (or the list is empty).
fn find_routable_address(addresses: &[IpAddr]) -> Option<IpAddr> {
    addresses.iter().copied().find(|ip| {
        let link_local = match ip {
            IpAddr::V4(v4) => v4.is_link_local(),
            IpAddr::V6(v6) => v6.is_unicast_link_local(),
        };
        !link_local
    })
}
683
684impl ChannelAddr {
685    /// The "any" address for the given transport type. This is used to
686    /// servers to "any" address.
687    pub fn any(transport: ChannelTransport) -> Self {
688        match transport {
689            ChannelTransport::Tcp(mode) => {
690                let ip = match mode {
691                    TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
692                    TcpMode::Hostname => {
693                        hostname::get()
694                            .ok()
695                            .and_then(|hostname| {
696                                // TODO: Avoid using DNS directly once we figure out a good extensibility story here
697                                hostname.to_str().and_then(|hostname_str| {
698                                    dns_lookup::lookup_host(hostname_str)
699                                        .ok()
700                                        .and_then(|addresses| find_routable_address(&addresses))
701                                })
702                            })
703                            .unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST))
704                    }
705                };
706                Self::Tcp(SocketAddr::new(ip, 0))
707            }
708            ChannelTransport::MetaTls(mode) => {
709                let host_address = match mode {
710                    TlsMode::Hostname => hostname::get()
711                        .ok()
712                        .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
713                        .unwrap_or("unknown_host".to_string()),
714                    TlsMode::IpV6 => {
715                        get_host_ipv6_address().expect("failed to retrieve ipv6 address")
716                    }
717                };
718                Self::MetaTls(TlsAddr::new(host_address, 0))
719            }
720            ChannelTransport::Local => Self::Local(0),
721            ChannelTransport::Tls => {
722                let host_address = hostname::get()
723                    .ok()
724                    .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
725                    .unwrap_or("localhost".to_string());
726                Self::Tls(TlsAddr::new(host_address, 0))
727            }
728            // This works because the file will be deleted but we know we have a unique file by this point.
729            ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
730        }
731    }
732
733    /// The transport used by this address.
734    pub fn transport(&self) -> ChannelTransport {
735        match self {
736            Self::Tcp(addr) => {
737                if addr.ip().is_loopback() {
738                    ChannelTransport::Tcp(TcpMode::Localhost)
739                } else {
740                    ChannelTransport::Tcp(TcpMode::Hostname)
741                }
742            }
743            Self::MetaTls(addr) => match addr.hostname.parse::<IpAddr>() {
744                Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
745                Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
746                Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
747            },
748            Self::Tls(_) => ChannelTransport::Tls,
749            Self::Local(_) => ChannelTransport::Local,
750            Self::Unix(_) => ChannelTransport::Unix,
751            // bind_to's transport is what is actually used in communication.
752            // Therefore we use its transport to represent the Alias.
753            Self::Alias { bind_to, .. } => bind_to.transport(),
754        }
755    }
756}
757
758#[cfg(fbcode_build)]
759fn get_host_ipv6_address() -> anyhow::Result<String> {
760    crate::meta::host_ip::host_ipv6_address()
761}
762
763#[cfg(not(fbcode_build))]
764fn get_host_ipv6_address() -> anyhow::Result<String> {
765    Ok(local_ip_address::local_ipv6()?.to_string())
766}
767
768impl fmt::Display for ChannelAddr {
769    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
770        match self {
771            Self::Tcp(addr) => write!(f, "tcp:{}", addr),
772            Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
773            Self::Tls(addr) => write!(f, "tls:{}", addr),
774            Self::Local(index) => write!(f, "local:{}", index),
775            Self::Unix(addr) => write!(f, "unix:{}", addr),
776            Self::Alias { dial_to, bind_to } => {
777                write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
778            }
779        }
780    }
781}
782
783impl FromStr for ChannelAddr {
784    type Err = anyhow::Error;
785
786    fn from_str(addr: &str) -> Result<Self, Self::Err> {
787        match addr.split_once('!').or_else(|| addr.split_once(':')) {
788            Some(("local", rest)) => rest
789                .parse::<u64>()
790                .map(Self::Local)
791                .map_err(anyhow::Error::from),
792            Some(("tcp", rest)) => rest
793                .parse::<SocketAddr>()
794                .map(Self::Tcp)
795                .map_err(anyhow::Error::from),
796            Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
797            Some(("tls", rest)) => net::tls::parse(rest).map_err(|e| e.into()),
798            Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
799            Some(("alias", _)) => Err(anyhow::anyhow!(
800                "detect possible alias address, but we currently do not support \
801                parsing alias' string representation since we only want to \
802                support parsing its zmq url format."
803            )),
804            Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
805            None => Err(anyhow::anyhow!("no channel type specified")),
806        }
807    }
808}
809
/// Canonicalizes a host string. A host that is an IP address (optionally
/// wrapped in URI-style brackets) is parsed and re-formatted; anything
/// else is returned exactly as given.
pub(crate) fn normalize_host(host: &str) -> String {
    // IpAddr::from_str only accepts bare addresses, so a matching pair of
    // URI-style brackets (e.g., "[::1]") is stripped before parsing.
    let unbracketed = match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']').unwrap_or(host),
        None => host,
    };

    match unbracketed.parse::<IpAddr>() {
        Ok(ip_addr) => ip_addr.to_string(),
        // Not an IP address: keep the original text, brackets included.
        Err(_) => host.to_string(),
    }
}
826
827impl ChannelAddr {
828    /// Parse ZMQ-style URL format: scheme://address
829    /// Supports:
830    /// - tcp://hostname:port or tcp://*:port (wildcard binding)
831    /// - inproc://endpoint-name (equivalent to local)
832    /// - ipc://path (equivalent to unix)
833    /// - metatls://hostname:port or metatls://*:port
834    /// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port)
835    ///   Note: Alias format is currently only supported for TCP addresses
836    pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
837        let (addr, _listener) = Self::from_zmq_url_with_listener(address)?;
838        Ok(addr)
839    }
840
841    /// Parse ZMQ-style URL format, with support for pre-opened file descriptors.
842    ///
843    /// When the port portion of a URL is `fdNNN` (e.g. `tcp://myhost:fd5`),
844    /// the file descriptor is adopted as a pre-bound `TcpListener`. The
845    /// returned `ChannelAddr` will contain the real port that the fd is bound
846    /// to, and the `Option<TcpListener>` will be `Some`.
847    ///
848    /// # Safety
849    /// When using the `fd` syntax, the caller must ensure the file descriptor
850    /// is a valid, bound TCP socket that is not used elsewhere. The socket
851    /// does not need to be in a listening state — `listen()` will be called
852    /// automatically.
853    pub fn from_zmq_url_with_listener(
854        address: &str,
855    ) -> Result<(Self, Option<std::net::TcpListener>), anyhow::Error> {
856        // Check for Alias format: dial_to_url@bind_to_url
857        // The @ character separates two valid ZMQ URLs.
858        if let Some(at_pos) = address
859            .find('@')
860            .filter(|&pos| address[..pos].starts_with("tcp://"))
861        {
862            let dial_to_str = &address[..at_pos];
863            let bind_to_str = &address[at_pos + 1..];
864
865            // Validate that both addresses use TCP scheme
866            if !dial_to_str.starts_with("tcp://") {
867                return Err(anyhow::anyhow!(
868                    "alias format is only supported for TCP addresses, got dial_to: {}",
869                    dial_to_str
870                ));
871            }
872            if !bind_to_str.starts_with("tcp://") {
873                return Err(anyhow::anyhow!(
874                    "alias format is only supported for TCP addresses, got bind_to: {}",
875                    bind_to_str
876                ));
877            }
878
879            let dial_to = Self::from_zmq_url(dial_to_str)?;
880            let bind_to = Self::from_zmq_url(bind_to_str)?;
881
882            return Ok((
883                Self::Alias {
884                    dial_to: Box::new(dial_to),
885                    bind_to: Box::new(bind_to),
886                },
887                None,
888            ));
889        }
890
891        // Try ZMQ-style URL format first (scheme://...)
892        let (scheme, address) = address.split_once("://").ok_or_else(|| {
893            anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
894        })?;
895
896        match scheme {
897            "tcp" => {
898                let (host, port, listener) = Self::parse_host_port_or_fd(address)?;
899                let socket_addr = if host == "*" {
900                    SocketAddr::new("::".parse().unwrap(), port)
901                } else {
902                    Self::resolve_hostname_to_socket_addr(host, port)?
903                };
904                Ok((Self::Tcp(socket_addr), listener))
905            }
906            "inproc" => {
907                let port = address.parse::<u64>().map_err(|_| {
908                    anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
909                })?;
910                Ok((Self::Local(port), None))
911            }
912            "ipc" => Ok((Self::Unix(net::unix::SocketAddr::from_str(address)?), None)),
913            "metatls" | "tls" => {
914                let (host, port, listener) = Self::parse_host_port_or_fd(address)?;
915                let hostname = if host == "*" {
916                    std::net::Ipv6Addr::UNSPECIFIED.to_string()
917                } else {
918                    host.to_string()
919                };
920                let addr = match scheme {
921                    "metatls" => Self::MetaTls(TlsAddr::new(hostname, port)),
922                    _ => Self::Tls(TlsAddr::new(hostname, port)),
923                };
924                Ok((addr, listener))
925            }
926            scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
927        }
928    }
929
930    /// Parse host:port where the port may be either a numeric port or `fdNNN`
931    /// referencing a pre-opened file descriptor. Returns (host, resolved_port, optional_listener).
932    fn parse_host_port_or_fd(
933        address: &str,
934    ) -> Result<(&str, u16, Option<std::net::TcpListener>), anyhow::Error> {
935        let (host, port_str) = address
936            .rsplit_once(':')
937            .ok_or_else(|| anyhow::anyhow!("invalid address format: {}", address))?;
938
939        if let Some(fd_str) = port_str.strip_prefix("fd") {
940            let fd_num: RawFd = fd_str
941                .parse()
942                .map_err(|_| anyhow::anyhow!("invalid file descriptor number: {}", port_str))?;
943            // Ensure the socket is in listening state. This is a no-op if
944            // listen() was already called, and required if only bind() was done.
945            // Safety: fd_num is valid and we are about to take ownership of it.
946            let borrowed = unsafe { std::os::unix::io::BorrowedFd::borrow_raw(fd_num) };
947            nix::sys::socket::listen(&borrowed, nix::sys::socket::Backlog::new(128)?)?;
948            // Safety: caller guarantees the fd is a valid bound TCP socket.
949            let std_listener = unsafe { std::net::TcpListener::from_raw_fd(fd_num) };
950            let local_addr = std_listener.local_addr()?;
951            Ok((host, local_addr.port(), Some(std_listener)))
952        } else {
953            let port: u16 = port_str
954                .parse()
955                .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
956            Ok((host, port, None))
957        }
958    }
959
960    /// Render as a ZMQ-style URL, the inverse of [`from_zmq_url`](Self::from_zmq_url).
961    pub fn to_zmq_url(&self) -> String {
962        match self {
963            Self::Tcp(addr) => format!("tcp://{}", addr),
964            Self::MetaTls(addr) => format!("metatls://{}:{}", addr.hostname, addr.port),
965            Self::Tls(addr) => format!("tls://{}:{}", addr.hostname, addr.port),
966            Self::Local(index) => format!("inproc://{}", index),
967            Self::Unix(addr) => format!("ipc://{}", addr),
968            Self::Alias { dial_to, bind_to } => {
969                format!("{}@{}", dial_to.to_zmq_url(), bind_to.to_zmq_url())
970            }
971        }
972    }
973
974    /// Resolve hostname to SocketAddr, handling both IP addresses and hostnames
975    fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
976        // Handle IPv6 addresses in brackets by stripping the brackets
977        let host_clean = if host.starts_with('[') && host.ends_with(']') {
978            &host[1..host.len() - 1]
979        } else {
980            host
981        };
982
983        // First try to parse as an IP address directly
984        if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
985            return Ok(SocketAddr::new(ip_addr, port));
986        }
987
988        // If not an IP, try hostname resolution
989        use std::net::ToSocketAddrs;
990        let mut addrs = (host_clean, port)
991            .to_socket_addrs()
992            .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
993
994        addrs
995            .next()
996            .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
997    }
998}
999
/// Universal channel transmitter.
///
/// A thin wrapper around the transport-specific transmitter chosen when the
/// channel was dialed; its [`Tx`] implementation forwards every call to that
/// inner transmitter.
pub struct ChannelTx<M: RemoteMessage> {
    /// The concrete (local or network) transmitter that calls are forwarded to.
    inner: ChannelTxKind<M>,
}
1004
1005impl<M: RemoteMessage> fmt::Debug for ChannelTx<M> {
1006    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1007        f.debug_struct("ChannelTx")
1008            .field("addr", &self.addr())
1009            .finish()
1010    }
1011}
1012
/// The transport-specific transmitter wrapped by a `ChannelTx`.
enum ChannelTxKind<M: RemoteMessage> {
    /// Transmitter provided by the `local` module.
    Local(local::LocalTx<M>),
    /// Transmitter provided by the `net` module.
    Net(net::NetTx<M>),
}
1018
1019#[async_trait]
1020impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
1021    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
1022        match &self.inner {
1023            ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
1024            ChannelTxKind::Net(tx) => tx.do_post(message, return_channel),
1025        }
1026    }
1027
1028    fn addr(&self) -> ChannelAddr {
1029        match &self.inner {
1030            ChannelTxKind::Local(tx) => tx.addr(),
1031            ChannelTxKind::Net(tx) => Tx::<M>::addr(tx),
1032        }
1033    }
1034
1035    fn status(&self) -> &watch::Receiver<TxStatus> {
1036        match &self.inner {
1037            ChannelTxKind::Local(tx) => tx.status(),
1038            ChannelTxKind::Net(tx) => tx.status(),
1039        }
1040    }
1041}
1042
/// Universal channel receiver.
///
/// A thin wrapper around the transport-specific receiver created when the
/// channel was served; its [`Rx`] implementation forwards every call to that
/// inner receiver.
pub struct ChannelRx<M: RemoteMessage> {
    /// The concrete (local or network) receiver that calls are forwarded to.
    inner: ChannelRxKind<M>,
}
1047
1048impl<M: RemoteMessage> fmt::Debug for ChannelRx<M> {
1049    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1050        f.debug_struct("ChannelRx")
1051            .field("addr", &self.addr())
1052            .finish()
1053    }
1054}
1055
/// The transport-specific receiver wrapped by a `ChannelRx`.
enum ChannelRxKind<M: RemoteMessage> {
    /// Receiver provided by the `local` module.
    Local(local::LocalRx<M>),
    /// Receiver provided by the `net` module.
    Net(net::NetRx<M>),
}
1061
1062#[async_trait]
1063impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
1064    #[tracing::instrument(level = "debug", skip_all)]
1065    async fn recv(&mut self) -> Result<M, ChannelError> {
1066        match &mut self.inner {
1067            ChannelRxKind::Local(rx) => rx.recv().await,
1068            ChannelRxKind::Net(rx) => rx.recv().await,
1069        }
1070    }
1071
1072    fn addr(&self) -> ChannelAddr {
1073        match &self.inner {
1074            ChannelRxKind::Local(rx) => rx.addr(),
1075            ChannelRxKind::Net(rx) => rx.addr(),
1076        }
1077    }
1078
1079    async fn join(self) {
1080        match self.inner {
1081            ChannelRxKind::Local(rx) => rx.join().await,
1082            ChannelRxKind::Net(rx) => rx.join().await,
1083        }
1084    }
1085}
1086
1087/// Dial the provided address, returning the corresponding Tx, or error
1088/// if the channel cannot be established. The underlying connection is
1089/// dropped whenever the returned Tx is dropped.
1090#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
1091#[track_caller]
1092pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
1093    tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
1094    let inner = match addr {
1095        ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
1096        ChannelAddr::Tcp(_)
1097        | ChannelAddr::Unix(_)
1098        | ChannelAddr::Tls(_)
1099        | ChannelAddr::MetaTls(_) => {
1100            ChannelTxKind::Net(net::spawn(net::link(addr, net::SessionId::random(), 0)?))
1101        }
1102        ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
1103    };
1104    Ok(ChannelTx { inner })
1105}
1106
1107/// Experimental: dial with out-of-order delivery and N parallel streams.
1108///
1109/// Opens N TCP connections sharing a single `SessionId` (distinct
1110/// `stream_id` in `1..=num_streams` so the server routes them through
1111/// the multi-stream receive path). Frames are load-balanced across
1112/// streams via a shared MPMC work queue — idle writers pull next.
1113/// API may change.
1114///
1115/// # Semantics (how this differs from [`dial`])
1116///
1117/// Multi-stream trades several of [`dial`]'s delivery guarantees for
1118/// aggregate bandwidth. Use [`dial`] when any of these matter:
1119///
1120/// - **Ordering.** Messages are delivered to the receiver in *arrival
1121///   order across streams*, not send order. Two messages posted back
1122///   to back on the sender may reach the receiver out of order when
1123///   they're carried on different TCP streams. [`dial`] is strictly
1124///   in-order.
1125///
1126/// - **Retransmission on reconnect.** If a writer's TCP connection
1127///   drops after the bytes of a message have been written but before
1128///   the peer acks, that message is **not retransmitted** on the new
1129///   connection. It may be silently lost. [`dial`] re-sends all
1130///   unacked messages on reconnect.
1131///
1132/// - **Delivery timeouts.** No delivery timeout is enforced. On
1133///   sustained peer outage, writers reconnect indefinitely (backoff
1134///   capped at 5s); senders block in `send().await` with no bound.
1135///   [`dial`] fails unacked sends after
1136///   `MESSAGE_DELIVERY_TIMEOUT`.
1137///
1138/// - **`Tx::send` return semantics.** On session shutdown, messages
1139///   still in the unacked buffer have their `return_channel` dropped.
1140///   Per the [`Tx::send`] contract, this makes `send().await` return
1141///   `Ok(())` for messages that may never have been delivered — i.e.
1142///   a "success" return does not actually confirm delivery here.
1143///   [`dial`] delivers a structured `SendError` instead.
1144///
1145/// # Ack semantics
1146///
1147/// Receivers ack a cumulative watermark: the highest `N` such that
1148/// all of `0..=N` have been observed across all streams. Acks may
1149/// stall behind a single missing seq on a slow stream.
1150pub fn exp_dial_unordered<M: RemoteMessage>(
1151    addr: ChannelAddr,
1152    num_streams: usize,
1153) -> Result<ChannelTx<M>, ChannelError> {
1154    assert!(num_streams > 0);
1155    let session_id = net::SessionId::random();
1156    let links: Vec<net::NetLink> = (1..=num_streams)
1157        .map(|i| net::link(addr.clone(), session_id, i as u8))
1158        .collect::<Result<_, _>>()?;
1159    let inner = ChannelTxKind::Net(net::spawn_unordered(links));
1160    Ok(ChannelTx { inner })
1161}
1162
1163/// Serve on the provided channel address. The server is turned down
1164/// when the returned Rx is dropped.
1165#[track_caller]
1166pub fn serve<M: RemoteMessage>(
1167    addr: ChannelAddr,
1168) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1169    serve_with_listener(addr, None)
1170}
1171
1172/// Serve on the provided channel address, optionally using a pre-opened TCP listener.
1173/// When `listener` is `Some`, the provided listener is used instead of binding a new socket.
1174/// The server is turned down when the returned Rx is dropped.
1175#[track_caller]
1176pub fn serve_with_listener<M: RemoteMessage>(
1177    addr: ChannelAddr,
1178    listener: Option<std::net::TcpListener>,
1179) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1180    let caller = Location::caller();
1181    serve_inner(addr, listener).map(|(addr, inner)| {
1182        tracing::debug!(
1183            name = "serve",
1184            %addr,
1185            %caller,
1186        );
1187        (addr, ChannelRx { inner })
1188    })
1189}
1190
1191fn serve_inner<M: RemoteMessage>(
1192    addr: ChannelAddr,
1193    listener: Option<std::net::TcpListener>,
1194) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
1195    match addr {
1196        ChannelAddr::Unix(_) => {
1197            assert!(
1198                listener.is_none(),
1199                "pre-opened listener not supported for Unix transport"
1200            );
1201            let (addr, rx) = net::server::serve::<M>(addr, listener)?;
1202            Ok((addr, ChannelRxKind::Net(rx)))
1203        }
1204        ChannelAddr::Tcp(_) | ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_) => {
1205            let (addr, rx) = net::server::serve::<M>(addr, listener)?;
1206            Ok((addr, ChannelRxKind::Net(rx)))
1207        }
1208        ChannelAddr::Local(0) => {
1209            assert!(
1210                listener.is_none(),
1211                "pre-opened listener not supported for Local transport"
1212            );
1213            let (port, rx) = local::serve::<M>();
1214            Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
1215        }
1216        ChannelAddr::Local(a) => Ok((
1217            ChannelAddr::Local(a),
1218            ChannelRxKind::Local(local::bind::<M>(a)?),
1219        )),
1220        ChannelAddr::Alias { dial_to, bind_to } => {
1221            let (bound_addr, rx) = serve_inner::<M>(*bind_to, listener)?;
1222            let alias_addr = ChannelAddr::Alias {
1223                dial_to,
1224                bind_to: Box::new(bound_addr),
1225            };
1226            Ok((alias_addr, rx))
1227        }
1228    }
1229}
1230
1231/// Serve on the local address. The server is turned down
1232/// when the returned Rx is dropped.
1233pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
1234    let (port, rx) = local::serve::<M>();
1235    (
1236        ChannelAddr::Local(port),
1237        ChannelRx {
1238            inner: ChannelRxKind::Local(rx),
1239        },
1240    )
1241}
1242
1243/// Reserve a local channel address that can be served later.
1244///
1245/// Local channels are backed by an in-process port registry, so reserving a
1246/// concrete address is a synchronous allocation that does not require a Tokio
1247/// runtime or an OS listener. Gateways use this to have a stable advertised
1248/// local location immediately, including when the process-wide gateway is
1249/// initialized from a [`std::sync::OnceLock`]. Serving is a separate step that
1250/// binds the reserved port to a receiver.
1251///
1252/// Network transports do not have an equivalent reservation API here: their
1253/// concrete addresses come from binding sockets and starting the corresponding
1254/// channel server.
1255pub fn reserve_local_addr() -> ChannelAddr {
1256    ChannelAddr::Local(local::reserve())
1257}
1258
1259#[cfg(test)]
1260mod tests {
1261    use std::assert_matches;
1262    use std::collections::HashSet;
1263    use std::net::IpAddr;
1264    use std::net::Ipv4Addr;
1265    use std::net::Ipv6Addr;
1266    use std::time::Duration;
1267
1268    use rand::RngExt as _;
1269    use rand::distr::Uniform;
1270    use tokio::task::JoinSet;
1271
1272    use super::net::*;
1273    use super::*;
1274    #[test]
1275    fn test_channel_addr() {
1276        let cases_ok = vec![
1277            (
1278                "tcp<DELIM>[::1]:1234",
1279                ChannelAddr::Tcp(SocketAddr::new(
1280                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1281                    1234,
1282                )),
1283            ),
1284            (
1285                "tcp<DELIM>127.0.0.1:8080",
1286                ChannelAddr::Tcp(SocketAddr::new(
1287                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1288                    8080,
1289                )),
1290            ),
1291            #[cfg(target_os = "linux")]
1292            ("local<DELIM>123", ChannelAddr::Local(123)),
1293            (
1294                "unix<DELIM>@yolo",
1295                ChannelAddr::Unix(
1296                    unix::SocketAddr::from_abstract_name("yolo")
1297                        .expect("can't make socket from abstract name"),
1298                ),
1299            ),
1300            (
1301                "unix<DELIM>/cool/socket-path",
1302                ChannelAddr::Unix(
1303                    unix::SocketAddr::from_pathname("/cool/socket-path")
1304                        .expect("can't make socket from path"),
1305                ),
1306            ),
1307        ];
1308
1309        for (raw, parsed) in cases_ok {
1310            for delim in ["!", ":"] {
1311                let raw = raw.replace("<DELIM>", delim);
1312                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
1313            }
1314        }
1315
1316        let cases_err = vec![
1317            ("tcp:abcdef..123124", "invalid socket address syntax"),
1318            ("xxx:foo", "no such channel type: xxx"),
1319            ("127.0.0.1", "no channel type specified"),
1320            ("local:abc", "invalid digit found in string"),
1321        ];
1322
1323        for (raw, error) in cases_err {
1324            let Err(err) = raw.parse::<ChannelAddr>() else {
1325                panic!("expected error parsing: {}", &raw)
1326            };
1327            assert_eq!(format!("{}", err), error);
1328        }
1329    }
1330
1331    #[test]
1332    fn test_zmq_style_channel_addr() {
1333        // Test TCP addresses
1334        assert_eq!(
1335            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
1336            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
1337        );
1338
1339        // Test TCP wildcard binding
1340        assert_eq!(
1341            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
1342            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
1343        );
1344
1345        // Test inproc (maps to local with numeric endpoint)
1346        assert_eq!(
1347            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
1348            ChannelAddr::Local(12345)
1349        );
1350
1351        // Test ipc (maps to unix)
1352        assert_eq!(
1353            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
1354            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
1355        );
1356
1357        // Test metatls with hostname
1358        assert_eq!(
1359            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
1360            ChannelAddr::MetaTls(TlsAddr::new("example.com", 443))
1361        );
1362
1363        // Test metatls with IP address (should be normalized)
1364        assert_eq!(
1365            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
1366            ChannelAddr::MetaTls(TlsAddr::new("192.168.1.1", 443))
1367        );
1368
1369        // Test metatls with wildcard (should use IPv6 unspecified address)
1370        assert_eq!(
1371            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
1372            ChannelAddr::MetaTls(TlsAddr::new("::", 8443))
1373        );
1374
1375        // Test TCP hostname resolution (should resolve hostname to IP)
1376        // Note: This test may fail in environments without proper DNS resolution
1377        // We test that it at least doesn't fail to parse
1378        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
1379        assert!(tcp_hostname_result.is_ok());
1380
1381        // Test IPv6 address
1382        assert_eq!(
1383            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
1384            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
1385        );
1386
1387        // Test error cases
1388        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
1389        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
1390        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
1391        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
1392
1393        // IPv6 normalization: leading zeros are stripped
1394        assert_eq!(
1395            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
1396                .unwrap(),
1397            ChannelAddr::MetaTls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
1398        );
1399
1400        // Short and long forms of the same IPv6 produce equal ChannelAddr values
1401        assert_eq!(
1402            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
1403                .unwrap(),
1404            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443")
1405                .unwrap(),
1406        );
1407
1408        // Bracketed IPv6 is normalized
1409        assert_eq!(
1410            ChannelAddr::from_zmq_url("metatls://[::1]:443").unwrap(),
1411            ChannelAddr::MetaTls(TlsAddr::new("::1", 443))
1412        );
1413
1414        // Same tests for tls://
1415        assert_eq!(
1416            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
1417            ChannelAddr::Tls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
1418        );
1419        assert_eq!(
1420            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
1421            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443").unwrap(),
1422        );
1423        assert_eq!(
1424            ChannelAddr::from_zmq_url("tls://[::1]:443").unwrap(),
1425            ChannelAddr::Tls(TlsAddr::new("::1", 443))
1426        );
1427    }
1428
1429    #[tokio::test]
1430    async fn test_reserved_local_addr_can_be_served() {
1431        let addr = reserve_local_addr();
1432        assert!(dial::<u64>(addr.clone()).is_err());
1433
1434        let (bound_addr, mut rx) = serve::<u64>(addr.clone()).unwrap();
1435        assert_eq!(bound_addr, addr);
1436
1437        let tx = dial::<u64>(addr.clone()).unwrap();
1438        tx.post(123);
1439        assert_eq!(rx.recv().await.unwrap(), 123);
1440        drop(rx);
1441
1442        let (rebound_addr, _rx) = serve::<u64>(addr.clone()).unwrap();
1443        assert_eq!(rebound_addr, addr);
1444    }
1445
1446    #[test]
1447    fn test_normalize_host() {
1448        // Plain IPv4 passes through
1449        assert_eq!(normalize_host("192.168.1.1"), "192.168.1.1");
1450
1451        // Plain hostname passes through
1452        assert_eq!(normalize_host("example.com"), "example.com");
1453
1454        // IPv6 with leading zeros gets normalized
1455        assert_eq!(
1456            normalize_host("2a03:83e4:5000:c000:56d7:00cf:75ce:144a"),
1457            "2a03:83e4:5000:c000:56d7:cf:75ce:144a"
1458        );
1459
1460        // Bracketed IPv6 is stripped and normalized
1461        assert_eq!(normalize_host("[::1]"), "::1");
1462
1463        // Without bracket stripping, IpAddr::from_str rejects bracketed
1464        // addresses. This demonstrates that the bracket stripping in
1465        // normalize_host is necessary.
1466        assert!("[::1]".parse::<IpAddr>().is_err());
1467    }
1468
1469    #[test]
1470    fn test_zmq_style_alias_channel_addr() {
1471        // Test Alias format: dial_to_url@bind_to_url
1472        // The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs
1473        // Note: Alias format is only supported for TCP addresses
1474
1475        // Test Alias with tcp on both sides
1476        let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
1477        match alias_addr {
1478            ChannelAddr::Alias { dial_to, bind_to } => {
1479                assert_eq!(
1480                    *dial_to,
1481                    ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
1482                );
1483                assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
1484            }
1485            _ => panic!("Expected Alias"),
1486        }
1487
1488        // Non-tcp left side: alias branch is skipped, parsed as regular address.
1489        // metatls:// with garbage host is not an alias.
1490        let non_alias = ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080");
1491        assert!(
1492            !matches!(non_alias, Ok(ChannelAddr::Alias { .. })),
1493            "non-tcp left side must not produce Alias"
1494        );
1495
1496        // Test error: alias with non-tcp bind_to (not supported)
1497        assert!(
1498            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
1499        );
1500
1501        // Test error: invalid scheme falls through to scheme parsing, errors there
1502        assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());
1503
1504        // Test error: invalid bind_to URL in Alias
1505        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());
1506
1507        // Test error: missing port in dial_to
1508        assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());
1509
1510        // Test error: missing port in bind_to
1511        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
1512    }
1513
1514    #[tokio::test]
1515    async fn test_multiple_connections() {
1516        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1517            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
1518
1519            let mut sends: JoinSet<()> = JoinSet::new();
1520            for message in 0u64..100u64 {
1521                let addr = listen_addr.clone();
1522                sends.spawn(async move {
1523                    let tx = dial::<u64>(addr).unwrap();
1524                    tx.post(message);
1525                });
1526            }
1527
1528            let mut received: HashSet<u64> = HashSet::new();
1529            while received.len() < 100 {
1530                received.insert(rx.recv().await.unwrap());
1531            }
1532
1533            for message in 0u64..100u64 {
1534                assert!(received.contains(&message));
1535            }
1536
1537            loop {
1538                match sends.join_next().await {
1539                    Some(Ok(())) => (),
1540                    Some(Err(err)) => panic!("{}", err),
1541                    None => break,
1542                }
1543            }
1544        }
1545    }
1546
    /// For non-network transports: after the receiver is dropped, a post
    /// made with a return channel must eventually come back as a
    /// `SendError` carrying `ChannelError::Closed` and the original message.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            if net::is_net_addr(&addr) {
                // Net has store-and-forward semantics. We don't expect failures
                // on closure.
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.post(123);
            drop(rx);

            // New transmits should fail... but there is buffering, etc.,
            // which can cause the failure to be delayed. We give it
            // a deadline, but it can still technically fail -- the test
            // should be considered a kind of integration test.
            let start = tokio::time::Instant::now();

            // Keep posting until a `SendError` arrives on the return channel
            // (`Ok` from the oneshot receiver) or the 10-second deadline
            // passes. An `Err` here means the oneshot sender was dropped
            // without sending anything, so we try again.
            let result = loop {
                let (return_tx, return_rx) = oneshot::channel();
                tx.try_post(123, return_tx);
                let result = return_rx.await;

                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
            };
            assert_matches!(
                result,
                Ok(SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    reason: None
                })
            );
        }
    }
1587
1588    fn addrs() -> Vec<ChannelAddr> {
1589        let rng = rand::rng();
1590        let uniform = Uniform::new_inclusive('a', 'z').unwrap();
1591        vec![
1592            "tcp:[::1]:0".parse().unwrap(),
1593            "local:0".parse().unwrap(),
1594            #[cfg(target_os = "linux")]
1595            "unix:".parse().unwrap(),
1596            #[cfg(target_os = "linux")]
1597            format!(
1598                "unix:@{}",
1599                rng.sample_iter(uniform).take(10).collect::<String>()
1600            )
1601            .parse()
1602            .unwrap(),
1603        ]
1604    }
1605
1606    #[test]
1607    fn test_bind_spec_from_str() {
1608        // Test parsing ChannelTransport strings -> BindSpec::Any
1609        assert_eq!(
1610            BindSpec::from_str("tcp").unwrap(),
1611            BindSpec::Any(ChannelTransport::Tcp(TcpMode::Hostname))
1612        );
1613        assert_eq!(
1614            BindSpec::from_str("metatls(Hostname)").unwrap(),
1615            BindSpec::Any(ChannelTransport::MetaTls(TlsMode::Hostname))
1616        );
1617
1618        // Test parsing ChannelAddr strings -> BindSpec::Addr
1619        assert_eq!(
1620            BindSpec::from_str("tcp:127.0.0.1:8080").unwrap(),
1621            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()))
1622        );
1623
1624        // Test parsing ZMQ URL format -> BindSpec::Addr
1625        assert_eq!(
1626            BindSpec::from_str("tcp://127.0.0.1:9000").unwrap(),
1627            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()))
1628        );
1629        assert_eq!(
1630            BindSpec::from_str("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap(),
1631            BindSpec::Addr(
1632                ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap()
1633            )
1634        );
1635
1636        // Test error cases
1637        assert!(BindSpec::from_str("invalid_spec").is_err());
1638        assert!(BindSpec::from_str("unknown://scheme").is_err());
1639        assert!(BindSpec::from_str("").is_err());
1640    }
1641
1642    #[tokio::test]
1643    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1644    #[cfg_attr(not(fbcode_build), ignore)]
1645    async fn test_dial_serve() {
1646        for addr in addrs() {
1647            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1648            let tx = crate::channel::dial(listen_addr).unwrap();
1649            tx.post(123);
1650            assert_eq!(rx.recv().await.unwrap(), 123);
1651        }
1652    }
1653
1654    #[tokio::test]
1655    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1656    #[cfg_attr(not(fbcode_build), ignore)]
1657    async fn test_send() {
1658        let config = hyperactor_config::global::lock();
1659
1660        // Use temporary config for this test
1661        let _guard1 = config.override_key(
1662            crate::config::MESSAGE_DELIVERY_TIMEOUT,
1663            Duration::from_secs(1),
1664        );
1665        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
1666        for addr in addrs() {
1667            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1668            let tx = crate::channel::dial(listen_addr).unwrap();
1669            tx.send(123).await.unwrap();
1670            assert_eq!(rx.recv().await.unwrap(), 123);
1671
1672            drop(rx);
1673            assert_matches!(
1674                tx.send(123).await.unwrap_err(),
1675                SendError {
1676                    error: ChannelError::Closed,
1677                    message: 123,
1678                    ..
1679                }
1680            );
1681        }
1682    }
1683
1684    #[test]
1685    fn test_find_routable_address_skips_link_local_ipv6() {
1686        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1687        let routable_v6: IpAddr = "2001:db8::1".parse().unwrap();
1688        let addrs = vec![link_local_v6, routable_v6];
1689        assert_eq!(find_routable_address(&addrs), Some(routable_v6));
1690    }
1691
1692    #[test]
1693    fn test_find_routable_address_skips_link_local_ipv4() {
1694        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1695        let routable_v4: IpAddr = "192.168.1.1".parse().unwrap();
1696        let addrs = vec![link_local_v4, routable_v4];
1697        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1698    }
1699
1700    #[test]
1701    fn test_find_routable_address_returns_none_when_all_link_local() {
1702        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1703        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1704        let addrs = vec![link_local_v6, link_local_v4];
1705        assert_eq!(find_routable_address(&addrs), None);
1706    }
1707
1708    #[test]
1709    fn test_find_routable_address_mixed() {
1710        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1711        let link_local_v4: IpAddr = "169.254.0.1".parse().unwrap();
1712        let routable_v4: IpAddr = "10.0.0.1".parse().unwrap();
1713        let routable_v6: IpAddr = "2001:db8::2".parse().unwrap();
1714
1715        // First routable address in list order should be returned.
1716        let addrs = vec![link_local_v6, link_local_v4, routable_v4, routable_v6];
1717        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1718    }
1719}