hyperactor/
channel.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! One-way, multi-process, typed communication channels. These are used
10//! to send messages between mailboxes residing in different processes.
11
12use core::net::SocketAddr;
13use std::fmt;
14use std::net::IpAddr;
15use std::net::Ipv6Addr;
16#[cfg(target_os = "linux")]
17use std::os::linux::net::SocketAddrExt;
18use std::panic::Location;
19use std::str::FromStr;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use enum_as_inner::EnumAsInner;
24use hyperactor_config::attrs::AttrValue;
25use serde::Deserialize;
26use serde::Serialize;
27use tokio::sync::mpsc;
28use tokio::sync::oneshot;
29use tokio::sync::watch;
30
31use crate as hyperactor;
32use crate::RemoteMessage;
33pub(crate) mod local;
34pub(crate) mod net;
35
36// Public TLS API for HTTP services (mesh admin, TUI, etc.). The
37// implementation lives in `net` but we re-export here to keep `net`'s
38// internal types out of the public API surface.
39pub use net::try_tls_acceptor;
40pub use net::try_tls_connector;
41pub use net::try_tls_pem_bundle;
42
/// Duplex channel API: a single connection carries messages in both directions.
///
/// This module defines nothing itself; it only re-exports the duplex items
/// from `net::duplex` so callers do not have to reach into `net`.
pub mod duplex {
    pub use super::net::duplex::DuplexRx;
    pub use super::net::duplex::DuplexServer;
    pub use super::net::duplex::DuplexTx;
    pub use super::net::duplex::dial;
    pub use super::net::duplex::serve;
}
51
/// The type of error that can occur on channel operations.
///
/// Variants annotated `#[error(transparent)]` delegate their `Display`
/// output and error source to the wrapped error. Variants carrying a
/// `#[from]` field get a generated `From` impl, so the wrapped error type
/// converts into `ChannelError` via `?`.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// An operation was attempted on a closed channel.
    #[error("channel closed")]
    Closed,

    /// An error occurred during send. The wrapped error is exposed as the
    /// source and is also included in the message (`send: <error>`).
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// A network client error.
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// The address was not valid. The payload is the offending address
    /// text; it is rendered with `Debug` formatting (i.e. quoted).
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// A serving error was encountered.
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// A bincode serialization or deserialization error occurred.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// Data encoding errors.
    #[error(transparent)]
    Data(#[from] wirevalue::Error),

    /// Some other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An operation timeout occurred. The payload is the duration reported
    /// in the message.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),
}
91
/// An error that occurred during send. Returns the message that failed to send.
///
/// The `Display` output is `"{error} for reason {reason:?}"`: `reason` is
/// rendered with `Debug` formatting, so it appears as `None` or
/// `Some("...")`. The undelivered `message` is not part of the output.
#[derive(thiserror::Error, Debug)]
#[error("{error} for reason {reason:?}")]
pub struct SendError<M: RemoteMessage> {
    /// Inner channel error; also exposed as this error's source.
    #[source]
    pub error: ChannelError,
    /// Message that couldn't be sent, handed back to the caller.
    pub message: M,
    /// Reason that message couldn't be sent, if any.
    pub reason: Option<String>,
}
104
105impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
106    fn from(error: SendError<M>) -> Self {
107        error.error
108    }
109}
110
/// The possible states of a `Tx`.
///
/// Published through the `watch::Receiver<TxStatus>` returned by
/// `Tx::status`.
#[derive(Debug, Clone, PartialEq, EnumAsInner)]
pub enum TxStatus {
    /// The tx is good.
    Active,
    /// The tx cannot be used for message delivery. The payload is a textual
    /// description of why (for example `"receiver dropped"`, which `MpscRx`
    /// publishes when it is dropped).
    Closed(Arc<str>),
}
119
120/// The transmit end of an M-typed channel.
121#[async_trait]
122pub trait Tx<M: RemoteMessage> {
123    /// Post a message; returning failed deliveries on the return channel, if provided.
124    /// If provided, the sender is dropped when the message has been
125    /// enqueued at the channel endpoint.
126    ///
127    /// Users should use the `try_post`, and `post` variants directly.
128    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
129
130    /// Enqueue a `message` on the local end of the channel. The
131    /// message is either delivered, or we eventually discover that
132    /// the channel has failed and it will be sent back on `return_channel`.
133    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
134    #[tracing::instrument(level = "debug", skip_all)]
135    fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
136        self.do_post(message, Some(return_channel));
137    }
138
139    /// Enqueue a message to be sent on the channel.
140    #[hyperactor::instrument_infallible]
141    fn post(&self, message: M) {
142        self.do_post(message, None);
143    }
144
145    /// Send a message synchronously, returning when the message has
146    /// been delivered to the remote end of the channel.
147    async fn send(&self, message: M) -> Result<(), SendError<M>> {
148        let (tx, rx) = oneshot::channel();
149        self.try_post(message, tx);
150        match rx.await {
151            // Channel was closed; the message was not delivered.
152            Ok(err) => Err(err),
153
154            // Channel was dropped; the message was successfully enqueued
155            // on the remote end of the channel.
156            Err(_) => Ok(()),
157        }
158    }
159
160    /// The channel address to which this Tx is sending.
161    fn addr(&self) -> ChannelAddr;
162
163    /// A means to monitor the health of a `Tx`.
164    fn status(&self) -> &watch::Receiver<TxStatus>;
165}
166
/// The receive end of an M-typed channel.
#[async_trait]
pub trait Rx<M: RemoteMessage> {
    /// Receive the next message from the channel. If the channel returns
    /// an error it is considered broken and should be discarded.
    ///
    /// # Errors
    ///
    /// Returns a [`ChannelError`] when no further message can be received.
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The channel address from which this Rx is receiving.
    fn addr(&self) -> ChannelAddr;

