hyperactor/
channel.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! One-way, multi-process, typed communication channels. These are used
10//! to send messages between mailboxes residing in different processes.
11
12#![allow(dead_code)] // Allow until this is used outside of tests.
13
14use core::net::SocketAddr;
15use std::fmt;
16use std::net::IpAddr;
17use std::net::Ipv6Addr;
18#[cfg(target_os = "linux")]
19use std::os::linux::net::SocketAddrExt;
20use std::panic::Location;
21use std::str::FromStr;
22
23use async_trait::async_trait;
24use enum_as_inner::EnumAsInner;
25use lazy_static::lazy_static;
26use local_ip_address::local_ipv6;
27use serde::Deserialize;
28use serde::Serialize;
29use tokio::sync::mpsc;
30use tokio::sync::oneshot;
31use tokio::sync::watch;
32
33use crate as hyperactor;
34use crate::Named;
35use crate::RemoteMessage;
36use crate::attrs::AttrValue;
37use crate::channel::sim::SimAddr;
38use crate::simnet::SimNetError;
39
40pub(crate) mod local;
41pub(crate) mod net;
42pub mod sim;
43
/// The type of error that can occur on channel operations.
///
/// Variants annotated with `#[from]` can be created from their source
/// error type through `From`, and therefore with the `?` operator.
/// Variants annotated with `#[error(transparent)]` display exactly as
/// the error they wrap.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// An operation was attempted on a closed channel.
    #[error("channel closed")]
    Closed,

    /// An error occurred during send. The wrapped error is the
    /// underlying cause.
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// A network client error.
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// The address was not valid. The payload is a human-readable
    /// description of the offending address.
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// A serving error was encountered.
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// A bincode serialization or deserialization error occurred.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// Data encoding errors.
    #[error(transparent)]
    Data(#[from] crate::data::Error),

    /// Some other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An operation timeout occurred. The payload is the duration
    /// reported in the error message.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),

    /// A simulator error occurred.
    #[error(transparent)]
    SimNetError(#[from] SimNetError),
}
87
/// An error that occurred during send. Returns the message that failed to send.
///
/// Field `0` is the [`ChannelError`] that caused the failure; it is both
/// the `Display` output and the error source. Field `1` is the message
/// that could not be sent, handed back to the caller.
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct SendError<M: RemoteMessage>(#[source] pub ChannelError, pub M);
92
93impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
94    fn from(error: SendError<M>) -> Self {
95        error.0
96    }
97}
98
/// The possible states of a `Tx`, as published through the
/// `watch::Receiver` returned by `Tx::status`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TxStatus {
    /// The tx is good.
    Active,
    /// The tx cannot be used for message delivery.
    Closed,
}
107
108/// The transmit end of an M-typed channel.
109#[async_trait]
110pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
111    /// Post a message; returning failed deliveries on the return channel, if provided.
112    /// If provided, the sender is dropped when the message has been
113    /// enqueued at the channel endpoint.
114    ///
115    /// Users should use the `try_post`, and `post` variants directly.
116    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
117
118    /// Enqueue a `message` on the local end of the channel. The
119    /// message is either delivered, or we eventually discover that
120    /// the channel has failed and it will be sent back on `return_channel`.
121    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
122    fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
123        self.do_post(message, Some(return_channel));
124    }
125
126    /// Enqueue a message to be sent on the channel.
127    fn post(&self, message: M) {
128        self.do_post(message, None);
129    }
130
131    /// Send a message synchronously, returning when the messsage has
132    /// been delivered to the remote end of the channel.
133    async fn send(&self, message: M) -> Result<(), SendError<M>> {
134        let (tx, rx) = oneshot::channel();
135        self.try_post(message, tx);
136        match rx.await {
137            // Channel was closed; the message was not delivered.
138            Ok(err) => Err(err),
139
140            // Channel was dropped; the message was successfully enqueued
141            // on the remote end of the channel.
142            Err(_) => Ok(()),
143        }
144    }
145
146    /// The channel address to which this Tx is sending.
147    fn addr(&self) -> ChannelAddr;
148
149    /// A means to monitor the health of a `Tx`.
150    fn status(&self) -> &watch::Receiver<TxStatus>;
151}
152
/// The receive end of an M-typed channel.
#[async_trait]
pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
    /// Receive the next message from the channel. If the channel returns
    /// an error it is considered broken and should be discarded.
    ///
    /// # Errors
    ///
    /// Returns a [`ChannelError`] describing why no message could be
    /// produced.
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The channel address from which this Rx is receiving.
    fn addr(&self) -> ChannelAddr;
}
163
/// A `Tx` implementation backed by an unbounded tokio mpsc sender.
#[derive(Debug)]
struct MpscTx<M: RemoteMessage> {
    // The sending half of the underlying mpsc channel.
    tx: mpsc::UnboundedSender<M>,
    // The address reported by `Tx::addr`.
    addr: ChannelAddr,
    // The status watch reported by `Tx::status`.
    status: watch::Receiver<TxStatus>,
}
170
171impl<M: RemoteMessage> MpscTx<M> {
172    pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
173        let (sender, receiver) = watch::channel(TxStatus::Active);
174        (
175            Self {
176                tx,
177                addr,
178                status: receiver,
179            },
180            sender,
181        )
182    }
183}
184
185#[async_trait]
186impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
187    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
188        if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
189            if let Some(return_channel) = return_channel {
190                return_channel
191                    .send(SendError(ChannelError::Closed, message))
192                    .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
193            }
194        }
195    }
196
197    fn addr(&self) -> ChannelAddr {
198        self.addr.clone()
199    }
200
201    fn status(&self) -> &watch::Receiver<TxStatus> {
202        &self.status
203    }
204}
205
/// An `Rx` implementation backed by an unbounded tokio mpsc receiver.
#[derive(Debug)]
struct MpscRx<M: RemoteMessage> {
    // The receiving half of the underlying mpsc channel.
    rx: mpsc::UnboundedReceiver<M>,
    // The address reported by `Rx::addr`.
    addr: ChannelAddr,
    // Used to report the status to the Tx side.
    status_sender: watch::Sender<TxStatus>,
}
213
214impl<M: RemoteMessage> MpscRx<M> {
215    pub fn new(
216        rx: mpsc::UnboundedReceiver<M>,
217        addr: ChannelAddr,
218        status_sender: watch::Sender<TxStatus>,
219    ) -> Self {
220        Self {
221            rx,
222            addr,
223            status_sender,
224        }
225    }
226}
227
228impl<M: RemoteMessage> Drop for MpscRx<M> {
229    fn drop(&mut self) {
230        let _ = self.status_sender.send(TxStatus::Closed);
231    }
232}
233
234#[async_trait]
235impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
236    async fn recv(&mut self) -> Result<M, ChannelError> {
237        self.rx.recv().await.ok_or(ChannelError::Closed)
238    }
239
240    fn addr(&self) -> ChannelAddr {
241        self.addr.clone()
242    }
243}
244
/// The kind of host address to use for TCP connections.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TcpMode {
    /// Use localhost/loopback for the connection.
    Localhost,
    /// Use host domain name for the connection.
    Hostname,
}
264
/// The kind of host identifier to use for TLS connections.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// Use IpV6 address for TLS connections.
    IpV6,
    /// Use host domain name for TLS connections.
    Hostname,
    // TODO: consider adding IpV4 support.
}
285
/// Address format for MetaTls channels. Supports both hostname/port pairs
/// (required for clients for host identity) and direct socket addresses
/// (allowed for servers).
///
/// The `Display` form is `hostname:port` for the `Host` variant and the
/// socket address's own `Display` form for the `Socket` variant.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd,
    EnumAsInner
)]
pub enum MetaTlsAddr {
    /// Hostname and port pair. Required for clients to establish host identity.
    Host {
        /// The hostname to connect to.
        hostname: Hostname,
        /// The port to connect to.
        port: Port,
    },
    /// Direct socket address. Allowed for servers.
    Socket(SocketAddr),
}
312
313impl MetaTlsAddr {
314    /// Returns the port number for this address.
315    pub fn port(&self) -> Port {
316        match self {
317            Self::Host { port, .. } => *port,
318            Self::Socket(addr) => addr.port(),
319        }
320    }
321
322    /// Returns the hostname if this is a Host variant, None otherwise.
323    pub fn hostname(&self) -> Option<&str> {
324        match self {
325            Self::Host { hostname, .. } => Some(hostname),
326            Self::Socket(_) => None,
327        }
328    }
329}
330
331impl fmt::Display for MetaTlsAddr {
332    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333        match self {
334            Self::Host { hostname, port } => write!(f, "{}:{}", hostname, port),
335            Self::Socket(addr) => write!(f, "{}", addr),
336        }
337    }
338}
339
/// Types of channel transports.
///
/// The textual form produced by `Display` is `tcp(<mode>)`,
/// `metatls(<mode>)`, `local`, `sim(<transport>)` or `unix`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
pub enum ChannelTransport {
    /// Transport over a TCP connection.
    Tcp(TcpMode),

