hyperactor/
channel.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! One-way, multi-process, typed communication channels. These are used
10//! to send messages between mailboxes residing in different processes.
11
12use core::net::SocketAddr;
13use std::fmt;
14use std::net::IpAddr;
15use std::net::Ipv6Addr;
16#[cfg(target_os = "linux")]
17use std::os::linux::net::SocketAddrExt;
18use std::panic::Location;
19use std::str::FromStr;
20
21use async_trait::async_trait;
22use hyperactor_config::attrs::AttrValue;
23use serde::Deserialize;
24use serde::Serialize;
25use tokio::sync::mpsc;
26use tokio::sync::oneshot;
27use tokio::sync::watch;
28
29use crate as hyperactor;
30use crate::RemoteMessage;
31pub(crate) mod local;
32pub(crate) mod net;
33
34// Public TLS API for HTTP services (mesh admin, TUI, etc.). The
35// implementation lives in `net` but we re-export here to keep `net`'s
36// internal types out of the public API surface.
37pub use net::try_tls_acceptor;
38pub use net::try_tls_connector;
39pub use net::try_tls_pem_bundle;
40
41/// Duplex channel API: a single connection carries messages in both directions.
42pub mod duplex {
43    pub use super::net::duplex::DuplexRx;
44    pub use super::net::duplex::DuplexServer;
45    pub use super::net::duplex::DuplexTx;
46    pub use super::net::duplex::dial;
47    pub use super::net::duplex::serve;
48}
49
50/// The type of error that can occur on channel operations.
51#[derive(thiserror::Error, Debug)]
52pub enum ChannelError {
53    /// An operation was attempted on a closed channel.
54    #[error("channel closed")]
55    Closed,
56
57    /// An error occurred during send.
58    #[error("send: {0}")]
59    Send(#[source] anyhow::Error),
60
61    /// A network client error.
62    #[error(transparent)]
63    Client(#[from] net::ClientError),
64
65    /// The address was not valid.
66    #[error("invalid address {0:?}")]
67    InvalidAddress(String),
68
69    /// A serving error was encountered.
70    #[error(transparent)]
71    Server(#[from] net::ServerError),
72
73    /// A bincode serialization or deserialization error occurred.
74    #[error(transparent)]
75    Bincode(#[from] Box<bincode::ErrorKind>),
76
77    /// Data encoding errors.
78    #[error(transparent)]
79    Data(#[from] wirevalue::Error),
80
81    /// Some other error.
82    #[error(transparent)]
83    Other(#[from] anyhow::Error),
84
85    /// An operation timeout occurred.
86    #[error("operation timed out after {0:?}")]
87    Timeout(std::time::Duration),
88}
89
90/// An error that occurred during send. Returns the message that failed to send.
91#[derive(thiserror::Error, Debug)]
92#[error("{error} for reason {reason:?}")]
93pub struct SendError<M: RemoteMessage> {
94    /// Inner channel error
95    #[source]
96    pub error: ChannelError,
97    /// Message that couldn't be sent
98    pub message: M,
99    /// Reason that message couldn't be sent, if any.
100    pub reason: Option<String>,
101}
102
103impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
104    fn from(error: SendError<M>) -> Self {
105        error.error
106    }
107}
108
/// The states a `Tx` can be in, as observed through `Tx::status`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TxStatus {
    /// The tx is healthy.
    Active,
    /// The tx can no longer be used to deliver messages.
    Closed,
}
117
118/// The transmit end of an M-typed channel.
119#[async_trait]
120pub trait Tx<M: RemoteMessage> {
121    /// Post a message; returning failed deliveries on the return channel, if provided.
122    /// If provided, the sender is dropped when the message has been
123    /// enqueued at the channel endpoint.
124    ///
125    /// Users should use the `try_post`, and `post` variants directly.
126    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
127
128    /// Enqueue a `message` on the local end of the channel. The
129    /// message is either delivered, or we eventually discover that
130    /// the channel has failed and it will be sent back on `return_channel`.
131    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
132    #[tracing::instrument(level = "debug", skip_all)]
133    fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
134        self.do_post(message, Some(return_channel));
135    }
136
137    /// Enqueue a message to be sent on the channel.
138    #[hyperactor::instrument_infallible]
139    fn post(&self, message: M) {
140        self.do_post(message, None);
141    }
142
143    /// Send a message synchronously, returning when the message has
144    /// been delivered to the remote end of the channel.
145    async fn send(&self, message: M) -> Result<(), SendError<M>> {
146        let (tx, rx) = oneshot::channel();
147        self.try_post(message, tx);
148        match rx.await {
149            // Channel was closed; the message was not delivered.
150            Ok(err) => Err(err),
151
152            // Channel was dropped; the message was successfully enqueued
153            // on the remote end of the channel.
154            Err(_) => Ok(()),
155        }
156    }
157
158    /// The channel address to which this Tx is sending.
159    fn addr(&self) -> ChannelAddr;
160
161    /// A means to monitor the health of a `Tx`.
162    fn status(&self) -> &watch::Receiver<TxStatus>;
163}
164
165/// The receive end of an M-typed channel.
166#[async_trait]
167pub trait Rx<M: RemoteMessage> {
168    /// Receive the next message from the channel. If the channel returns
169    /// an error it is considered broken and should be discarded.
170    async fn recv(&mut self) -> Result<M, ChannelError>;
171
172    /// The channel address from which this Rx is receiving.
173    fn addr(&self) -> ChannelAddr;
174
175    /// Gracefully shut down the channel receiver, flushing any pending
176    /// acks before returning. Implementations must ensure all pending
177    /// acks are sent before this method returns.
178    async fn join(self)
179    where
180        Self: Sized;
181}
182
183#[allow(dead_code)] // Not used outside tests.
184struct MpscTx<M: RemoteMessage> {
185    tx: mpsc::UnboundedSender<M>,
186    addr: ChannelAddr,
187    status: watch::Receiver<TxStatus>,
188}
189
190impl<M: RemoteMessage> MpscTx<M> {
191    #[allow(dead_code)] // Not used outside tests.
192    pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
193        let (sender, receiver) = watch::channel(TxStatus::Active);
194        (
195            Self {
196                tx,
197                addr,
198                status: receiver,
199            },
200            sender,
201        )
202    }
203}
204
205#[async_trait]
206impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
207    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
208        if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
209            if let Some(return_channel) = return_channel {
210                return_channel
211                    .send(SendError {
212                        error: ChannelError::Closed,
213                        message,
214                        reason: None,
215                    })
216                    .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
217            }
218        }
219    }
220
221    fn addr(&self) -> ChannelAddr {
222        self.addr.clone()
223    }
224
225    fn status(&self) -> &watch::Receiver<TxStatus> {
226        &self.status
227    }
228}
229
230#[allow(dead_code)] // Not used outside tests.
231struct MpscRx<M: RemoteMessage> {
232    rx: mpsc::UnboundedReceiver<M>,
233    addr: ChannelAddr,
234    // Used to report the status to the Tx side.
235    status_sender: watch::Sender<TxStatus>,
236}
237
238impl<M: RemoteMessage> MpscRx<M> {
239    #[allow(dead_code)] // Not used outside tests.
240    pub fn new(
241        rx: mpsc::UnboundedReceiver<M>,
242        addr: ChannelAddr,
243        status_sender: watch::Sender<TxStatus>,
244    ) -> Self {
245        Self {
246            rx,
247            addr,
248            status_sender,
249        }
250    }
251}
252
253impl<M: RemoteMessage> Drop for MpscRx<M> {
254    fn drop(&mut self) {
255        let _ = self.status_sender.send(TxStatus::Closed);
256    }
257}
258
259#[async_trait]
260impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
261    async fn recv(&mut self) -> Result<M, ChannelError> {
262        self.rx.recv().await.ok_or(ChannelError::Closed)
263    }
264
265    fn addr(&self) -> ChannelAddr {
266        self.addr.clone()
267    }
268
269    async fn join(self) {}
270}
271
272/// The hostname to use for TLS connections.
273#[derive(
274    Clone,
275    Debug,
276    PartialEq,
277    Eq,
278    Hash,
279    Serialize,
280    Deserialize,
281    strum::EnumIter,
282    strum::Display,
283    strum::EnumString
284)]
285pub enum TcpMode {
286    /// Use localhost/loopback for the connection.
287    Localhost,
288    /// Use host domain name for the connection.
289    Hostname,
290}
291
292/// The hostname to use for TLS connections.
293#[derive(
294    Clone,
295    Debug,
296    PartialEq,
297    Eq,
298    Hash,
299    Serialize,
300    Deserialize,
301    strum::EnumIter,
302    strum::Display,
303    strum::EnumString
304)]
305pub enum TlsMode {
306    /// Use IpV6 address for TLS connections.
307    IpV6,
308    /// Use host domain name for TLS connections.
309    Hostname,
310    // TODO: consider adding IpV4 support.
311}
312
313/// Address format for TLS channels.
314#[derive(
315    Clone,
316    Debug,
317    PartialEq,
318    Eq,
319    Hash,
320    Serialize,
321    Deserialize,
322    Ord,
323    PartialOrd
324)]
325pub struct TlsAddr {
326    /// The hostname to connect to.
327    pub hostname: Hostname,
328    /// The port to connect to.
329    pub port: Port,
330}
331
332impl TlsAddr {
333    /// Creates a new TLS address with a normalized hostname.
334    pub fn new(hostname: impl Into<Hostname>, port: Port) -> Self {
335        Self {
336            hostname: normalize_host(&hostname.into()),
337            port,
338        }
339    }
340
341    /// Returns the port number for this address.
342    pub fn port(&self) -> Port {
343        self.port
344    }
345
346    /// Returns the hostname for this address.
347    pub fn hostname(&self) -> &str {
348        &self.hostname
349    }
350}
351
352impl fmt::Display for TlsAddr {
353    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354        write!(f, "{}:{}", self.hostname, self.port)
355    }
356}
357
358/// Types of channel transports.
359#[derive(
360    Clone,
361    Debug,
362    PartialEq,
363    Eq,
364    Hash,
365    Serialize,
366    Deserialize,
367    typeuri::Named
368)]
369pub enum ChannelTransport {
370    /// Transport over a TCP connection.
371    Tcp(TcpMode),
372
373    /// Transport over a TCP connection with TLS support within Meta
374    MetaTls(TlsMode),
375
376    /// Transport over a TCP connection with configurable TLS support
377    Tls,
378
379    /// Local transports uses an in-process registry and mpsc channels.
380    Local,
381
382    /// Transport over unix domain socket.
383    Unix,
384}
385
386impl fmt::Display for ChannelTransport {
387    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388        match self {
389            Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
390            Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
391            Self::Tls => write!(f, "tls"),
392            Self::Local => write!(f, "local"),
393            Self::Unix => write!(f, "unix"),
394        }
395    }
396}
397
398impl FromStr for ChannelTransport {
399    type Err = anyhow::Error;
400
401    fn from_str(s: &str) -> Result<Self, Self::Err> {
402        match s {
403            // Default to TcpMode::Hostname, if the mode isn't set
404            "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
405            s if s.starts_with("tcp(") => {
406                let inner = &s["tcp(".len()..s.len() - 1];
407                let mode = inner.parse()?;
408                Ok(ChannelTransport::Tcp(mode))
409            }
410            "local" => Ok(ChannelTransport::Local),
411            "unix" => Ok(ChannelTransport::Unix),
412            "tls" => Ok(ChannelTransport::Tls),
413            s if s.starts_with("metatls(") && s.ends_with(")") => {
414                let inner = &s["metatls(".len()..s.len() - 1];
415                let mode = inner.parse()?;
416                Ok(ChannelTransport::MetaTls(mode))
417            }
418            unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
419        }
420    }
421}
422
423impl ChannelTransport {
424    /// All known channel transports.
425    pub fn all() -> [ChannelTransport; 3] {
426        [
427            // TODO: @rusch add back once figuring out unspecified override for OSS CI
428            // ChannelTransport::Tcp(TcpMode::Localhost),
429            ChannelTransport::Tcp(TcpMode::Hostname),
430            ChannelTransport::Local,
431            ChannelTransport::Unix,
432            // Tls requires certificate configuration, tested separately in tls::tests
433            // TODO add MetaTls (T208303369)
434        ]
435    }
436
437    /// Return an "any" address for this transport.
438    pub fn any(&self) -> ChannelAddr {
439        ChannelAddr::any(self.clone())
440    }
441
442    /// Returns true if this transport type represents a remote channel.
443    pub fn is_remote(&self) -> bool {
444        match self {
445            ChannelTransport::Tcp(_) => true,
446            ChannelTransport::MetaTls(_) => true,
447            ChannelTransport::Tls => true,
448            ChannelTransport::Local => false,
449            ChannelTransport::Unix => false,
450        }
451    }
452}
453
454impl AttrValue for ChannelTransport {
455    fn display(&self) -> String {
456        self.to_string()
457    }
458
459    fn parse(s: &str) -> Result<Self, anyhow::Error> {
460        s.parse()
461    }
462}
463
464/// Specifies how to bind a channel server.
465#[derive(
466    Clone,
467    Debug,
468    PartialEq,
469    Eq,
470    Hash,
471    Serialize,
472    Deserialize,
473    typeuri::Named
474)]
475pub enum BindSpec {
476    /// Bind to any available address for the given transport.
477    Any(ChannelTransport),
478
479    /// Bind to a specific channel address.
480    Addr(ChannelAddr),
481}
482
483impl BindSpec {
484    /// Return an "any" address for this bind spec.
485    pub fn binding_addr(&self) -> ChannelAddr {
486        match self {
487            BindSpec::Any(transport) => ChannelAddr::any(transport.clone()),
488            BindSpec::Addr(addr) => addr.clone(),
489        }
490    }
491}
492
493impl From<ChannelTransport> for BindSpec {
494    fn from(transport: ChannelTransport) -> Self {
495        BindSpec::Any(transport)
496    }
497}
498
499impl From<ChannelAddr> for BindSpec {
500    fn from(addr: ChannelAddr) -> Self {
501        BindSpec::Addr(addr)
502    }
503}
504
505impl fmt::Display for BindSpec {
506    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507        match self {
508            Self::Any(transport) => write!(f, "{}", transport),
509            Self::Addr(addr) => write!(f, "{}", addr),
510        }
511    }
512}
513
514impl FromStr for BindSpec {
515    type Err = anyhow::Error;
516
517    fn from_str(s: &str) -> Result<Self, Self::Err> {
518        if let Ok(transport) = ChannelTransport::from_str(s) {
519            Ok(BindSpec::Any(transport))
520        } else if let Ok(addr) = ChannelAddr::from_zmq_url(s) {
521            Ok(BindSpec::Addr(addr))
522        } else if let Ok(addr) = ChannelAddr::from_str(s) {
523            Ok(BindSpec::Addr(addr))
524        } else {
525            Err(anyhow::anyhow!("invalid bind spec: {}", s))
526        }
527    }
528}
529
530impl AttrValue for BindSpec {
531    fn display(&self) -> String {
532        self.to_string()
533    }
534
535    fn parse(s: &str) -> Result<Self, anyhow::Error> {
536        Self::from_str(s)
537    }
538}
539
/// A (TCP) hostname, held as an owned string.
pub type Hostname = String;
542
/// A (TCP) port number.
pub type Port = u16;
545
546/// The type of a channel address, used to multiplex different underlying
547/// channel implementations. ChannelAddrs also have a concrete syntax:
548/// the address type (e.g., "tcp" or "local"), followed by ":", and an address
549/// parseable to that type. For example:
550///
551/// - `tcp:127.0.0.1:1234` - localhost port 1234 over TCP
552/// - `tcp:192.168.0.1:1111` - 192.168.0.1 port 1111 over TCP
553/// - `local:123` - the (in-process) local port 123
554/// - `unix:/some/path` - the Unix socket at `/some/path`
555///
556/// Both local and TCP ports 0 are reserved to indicate "any available
557/// port" when serving.
558///
559/// ```
560/// # use hyperactor::channel::ChannelAddr;
561/// let addr: ChannelAddr = "tcp:127.0.0.1:1234".parse().unwrap();
562/// let ChannelAddr::Tcp(socket_addr) = addr else {
563///     panic!()
564/// };
565/// assert_eq!(socket_addr.port(), 1234);
566/// assert_eq!(socket_addr.is_ipv4(), true);
567/// ```
568#[derive(
569    Clone,
570    Debug,
571    PartialEq,
572    Eq,
573    Ord,
574    PartialOrd,
575    Serialize,
576    Deserialize,
577    Hash,
578    typeuri::Named
579)]
580pub enum ChannelAddr {
581    /// A socket address used to establish TCP channels. Supports
582    /// both  IPv4 and IPv6 address / port pairs.
583    Tcp(SocketAddr),
584
585    /// An address to establish TCP channels with TLS support within Meta.
586    /// Uses TlsAddr with hostname and port.
587    MetaTls(TlsAddr),
588
589    /// An address to establish TCP channels with configurable TLS support.
590    /// Uses TlsAddr with hostname and port.
591    Tls(TlsAddr),
592
593    /// Local addresses are registered in-process and given an integral
594    /// index.
595    Local(u64),
596
597    /// A unix domain socket address. Supports both absolute path names as
598    ///  well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets
599    Unix(net::unix::SocketAddr),
600
601    /// A pair of addresses, one for the client and one for the server:
602    ///   - The client should dial to the `dial_to` address.
603    ///   - The server should bind to the `bind_to` address.
604    ///
605    /// The user is responsible for ensuring the traffic to the `dial_to` address
606    /// is routed to the `bind_to` address.
607    ///
608    /// This is useful for scenarios where the network is configured in a way,
609    /// that the bound address is not directly accessible from the client.
610    ///
611    /// For example, in AWS, the client could be provided with the public IP
612    /// address, yet the server is bound to a private IP address or simply
613    /// INADDR_ANY. Traffic to the public IP address is mapped to the private
614    /// IP address through network address translation (NAT).
615    Alias {
616        /// The address to which the client should dial to.
617        dial_to: Box<ChannelAddr>,
618        /// The address to which the server should bind to.
619        bind_to: Box<ChannelAddr>,
620    },
621}
622
623impl From<SocketAddr> for ChannelAddr {
624    fn from(value: SocketAddr) -> Self {
625        Self::Tcp(value)
626    }
627}
628
629impl From<net::unix::SocketAddr> for ChannelAddr {
630    fn from(value: net::unix::SocketAddr) -> Self {
631        Self::Unix(value)
632    }
633}
634
635impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
636    fn from(value: std::os::unix::net::SocketAddr) -> Self {
637        Self::Unix(net::unix::SocketAddr::new(value))
638    }
639}
640
641impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
642    fn from(value: tokio::net::unix::SocketAddr) -> Self {
643        std::os::unix::net::SocketAddr::from(value).into()
644    }
645}
646
/// Pick the first address in `addresses` that is not link-local, or `None`
/// when every address is link-local (or the list is empty).
fn find_routable_address(addresses: &[IpAddr]) -> Option<IpAddr> {
    for candidate in addresses {
        let link_local = match candidate {
            IpAddr::V4(v4) => v4.is_link_local(),
            IpAddr::V6(v6) => v6.is_unicast_link_local(),
        };
        if !link_local {
            return Some(candidate.clone());
        }
    }
    None
}
657
658impl ChannelAddr {
659    /// The "any" address for the given transport type. This is used to
660    /// servers to "any" address.
661    pub fn any(transport: ChannelTransport) -> Self {
662        match transport {
663            ChannelTransport::Tcp(mode) => {
664                let ip = match mode {
665                    TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
666                    TcpMode::Hostname => {
667                        hostname::get()
668                            .ok()
669                            .and_then(|hostname| {
670                                // TODO: Avoid using DNS directly once we figure out a good extensibility story here
671                                hostname.to_str().and_then(|hostname_str| {
672                                    dns_lookup::lookup_host(hostname_str)
673                                        .ok()
674                                        .and_then(|addresses| find_routable_address(&addresses))
675                                })
676                            })
677                            .expect("failed to resolve hostname to ip address")
678                    }
679                };
680                Self::Tcp(SocketAddr::new(ip, 0))
681            }
682            ChannelTransport::MetaTls(mode) => {
683                let host_address = match mode {
684                    TlsMode::Hostname => hostname::get()
685                        .ok()
686                        .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
687                        .unwrap_or("unknown_host".to_string()),
688                    TlsMode::IpV6 => {
689                        get_host_ipv6_address().expect("failed to retrieve ipv6 address")
690                    }
691                };
692                Self::MetaTls(TlsAddr::new(host_address, 0))
693            }
694            ChannelTransport::Local => Self::Local(0),
695            ChannelTransport::Tls => {
696                let host_address = hostname::get()
697                    .ok()
698                    .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
699                    .unwrap_or("localhost".to_string());
700                Self::Tls(TlsAddr::new(host_address, 0))
701            }
702            // This works because the file will be deleted but we know we have a unique file by this point.
703            ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
704        }
705    }
706
707    /// The transport used by this address.
708    pub fn transport(&self) -> ChannelTransport {
709        match self {
710            Self::Tcp(addr) => {
711                if addr.ip().is_loopback() {
712                    ChannelTransport::Tcp(TcpMode::Localhost)
713                } else {
714                    ChannelTransport::Tcp(TcpMode::Hostname)
715                }
716            }
717            Self::MetaTls(addr) => match addr.hostname.parse::<IpAddr>() {
718                Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
719                Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
720                Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
721            },
722            Self::Tls(_) => ChannelTransport::Tls,
723            Self::Local(_) => ChannelTransport::Local,
724            Self::Unix(_) => ChannelTransport::Unix,
725            // bind_to's transport is what is actually used in communication.
726            // Therefore we use its transport to represent the Alias.
727            Self::Alias { bind_to, .. } => bind_to.transport(),
728        }
729    }
730}
731
/// The host's IPv6 address as a string. In fbcode builds this delegates to
/// `crate::meta::host_ip::host_ipv6_address`.
#[cfg(fbcode_build)]
fn get_host_ipv6_address() -> anyhow::Result<String> {
    let address = crate::meta::host_ip::host_ipv6_address();
    address
}
736
737#[cfg(not(fbcode_build))]
738fn get_host_ipv6_address() -> anyhow::Result<String> {
739    Ok(local_ip_address::local_ipv6()?.to_string())
740}
741
742impl fmt::Display for ChannelAddr {
743    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
744        match self {
745            Self::Tcp(addr) => write!(f, "tcp:{}", addr),
746            Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
747            Self::Tls(addr) => write!(f, "tls:{}", addr),
748            Self::Local(index) => write!(f, "local:{}", index),
749            Self::Unix(addr) => write!(f, "unix:{}", addr),
750            Self::Alias { dial_to, bind_to } => {
751                write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
752            }
753        }
754    }
755}
756
757impl FromStr for ChannelAddr {
758    type Err = anyhow::Error;
759
760    fn from_str(addr: &str) -> Result<Self, Self::Err> {
761        match addr.split_once('!').or_else(|| addr.split_once(':')) {
762            Some(("local", rest)) => rest
763                .parse::<u64>()
764                .map(Self::Local)
765                .map_err(anyhow::Error::from),
766            Some(("tcp", rest)) => rest
767                .parse::<SocketAddr>()
768                .map(Self::Tcp)
769                .map_err(anyhow::Error::from),
770            Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
771            Some(("tls", rest)) => net::tls::parse(rest).map_err(|e| e.into()),
772            Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
773            Some(("alias", _)) => Err(anyhow::anyhow!(
774                "detect possible alias address, but we currently do not support \
775                parsing alias' string representation since we only want to \
776                support parsing its zmq url format."
777            )),
778            Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
779            None => Err(anyhow::anyhow!("no channel type specified")),
780        }
781    }
782}
783
/// Normalize a host string. A host that is an IP address (optionally
/// wrapped in URI-style brackets) is parsed and re-formatted into its
/// canonical form; anything else is returned unchanged.
pub(crate) fn normalize_host(host: &str) -> String {
    // IpAddr::from_str only accepts bare addresses, so URI-style brackets
    // (e.g., "[::1]") are removed first. Both brackets must be present.
    let unbracketed = match host.strip_prefix('[') {
        Some(rest) => rest.strip_suffix(']').unwrap_or(host),
        None => host,
    };

    match unbracketed.parse::<IpAddr>() {
        Ok(ip_addr) => ip_addr.to_string(),
        // Not an IP address: keep the input exactly as given, brackets
        // included.
        Err(_) => host.to_string(),
    }
}
800
801impl ChannelAddr {
802    /// Parse ZMQ-style URL format: scheme://address
803    /// Supports:
804    /// - tcp://hostname:port or tcp://*:port (wildcard binding)
805    /// - inproc://endpoint-name (equivalent to local)
806    /// - ipc://path (equivalent to unix)
807    /// - metatls://hostname:port or metatls://*:port
808    /// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port)
809    ///   Note: Alias format is currently only supported for TCP addresses
810    pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
811        // Check for Alias format: dial_to_url@bind_to_url
812        // The @ character separates two valid ZMQ URLs
813        if let Some(at_pos) = address.find('@') {
814            let dial_to_str = &address[..at_pos];
815            let bind_to_str = &address[at_pos + 1..];
816
817            // Validate that both addresses use TCP scheme
818            if !dial_to_str.starts_with("tcp://") {
819                return Err(anyhow::anyhow!(
820                    "alias format is only supported for TCP addresses, got dial_to: {}",
821                    dial_to_str
822                ));
823            }
824            if !bind_to_str.starts_with("tcp://") {
825                return Err(anyhow::anyhow!(
826                    "alias format is only supported for TCP addresses, got bind_to: {}",
827                    bind_to_str
828                ));
829            }
830
831            let dial_to = Self::from_zmq_url(dial_to_str)?;
832            let bind_to = Self::from_zmq_url(bind_to_str)?;
833
834            return Ok(Self::Alias {
835                dial_to: Box::new(dial_to),
836                bind_to: Box::new(bind_to),
837            });
838        }
839
840        // Try ZMQ-style URL format first (scheme://...)
841        let (scheme, address) = address.split_once("://").ok_or_else(|| {
842            anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
843        })?;
844
845        match scheme {
846            "tcp" => {
847                let (host, port) = Self::split_host_port(address)?;
848
849                if host == "*" {
850                    // Wildcard binding - use IPv6 unspecified address
851                    Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
852                } else {
853                    // Resolve hostname to IP address for proper SocketAddr creation
854                    let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
855                    Ok(Self::Tcp(socket_addr))
856                }
857            }
858            "inproc" => {
859                // inproc://port -> local:port
860                // Port must be a valid u64 number
861                let port = address.parse::<u64>().map_err(|_| {
862                    anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
863                })?;
864                Ok(Self::Local(port))
865            }
866            "ipc" => {
867                // ipc://path -> unix:path
868                Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
869            }
870            "metatls" => {
871                let (host, port) = Self::split_host_port(address)?;
872
873                if host == "*" {
874                    // Wildcard binding - use IPv6 unspecified address directly without hostname resolution
875                    Ok(Self::MetaTls(TlsAddr::new(
876                        std::net::Ipv6Addr::UNSPECIFIED.to_string(),
877                        port,
878                    )))
879                } else {
880                    Ok(Self::MetaTls(TlsAddr::new(host, port)))
881                }
882            }
883            "tls" => {
884                let (host, port) = Self::split_host_port(address)?;
885
886                if host == "*" {
887                    // Wildcard binding - use IPv6 unspecified address directly without hostname resolution
888                    Ok(Self::Tls(TlsAddr::new(
889                        std::net::Ipv6Addr::UNSPECIFIED.to_string(),
890                        port,
891                    )))
892                } else {
893                    Ok(Self::Tls(TlsAddr::new(host, port)))
894                }
895            }
896            scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
897        }
898    }
899
900    /// Split host:port string, supporting IPv6 addresses
901    fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
902        if let Some((host, port_str)) = address.rsplit_once(':') {
903            let port: u16 = port_str
904                .parse()
905                .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
906            Ok((host, port))
907        } else {
908            Err(anyhow::anyhow!("invalid address format: {}", address))
909        }
910    }
911
912    /// Render as a ZMQ-style URL, the inverse of [`from_zmq_url`](Self::from_zmq_url).
913    pub fn to_zmq_url(&self) -> String {
914        match self {
915            Self::Tcp(addr) => format!("tcp://{}", addr),
916            Self::MetaTls(addr) => format!("metatls://{}:{}", addr.hostname, addr.port),
917            Self::Tls(addr) => format!("tls://{}:{}", addr.hostname, addr.port),
918            Self::Local(index) => format!("inproc://{}", index),
919            Self::Unix(addr) => format!("ipc://{}", addr),
920            Self::Alias { dial_to, bind_to } => {
921                format!("{}@{}", dial_to.to_zmq_url(), bind_to.to_zmq_url())
922            }
923        }
924    }
925
926    /// Resolve hostname to SocketAddr, handling both IP addresses and hostnames
927    fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
928        // Handle IPv6 addresses in brackets by stripping the brackets
929        let host_clean = if host.starts_with('[') && host.ends_with(']') {
930            &host[1..host.len() - 1]
931        } else {
932            host
933        };
934
935        // First try to parse as an IP address directly
936        if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
937            return Ok(SocketAddr::new(ip_addr, port));
938        }
939
940        // If not an IP, try hostname resolution
941        use std::net::ToSocketAddrs;
942        let mut addrs = (host_clean, port)
943            .to_socket_addrs()
944            .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
945
946        addrs
947            .next()
948            .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
949    }
950}
951
952/// Universal channel transmitter.
953pub struct ChannelTx<M: RemoteMessage> {
954    inner: ChannelTxKind<M>,
955}
956
957impl<M: RemoteMessage> fmt::Debug for ChannelTx<M> {
958    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
959        f.debug_struct("ChannelTx")
960            .field("addr", &self.addr())
961            .finish()
962    }
963}
964
965/// Universal channel transmitter.
966enum ChannelTxKind<M: RemoteMessage> {
967    Local(local::LocalTx<M>),
968    Net(net::NetTx<M>),
969}
970
971#[async_trait]
972impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
973    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
974        match &self.inner {
975            ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
976            ChannelTxKind::Net(tx) => tx.do_post(message, return_channel),
977        }
978    }
979
980    fn addr(&self) -> ChannelAddr {
981        match &self.inner {
982            ChannelTxKind::Local(tx) => tx.addr(),
983            ChannelTxKind::Net(tx) => Tx::<M>::addr(tx),
984        }
985    }
986
987    fn status(&self) -> &watch::Receiver<TxStatus> {
988        match &self.inner {
989            ChannelTxKind::Local(tx) => tx.status(),
990            ChannelTxKind::Net(tx) => tx.status(),
991        }
992    }
993}
994
995/// Universal channel receiver.
996pub struct ChannelRx<M: RemoteMessage> {
997    inner: ChannelRxKind<M>,
998}
999
1000impl<M: RemoteMessage> fmt::Debug for ChannelRx<M> {
1001    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1002        f.debug_struct("ChannelRx")
1003            .field("addr", &self.addr())
1004            .finish()
1005    }
1006}
1007
1008/// Universal channel receiver.
1009enum ChannelRxKind<M: RemoteMessage> {
1010    Local(local::LocalRx<M>),
1011    Net(net::NetRx<M>),
1012}
1013
1014#[async_trait]
1015impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
1016    #[tracing::instrument(level = "debug", skip_all)]
1017    async fn recv(&mut self) -> Result<M, ChannelError> {
1018        match &mut self.inner {
1019            ChannelRxKind::Local(rx) => rx.recv().await,
1020            ChannelRxKind::Net(rx) => rx.recv().await,
1021        }
1022    }
1023
1024    fn addr(&self) -> ChannelAddr {
1025        match &self.inner {
1026            ChannelRxKind::Local(rx) => rx.addr(),
1027            ChannelRxKind::Net(rx) => rx.addr(),
1028        }
1029    }
1030
1031    async fn join(self) {
1032        match self.inner {
1033            ChannelRxKind::Local(rx) => rx.join().await,
1034            ChannelRxKind::Net(rx) => rx.join().await,
1035        }
1036    }
1037}
1038
1039/// Dial the provided address, returning the corresponding Tx, or error
1040/// if the channel cannot be established. The underlying connection is
1041/// dropped whenever the returned Tx is dropped.
1042#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
1043#[track_caller]
1044pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
1045    tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
1046    let inner = match addr {
1047        ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
1048        ChannelAddr::Tcp(_)
1049        | ChannelAddr::Unix(_)
1050        | ChannelAddr::Tls(_)
1051        | ChannelAddr::MetaTls(_) => ChannelTxKind::Net(net::spawn(net::link(addr)?)),
1052        ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
1053    };
1054    Ok(ChannelTx { inner })
1055}
1056
1057/// Serve on the provided channel address. The server is turned down
1058/// when the returned Rx is dropped.
1059#[track_caller]
1060pub fn serve<M: RemoteMessage>(
1061    addr: ChannelAddr,
1062) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1063    let caller = Location::caller();
1064    serve_inner(addr).map(|(addr, inner)| {
1065        tracing::debug!(
1066            name = "serve",
1067            %addr,
1068            %caller,
1069        );
1070        (addr, ChannelRx { inner })
1071    })
1072}
1073
1074fn serve_inner<M: RemoteMessage>(
1075    addr: ChannelAddr,
1076) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
1077    match addr {
1078        ChannelAddr::Tcp(_)
1079        | ChannelAddr::Unix(_)
1080        | ChannelAddr::Tls(_)
1081        | ChannelAddr::MetaTls(_) => {
1082            let (addr, rx) = net::server::serve::<M>(addr)?;
1083            Ok((addr, ChannelRxKind::Net(rx)))
1084        }
1085        ChannelAddr::Local(0) => {
1086            let (port, rx) = local::serve::<M>();
1087            Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
1088        }
1089        ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
1090            "invalid local addr: {}",
1091            a
1092        ))),
1093        ChannelAddr::Alias { dial_to, bind_to } => {
1094            let (bound_addr, rx) = serve_inner::<M>(*bind_to)?;
1095            let alias_addr = ChannelAddr::Alias {
1096                dial_to,
1097                bind_to: Box::new(bound_addr),
1098            };
1099            Ok((alias_addr, rx))
1100        }
1101    }
1102}
1103
1104/// Serve on the local address. The server is turned down
1105/// when the returned Rx is dropped.
1106pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
1107    let (port, rx) = local::serve::<M>();
1108    (
1109        ChannelAddr::Local(port),
1110        ChannelRx {
1111            inner: ChannelRxKind::Local(rx),
1112        },
1113    )
1114}
1115
1116#[cfg(test)]
1117mod tests {
1118    use std::assert_matches::assert_matches;
1119    use std::collections::HashSet;
1120    use std::net::IpAddr;
1121    use std::net::Ipv4Addr;
1122    use std::net::Ipv6Addr;
1123    use std::time::Duration;
1124
1125    use tokio::task::JoinSet;
1126
1127    use super::net::*;
1128    use super::*;
1129    #[test]
1130    fn test_channel_addr() {
1131        let cases_ok = vec![
1132            (
1133                "tcp<DELIM>[::1]:1234",
1134                ChannelAddr::Tcp(SocketAddr::new(
1135                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1136                    1234,
1137                )),
1138            ),
1139            (
1140                "tcp<DELIM>127.0.0.1:8080",
1141                ChannelAddr::Tcp(SocketAddr::new(
1142                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1143                    8080,
1144                )),
1145            ),
1146            #[cfg(target_os = "linux")]
1147            ("local<DELIM>123", ChannelAddr::Local(123)),
1148            (
1149                "unix<DELIM>@yolo",
1150                ChannelAddr::Unix(
1151                    unix::SocketAddr::from_abstract_name("yolo")
1152                        .expect("can't make socket from abstract name"),
1153                ),
1154            ),
1155            (
1156                "unix<DELIM>/cool/socket-path",
1157                ChannelAddr::Unix(
1158                    unix::SocketAddr::from_pathname("/cool/socket-path")
1159                        .expect("can't make socket from path"),
1160                ),
1161            ),
1162        ];
1163
1164        for (raw, parsed) in cases_ok {
1165            for delim in ["!", ":"] {
1166                let raw = raw.replace("<DELIM>", delim);
1167                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
1168            }
1169        }
1170
1171        let cases_err = vec![
1172            ("tcp:abcdef..123124", "invalid socket address syntax"),
1173            ("xxx:foo", "no such channel type: xxx"),
1174            ("127.0.0.1", "no channel type specified"),
1175            ("local:abc", "invalid digit found in string"),
1176        ];
1177
1178        for (raw, error) in cases_err {
1179            let Err(err) = raw.parse::<ChannelAddr>() else {
1180                panic!("expected error parsing: {}", &raw)
1181            };
1182            assert_eq!(format!("{}", err), error);
1183        }
1184    }
1185
1186    #[test]
1187    fn test_zmq_style_channel_addr() {
1188        // Test TCP addresses
1189        assert_eq!(
1190            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
1191            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
1192        );
1193
1194        // Test TCP wildcard binding
1195        assert_eq!(
1196            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
1197            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
1198        );
1199
1200        // Test inproc (maps to local with numeric endpoint)
1201        assert_eq!(
1202            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
1203            ChannelAddr::Local(12345)
1204        );
1205
1206        // Test ipc (maps to unix)
1207        assert_eq!(
1208            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
1209            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
1210        );
1211
1212        // Test metatls with hostname
1213        assert_eq!(
1214            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
1215            ChannelAddr::MetaTls(TlsAddr::new("example.com", 443))
1216        );
1217
1218        // Test metatls with IP address (should be normalized)
1219        assert_eq!(
1220            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
1221            ChannelAddr::MetaTls(TlsAddr::new("192.168.1.1", 443))
1222        );
1223
1224        // Test metatls with wildcard (should use IPv6 unspecified address)
1225        assert_eq!(
1226            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
1227            ChannelAddr::MetaTls(TlsAddr::new("::", 8443))
1228        );
1229
1230        // Test TCP hostname resolution (should resolve hostname to IP)
1231        // Note: This test may fail in environments without proper DNS resolution
1232        // We test that it at least doesn't fail to parse
1233        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
1234        assert!(tcp_hostname_result.is_ok());
1235
1236        // Test IPv6 address
1237        assert_eq!(
1238            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
1239            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
1240        );
1241
1242        // Test error cases
1243        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
1244        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
1245        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
1246        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
1247
1248        // IPv6 normalization: leading zeros are stripped
1249        assert_eq!(
1250            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
1251                .unwrap(),
1252            ChannelAddr::MetaTls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
1253        );
1254
1255        // Short and long forms of the same IPv6 produce equal ChannelAddr values
1256        assert_eq!(
1257            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
1258                .unwrap(),
1259            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443")
1260                .unwrap(),
1261        );
1262
1263        // Bracketed IPv6 is normalized
1264        assert_eq!(
1265            ChannelAddr::from_zmq_url("metatls://[::1]:443").unwrap(),
1266            ChannelAddr::MetaTls(TlsAddr::new("::1", 443))
1267        );
1268
1269        // Same tests for tls://
1270        assert_eq!(
1271            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
1272            ChannelAddr::Tls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
1273        );
1274        assert_eq!(
1275            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
1276            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443").unwrap(),
1277        );
1278        assert_eq!(
1279            ChannelAddr::from_zmq_url("tls://[::1]:443").unwrap(),
1280            ChannelAddr::Tls(TlsAddr::new("::1", 443))
1281        );
1282    }
1283
1284    #[test]
1285    fn test_normalize_host() {
1286        // Plain IPv4 passes through
1287        assert_eq!(normalize_host("192.168.1.1"), "192.168.1.1");
1288
1289        // Plain hostname passes through
1290        assert_eq!(normalize_host("example.com"), "example.com");
1291
1292        // IPv6 with leading zeros gets normalized
1293        assert_eq!(
1294            normalize_host("2a03:83e4:5000:c000:56d7:00cf:75ce:144a"),
1295            "2a03:83e4:5000:c000:56d7:cf:75ce:144a"
1296        );
1297
1298        // Bracketed IPv6 is stripped and normalized
1299        assert_eq!(normalize_host("[::1]"), "::1");
1300
1301        // Without bracket stripping, IpAddr::from_str rejects bracketed
1302        // addresses. This demonstrates that the bracket stripping in
1303        // normalize_host is necessary.
1304        assert!("[::1]".parse::<IpAddr>().is_err());
1305    }
1306
1307    #[test]
1308    fn test_zmq_style_alias_channel_addr() {
1309        // Test Alias format: dial_to_url@bind_to_url
1310        // The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs
1311        // Note: Alias format is only supported for TCP addresses
1312
1313        // Test Alias with tcp on both sides
1314        let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
1315        match alias_addr {
1316            ChannelAddr::Alias { dial_to, bind_to } => {
1317                assert_eq!(
1318                    *dial_to,
1319                    ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
1320                );
1321                assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
1322            }
1323            _ => panic!("Expected Alias"),
1324        }
1325
1326        // Test error: alias with non-tcp dial_to (not supported)
1327        assert!(
1328            ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err()
1329        );
1330
1331        // Test error: alias with non-tcp bind_to (not supported)
1332        assert!(
1333            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
1334        );
1335
1336        // Test error: invalid dial_to URL in Alias
1337        assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());
1338
1339        // Test error: invalid bind_to URL in Alias
1340        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());
1341
1342        // Test error: missing port in dial_to
1343        assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());
1344
1345        // Test error: missing port in bind_to
1346        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
1347    }
1348
1349    #[tokio::test]
1350    async fn test_multiple_connections() {
1351        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1352            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
1353
1354            let mut sends: JoinSet<()> = JoinSet::new();
1355            for message in 0u64..100u64 {
1356                let addr = listen_addr.clone();
1357                sends.spawn(async move {
1358                    let tx = dial::<u64>(addr).unwrap();
1359                    tx.post(message);
1360                });
1361            }
1362
1363            let mut received: HashSet<u64> = HashSet::new();
1364            while received.len() < 100 {
1365                received.insert(rx.recv().await.unwrap());
1366            }
1367
1368            for message in 0u64..100u64 {
1369                assert!(received.contains(&message));
1370            }
1371
1372            loop {
1373                match sends.join_next().await {
1374                    Some(Ok(())) => (),
1375                    Some(Err(err)) => panic!("{}", err),
1376                    None => break,
1377                }
1378            }
1379        }
1380    }
1381
1382    #[tokio::test]
1383    async fn test_server_close() {
1384        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1385            if net::is_net_addr(&addr) {
1386                // Net has store-and-forward semantics. We don't expect failures
1387                // on closure.
1388                continue;
1389            }
1390
1391            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();
1392
1393            let tx = dial::<u64>(listen_addr).unwrap();
1394            tx.post(123);
1395            drop(rx);
1396
1397            // New transmits should fail... but there is buffering, etc.,
1398            // which can cause the failure to be delayed. We give it
1399            // a deadline, but it can still technically fail -- the test
1400            // should be considered a kind of integration test.
1401            let start = tokio::time::Instant::now();
1402
1403            let result = loop {
1404                let (return_tx, return_rx) = oneshot::channel();
1405                tx.try_post(123, return_tx);
1406                let result = return_rx.await;
1407
1408                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
1409                    break result;
1410                }
1411            };
1412            assert_matches!(
1413                result,
1414                Ok(SendError {
1415                    error: ChannelError::Closed,
1416                    message: 123,
1417                    reason: None
1418                })
1419            );
1420        }
1421    }
1422
1423    fn addrs() -> Vec<ChannelAddr> {
1424        use rand::Rng;
1425        use rand::distributions::Uniform;
1426
1427        let rng = rand::thread_rng();
1428        vec![
1429            "tcp:[::1]:0".parse().unwrap(),
1430            "local:0".parse().unwrap(),
1431            #[cfg(target_os = "linux")]
1432            "unix:".parse().unwrap(),
1433            #[cfg(target_os = "linux")]
1434            format!(
1435                "unix:@{}",
1436                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
1437                    .take(10)
1438                    .collect::<String>()
1439            )
1440            .parse()
1441            .unwrap(),
1442        ]
1443    }
1444
1445    #[test]
1446    fn test_bind_spec_from_str() {
1447        // Test parsing ChannelTransport strings -> BindSpec::Any
1448        assert_eq!(
1449            BindSpec::from_str("tcp").unwrap(),
1450            BindSpec::Any(ChannelTransport::Tcp(TcpMode::Hostname))
1451        );
1452        assert_eq!(
1453            BindSpec::from_str("metatls(Hostname)").unwrap(),
1454            BindSpec::Any(ChannelTransport::MetaTls(TlsMode::Hostname))
1455        );
1456
1457        // Test parsing ChannelAddr strings -> BindSpec::Addr
1458        assert_eq!(
1459            BindSpec::from_str("tcp:127.0.0.1:8080").unwrap(),
1460            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()))
1461        );
1462
1463        // Test parsing ZMQ URL format -> BindSpec::Addr
1464        assert_eq!(
1465            BindSpec::from_str("tcp://127.0.0.1:9000").unwrap(),
1466            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()))
1467        );
1468        assert_eq!(
1469            BindSpec::from_str("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap(),
1470            BindSpec::Addr(
1471                ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap()
1472            )
1473        );
1474
1475        // Test error cases
1476        assert!(BindSpec::from_str("invalid_spec").is_err());
1477        assert!(BindSpec::from_str("unknown://scheme").is_err());
1478        assert!(BindSpec::from_str("").is_err());
1479    }
1480
1481    #[tokio::test]
1482    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1483    #[cfg_attr(not(fbcode_build), ignore)]
1484    async fn test_dial_serve() {
1485        for addr in addrs() {
1486            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1487            let tx = crate::channel::dial(listen_addr).unwrap();
1488            tx.post(123);
1489            assert_eq!(rx.recv().await.unwrap(), 123);
1490        }
1491    }
1492
1493    #[tokio::test]
1494    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1495    #[cfg_attr(not(fbcode_build), ignore)]
1496    async fn test_send() {
1497        let config = hyperactor_config::global::lock();
1498
1499        // Use temporary config for this test
1500        let _guard1 = config.override_key(
1501            crate::config::MESSAGE_DELIVERY_TIMEOUT,
1502            Duration::from_secs(1),
1503        );
1504        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
1505        for addr in addrs() {
1506            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1507            let tx = crate::channel::dial(listen_addr).unwrap();
1508            tx.send(123).await.unwrap();
1509            assert_eq!(rx.recv().await.unwrap(), 123);
1510
1511            drop(rx);
1512            assert_matches!(
1513                tx.send(123).await.unwrap_err(),
1514                SendError {
1515                    error: ChannelError::Closed,
1516                    message: 123,
1517                    ..
1518                }
1519            );
1520        }
1521    }
1522
1523    #[test]
1524    fn test_find_routable_address_skips_link_local_ipv6() {
1525        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1526        let routable_v6: IpAddr = "2001:db8::1".parse().unwrap();
1527        let addrs = vec![link_local_v6, routable_v6];
1528        assert_eq!(find_routable_address(&addrs), Some(routable_v6));
1529    }
1530
1531    #[test]
1532    fn test_find_routable_address_skips_link_local_ipv4() {
1533        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1534        let routable_v4: IpAddr = "192.168.1.1".parse().unwrap();
1535        let addrs = vec![link_local_v4, routable_v4];
1536        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1537    }
1538
1539    #[test]
1540    fn test_find_routable_address_returns_none_when_all_link_local() {
1541        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1542        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1543        let addrs = vec![link_local_v6, link_local_v4];
1544        assert_eq!(find_routable_address(&addrs), None);
1545    }
1546
1547    #[test]
1548    fn test_find_routable_address_mixed() {
1549        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1550        let link_local_v4: IpAddr = "169.254.0.1".parse().unwrap();
1551        let routable_v4: IpAddr = "10.0.0.1".parse().unwrap();
1552        let routable_v6: IpAddr = "2001:db8::2".parse().unwrap();
1553
1554        // First routable address in list order should be returned.
1555        let addrs = vec![link_local_v6, link_local_v4, routable_v4, routable_v6];
1556        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1557    }
1558}