    /// Gracefully shut down the channel receiver, flushing any pending
    /// acks before returning. Implementations must ensure all pending
    /// acks are sent before this method returns.
    ///
    /// Takes the receiver by value, so it cannot be used afterwards.
    async fn join(self)
    where
        Self: Sized;
}
184
/// A `Tx` implementation backed by an unbounded tokio mpsc sender.
#[allow(dead_code)] // Not used outside tests.
struct MpscTx<M: RemoteMessage> {
    // Sending half of the underlying unbounded mpsc channel.
    tx: mpsc::UnboundedSender<M>,
    // Address reported by `Tx::addr`.
    addr: ChannelAddr,
    // Receiving half of the status watch channel, exposed by `Tx::status`.
    status: watch::Receiver<TxStatus>,
}
191
192impl<M: RemoteMessage> MpscTx<M> {
193    #[allow(dead_code)] // Not used outside tests.
194    pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
195        let (sender, receiver) = watch::channel(TxStatus::Active);
196        (
197            Self {
198                tx,
199                addr,
200                status: receiver,
201            },
202            sender,
203        )
204    }
205}
206
207#[async_trait]
208impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
209    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
210        if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
211            if let Some(return_channel) = return_channel {
212                return_channel
213                    .send(SendError {
214                        error: ChannelError::Closed,
215                        message,
216                        reason: None,
217                    })
218                    .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
219            }
220        }
221    }
222
223    fn addr(&self) -> ChannelAddr {
224        self.addr.clone()
225    }
226
227    fn status(&self) -> &watch::Receiver<TxStatus> {
228        &self.status
229    }
230}
231
/// An `Rx` implementation backed by an unbounded tokio mpsc receiver.
#[allow(dead_code)] // Not used outside tests.
struct MpscRx<M: RemoteMessage> {
    // Receiving half of the underlying unbounded mpsc channel.
    rx: mpsc::UnboundedReceiver<M>,
    // Address reported by `Rx::addr`.
    addr: ChannelAddr,
    // Used to report the status to the Tx side. `TxStatus::Closed` is
    // published on it when this receiver is dropped.
    status_sender: watch::Sender<TxStatus>,
}
239
impl<M: RemoteMessage> MpscRx<M> {
    /// Builds a receiver from its parts: the mpsc receiving half `rx`, the
    /// address `addr` to report from `Rx::addr`, and the `status_sender`
    /// used to publish status changes (such as the one returned by
    /// `MpscTx::new`).
    #[allow(dead_code)] // Not used outside tests.
    pub fn new(
        rx: mpsc::UnboundedReceiver<M>,
        addr: ChannelAddr,
        status_sender: watch::Sender<TxStatus>,
    ) -> Self {
        Self {
            rx,
            addr,
            status_sender,
        }
    }
}
254
255impl<M: RemoteMessage> Drop for MpscRx<M> {
256    fn drop(&mut self) {
257        let _ = self
258            .status_sender
259            .send(TxStatus::Closed("receiver dropped".into()));
260    }
261}
262
263#[async_trait]
264impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
265    async fn recv(&mut self) -> Result<M, ChannelError> {
266        self.rx.recv().await.ok_or(ChannelError::Closed)
267    }
268
269    fn addr(&self) -> ChannelAddr {
270        self.addr.clone()
271    }
272
273    async fn join(self) {}
274}
275
/// The address mode to use for TCP connections.
///
/// `ChannelAddr::any` uses this to pick the IP address of a TCP "any"
/// address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TcpMode {
    /// Use localhost/loopback for the connection.
    Localhost,
    /// Use host domain name for the connection.
    Hostname,
}
295
/// The hostname to use for TLS connections.
///
/// `ChannelAddr::any` uses this to pick the host part of a MetaTls "any"
/// address: the machine's host name, or its IPv6 address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// Use IpV6 address for TLS connections.
    IpV6,
    /// Use host domain name for TLS connections.
    Hostname,
    // TODO: consider adding IpV4 support.
}
316
/// Address format for TLS channels.
///
/// Displayed as `hostname:port`. `TlsAddr::new` normalizes the hostname.
/// The fields are public, so a value built with a struct literal bypasses
/// that normalization.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd
)]
pub struct TlsAddr {
    /// The hostname to connect to.
    pub hostname: Hostname,
    /// The port to connect to.
    pub port: Port,
}
335
336impl TlsAddr {
337    /// Creates a new TLS address with a normalized hostname.
338    pub fn new(hostname: impl Into<Hostname>, port: Port) -> Self {
339        Self {
340            hostname: normalize_host(&hostname.into()),
341            port,
342        }
343    }
344
345    /// Returns the port number for this address.
346    pub fn port(&self) -> Port {
347        self.port
348    }
349
350    /// Returns the hostname for this address.
351    pub fn hostname(&self) -> &str {
352        &self.hostname
353    }
354}
355
356impl fmt::Display for TlsAddr {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        write!(f, "{}:{}", self.hostname, self.port)
359    }
360}
361
/// Types of channel transports.
///
/// The textual form (see the `Display` and `FromStr` impls) is
/// `tcp(<mode>)`, `metatls(<mode>)`, `tls`, `local` or `unix`. A bare
/// `tcp` is also accepted when parsing.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum ChannelTransport {
    /// Transport over a TCP connection.
    Tcp(TcpMode),

    /// Transport over a TCP connection with TLS support within Meta
    MetaTls(TlsMode),

    /// Transport over a TCP connection with configurable TLS support
    Tls,

    /// Local transports uses an in-process registry and mpsc channels.
    Local,

    /// Transport over unix domain socket.
    Unix,
}
389
390impl fmt::Display for ChannelTransport {
391    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392        match self {
393            Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
394            Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
395            Self::Tls => write!(f, "tls"),
396            Self::Local => write!(f, "local"),
397            Self::Unix => write!(f, "unix"),
398        }
399    }
400}
401
402impl FromStr for ChannelTransport {
403    type Err = anyhow::Error;
404
405    fn from_str(s: &str) -> Result<Self, Self::Err> {
406        match s {
407            // Default to TcpMode::Hostname, if the mode isn't set
408            "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
409            s if s.starts_with("tcp(") => {
410                let inner = &s["tcp(".len()..s.len() - 1];
411                let mode = inner.parse()?;
412                Ok(ChannelTransport::Tcp(mode))
413            }
414            "local" => Ok(ChannelTransport::Local),
415            "unix" => Ok(ChannelTransport::Unix),
416            "tls" => Ok(ChannelTransport::Tls),
417            s if s.starts_with("metatls(") && s.ends_with(")") => {
418                let inner = &s["metatls(".len()..s.len() - 1];
419                let mode = inner.parse()?;
420                Ok(ChannelTransport::MetaTls(mode))
421            }
422            unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
423        }
424    }
425}
426
427impl ChannelTransport {
428    /// All known channel transports.
429    pub fn all() -> [ChannelTransport; 3] {
430        [
431            // TODO: @rusch add back once figuring out unspecified override for OSS CI
432            // ChannelTransport::Tcp(TcpMode::Localhost),
433            ChannelTransport::Tcp(TcpMode::Hostname),
434            ChannelTransport::Local,
435            ChannelTransport::Unix,
436            // Tls requires certificate configuration, tested separately in tls::tests
437            // TODO add MetaTls (T208303369)
438        ]
439    }
440
441    /// Return an "any" address for this transport.
442    pub fn any(&self) -> ChannelAddr {
443        ChannelAddr::any(self.clone())
444    }
445
446    /// Returns true if this transport type represents a remote channel.
447    pub fn is_remote(&self) -> bool {
448        match self {
449            ChannelTransport::Tcp(_) => true,
450            ChannelTransport::MetaTls(_) => true,
451            ChannelTransport::Tls => true,
452            ChannelTransport::Local => false,
453            ChannelTransport::Unix => false,
454        }
455    }
456}
457
458impl AttrValue for ChannelTransport {
459    fn display(&self) -> String {
460        self.to_string()
461    }
462
463    fn parse(s: &str) -> Result<Self, anyhow::Error> {
464        s.parse()
465    }
466}
467
/// Specifies how to bind a channel server.
///
/// When parsed from a string, the input is tried as a transport name
/// first, then as a ZMQ-style URL, then as a channel address string.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum BindSpec {
    /// Bind to any available address for the given transport.
    Any(ChannelTransport),

    /// Bind to a specific channel address.
    Addr(ChannelAddr),
}
486
487impl BindSpec {
488    /// Return an "any" address for this bind spec.
489    pub fn binding_addr(&self) -> ChannelAddr {
490        match self {
491            BindSpec::Any(transport) => ChannelAddr::any(transport.clone()),
492            BindSpec::Addr(addr) => addr.clone(),
493        }
494    }
495}
496
497impl From<ChannelTransport> for BindSpec {
498    fn from(transport: ChannelTransport) -> Self {
499        BindSpec::Any(transport)
500    }
501}
502
503impl From<ChannelAddr> for BindSpec {
504    fn from(addr: ChannelAddr) -> Self {
505        BindSpec::Addr(addr)
506    }
507}
508
509impl fmt::Display for BindSpec {
510    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511        match self {
512            Self::Any(transport) => write!(f, "{}", transport),
513            Self::Addr(addr) => write!(f, "{}", addr),
514        }
515    }
516}
517
518impl FromStr for BindSpec {
519    type Err = anyhow::Error;
520
521    fn from_str(s: &str) -> Result<Self, Self::Err> {
522        if let Ok(transport) = ChannelTransport::from_str(s) {
523            Ok(BindSpec::Any(transport))
524        } else if let Ok(addr) = ChannelAddr::from_zmq_url(s) {
525            Ok(BindSpec::Addr(addr))
526        } else if let Ok(addr) = ChannelAddr::from_str(s) {
527            Ok(BindSpec::Addr(addr))
528        } else {
529            Err(anyhow::anyhow!("invalid bind spec: {}", s))
530        }
531    }
532}
533
534impl AttrValue for BindSpec {
535    fn display(&self) -> String {
536        self.to_string()
537    }
538
539    fn parse(s: &str) -> Result<Self, anyhow::Error> {
540        Self::from_str(s)
541    }
542}
543
/// The type of (TCP) hostnames. Also the type of the host part of a
/// [`TlsAddr`].
pub type Hostname = String;