    /// Transport over a TCP connection with TLS support within Meta
    MetaTls(TlsMode),

    /// Local transports uses an in-process registry and mpsc channels.
    Local,

    /// Sim is a simulated channel for testing. The boxed value is the
    /// transport being simulated.
    Sim(/*simulated transport:*/ Box<ChannelTransport>),

    /// Transport over unix domain socket.
    Unix,
}
358
359impl fmt::Display for ChannelTransport {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        match self {
362            Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
363            Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
364            Self::Local => write!(f, "local"),
365            Self::Sim(transport) => write!(f, "sim({})", transport),
366            Self::Unix => write!(f, "unix"),
367        }
368    }
369}
370
371impl FromStr for ChannelTransport {
372    type Err = anyhow::Error;
373
374    fn from_str(s: &str) -> Result<Self, Self::Err> {
375        // Hacky parsing; can't recurse (e.g., sim(sim(..)))
376        if let Some(rest) = s.strip_prefix("sim(") {
377            if let Some(end) = rest.rfind(')') {
378                let inner = &rest[..end];
379                let inner_transport = ChannelTransport::from_str(inner)?;
380                return Ok(ChannelTransport::Sim(Box::new(inner_transport)));
381            } else {
382                return Err(anyhow::anyhow!("invalid sim transport"));
383            }
384        }
385
386        match s {
387            // Default to TcpMode::Hostname, if the mode isn't set
388            "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
389            s if s.starts_with("tcp(") => {
390                let inner = &s["tcp(".len()..s.len() - 1];
391                let mode = inner.parse()?;
392                Ok(ChannelTransport::Tcp(mode))
393            }
394            "local" => Ok(ChannelTransport::Local),
395            "unix" => Ok(ChannelTransport::Unix),
396            s if s.starts_with("metatls(") && s.ends_with(")") => {
397                let inner = &s["metatls(".len()..s.len() - 1];
398                let mode = inner.parse()?;
399                Ok(ChannelTransport::MetaTls(mode))
400            }
401            unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
402        }
403    }
404}
405
406impl ChannelTransport {
407    /// All known channel transports.
408    pub fn all() -> [ChannelTransport; 3] {
409        [
410            // TODO: @rusch add back once figuring out unspecified override for OSS CI
411            // ChannelTransport::Tcp(TcpMode::Localhost),
412            ChannelTransport::Tcp(TcpMode::Hostname),
413            ChannelTransport::Local,
414            ChannelTransport::Unix,
415            // TODO add MetaTls (T208303369)
416            // TODO ChannelTransport::Sim(Box::new(ChannelTransport::Tcp)),
417            // TODO ChannelTransport::Sim(Box::new(ChannelTransport::Local)),
418        ]
419    }
420
421    /// Return an "any" address for this transport.
422    pub fn any(&self) -> ChannelAddr {
423        ChannelAddr::any(self.clone())
424    }
425
426    /// Returns true if this transport type represents a remote channel.
427    pub fn is_remote(&self) -> bool {
428        match self {
429            ChannelTransport::Tcp(_) => true,
430            ChannelTransport::MetaTls(_) => true,
431            ChannelTransport::Local => false,
432            ChannelTransport::Sim(_) => false,
433            ChannelTransport::Unix => false,
434        }
435    }
436}
437
438impl AttrValue for ChannelTransport {
439    fn display(&self) -> String {
440        self.to_string()
441    }
442
443    fn parse(s: &str) -> Result<Self, anyhow::Error> {
444        s.parse()
445    }
446}
447
/// The type of (TCP) hostnames. This is a plain alias for `String`.
pub type Hostname = String;
450
/// The type of (TCP) ports. This is a plain alias for `u16`.
pub type Port = u16;
453
/// The type of a channel address, used to multiplex different underlying
/// channel implementations. ChannelAddrs also have a concrete syntax:
/// the address type (e.g., "tcp" or "local"), followed by ":", and an address
/// parseable to that type. For example:
///
/// - `tcp:127.0.0.1:1234` - localhost port 1234 over TCP
/// - `tcp:192.168.0.1:1111` - 192.168.0.1 port 1111 over TCP
/// - `local:123` - the (in-process) local port 123
/// - `unix:/some/path` - the Unix socket at `/some/path`
///
/// When parsing, `!` is also accepted as the separator between the
/// address type and the address.
///
/// Both local and TCP ports 0 are reserved to indicate "any available
/// port" when serving.
///
/// ```
/// # use hyperactor::channel::ChannelAddr;
/// let addr: ChannelAddr = "tcp:127.0.0.1:1234".parse().unwrap();
/// let ChannelAddr::Tcp(socket_addr) = addr else {
///     panic!()
/// };
/// assert_eq!(socket_addr.port(), 1234);
/// assert_eq!(socket_addr.is_ipv4(), true);
/// ```
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    Named
)]
pub enum ChannelAddr {
    /// A socket address used to establish TCP channels. Supports
    /// both IPv4 and IPv6 address / port pairs.
    Tcp(SocketAddr),