/// The type of (TCP) ports. Also the type of the port part of a
/// [`TlsAddr`].
pub type Port = u16;
549
/// The type of a channel address, used to multiplex different underlying
/// channel implementations. ChannelAddrs also have a concrete syntax:
/// the address type (e.g., "tcp" or "local"), followed by ":", and an address
/// parseable to that type. For example:
///
/// - `tcp:127.0.0.1:1234` - localhost port 1234 over TCP
/// - `tcp:192.168.0.1:1111` - 192.168.0.1 port 1111 over TCP
/// - `local:123` - the (in-process) local port 123
/// - `unix:/some/path` - the Unix socket at `/some/path`
///
/// Both local and TCP ports 0 are reserved to indicate "any available
/// port" when serving.
///
/// Alias addresses are displayed as `alias:dial_to=...;bind_to=...`, but
/// that form is not accepted when parsing; use the ZMQ-style URL form
/// (`ChannelAddr::from_zmq_url`) for aliases instead.
///
/// ```
/// # use hyperactor::channel::ChannelAddr;
/// let addr: ChannelAddr = "tcp:127.0.0.1:1234".parse().unwrap();
/// let ChannelAddr::Tcp(socket_addr) = addr else {
///     panic!()
/// };
/// assert_eq!(socket_addr.port(), 1234);
/// assert_eq!(socket_addr.is_ipv4(), true);
/// ```
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    typeuri::Named
)]
pub enum ChannelAddr {
    /// A socket address used to establish TCP channels. Supports
    /// both IPv4 and IPv6 address / port pairs.
    Tcp(SocketAddr),

    /// An address to establish TCP channels with TLS support within Meta.
    /// Uses TlsAddr with hostname and port.
    MetaTls(TlsAddr),

    /// An address to establish TCP channels with configurable TLS support.
    /// Uses TlsAddr with hostname and port.
    Tls(TlsAddr),

    /// Local addresses are registered in-process and given an integral
    /// index.
    Local(u64),

    /// A unix domain socket address. Supports both absolute path names as
    /// well as "abstract" names per <https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets>
    Unix(net::unix::SocketAddr),

    /// A pair of addresses, one for the client and one for the server:
    ///   - The client should dial to the `dial_to` address.
    ///   - The server should bind to the `bind_to` address.
    ///
    /// The user is responsible for ensuring the traffic to the `dial_to` address
    /// is routed to the `bind_to` address.
    ///
    /// This is useful for scenarios where the network is configured in such
    /// a way that the bound address is not directly accessible from the client.
    ///
    /// For example, in AWS, the client could be provided with the public IP
    /// address, yet the server is bound to a private IP address or simply
    /// INADDR_ANY. Traffic to the public IP address is mapped to the private
    /// IP address through network address translation (NAT).
    Alias {
        /// The address the client should dial.
        dial_to: Box<ChannelAddr>,
        /// The address the server should bind to.
        bind_to: Box<ChannelAddr>,
    },
}
626
627impl From<SocketAddr> for ChannelAddr {
628    fn from(value: SocketAddr) -> Self {
629        Self::Tcp(value)
630    }
631}
632
633impl From<net::unix::SocketAddr> for ChannelAddr {
634    fn from(value: net::unix::SocketAddr) -> Self {
635        Self::Unix(value)
636    }
637}
638
639impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
640    fn from(value: std::os::unix::net::SocketAddr) -> Self {
641        Self::Unix(net::unix::SocketAddr::new(value))
642    }
643}
644
645impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
646    fn from(value: tokio::net::unix::SocketAddr) -> Self {
647        std::os::unix::net::SocketAddr::from(value).into()
648    }
649}
650
/// Picks the first address in `addresses` that is not link-local, or
/// `None` when every address is link-local (or the slice is empty).
///
/// For IPv6 the check is unicast link-local; for IPv4 it is link-local.
fn find_routable_address(addresses: &[IpAddr]) -> Option<IpAddr> {
    addresses.iter().copied().find(|candidate| {
        let link_local = match candidate {
            IpAddr::V4(v4) => v4.is_link_local(),
            IpAddr::V6(v6) => v6.is_unicast_link_local(),
        };
        !link_local
    })
}
661
662impl ChannelAddr {
663    /// The "any" address for the given transport type. This is used to
664    /// servers to "any" address.
665    pub fn any(transport: ChannelTransport) -> Self {
666        match transport {
667            ChannelTransport::Tcp(mode) => {
668                let ip = match mode {
669                    TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
670                    TcpMode::Hostname => {
671                        hostname::get()
672                            .ok()
673                            .and_then(|hostname| {
674                                // TODO: Avoid using DNS directly once we figure out a good extensibility story here
675                                hostname.to_str().and_then(|hostname_str| {
676                                    dns_lookup::lookup_host(hostname_str)
677                                        .ok()
678                                        .and_then(|addresses| find_routable_address(&addresses))
679                                })
680                            })
681                            .expect("failed to resolve hostname to ip address")
682                    }
683                };
684                Self::Tcp(SocketAddr::new(ip, 0))
685            }
686            ChannelTransport::MetaTls(mode) => {
687                let host_address = match mode {
688                    TlsMode::Hostname => hostname::get()
689                        .ok()
690                        .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
691                        .unwrap_or("unknown_host".to_string()),
692                    TlsMode::IpV6 => {
693                        get_host_ipv6_address().expect("failed to retrieve ipv6 address")
694                    }
695                };
696                Self::MetaTls(TlsAddr::new(host_address, 0))
697            }
698            ChannelTransport::Local => Self::Local(0),
699            ChannelTransport::Tls => {
700                let host_address = hostname::get()
701                    .ok()
702                    .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
703                    .unwrap_or("localhost".to_string());
704                Self::Tls(TlsAddr::new(host_address, 0))
705            }
706            // This works because the file will be deleted but we know we have a unique file by this point.
707            ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
708        }
709    }
710
711    /// The transport used by this address.
712    pub fn transport(&self) -> ChannelTransport {
713        match self {
714            Self::Tcp(addr) => {
715                if addr.ip().is_loopback() {
716                    ChannelTransport::Tcp(TcpMode::Localhost)
717                } else {
718                    ChannelTransport::Tcp(TcpMode::Hostname)
719                }
720            }
721            Self::MetaTls(addr) => match addr.hostname.parse::<IpAddr>() {
722                Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
723                Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
724                Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
725            },
726            Self::Tls(_) => ChannelTransport::Tls,
727            Self::Local(_) => ChannelTransport::Local,
728            Self::Unix(_) => ChannelTransport::Unix,
729            // bind_to's transport is what is actually used in communication.
730            // Therefore we use its transport to represent the Alias.
731            Self::Alias { bind_to, .. } => bind_to.transport(),
732        }
733    }
734}
735
/// Returns this host's IPv6 address as a string (fbcode builds).
/// Delegates to `crate::meta::host_ip::host_ipv6_address`.
#[cfg(fbcode_build)]
fn get_host_ipv6_address() -> anyhow::Result<String> {
    crate::meta::host_ip::host_ipv6_address()
}
740
741#[cfg(not(fbcode_build))]
742fn get_host_ipv6_address() -> anyhow::Result<String> {
743    Ok(local_ip_address::local_ipv6()?.to_string())
744}
745
746impl fmt::Display for ChannelAddr {
747    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
748        match self {
749            Self::Tcp(addr) => write!(f, "tcp:{}", addr),
750            Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
751            Self::Tls(addr) => write!(f, "tls:{}", addr),
752            Self::Local(index) => write!(f, "local:{}", index),
753            Self::Unix(addr) => write!(f, "unix:{}", addr),
754            Self::Alias { dial_to, bind_to } => {
755                write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
756            }
757        }
758    }
759}
760
761impl FromStr for ChannelAddr {
762    type Err = anyhow::Error;
763
764    fn from_str(addr: &str) -> Result<Self, Self::Err> {
765        match addr.split_once('!').or_else(|| addr.split_once(':')) {
766            Some(("local", rest)) => rest
767                .parse::<u64>()
768                .map(Self::Local)
769                .map_err(anyhow::Error::from),
770            Some(("tcp", rest)) => rest
771                .parse::<SocketAddr>()
772                .map(Self::Tcp)
773                .map_err(anyhow::Error::from),
774            Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
775            Some(("tls", rest)) => net::tls::parse(rest).map_err(|e| e.into()),
776            Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
777            Some(("alias", _)) => Err(anyhow::anyhow!(
778                "detect possible alias address, but we currently do not support \
779                parsing alias' string representation since we only want to \
780                support parsing its zmq url format."
781            )),
782            Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
783            None => Err(anyhow::anyhow!("no channel type specified")),
784        }
785    }
786}
787
/// Canonicalizes a host string.
///
/// When `host` is an IP address, optionally wrapped in URI-style brackets
/// (e.g. `"[::1]"`), the address is parsed and re-formatted so that
/// equivalent spellings produce the same string. Any other input is
/// returned unchanged, brackets included.
pub(crate) fn normalize_host(host: &str) -> String {
    // `IpAddr::from_str` only accepts bare addresses, so peel off a
    // matching pair of brackets first. A lone `[` or `]` is left alone.
    let bare = match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']').unwrap_or(host),
        None => host,
    };

    match bare.parse::<IpAddr>() {
        Ok(ip) => ip.to_string(),
        Err(_) => host.to_string(),
    }
}
804
805impl ChannelAddr {
806    /// Parse ZMQ-style URL format: scheme://address
807    /// Supports:
808    /// - tcp://hostname:port or tcp://*:port (wildcard binding)
809    /// - inproc://endpoint-name (equivalent to local)
810    /// - ipc://path (equivalent to unix)
811    /// - metatls://hostname:port or metatls://*:port
812    /// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port)
813    ///   Note: Alias format is currently only supported for TCP addresses
814    pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
815        // Check for Alias format: dial_to_url@bind_to_url
816        // The @ character separates two valid ZMQ URLs
817        if let Some(at_pos) = address.find('@') {
818            let dial_to_str = &address[..at_pos];
819            let bind_to_str = &address[at_pos + 1..];
820
821            // Validate that both addresses use TCP scheme
822            if !dial_to_str.starts_with("tcp://") {
823                return Err(anyhow::anyhow!(
824                    "alias format is only supported for TCP addresses, got dial_to: {}",
825                    dial_to_str
826                ));
827            }
828            if !bind_to_str.starts_with("tcp://") {
829                return Err(anyhow::anyhow!(
830                    "alias format is only supported for TCP addresses, got bind_to: {}",
831                    bind_to_str
832                ));
833            }
834
835            let dial_to = Self::from_zmq_url(dial_to_str)?;
836            let bind_to = Self::from_zmq_url(bind_to_str)?;
837
838            return Ok(Self::Alias {
839                dial_to: Box::new(dial_to),
840                bind_to: Box::new(bind_to),
841            });
842        }
843
844        // Try ZMQ-style URL format first (scheme://...)
845        let (scheme, address) = address.split_once("://").ok_or_else(|| {
846            anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
847        })?;
848
849        match scheme {
850            "tcp" => {
851                let (host, port) = Self::split_host_port(address)?;
852
853                if host == "*" {
854                    // Wildcard binding - use IPv6 unspecified address
855                    Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
856                } else {
857                    // Resolve hostname to IP address for proper SocketAddr creation
858                    let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
859                    Ok(Self::Tcp(socket_addr))
860                }
861            }
862            "inproc" => {
863                // inproc://port -> local:port
864                // Port must be a valid u64 number
865                let port = address.parse::<u64>().map_err(|_| {
866                    anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
867                })?;
868                Ok(Self::Local(port))
869            }
870            "ipc" => {
871                // ipc://path -> unix:path
872                Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
873            }
874            "metatls" => {
875                let (host, port) = Self::split_host_port(address)?;
876
877                if host == "*" {
878                    // Wildcard binding - use IPv6 unspecified address directly without hostname resolution
879                    Ok(Self::MetaTls(TlsAddr::new(
880                        std::net::Ipv6Addr::UNSPECIFIED.to_string(),
881                        port,
882                    )))
883                } else {
884                    Ok(Self::MetaTls(TlsAddr::new(host, port)))
885                }
886            }
887            "tls" => {
888                let (host, port) = Self::split_host_port(address)?;
889
890                if host == "*" {
891                    // Wildcard binding - use IPv6 unspecified address directly without hostname resolution
892                    Ok(Self::Tls(TlsAddr::new(
893                        std::net::Ipv6Addr::UNSPECIFIED.to_string(),
894                        port,
895                    )))
896                } else {
897                    Ok(Self::Tls(TlsAddr::new(host, port)))
898                }
899            }
900            scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
901        }
902    }
903
904    /// Split host:port string, supporting IPv6 addresses
905    fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
906        if let Some((host, port_str)) = address.rsplit_once(':') {
907            let port: u16 = port_str
908                .parse()
909                .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
910            Ok((host, port))
911        } else {
912            Err(anyhow::anyhow!("invalid address format: {}", address))
913        }
914    }
915
916    /// Render as a ZMQ-style URL, the inverse of [`from_zmq_url`](Self::from_zmq_url).
917    pub fn to_zmq_url(&self) -> String {
918        match self {
919            Self::Tcp(addr) => format!("tcp://{}", addr),
920            Self::MetaTls(addr) => format!("metatls://{}:{}", addr.hostname, addr.port),
921            Self::Tls(addr) => format!("tls://{}:{}", addr.hostname, addr.port),
922            Self::Local(index) => format!("inproc://{}", index),
923            Self::Unix(addr) => format!("ipc://{}", addr),
924            Self::Alias { dial_to, bind_to } => {
925                format!("{}@{}", dial_to.to_zmq_url(), bind_to.to_zmq_url())
926            }
927        }
928    }
929
930    /// Resolve hostname to SocketAddr, handling both IP addresses and hostnames
931    fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
932        // Handle IPv6 addresses in brackets by stripping the brackets
933        let host_clean = if host.starts_with('[') && host.ends_with(']') {
934            &host[1..host.len() - 1]
935        } else {
936            host
937        };
938
939        // First try to parse as an IP address directly
940        if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
941            return Ok(SocketAddr::new(ip_addr, port));
942        }
943
944        // If not an IP, try hostname resolution
945        use std::net::ToSocketAddrs;
946        let mut addrs = (host_clean, port)
947            .to_socket_addrs()
948            .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
949
950        addrs
951            .next()
952            .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
953    }
954}
955
/// Universal channel transmitter.
///
/// A single sender type returned by [`dial`] regardless of which transport
/// the dialed address uses; the transport-specific sender is held in `inner`.
pub struct ChannelTx<M: RemoteMessage> {
    // The concrete sender: either a local or a net sender.
    inner: ChannelTxKind<M>,
}
960
961impl<M: RemoteMessage> fmt::Debug for ChannelTx<M> {
962    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
963        f.debug_struct("ChannelTx")
964            .field("addr", &self.addr())
965            .finish()
966    }
967}
968
/// The transport-specific sender wrapped by `ChannelTx`.
enum ChannelTxKind<M: RemoteMessage> {
    /// Sender produced by the `local` module.
    Local(local::LocalTx<M>),
    /// Sender produced by the `net` module.
    Net(net::NetTx<M>),
}
974
975#[async_trait]
976impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
977    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
978        match &self.inner {
979            ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
980            ChannelTxKind::Net(tx) => tx.do_post(message, return_channel),
981        }
982    }
983
984    fn addr(&self) -> ChannelAddr {
985        match &self.inner {
986            ChannelTxKind::Local(tx) => tx.addr(),
987            ChannelTxKind::Net(tx) => Tx::<M>::addr(tx),
988        }
989    }
990
991    fn status(&self) -> &watch::Receiver<TxStatus> {
992        match &self.inner {
993            ChannelTxKind::Local(tx) => tx.status(),
994            ChannelTxKind::Net(tx) => tx.status(),
995        }
996    }
997}
998
/// Universal channel receiver.
///
/// A single receiver type returned by [`serve`] and [`serve_local`]
/// regardless of transport; the transport-specific receiver is held in `inner`.
pub struct ChannelRx<M: RemoteMessage> {
    // The concrete receiver: either a local or a net receiver.
    inner: ChannelRxKind<M>,
}
1003
1004impl<M: RemoteMessage> fmt::Debug for ChannelRx<M> {
1005    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1006        f.debug_struct("ChannelRx")
1007            .field("addr", &self.addr())
1008            .finish()
1009    }
1010}
1011
/// The transport-specific receiver wrapped by `ChannelRx`.
enum ChannelRxKind<M: RemoteMessage> {
    /// Receiver produced by the `local` module.
    Local(local::LocalRx<M>),
    /// Receiver produced by the `net` module.
    Net(net::NetRx<M>),
}
1017
1018#[async_trait]
1019impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
1020    #[tracing::instrument(level = "debug", skip_all)]
1021    async fn recv(&mut self) -> Result<M, ChannelError> {
1022        match &mut self.inner {
1023            ChannelRxKind::Local(rx) => rx.recv().await,
1024            ChannelRxKind::Net(rx) => rx.recv().await,
1025        }
1026    }
1027
1028    fn addr(&self) -> ChannelAddr {
1029        match &self.inner {
1030            ChannelRxKind::Local(rx) => rx.addr(),
1031            ChannelRxKind::Net(rx) => rx.addr(),
1032        }
1033    }
1034
1035    async fn join(self) {
1036        match self.inner {
1037            ChannelRxKind::Local(rx) => rx.join().await,
1038            ChannelRxKind::Net(rx) => rx.join().await,
1039        }
1040    }
1041}
1042
1043/// Dial the provided address, returning the corresponding Tx, or error
1044/// if the channel cannot be established. The underlying connection is
1045/// dropped whenever the returned Tx is dropped.
1046#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
1047#[track_caller]
1048pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
1049    tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
1050    let inner = match addr {
1051        ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
1052        ChannelAddr::Tcp(_)
1053        | ChannelAddr::Unix(_)
1054        | ChannelAddr::Tls(_)
1055        | ChannelAddr::MetaTls(_) => ChannelTxKind::Net(net::spawn(net::link(addr)?)),
1056        ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
1057    };
1058    Ok(ChannelTx { inner })
1059}
1060
1061/// Serve on the provided channel address. The server is turned down
1062/// when the returned Rx is dropped.
1063#[track_caller]
1064pub fn serve<M: RemoteMessage>(
1065    addr: ChannelAddr,
1066) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1067    let caller = Location::caller();
1068    serve_inner(addr).map(|(addr, inner)| {
1069        tracing::debug!(
1070            name = "serve",
1071            %addr,
1072            %caller,
1073        );
1074        (addr, ChannelRx { inner })
1075    })
1076}
1077
1078fn serve_inner<M: RemoteMessage>(
1079    addr: ChannelAddr,
1080) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
1081    match addr {
1082        ChannelAddr::Tcp(_)
1083        | ChannelAddr::Unix(_)
1084        | ChannelAddr::Tls(_)
1085        | ChannelAddr::MetaTls(_) => {
1086            let (addr, rx) = net::server::serve::<M>(addr)?;
1087            Ok((addr, ChannelRxKind::Net(rx)))
1088        }
1089        ChannelAddr::Local(0) => {
1090            let (port, rx) = local::serve::<M>();
1091            Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
1092        }
1093        ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
1094            "invalid local addr: {}",
1095            a
1096        ))),
1097        ChannelAddr::Alias { dial_to, bind_to } => {
1098            let (bound_addr, rx) = serve_inner::<M>(*bind_to)?;
1099            let alias_addr = ChannelAddr::Alias {
1100                dial_to,
1101                bind_to: Box::new(bound_addr),
1102            };
1103            Ok((alias_addr, rx))
1104        }
1105    }
1106}
1107
1108/// Serve on the local address. The server is turned down
1109/// when the returned Rx is dropped.
1110pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
1111    let (port, rx) = local::serve::<M>();
1112    (
1113        ChannelAddr::Local(port),
1114        ChannelRx {
1115            inner: ChannelRxKind::Local(rx),
1116        },
1117    )
1118}
1119
1120#[cfg(test)]
1121mod tests {
1122    use std::assert_matches::assert_matches;
1123    use std::collections::HashSet;
1124    use std::net::IpAddr;
1125    use std::net::Ipv4Addr;
1126    use std::net::Ipv6Addr;
1127    use std::time::Duration;
1128
1129    use tokio::task::JoinSet;
1130
1131    use super::net::*;
1132    use super::*;
1133    #[test]
1134    fn test_channel_addr() {
1135        let cases_ok = vec![
1136            (
1137                "tcp<DELIM>[::1]:1234",
1138                ChannelAddr::Tcp(SocketAddr::new(
1139                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1140                    1234,
1141                )),
1142            ),
1143            (
1144                "tcp<DELIM>127.0.0.1:8080",
1145                ChannelAddr::Tcp(SocketAddr::new(
1146                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1147                    8080,
1148                )),
1149            ),
1150            #[cfg(target_os = "linux")]
1151            ("local<DELIM>123", ChannelAddr::Local(123)),
1152            (
1153                "unix<DELIM>@yolo",
1154                ChannelAddr::Unix(
1155                    unix::SocketAddr::from_abstract_name("yolo")
1156                        .expect("can't make socket from abstract name"),
1157                ),
1158            ),
1159            (
1160                "unix<DELIM>/cool/socket-path",
1161                ChannelAddr::Unix(
1162                    unix::SocketAddr::from_pathname("/cool/socket-path")
1163                        .expect("can't make socket from path"),
1164                ),
1165            ),
1166        ];
1167
1168        for (raw, parsed) in cases_ok {
1169            for delim in ["!", ":"] {
1170                let raw = raw.replace("<DELIM>", delim);
1171                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
1172            }
1173        }
1174
1175        let cases_err = vec![
1176            ("tcp:abcdef..123124", "invalid socket address syntax"),
1177            ("xxx:foo", "no such channel type: xxx"),
1178            ("127.0.0.1", "no channel type specified"),
1179            ("local:abc", "invalid digit found in string"),
1180        ];
1181
1182        for (raw, error) in cases_err {
1183            let Err(err) = raw.parse::<ChannelAddr>() else {
1184                panic!("expected error parsing: {}", &raw)
1185            };
1186            assert_eq!(format!("{}", err), error);
1187        }
1188    }
1189
1190    #[test]
1191    fn test_zmq_style_channel_addr() {
1192        // Test TCP addresses
1193        assert_eq!(
1194            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
1195            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
1196        );
1197
1198        // Test TCP wildcard binding
1199        assert_eq!(
1200            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
1201            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
1202        );
1203
1204        // Test inproc (maps to local with numeric endpoint)
1205        assert_eq!(
1206            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
1207            ChannelAddr::Local(12345)
1208        );
1209
1210        // Test ipc (maps to unix)
1211        assert_eq!(
1212            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
1213            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
1214        );
1215
1216        // Test metatls with hostname
1217        assert_eq!(
1218            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
1219            ChannelAddr::MetaTls(TlsAddr::new("example.com", 443))
1220        );
1221
1222        // Test metatls with IP address (should be normalized)
1223        assert_eq!(
1224            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
1225            ChannelAddr::MetaTls(TlsAddr::new("192.168.1.1", 443))
1226        );
1227
1228        // Test metatls with wildcard (should use IPv6 unspecified address)
1229        assert_eq!(
1230            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
1231            ChannelAddr::MetaTls(TlsAddr::new("::", 8443))
1232        );
1233
1234        // Test TCP hostname resolution (should resolve hostname to IP)
1235        // Note: This test may fail in environments without proper DNS resolution
1236        // We test that it at least doesn't fail to parse
1237        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
1238        assert!(tcp_hostname_result.is_ok());
1239
1240        // Test IPv6 address
1241        assert_eq!(
1242            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
1243            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
1244        );
1245
1246        // Test error cases
1247        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
1248        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
1249        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
1250        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
1251
1252        // IPv6 normalization: leading zeros are stripped
1253        assert_eq!(
1254            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
1255                .unwrap(),
1256            ChannelAddr::MetaTls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
1257        );
1258
1259        // Short and long forms of the same IPv6 produce equal ChannelAddr values
1260        assert_eq!(
1261            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
1262                .unwrap(),
1263            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443")
1264                .unwrap(),
1265        );
1266
1267        // Bracketed IPv6 is normalized
1268        assert_eq!(
1269            ChannelAddr::from_zmq_url("metatls://[::1]:443").unwrap(),
1270            ChannelAddr::MetaTls(TlsAddr::new("::1", 443))
1271        );
1272
1273        // Same tests for tls://
1274        assert_eq!(
1275            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
1276            ChannelAddr::Tls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
1277        );
1278        assert_eq!(
1279            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
1280            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443").unwrap(),
1281        );
1282        assert_eq!(
1283            ChannelAddr::from_zmq_url("tls://[::1]:443").unwrap(),
1284            ChannelAddr::Tls(TlsAddr::new("::1", 443))
1285        );
1286    }
1287
1288    #[test]
1289    fn test_normalize_host() {
1290        // Plain IPv4 passes through
1291        assert_eq!(normalize_host("192.168.1.1"), "192.168.1.1");
1292
1293        // Plain hostname passes through
1294        assert_eq!(normalize_host("example.com"), "example.com");
1295
1296        // IPv6 with leading zeros gets normalized
1297        assert_eq!(
1298            normalize_host("2a03:83e4:5000:c000:56d7:00cf:75ce:144a"),
1299            "2a03:83e4:5000:c000:56d7:cf:75ce:144a"
1300        );
1301
1302        // Bracketed IPv6 is stripped and normalized
1303        assert_eq!(normalize_host("[::1]"), "::1");
1304
1305        // Without bracket stripping, IpAddr::from_str rejects bracketed
1306        // addresses. This demonstrates that the bracket stripping in
1307        // normalize_host is necessary.
1308        assert!("[::1]".parse::<IpAddr>().is_err());
1309    }
1310
1311    #[test]
1312    fn test_zmq_style_alias_channel_addr() {
1313        // Test Alias format: dial_to_url@bind_to_url
1314        // The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs
1315        // Note: Alias format is only supported for TCP addresses
1316
1317        // Test Alias with tcp on both sides
1318        let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
1319        match alias_addr {
1320            ChannelAddr::Alias { dial_to, bind_to } => {
1321                assert_eq!(
1322                    *dial_to,
1323                    ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
1324                );
1325                assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
1326            }
1327            _ => panic!("Expected Alias"),
1328        }
1329
1330        // Test error: alias with non-tcp dial_to (not supported)
1331        assert!(
1332            ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err()
1333        );
1334
1335        // Test error: alias with non-tcp bind_to (not supported)
1336        assert!(
1337            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
1338        );
1339
1340        // Test error: invalid dial_to URL in Alias
1341        assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());
1342
1343        // Test error: invalid bind_to URL in Alias
1344        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());
1345
1346        // Test error: missing port in dial_to
1347        assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());
1348
1349        // Test error: missing port in bind_to
1350        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
1351    }
1352
1353    #[tokio::test]
1354    async fn test_multiple_connections() {
1355        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1356            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
1357
1358            let mut sends: JoinSet<()> = JoinSet::new();
1359            for message in 0u64..100u64 {
1360                let addr = listen_addr.clone();
1361                sends.spawn(async move {
1362                    let tx = dial::<u64>(addr).unwrap();
1363                    tx.post(message);
1364                });
1365            }
1366
1367            let mut received: HashSet<u64> = HashSet::new();
1368            while received.len() < 100 {
1369                received.insert(rx.recv().await.unwrap());
1370            }
1371
1372            for message in 0u64..100u64 {
1373                assert!(received.contains(&message));
1374            }
1375
1376            loop {
1377                match sends.join_next().await {
1378                    Some(Ok(())) => (),
1379                    Some(Err(err)) => panic!("{}", err),
1380                    None => break,
1381                }
1382            }
1383        }
1384    }
1385
1386    #[tokio::test]
1387    async fn test_server_close() {
1388        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1389            if net::is_net_addr(&addr) {
1390                // Net has store-and-forward semantics. We don't expect failures
1391                // on closure.
1392                continue;
1393            }
1394
1395            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();
1396
1397            let tx = dial::<u64>(listen_addr).unwrap();
1398            tx.post(123);
1399            drop(rx);
1400
1401            // New transmits should fail... but there is buffering, etc.,
1402            // which can cause the failure to be delayed. We give it
1403            // a deadline, but it can still technically fail -- the test
1404            // should be considered a kind of integration test.
1405            let start = tokio::time::Instant::now();
1406
1407            let result = loop {
1408                let (return_tx, return_rx) = oneshot::channel();
1409                tx.try_post(123, return_tx);
1410                let result = return_rx.await;
1411
1412                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
1413                    break result;
1414                }
1415            };
1416            assert_matches!(
1417                result,
1418                Ok(SendError {
1419                    error: ChannelError::Closed,
1420                    message: 123,
1421                    reason: None
1422                })
1423            );
1424        }
1425    }
1426
1427    fn addrs() -> Vec<ChannelAddr> {
1428        use rand::Rng;
1429        use rand::distr::Uniform;
1430
1431        let rng = rand::rng();
1432        let uniform = Uniform::new_inclusive('a', 'z').unwrap();
1433        vec![
1434            "tcp:[::1]:0".parse().unwrap(),
1435            "local:0".parse().unwrap(),
1436            #[cfg(target_os = "linux")]
1437            "unix:".parse().unwrap(),
1438            #[cfg(target_os = "linux")]
1439            format!(
1440                "unix:@{}",
1441                rng.sample_iter(uniform).take(10).collect::<String>()
1442            )
1443            .parse()
1444            .unwrap(),
1445        ]
1446    }
1447
1448    #[test]
1449    fn test_bind_spec_from_str() {
1450        // Test parsing ChannelTransport strings -> BindSpec::Any
1451        assert_eq!(
1452            BindSpec::from_str("tcp").unwrap(),
1453            BindSpec::Any(ChannelTransport::Tcp(TcpMode::Hostname))
1454        );
1455        assert_eq!(
1456            BindSpec::from_str("metatls(Hostname)").unwrap(),
1457            BindSpec::Any(ChannelTransport::MetaTls(TlsMode::Hostname))
1458        );
1459
1460        // Test parsing ChannelAddr strings -> BindSpec::Addr
1461        assert_eq!(
1462            BindSpec::from_str("tcp:127.0.0.1:8080").unwrap(),
1463            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()))
1464        );
1465
1466        // Test parsing ZMQ URL format -> BindSpec::Addr
1467        assert_eq!(
1468            BindSpec::from_str("tcp://127.0.0.1:9000").unwrap(),
1469            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()))
1470        );
1471        assert_eq!(
1472            BindSpec::from_str("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap(),
1473            BindSpec::Addr(
1474                ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap()
1475            )
1476        );
1477
1478        // Test error cases
1479        assert!(BindSpec::from_str("invalid_spec").is_err());
1480        assert!(BindSpec::from_str("unknown://scheme").is_err());
1481        assert!(BindSpec::from_str("").is_err());
1482    }
1483
1484    #[tokio::test]
1485    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1486    #[cfg_attr(not(fbcode_build), ignore)]
1487    async fn test_dial_serve() {
1488        for addr in addrs() {
1489            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1490            let tx = crate::channel::dial(listen_addr).unwrap();
1491            tx.post(123);
1492            assert_eq!(rx.recv().await.unwrap(), 123);
1493        }
1494    }
1495
1496    #[tokio::test]
1497    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1498    #[cfg_attr(not(fbcode_build), ignore)]
1499    async fn test_send() {
1500        let config = hyperactor_config::global::lock();
1501
1502        // Use temporary config for this test
1503        let _guard1 = config.override_key(
1504            crate::config::MESSAGE_DELIVERY_TIMEOUT,
1505            Duration::from_secs(1),
1506        );
1507        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
1508        for addr in addrs() {
1509            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1510            let tx = crate::channel::dial(listen_addr).unwrap();
1511            tx.send(123).await.unwrap();
1512            assert_eq!(rx.recv().await.unwrap(), 123);
1513
1514            drop(rx);
1515            assert_matches!(
1516                tx.send(123).await.unwrap_err(),
1517                SendError {
1518                    error: ChannelError::Closed,
1519                    message: 123,
1520                    ..
1521                }
1522            );
1523        }
1524    }
1525
1526    #[test]
1527    fn test_find_routable_address_skips_link_local_ipv6() {
1528        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1529        let routable_v6: IpAddr = "2001:db8::1".parse().unwrap();
1530        let addrs = vec![link_local_v6, routable_v6];
1531        assert_eq!(find_routable_address(&addrs), Some(routable_v6));
1532    }
1533
1534    #[test]
1535    fn test_find_routable_address_skips_link_local_ipv4() {
1536        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1537        let routable_v4: IpAddr = "192.168.1.1".parse().unwrap();
1538        let addrs = vec![link_local_v4, routable_v4];
1539        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1540    }
1541
1542    #[test]
1543    fn test_find_routable_address_returns_none_when_all_link_local() {
1544        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1545        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1546        let addrs = vec![link_local_v6, link_local_v4];
1547        assert_eq!(find_routable_address(&addrs), None);
1548    }
1549
1550    #[test]
1551    fn test_find_routable_address_mixed() {
1552        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1553        let link_local_v4: IpAddr = "169.254.0.1".parse().unwrap();
1554        let routable_v4: IpAddr = "10.0.0.1".parse().unwrap();
1555        let routable_v6: IpAddr = "2001:db8::2".parse().unwrap();
1556
1557        // First routable address in list order should be returned.
1558        let addrs = vec![link_local_v6, link_local_v4, routable_v4, routable_v6];
1559        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1560    }
1561}