    /// An address to establish TCP channels with TLS support within Meta.
    /// Supports both hostname/port pairs (required for clients) and
    /// socket addresses (allowed for servers).
    MetaTls(MetaTlsAddr),

    /// Local addresses are registered in-process and given an integral
    /// index.
    Local(u64),

    /// Sim is a simulated channel for testing.
    Sim(SimAddr),

    /// A unix domain socket address. Supports both absolute path names as
    /// well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets
    Unix(net::unix::SocketAddr),
}
509
510impl From<SocketAddr> for ChannelAddr {
511    fn from(value: SocketAddr) -> Self {
512        Self::Tcp(value)
513    }
514}
515
516impl From<net::unix::SocketAddr> for ChannelAddr {
517    fn from(value: net::unix::SocketAddr) -> Self {
518        Self::Unix(value)
519    }
520}
521
522impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
523    fn from(value: std::os::unix::net::SocketAddr) -> Self {
524        Self::Unix(net::unix::SocketAddr::new(value))
525    }
526}
527
528impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
529    fn from(value: tokio::net::unix::SocketAddr) -> Self {
530        std::os::unix::net::SocketAddr::from(value).into()
531    }
532}
533
534impl ChannelAddr {
535    /// The "any" address for the given transport type. This is used to
536    /// servers to "any" address.
537    pub fn any(transport: ChannelTransport) -> Self {
538        match transport {
539            ChannelTransport::Tcp(mode) => {
540                let ip = match mode {
541                    TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
542                    TcpMode::Hostname => {
543                        hostname::get()
544                            .ok()
545                            .and_then(|hostname| {
546                                // TODO: Avoid using DNS directly once we figure out a good extensibility story here
547                                hostname.to_str().and_then(|hostname_str| {
548                                    dns_lookup::lookup_host(hostname_str)
549                                        .ok()
550                                        .and_then(|addresses| addresses.first().cloned())
551                                })
552                            })
553                            .expect("failed to resolve hostname to ip address")
554                    }
555                };
556                Self::Tcp(SocketAddr::new(ip, 0))
557            }
558            ChannelTransport::MetaTls(mode) => {
559                let host_address = match mode {
560                    TlsMode::Hostname => hostname::get()
561                        .ok()
562                        .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
563                        .unwrap_or("unknown_host".to_string()),
564                    TlsMode::IpV6 => local_ipv6()
565                        .ok()
566                        .and_then(|addr| addr.to_string().parse().ok())
567                        .expect("failed to retrieve ipv6 address"),
568                };
569                Self::MetaTls(MetaTlsAddr::Host {
570                    hostname: host_address,
571                    port: 0,
572                })
573            }
574            ChannelTransport::Local => Self::Local(0),
575            ChannelTransport::Sim(transport) => sim::any(*transport),
576            // This works because the file will be deleted but we know we have a unique file by this point.
577            ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
578        }
579    }
580
581    /// The transport used by this address.
582    pub fn transport(&self) -> ChannelTransport {
583        match self {
584            Self::Tcp(addr) => {
585                if addr.ip().is_loopback() {
586                    ChannelTransport::Tcp(TcpMode::Localhost)
587                } else {
588                    ChannelTransport::Tcp(TcpMode::Hostname)
589                }
590            }
591            Self::MetaTls(addr) => match addr {
592                MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
593                    Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
594                    Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
595                    Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
596                },
597                MetaTlsAddr::Socket(socket_addr) => match socket_addr.ip() {
598                    IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
599                    IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
600                },
601            },
602            Self::Local(_) => ChannelTransport::Local,
603            Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
604            Self::Unix(_) => ChannelTransport::Unix,
605        }
606    }
607}
608
609impl fmt::Display for ChannelAddr {
610    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
611        match self {
612            Self::Tcp(addr) => write!(f, "tcp:{}", addr),
613            Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
614            Self::Local(index) => write!(f, "local:{}", index),
615            Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
616            Self::Unix(addr) => write!(f, "unix:{}", addr),
617        }
618    }
619}
620
621impl FromStr for ChannelAddr {
622    type Err = anyhow::Error;
623
624    fn from_str(addr: &str) -> Result<Self, Self::Err> {
625        match addr.split_once('!').or_else(|| addr.split_once(':')) {
626            Some(("local", rest)) => rest
627                .parse::<u64>()
628                .map(Self::Local)
629                .map_err(anyhow::Error::from),
630            Some(("tcp", rest)) => rest
631                .parse::<SocketAddr>()
632                .map(Self::Tcp)
633                .map_err(anyhow::Error::from),
634            Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
635            Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
636            Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
637            Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
638            None => Err(anyhow::anyhow!("no channel type specified")),
639        }
640    }
641}
642
643impl ChannelAddr {
644    /// Parse ZMQ-style URL format: scheme://address
645    /// Supports:
646    /// - tcp://hostname:port or tcp://*:port (wildcard binding)
647    /// - inproc://endpoint-name (equivalent to local)
648    /// - ipc://path (equivalent to unix)
649    /// - metatls://hostname:port or metatls://*:port
650    pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
651        // Try ZMQ-style URL format first (scheme://...)
652        let (scheme, address) = address.split_once("://").ok_or_else(|| {
653            anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
654        })?;
655
656        match scheme {
657            "tcp" => {
658                let (host, port) = Self::split_host_port(address)?;
659
660                if host == "*" {
661                    // Wildcard binding - use IPv6 unspecified address
662                    Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
663                } else {
664                    // Resolve hostname to IP address for proper SocketAddr creation
665                    let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
666                    Ok(Self::Tcp(socket_addr))
667                }
668            }
669            "inproc" => {
670                // inproc://port -> local:port
671                // Port must be a valid u64 number
672                let port = address.parse::<u64>().map_err(|_| {
673                    anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
674                })?;
675                Ok(Self::Local(port))
676            }
677            "ipc" => {
678                // ipc://path -> unix:path
679                Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
680            }
681            "metatls" => {
682                let (host, port) = Self::split_host_port(address)?;
683
684                if host == "*" {
685                    // Wildcard binding - use IPv6 unspecified address directly without hostname resolution
686                    Ok(Self::MetaTls(MetaTlsAddr::Host {
687                        hostname: std::net::Ipv6Addr::UNSPECIFIED.to_string(),
688                        port,
689                    }))
690                } else {
691                    Ok(Self::MetaTls(MetaTlsAddr::Host {
692                        hostname: host.to_string(),
693                        port,
694                    }))
695                }
696            }
697            scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
698        }
699    }
700
701    /// Split host:port string, supporting IPv6 addresses
702    fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
703        if let Some((host, port_str)) = address.rsplit_once(':') {
704            let port: u16 = port_str
705                .parse()
706                .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
707            Ok((host, port))
708        } else {
709            Err(anyhow::anyhow!("invalid address format: {}", address))
710        }
711    }
712
713    /// Resolve hostname to SocketAddr, handling both IP addresses and hostnames
714    fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
715        // Handle IPv6 addresses in brackets by stripping the brackets
716        let host_clean = if host.starts_with('[') && host.ends_with(']') {
717            &host[1..host.len() - 1]
718        } else {
719            host
720        };
721
722        // First try to parse as an IP address directly
723        if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
724            return Ok(SocketAddr::new(ip_addr, port));
725        }
726
727        // If not an IP, try hostname resolution
728        use std::net::ToSocketAddrs;
729        let mut addrs = (host_clean, port)
730            .to_socket_addrs()
731            .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
732
733        addrs
734            .next()
735            .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
736    }
737}
738
/// Universal channel transmitter. Wraps the transport-specific sender
/// selected when the channel was dialed.
#[derive(Debug)]
pub struct ChannelTx<M: RemoteMessage> {
    // The transport-specific sender all `Tx` calls are forwarded to.
    inner: ChannelTxKind<M>,
}
744
/// The transport-specific sender held by a `ChannelTx`, one variant per
/// kind of channel address. The `Tcp`, `MetaTls` and `Unix` variants all
/// hold a `net::NetTx`.
#[derive(Debug)]
enum ChannelTxKind<M: RemoteMessage> {
    Local(local::LocalTx<M>),
    Tcp(net::NetTx<M>),
    MetaTls(net::NetTx<M>),
    Unix(net::NetTx<M>),
    Sim(sim::SimTx<M>),
}
754
755#[async_trait]
756impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
757    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
758        match &self.inner {
759            ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
760            ChannelTxKind::Tcp(tx) => tx.do_post(message, return_channel),
761            ChannelTxKind::MetaTls(tx) => tx.do_post(message, return_channel),
762            ChannelTxKind::Sim(tx) => tx.do_post(message, return_channel),
763            ChannelTxKind::Unix(tx) => tx.do_post(message, return_channel),
764        }
765    }
766
767    fn addr(&self) -> ChannelAddr {
768        match &self.inner {
769            ChannelTxKind::Local(tx) => tx.addr(),
770            ChannelTxKind::Tcp(tx) => Tx::<M>::addr(tx),
771            ChannelTxKind::MetaTls(tx) => Tx::<M>::addr(tx),
772            ChannelTxKind::Sim(tx) => tx.addr(),
773            ChannelTxKind::Unix(tx) => Tx::<M>::addr(tx),
774        }
775    }
776
777    fn status(&self) -> &watch::Receiver<TxStatus> {
778        match &self.inner {
779            ChannelTxKind::Local(tx) => tx.status(),
780            ChannelTxKind::Tcp(tx) => tx.status(),
781            ChannelTxKind::MetaTls(tx) => tx.status(),
782            ChannelTxKind::Sim(tx) => tx.status(),
783            ChannelTxKind::Unix(tx) => tx.status(),
784        }
785    }
786}
787
/// Universal channel receiver. Wraps the transport-specific receiver
/// selected when the channel was served.
#[derive(Debug)]
pub struct ChannelRx<M: RemoteMessage> {
    // The transport-specific receiver all `Rx` calls are forwarded to.
    inner: ChannelRxKind<M>,
}
793
/// The transport-specific receiver held by a `ChannelRx`, one variant per
/// kind of channel address. The `Tcp`, `MetaTls` and `Unix` variants all
/// hold a `net::NetRx`.
#[derive(Debug)]
enum ChannelRxKind<M: RemoteMessage> {
    Local(local::LocalRx<M>),
    Tcp(net::NetRx<M>),
    MetaTls(net::NetRx<M>),
    Unix(net::NetRx<M>),
    Sim(sim::SimRx<M>),
}
803
804#[async_trait]
805impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
806    async fn recv(&mut self) -> Result<M, ChannelError> {
807        match &mut self.inner {
808            ChannelRxKind::Local(rx) => rx.recv().await,
809            ChannelRxKind::Tcp(rx) => rx.recv().await,
810            ChannelRxKind::MetaTls(rx) => rx.recv().await,
811            ChannelRxKind::Sim(rx) => rx.recv().await,
812            ChannelRxKind::Unix(rx) => rx.recv().await,
813        }
814    }
815
816    fn addr(&self) -> ChannelAddr {
817        match &self.inner {
818            ChannelRxKind::Local(rx) => rx.addr(),
819            ChannelRxKind::Tcp(rx) => rx.addr(),
820            ChannelRxKind::MetaTls(rx) => rx.addr(),
821            ChannelRxKind::Sim(rx) => rx.addr(),
822            ChannelRxKind::Unix(rx) => rx.addr(),
823        }
824    }
825}
826
827/// Dial the provided address, returning the corresponding Tx, or error
828/// if the channel cannot be established. The underlying connection is
829/// dropped whenever the returned Tx is dropped.
830#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
831#[track_caller]
832pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
833    tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
834    let inner = match addr {
835        ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
836        ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
837        ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
838        ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
839        ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
840    };
841    Ok(ChannelTx { inner })
842}
843
844/// Serve on the provided channel address. The server is turned down
845/// when the returned Rx is dropped.
846#[crate::instrument]
847#[track_caller]
848pub fn serve<M: RemoteMessage>(
849    addr: ChannelAddr,
850) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
851    let caller = Location::caller();
852    match addr {
853        ChannelAddr::Tcp(addr) => {
854            let (addr, rx) = net::tcp::serve::<M>(addr)?;
855            Ok((addr, ChannelRxKind::Tcp(rx)))
856        }
857        ChannelAddr::MetaTls(meta_addr) => {
858            let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
859            Ok((addr, ChannelRxKind::MetaTls(rx)))
860        }
861        ChannelAddr::Unix(path) => {
862            let (addr, rx) = net::unix::serve::<M>(path)?;
863            Ok((addr, ChannelRxKind::Unix(rx)))
864        }
865        ChannelAddr::Local(0) => {
866            let (port, rx) = local::serve::<M>();
867            Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
868        }
869        ChannelAddr::Sim(sim_addr) => {
870            let (addr, rx) = sim::serve::<M>(sim_addr)?;
871            Ok((addr, ChannelRxKind::Sim(rx)))
872        }
873        ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
874            "invalid local addr: {}",
875            a
876        ))),
877    }
878    .map(|(addr, inner)| {
879        tracing::debug!(
880            name = "serve",
881            %addr,
882            %caller,
883        );
884        (addr, ChannelRx { inner })
885    })
886}
887
888/// Serve on the local address. The server is turned down
889/// when the returned Rx is dropped.
890pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
891    let (port, rx) = local::serve::<M>();
892    (
893        ChannelAddr::Local(port),
894        ChannelRx {
895            inner: ChannelRxKind::Local(rx),
896        },
897    )
898}
899
900#[cfg(test)]
901mod tests {
902    use std::assert_matches::assert_matches;
903    use std::collections::HashSet;
904    use std::net::IpAddr;
905    use std::net::Ipv4Addr;
906    use std::net::Ipv6Addr;
907    use std::time::Duration;
908
909    use tokio::task::JoinSet;
910
911    use super::net::*;
912    use super::*;
913    use crate::clock::Clock;
914    use crate::clock::RealClock;
915
916    #[test]
917    fn test_channel_addr() {
918        let cases_ok = vec![
919            (
920                "tcp<DELIM>[::1]:1234",
921                ChannelAddr::Tcp(SocketAddr::new(
922                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
923                    1234,
924                )),
925            ),
926            (
927                "tcp<DELIM>127.0.0.1:8080",
928                ChannelAddr::Tcp(SocketAddr::new(
929                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
930                    8080,
931                )),
932            ),
933            #[cfg(target_os = "linux")]
934            ("local<DELIM>123", ChannelAddr::Local(123)),
935            (
936                "unix<DELIM>@yolo",
937                ChannelAddr::Unix(
938                    unix::SocketAddr::from_abstract_name("yolo")
939                        .expect("can't make socket from abstract name"),
940                ),
941            ),
942            (
943                "unix<DELIM>/cool/socket-path",
944                ChannelAddr::Unix(
945                    unix::SocketAddr::from_pathname("/cool/socket-path")
946                        .expect("can't make socket from path"),
947                ),
948            ),
949        ];
950
951        for (raw, parsed) in cases_ok.clone() {
952            for delim in ["!", ":"] {
953                let raw = raw.replace("<DELIM>", delim);
954                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
955            }
956        }
957
958        for (raw, parsed) in cases_ok {
959            for delim in ["!", ":"] {
960                // We don't allow mixing and matching delims
961                let raw = format!("sim{}{}", delim, raw.replace("<DELIM>", delim));
962                assert_eq!(
963                    raw.parse::<ChannelAddr>().unwrap(),
964                    ChannelAddr::Sim(SimAddr::new(parsed.clone()).unwrap())
965                );
966            }
967        }
968
969        let cases_err = vec![
970            ("tcp:abcdef..123124", "invalid socket address syntax"),
971            ("xxx:foo", "no such channel type: xxx"),
972            ("127.0.0.1", "no channel type specified"),
973            ("local:abc", "invalid digit found in string"),
974        ];
975
976        for (raw, error) in cases_err {
977            let Err(err) = raw.parse::<ChannelAddr>() else {
978                panic!("expected error parsing: {}", &raw)
979            };
980            assert_eq!(format!("{}", err), error);
981        }
982    }
983
984    #[test]
985    fn test_zmq_style_channel_addr() {
986        // Test TCP addresses
987        assert_eq!(
988            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
989            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
990        );
991
992        // Test TCP wildcard binding
993        assert_eq!(
994            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
995            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
996        );
997
998        // Test inproc (maps to local with numeric endpoint)
999        assert_eq!(
1000            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
1001            ChannelAddr::Local(12345)
1002        );
1003
1004        // Test ipc (maps to unix)
1005        assert_eq!(
1006            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
1007            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
1008        );
1009
1010        // Test metatls with hostname
1011        assert_eq!(
1012            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
1013            ChannelAddr::MetaTls(MetaTlsAddr::Host {
1014                hostname: "example.com".to_string(),
1015                port: 443
1016            })
1017        );
1018
1019        // Test metatls with IP address (should be normalized)
1020        assert_eq!(
1021            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
1022            ChannelAddr::MetaTls(MetaTlsAddr::Host {
1023                hostname: "192.168.1.1".to_string(),
1024                port: 443
1025            })
1026        );
1027
1028        // Test metatls with wildcard (should use IPv6 unspecified address)
1029        assert_eq!(
1030            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
1031            ChannelAddr::MetaTls(MetaTlsAddr::Host {
1032                hostname: "::".to_string(),
1033                port: 8443
1034            })
1035        );
1036
1037        // Test TCP hostname resolution (should resolve hostname to IP)
1038        // Note: This test may fail in environments without proper DNS resolution
1039        // We test that it at least doesn't fail to parse
1040        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
1041        assert!(tcp_hostname_result.is_ok());
1042
1043        // Test IPv6 address
1044        assert_eq!(
1045            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
1046            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
1047        );
1048
1049        // Test error cases
1050        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
1051        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
1052        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
1053        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
1054    }
1055
1056    #[tokio::test]
1057    async fn test_multiple_connections() {
1058        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1059            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
1060
1061            let mut sends: JoinSet<()> = JoinSet::new();
1062            for message in 0u64..100u64 {
1063                let addr = listen_addr.clone();
1064                sends.spawn(async move {
1065                    let tx = dial::<u64>(addr).unwrap();
1066                    tx.post(message);
1067                });
1068            }
1069
1070            let mut received: HashSet<u64> = HashSet::new();
1071            while received.len() < 100 {
1072                received.insert(rx.recv().await.unwrap());
1073            }
1074
1075            for message in 0u64..100u64 {
1076                assert!(received.contains(&message));
1077            }
1078
1079            loop {
1080                match sends.join_next().await {
1081                    Some(Ok(())) => (),
1082                    Some(Err(err)) => panic!("{}", err),
1083                    None => break,
1084                }
1085            }
1086        }
1087    }
1088
1089    #[tokio::test]
1090    async fn test_server_close() {
1091        for addr in ChannelTransport::all().map(ChannelAddr::any) {
1092            if net::is_net_addr(&addr) {
1093                // Net has store-and-forward semantics. We don't expect failures
1094                // on closure.
1095                continue;
1096            }
1097
1098            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();
1099
1100            let tx = dial::<u64>(listen_addr).unwrap();
1101            tx.post(123);
1102            drop(rx);
1103
1104            // New transmits should fail... but there is buffering, etc.,
1105            // which can cause the failure to be delayed. We give it
1106            // a deadline, but it can still technically fail -- the test
1107            // should be considered a kind of integration test.
1108            let start = RealClock.now();
1109
1110            let result = loop {
1111                let (return_tx, return_rx) = oneshot::channel();
1112                tx.try_post(123, return_tx);
1113                let result = return_rx.await;
1114
1115                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
1116                    break result;
1117                }
1118            };
1119            assert_matches!(result, Ok(SendError(ChannelError::Closed, 123)));
1120        }
1121    }
1122
1123    fn addrs() -> Vec<ChannelAddr> {
1124        use rand::Rng;
1125        use rand::distributions::Uniform;
1126
1127        let rng = rand::thread_rng();
1128        vec![
1129            "tcp:[::1]:0".parse().unwrap(),
1130            "local:0".parse().unwrap(),
1131            #[cfg(target_os = "linux")]
1132            "unix:".parse().unwrap(),
1133            #[cfg(target_os = "linux")]
1134            format!(
1135                "unix:@{}",
1136                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
1137                    .take(10)
1138                    .collect::<String>()
1139            )
1140            .parse()
1141            .unwrap(),
1142        ]
1143    }
1144
1145    #[tokio::test]
1146    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1147    #[cfg_attr(not(fbcode_build), ignore)]
1148    async fn test_dial_serve() {
1149        for addr in addrs() {
1150            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1151            let tx = crate::channel::dial(listen_addr).unwrap();
1152            tx.post(123);
1153            assert_eq!(rx.recv().await.unwrap(), 123);
1154        }
1155    }
1156
1157    #[tokio::test]
1158    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
1159    #[cfg_attr(not(fbcode_build), ignore)]
1160    async fn test_send() {
1161        let config = crate::config::global::lock();
1162
1163        // Use temporary config for this test
1164        let _guard1 = config.override_key(
1165            crate::config::MESSAGE_DELIVERY_TIMEOUT,
1166            Duration::from_secs(1),
1167        );
1168        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
1169        for addr in addrs() {
1170            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1171            let tx = crate::channel::dial(listen_addr).unwrap();
1172            tx.send(123).await.unwrap();
1173            assert_eq!(rx.recv().await.unwrap(), 123);
1174
1175            drop(rx);
1176            assert_matches!(
1177                tx.send(123).await.unwrap_err(),
1178                SendError(ChannelError::Closed, 123)
1179            );
1180        }
1181    }
1182}