hyperactor/
channel.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! One-way, multi-process, typed communication channels. These are used
10//! to send messages between mailboxes residing in different processes.
11
12#![allow(dead_code)] // Allow until this is used outside of tests.
13
14use core::net::SocketAddr;
15use std::fmt;
16use std::net::IpAddr;
17#[cfg(target_os = "linux")]
18use std::os::linux::net::SocketAddrExt;
19use std::str::FromStr;
20
21use async_trait::async_trait;
22use enum_as_inner::EnumAsInner;
23use lazy_static::lazy_static;
24use local_ip_address::local_ipv6;
25use serde::Deserialize;
26use serde::Serialize;
27use tokio::sync::mpsc;
28use tokio::sync::oneshot;
29use tokio::sync::watch;
30
31use crate as hyperactor;
32use crate::Named;
33use crate::RemoteMessage;
34use crate::attrs::AttrValue;
35use crate::channel::sim::SimAddr;
36use crate::simnet::SimNetError;
37
38pub(crate) mod local;
39pub(crate) mod net;
40pub mod sim;
41
42/// The type of error that can occur on channel operations.
43#[derive(thiserror::Error, Debug)]
44pub enum ChannelError {
45    /// An operation was attempted on a closed channel.
46    #[error("channel closed")]
47    Closed,
48
49    /// An error occurred during send.
50    #[error("send: {0}")]
51    Send(#[source] anyhow::Error),
52
53    /// A network client error.
54    #[error(transparent)]
55    Client(#[from] net::ClientError),
56
57    /// The address was not valid.
58    #[error("invalid address {0:?}")]
59    InvalidAddress(String),
60
61    /// A serving error was encountered.
62    #[error(transparent)]
63    Server(#[from] net::ServerError),
64
65    /// A bincode serialization or deserialization error occurred.
66    #[error(transparent)]
67    Bincode(#[from] Box<bincode::ErrorKind>),
68
69    /// Data encoding errors.
70    #[error(transparent)]
71    Data(#[from] crate::data::Error),
72
73    /// Some other error.
74    #[error(transparent)]
75    Other(#[from] anyhow::Error),
76
77    /// An operation timeout occurred.
78    #[error("operation timed out after {0:?}")]
79    Timeout(std::time::Duration),
80
81    /// A simulator error occurred.
82    #[error(transparent)]
83    SimNetError(#[from] SimNetError),
84}
85
86/// An error that occurred during send. Returns the message that failed to send.
87#[derive(thiserror::Error, Debug)]
88#[error("{0}")]
89pub struct SendError<M: RemoteMessage>(#[source] pub ChannelError, pub M);
90
91impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
92    fn from(error: SendError<M>) -> Self {
93        error.0
94    }
95}
96
/// Health of a `Tx`, as published through its status watch channel.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TxStatus {
    /// The tx can still be used to deliver messages.
    Active,
    /// The tx can no longer be used to deliver messages.
    Closed,
}
105
106/// The transmit end of an M-typed channel.
107#[async_trait]
108pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
109    /// Enqueue a `message` on the local end of the channel. The
110    /// message is either delivered, or we eventually discover that
111    /// the channel has failed and it will be sent back on `return_handle`.
112    // TODO: the return channel should be SendError<M> directly, and we should drop
113    // the returned result.
114    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`.
115    fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>>;
116
117    /// Enqueue a message to be sent on the channel. The caller is expected to monitor
118    /// the channel status for failures.
119    fn post(&self, message: M) {
120        // We ignore errors here because the caller is meant to monitor the channel's
121        // status, rather than rely on this function to report errors.
122        let _ignore = self.try_post(message, oneshot::channel().0);
123    }
124
125    /// Send a message synchronously, returning when the messsage has
126    /// been delivered to the remote end of the channel.
127    async fn send(&self, message: M) -> Result<(), SendError<M>> {
128        let (tx, rx) = oneshot::channel();
129        self.try_post(message, tx)?;
130        match rx.await {
131            // Channel was closed; the message was not delivered.
132            Ok(m) => Err(SendError(ChannelError::Closed, m)),
133
134            // Channel was dropped; the message was successfully enqueued
135            // on the remote end of the channel.
136            Err(_) => Ok(()),
137        }
138    }
139
140    /// The channel address to which this Tx is sending.
141    fn addr(&self) -> ChannelAddr;
142
143    /// A means to monitor the health of a `Tx`.
144    fn status(&self) -> &watch::Receiver<TxStatus>;
145}
146
147/// The receive end of an M-typed channel.
148#[async_trait]
149pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
150    /// Receive the next message from the channel. If the channel returns
151    /// an error it is considered broken and should be discarded.
152    async fn recv(&mut self) -> Result<M, ChannelError>;
153
154    /// The channel address from which this Rx is receiving.
155    fn addr(&self) -> ChannelAddr;
156}
157
158#[derive(Debug)]
159struct MpscTx<M: RemoteMessage> {
160    tx: mpsc::UnboundedSender<M>,
161    addr: ChannelAddr,
162    status: watch::Receiver<TxStatus>,
163}
164
165impl<M: RemoteMessage> MpscTx<M> {
166    pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
167        let (sender, receiver) = watch::channel(TxStatus::Active);
168        (
169            Self {
170                tx,
171                addr,
172                status: receiver,
173            },
174            sender,
175        )
176    }
177}
178
179#[async_trait]
180impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
181    fn try_post(
182        &self,
183        message: M,
184        _return_channel: oneshot::Sender<M>,
185    ) -> Result<(), SendError<M>> {
186        self.tx
187            .send(message)
188            .map_err(|mpsc::error::SendError(message)| SendError(ChannelError::Closed, message))
189    }
190
191    fn addr(&self) -> ChannelAddr {
192        self.addr.clone()
193    }
194
195    fn status(&self) -> &watch::Receiver<TxStatus> {
196        &self.status
197    }
198}
199
200#[derive(Debug)]
201struct MpscRx<M: RemoteMessage> {
202    rx: mpsc::UnboundedReceiver<M>,
203    addr: ChannelAddr,
204    // Used to report the status to the Tx side.
205    status_sender: watch::Sender<TxStatus>,
206}
207
208impl<M: RemoteMessage> MpscRx<M> {
209    pub fn new(
210        rx: mpsc::UnboundedReceiver<M>,
211        addr: ChannelAddr,
212        status_sender: watch::Sender<TxStatus>,
213    ) -> Self {
214        Self {
215            rx,
216            addr,
217            status_sender,
218        }
219    }
220}
221
222impl<M: RemoteMessage> Drop for MpscRx<M> {
223    fn drop(&mut self) {
224        let _ = self.status_sender.send(TxStatus::Closed);
225    }
226}
227
228#[async_trait]
229impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
230    async fn recv(&mut self) -> Result<M, ChannelError> {
231        self.rx.recv().await.ok_or(ChannelError::Closed)
232    }
233
234    fn addr(&self) -> ChannelAddr {
235        self.addr.clone()
236    }
237}
238
239/// The hostname to use for TLS connections.
240#[derive(
241    Clone,
242    Debug,
243    PartialEq,
244    Eq,
245    Hash,
246    Serialize,
247    Deserialize,
248    strum::EnumIter,
249    strum::Display,
250    strum::EnumString
251)]
252pub enum TlsMode {
253    /// Use IpV6 address for TLS connections.
254    IpV6,
255    /// Use host domain name for TLS connections.
256    Hostname,
257    // TODO: consider adding IpV4 support.
258}
259
260/// Address format for MetaTls channels. Supports both hostname/port pairs
261/// (required for clients for host identity) and direct socket addresses
262/// (allowed for servers).
263#[derive(
264    Clone,
265    Debug,
266    PartialEq,
267    Eq,
268    Hash,
269    Serialize,
270    Deserialize,
271    Ord,
272    PartialOrd,
273    EnumAsInner
274)]
275pub enum MetaTlsAddr {
276    /// Hostname and port pair. Required for clients to establish host identity.
277    Host {
278        /// The hostname to connect to.
279        hostname: Hostname,
280        /// The port to connect to.
281        port: Port,
282    },
283    /// Direct socket address. Allowed for servers.
284    Socket(SocketAddr),
285}
286
287impl MetaTlsAddr {
288    /// Returns the port number for this address.
289    pub fn port(&self) -> Port {
290        match self {
291            Self::Host { port, .. } => *port,
292            Self::Socket(addr) => addr.port(),
293        }
294    }
295
296    /// Returns the hostname if this is a Host variant, None otherwise.
297    pub fn hostname(&self) -> Option<&str> {
298        match self {
299            Self::Host { hostname, .. } => Some(hostname),
300            Self::Socket(_) => None,
301        }
302    }
303}
304
305impl fmt::Display for MetaTlsAddr {
306    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307        match self {
308            Self::Host { hostname, port } => write!(f, "{}:{}", hostname, port),
309            Self::Socket(addr) => write!(f, "{}", addr),
310        }
311    }
312}
313
314/// Types of channel transports.
315#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
316pub enum ChannelTransport {
317    /// Transport over a TCP connection.
318    Tcp,
319
320    /// Transport over a TCP connection with TLS support within Meta
321    MetaTls(TlsMode),
322
323    /// Local transports uses an in-process registry and mpsc channels.
324    Local,
325
326    /// Sim is a simulated channel for testing.
327    Sim(/*simulated transport:*/ Box<ChannelTransport>),
328
329    /// Transport over unix domain socket.
330    Unix,
331}
332
333impl fmt::Display for ChannelTransport {
334    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335        match self {
336            Self::Tcp => write!(f, "tcp"),
337            Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
338            Self::Local => write!(f, "local"),
339            Self::Sim(transport) => write!(f, "sim({})", transport),
340            Self::Unix => write!(f, "unix"),
341        }
342    }
343}
344
345impl FromStr for ChannelTransport {
346    type Err = anyhow::Error;
347
348    fn from_str(s: &str) -> Result<Self, Self::Err> {
349        // Hacky parsing; can't recurse (e.g., sim(sim(..)))
350        if let Some(rest) = s.strip_prefix("sim(") {
351            if let Some(end) = rest.rfind(')') {
352                let inner = &rest[..end];
353                let inner_transport = ChannelTransport::from_str(inner)?;
354                return Ok(ChannelTransport::Sim(Box::new(inner_transport)));
355            } else {
356                return Err(anyhow::anyhow!("invalid sim transport"));
357            }
358        }
359
360        match s {
361            "tcp" => Ok(ChannelTransport::Tcp),
362            "local" => Ok(ChannelTransport::Local),
363            "unix" => Ok(ChannelTransport::Unix),
364            s if s.starts_with("metatls(") && s.ends_with(")") => {
365                let inner = &s["metatls(".len()..s.len() - 1];
366                let mode = inner.parse()?;
367                Ok(ChannelTransport::MetaTls(mode))
368            }
369            unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
370        }
371    }
372}
373
374impl ChannelTransport {
375    /// All known channel transports.
376    pub fn all() -> [ChannelTransport; 3] {
377        [
378            ChannelTransport::Tcp,
379            ChannelTransport::Local,
380            ChannelTransport::Unix,
381            // TODO add MetaTls (T208303369)
382            // TODO ChannelTransport::Sim(Box::new(ChannelTransport::Tcp)),
383            // TODO ChannelTransport::Sim(Box::new(ChannelTransport::Local)),
384        ]
385    }
386
387    /// Return an "any" address for this transport.
388    pub fn any(&self) -> ChannelAddr {
389        ChannelAddr::any(self.clone())
390    }
391
392    /// Returns true if this transport type represents a remote channel.
393    pub fn is_remote(&self) -> bool {
394        match self {
395            ChannelTransport::Tcp => true,
396            ChannelTransport::MetaTls(_) => true,
397            ChannelTransport::Local => false,
398            ChannelTransport::Sim(_) => false,
399            ChannelTransport::Unix => false,
400        }
401    }
402}
403
404impl AttrValue for ChannelTransport {
405    fn display(&self) -> String {
406        self.to_string()
407    }
408
409    fn parse(s: &str) -> Result<Self, anyhow::Error> {
410        s.parse()
411    }
412}
413
/// Hostname type used by TCP-based channel addresses.
pub type Hostname = String;

/// Port number type used by TCP-based channel addresses.
pub type Port = u16;
419
420/// The type of a channel address, used to multiplex different underlying
421/// channel implementations. ChannelAddrs also have a concrete syntax:
422/// the address type (e.g., "tcp" or "local"), followed by ":", and an address
423/// parseable to that type. For example:
424///
425/// - `tcp:127.0.0.1:1234` - localhost port 1234 over TCP
426/// - `tcp:192.168.0.1:1111` - 192.168.0.1 port 1111 over TCP
427/// - `local:123` - the (in-process) local port 123
428/// - `unix:/some/path` - the Unix socket at `/some/path`
429///
430/// Both local and TCP ports 0 are reserved to indicate "any available
431/// port" when serving.
432///
433/// ```
434/// # use hyperactor::channel::ChannelAddr;
435/// let addr: ChannelAddr = "tcp:127.0.0.1:1234".parse().unwrap();
436/// let ChannelAddr::Tcp(socket_addr) = addr else {
437///     panic!()
438/// };
439/// assert_eq!(socket_addr.port(), 1234);
440/// assert_eq!(socket_addr.is_ipv4(), true);
441/// ```
442#[derive(
443    Clone,
444    Debug,
445    PartialEq,
446    Eq,
447    Ord,
448    PartialOrd,
449    Serialize,
450    Deserialize,
451    Hash,
452    Named
453)]
454pub enum ChannelAddr {
455    /// A socket address used to establish TCP channels. Supports
456    /// both  IPv4 and IPv6 address / port pairs.
457    Tcp(SocketAddr),
458
459    /// An address to establish TCP channels with TLS support within Meta.
460    /// Supports both hostname/port pairs (required for clients) and
461    /// socket addresses (allowed for servers).
462    MetaTls(MetaTlsAddr),
463
464    /// Local addresses are registered in-process and given an integral
465    /// index.
466    Local(u64),
467
468    /// Sim is a simulated channel for testing.
469    Sim(SimAddr),
470
471    /// A unix domain socket address. Supports both absolute path names as
472    ///  well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets
473    Unix(net::unix::SocketAddr),
474}
475
476impl From<SocketAddr> for ChannelAddr {
477    fn from(value: SocketAddr) -> Self {
478        Self::Tcp(value)
479    }
480}
481
482impl From<net::unix::SocketAddr> for ChannelAddr {
483    fn from(value: net::unix::SocketAddr) -> Self {
484        Self::Unix(value)
485    }
486}
487
488impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
489    fn from(value: std::os::unix::net::SocketAddr) -> Self {
490        Self::Unix(net::unix::SocketAddr::new(value))
491    }
492}
493
494impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
495    fn from(value: tokio::net::unix::SocketAddr) -> Self {
496        std::os::unix::net::SocketAddr::from(value).into()
497    }
498}
499
500impl ChannelAddr {
501    /// The "any" address for the given transport type. This is used to
502    /// servers to "any" address.
503    pub fn any(transport: ChannelTransport) -> Self {
504        match transport {
505            ChannelTransport::Tcp => {
506                let ip = hostname::get()
507                    .ok()
508                    .and_then(|hostname| {
509                        // TODO: Avoid using DNS directly once we figure out a good extensibility story here
510                        hostname.to_str().and_then(|hostname_str| {
511                            dns_lookup::lookup_host(hostname_str)
512                                .ok()
513                                .and_then(|addresses| addresses.first().cloned())
514                        })
515                    })
516                    .unwrap_or_else(|| IpAddr::from_str("::1").unwrap());
517                Self::Tcp(SocketAddr::new(ip, 0))
518            }
519            ChannelTransport::MetaTls(mode) => {
520                let host_address = match mode {
521                    TlsMode::Hostname => hostname::get()
522                        .ok()
523                        .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
524                        .unwrap_or("unknown_host".to_string()),
525                    TlsMode::IpV6 => local_ipv6()
526                        .ok()
527                        .and_then(|addr| addr.to_string().parse().ok())
528                        .expect("failed to retrieve ipv6 address"),
529                };
530                Self::MetaTls(MetaTlsAddr::Host {
531                    hostname: host_address,
532                    port: 0,
533                })
534            }
535            ChannelTransport::Local => Self::Local(0),
536            ChannelTransport::Sim(transport) => sim::any(*transport),
537            // This works because the file will be deleted but we know we have a unique file by this point.
538            ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
539        }
540    }
541
542    /// The transport used by this address.
543    pub fn transport(&self) -> ChannelTransport {
544        match self {
545            Self::Tcp(_) => ChannelTransport::Tcp,
546            Self::MetaTls(addr) => match addr {
547                MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
548                    Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
549                    Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
550                    Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
551                },
552                MetaTlsAddr::Socket(socket_addr) => match socket_addr.ip() {
553                    IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
554                    IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
555                },
556            },
557            Self::Local(_) => ChannelTransport::Local,
558            Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
559            Self::Unix(_) => ChannelTransport::Unix,
560        }
561    }
562}
563
564impl fmt::Display for ChannelAddr {
565    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566        match self {
567            Self::Tcp(addr) => write!(f, "tcp:{}", addr),
568            Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
569            Self::Local(index) => write!(f, "local:{}", index),
570            Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
571            Self::Unix(addr) => write!(f, "unix:{}", addr),
572        }
573    }
574}
575
576impl FromStr for ChannelAddr {
577    type Err = anyhow::Error;
578
579    fn from_str(addr: &str) -> Result<Self, Self::Err> {
580        match addr.split_once('!').or_else(|| addr.split_once(':')) {
581            Some(("local", rest)) => rest
582                .parse::<u64>()
583                .map(Self::Local)
584                .map_err(anyhow::Error::from),
585            Some(("tcp", rest)) => rest
586                .parse::<SocketAddr>()
587                .map(Self::Tcp)
588                .map_err(anyhow::Error::from),
589            Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
590            Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
591            Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
592            Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
593            None => Err(anyhow::anyhow!("no channel type specified")),
594        }
595    }
596}
597
598impl ChannelAddr {
599    /// Parse ZMQ-style URL format: scheme://address
600    /// Supports:
601    /// - tcp://hostname:port or tcp://*:port (wildcard binding)
602    /// - inproc://endpoint-name (equivalent to local)
603    /// - ipc://path (equivalent to unix)
604    /// - metatls://hostname:port or metatls://*:port
605    pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
606        // Try ZMQ-style URL format first (scheme://...)
607        let (scheme, address) = address.split_once("://").ok_or_else(|| {
608            anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
609        })?;
610
611        match scheme {
612            "tcp" => {
613                let (host, port) = Self::split_host_port(address)?;
614
615                if host == "*" {
616                    // Wildcard binding - use IPv6 unspecified address
617                    Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
618                } else {
619                    // Resolve hostname to IP address for proper SocketAddr creation
620                    let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
621                    Ok(Self::Tcp(socket_addr))
622                }
623            }
624            "inproc" => {
625                // inproc://port -> local:port
626                // Port must be a valid u64 number
627                let port = address.parse::<u64>().map_err(|_| {
628                    anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
629                })?;
630                Ok(Self::Local(port))
631            }
632            "ipc" => {
633                // ipc://path -> unix:path
634                Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
635            }
636            "metatls" => {
637                let (host, port) = Self::split_host_port(address)?;
638
639                if host == "*" {
640                    // Wildcard binding - use IPv6 unspecified address directly without hostname resolution
641                    Ok(Self::MetaTls(MetaTlsAddr::Host {
642                        hostname: std::net::Ipv6Addr::UNSPECIFIED.to_string(),
643                        port,
644                    }))
645                } else {
646                    Ok(Self::MetaTls(MetaTlsAddr::Host {
647                        hostname: host.to_string(),
648                        port,
649                    }))
650                }
651            }
652            scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
653        }
654    }
655
656    /// Split host:port string, supporting IPv6 addresses
657    fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
658        if let Some((host, port_str)) = address.rsplit_once(':') {
659            let port: u16 = port_str
660                .parse()
661                .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
662            Ok((host, port))
663        } else {
664            Err(anyhow::anyhow!("invalid address format: {}", address))
665        }
666    }
667
668    /// Resolve hostname to SocketAddr, handling both IP addresses and hostnames
669    fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
670        // Handle IPv6 addresses in brackets by stripping the brackets
671        let host_clean = if host.starts_with('[') && host.ends_with(']') {
672            &host[1..host.len() - 1]
673        } else {
674            host
675        };
676
677        // First try to parse as an IP address directly
678        if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
679            return Ok(SocketAddr::new(ip_addr, port));
680        }
681
682        // If not an IP, try hostname resolution
683        use std::net::ToSocketAddrs;
684        let mut addrs = (host_clean, port)
685            .to_socket_addrs()
686            .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
687
688        addrs
689            .next()
690            .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
691    }
692}
693
694/// Universal channel transmitter.
695#[derive(Debug)]
696pub struct ChannelTx<M: RemoteMessage> {
697    inner: ChannelTxKind<M>,
698}
699
700/// Universal channel transmitter.
701#[derive(Debug)]
702enum ChannelTxKind<M: RemoteMessage> {
703    Local(local::LocalTx<M>),
704    Tcp(net::NetTx<M>),
705    MetaTls(net::NetTx<M>),
706    Unix(net::NetTx<M>),
707    Sim(sim::SimTx<M>),
708}
709
710#[async_trait]
711impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
712    fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>> {
713        match &self.inner {
714            ChannelTxKind::Local(tx) => tx.try_post(message, return_channel),
715            ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel),
716            ChannelTxKind::MetaTls(tx) => tx.try_post(message, return_channel),
717            ChannelTxKind::Sim(tx) => tx.try_post(message, return_channel),
718            ChannelTxKind::Unix(tx) => tx.try_post(message, return_channel),
719        }
720    }
721
722    fn addr(&self) -> ChannelAddr {
723        match &self.inner {
724            ChannelTxKind::Local(tx) => tx.addr(),
725            ChannelTxKind::Tcp(tx) => Tx::<M>::addr(tx),
726            ChannelTxKind::MetaTls(tx) => Tx::<M>::addr(tx),
727            ChannelTxKind::Sim(tx) => tx.addr(),
728            ChannelTxKind::Unix(tx) => Tx::<M>::addr(tx),
729        }
730    }
731
732    fn status(&self) -> &watch::Receiver<TxStatus> {
733        match &self.inner {
734            ChannelTxKind::Local(tx) => tx.status(),
735            ChannelTxKind::Tcp(tx) => tx.status(),
736            ChannelTxKind::MetaTls(tx) => tx.status(),
737            ChannelTxKind::Sim(tx) => tx.status(),
738            ChannelTxKind::Unix(tx) => tx.status(),
739        }
740    }
741}
742
743/// Universal channel receiver.
744#[derive(Debug)]
745pub struct ChannelRx<M: RemoteMessage> {
746    inner: ChannelRxKind<M>,
747}
748
749/// Universal channel receiver.
750#[derive(Debug)]
751enum ChannelRxKind<M: RemoteMessage> {
752    Local(local::LocalRx<M>),
753    Tcp(net::NetRx<M>),
754    MetaTls(net::NetRx<M>),
755    Unix(net::NetRx<M>),
756    Sim(sim::SimRx<M>),
757}
758
759#[async_trait]
760impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
761    async fn recv(&mut self) -> Result<M, ChannelError> {
762        match &mut self.inner {
763            ChannelRxKind::Local(rx) => rx.recv().await,
764            ChannelRxKind::Tcp(rx) => rx.recv().await,
765            ChannelRxKind::MetaTls(rx) => rx.recv().await,
766            ChannelRxKind::Sim(rx) => rx.recv().await,
767            ChannelRxKind::Unix(rx) => rx.recv().await,
768        }
769    }
770
771    fn addr(&self) -> ChannelAddr {
772        match &self.inner {
773            ChannelRxKind::Local(rx) => rx.addr(),
774            ChannelRxKind::Tcp(rx) => rx.addr(),
775            ChannelRxKind::MetaTls(rx) => rx.addr(),
776            ChannelRxKind::Sim(rx) => rx.addr(),
777            ChannelRxKind::Unix(rx) => rx.addr(),
778        }
779    }
780}
781
782/// Dial the provided address, returning the corresponding Tx, or error
783/// if the channel cannot be established. The underlying connection is
784/// dropped whenever the returned Tx is dropped.
785#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`.
786pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
787    tracing::debug!(name = "dial", "dialing channel {}", addr);
788    let inner = match addr {
789        ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
790        ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
791        ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
792        ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
793        ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
794    };
795    Ok(ChannelTx { inner })
796}
797
798/// Serve on the provided channel address. The server is turned down
799/// when the returned Rx is dropped.
800#[crate::instrument]
801pub fn serve<M: RemoteMessage>(
802    addr: ChannelAddr,
803) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
804    tracing::debug!(name = "serve", "serving channel address {}", addr);
805    match addr {
806        ChannelAddr::Tcp(addr) => {
807            let (addr, rx) = net::tcp::serve::<M>(addr)?;
808            Ok((addr, ChannelRxKind::Tcp(rx)))
809        }
810        ChannelAddr::MetaTls(meta_addr) => {
811            let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
812            Ok((addr, ChannelRxKind::MetaTls(rx)))
813        }
814        ChannelAddr::Unix(path) => {
815            let (addr, rx) = net::unix::serve::<M>(path)?;
816            Ok((addr, ChannelRxKind::Unix(rx)))
817        }
818        ChannelAddr::Local(0) => {
819            let (port, rx) = local::serve::<M>();
820            Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
821        }
822        ChannelAddr::Sim(sim_addr) => {
823            let (addr, rx) = sim::serve::<M>(sim_addr)?;
824            Ok((addr, ChannelRxKind::Sim(rx)))
825        }
826        ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
827            "invalid local addr: {}",
828            a
829        ))),
830    }
831    .map(|(addr, inner)| (addr, ChannelRx { inner }))
832}
833
834/// Serve on the local address. The server is turned down
835/// when the returned Rx is dropped.
836pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
837    let (port, rx) = local::serve::<M>();
838    (
839        ChannelAddr::Local(port),
840        ChannelRx {
841            inner: ChannelRxKind::Local(rx),
842        },
843    )
844}
845
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::Ipv6Addr;
    use std::time::Duration;

    use tokio::task::JoinSet;

    use super::net::*;
    use super::*;
    use crate::clock::Clock;
    use crate::clock::RealClock;

    /// Parses textual channel addresses with both accepted delimiters
    /// (`!` and `:`), both bare and wrapped in a `sim` prefix, and checks the
    /// exact error messages produced for malformed inputs.
    #[test]
    fn test_channel_addr() {
        let cases_ok = vec![
            (
                "tcp<DELIM>[::1]:1234",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    1234,
                )),
            ),
            (
                "tcp<DELIM>127.0.0.1:8080",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                    8080,
                )),
            ),
            #[cfg(target_os = "linux")]
            ("local<DELIM>123", ChannelAddr::Local(123)),
            (
                "unix<DELIM>@yolo",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_abstract_name("yolo")
                        .expect("can't make socket from abstract name"),
                ),
            ),
            (
                "unix<DELIM>/cool/socket-path",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_pathname("/cool/socket-path")
                        .expect("can't make socket from path"),
                ),
            ),
        ];

        for (raw, parsed) in cases_ok.clone() {
            for delim in ["!", ":"] {
                let raw = raw.replace("<DELIM>", delim);
                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
            }
        }

        for (raw, parsed) in cases_ok {
            for delim in ["!", ":"] {
                // We don't allow mixing and matching delims
                let raw = format!("sim{}{}", delim, raw.replace("<DELIM>", delim));
                assert_eq!(
                    raw.parse::<ChannelAddr>().unwrap(),
                    ChannelAddr::Sim(SimAddr::new(parsed.clone()).unwrap())
                );
            }
        }

        let cases_err = vec![
            ("tcp:abcdef..123124", "invalid socket address syntax"),
            ("xxx:foo", "no such channel type: xxx"),
            ("127.0.0.1", "no channel type specified"),
            ("local:abc", "invalid digit found in string"),
        ];

        for (raw, error) in cases_err {
            let Err(err) = raw.parse::<ChannelAddr>() else {
                panic!("expected error parsing: {}", &raw)
            };
            assert_eq!(format!("{}", err), error);
        }
    }

    /// Exercises `ChannelAddr::from_zmq_url` across every supported scheme
    /// (`tcp`, `inproc`, `ipc`, `metatls`), including wildcard binds, IPv6
    /// literals, and malformed URLs that must be rejected.
    #[test]
    fn test_zmq_style_channel_addr() {
        // Test TCP addresses
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
        );

        // Test TCP wildcard binding
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
        );

        // Test inproc (maps to local with numeric endpoint)
        assert_eq!(
            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
            ChannelAddr::Local(12345)
        );

        // Test ipc (maps to unix)
        assert_eq!(
            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
        );

        // Test metatls with hostname
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "example.com".to_string(),
                port: 443
            })
        );

        // Test metatls with IP address (should be normalized)
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "192.168.1.1".to_string(),
                port: 443
            })
        );

        // Test metatls with wildcard (should use IPv6 unspecified address)
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "::".to_string(),
                port: 8443
            })
        );

        // Test TCP hostname resolution (should resolve hostname to IP).
        // The resolved address is environment-dependent, so only success is
        // asserted. This can still fail where no resolver is available; the
        // error is included in the assertion message to make that diagnosable.
        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
        assert!(
            tcp_hostname_result.is_ok(),
            "failed to resolve tcp://localhost:8080: {:?}",
            tcp_hostname_result
        );

        // Test IPv6 address
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
        );

        // Test error cases
        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
    }

    /// For every transport, serves once and has 100 concurrently spawned
    /// tasks each dial their own sender and post one distinct message; the
    /// receiver must observe all 100 values.
    #[tokio::test]
    async fn test_multiple_connections() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();

            let mut sends: JoinSet<()> = JoinSet::new();
            for message in 0u64..100u64 {
                let addr = listen_addr.clone();
                sends.spawn(async move {
                    let tx = dial::<u64>(addr).unwrap();
                    tx.try_post(message, oneshot::channel().0).unwrap();
                });
            }

            let mut received: HashSet<u64> = HashSet::new();
            while received.len() < 100 {
                received.insert(rx.recv().await.unwrap());
            }

            for message in 0u64..100u64 {
                assert!(received.contains(&message));
            }

            // Surface any panic from a sender task.
            while let Some(joined) = sends.join_next().await {
                if let Err(err) = joined {
                    panic!("{}", err);
                }
            }
        }
    }

    /// For every non-net transport, checks that posting eventually fails with
    /// `ChannelError::Closed` once the receiving side has been dropped.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            if net::is_net_addr(&addr) {
                // Net has store-and-forward semantics. We don't expect failures
                // on closure.
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.try_post(123, oneshot::channel().0).unwrap();
            drop(rx);

            // New transmits should fail... but there is buffering, etc.,
            // which can cause the failure to be delayed. We give it
            // a deadline, but it can still technically fail -- the test
            // should be considered a kind of integration test.
            let start = RealClock.now();

            let result = loop {
                let result = tx.try_post(123, oneshot::channel().0);
                if result.is_err() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
                // `try_post` is synchronous, so without an explicit yield this
                // loop never returns control to the runtime. On tokio's default
                // current-thread test runtime that would starve any task whose
                // progress is needed for the closure to become observable.
                tokio::task::yield_now().await;
            };
            assert_matches!(result, Err(SendError(ChannelError::Closed, 123)));
        }
    }

    /// Addresses used by the dial/serve tests: TCP on `[::1]` with port 0, a
    /// `local:0` address and, on Linux only, an empty `unix:` address plus a
    /// unix abstract-namespace address with a random 10-letter lowercase name
    /// (presumably to avoid collisions between runs — TODO confirm).
    fn addrs() -> Vec<ChannelAddr> {
        use rand::Rng;
        use rand::distributions::Uniform;

        let rng = rand::thread_rng();
        vec![
            "tcp:[::1]:0".parse().unwrap(),
            "local:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            "unix:".parse().unwrap(),
            #[cfg(target_os = "linux")]
            format!(
                "unix:@{}",
                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
                    .take(10)
                    .collect::<String>()
            )
            .parse()
            .unwrap(),
        ]
    }

    /// Serves on each address from `addrs()`, dials the returned listen
    /// address, and checks that one posted message is received intact.
    #[tokio::test]
    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
    #[cfg_attr(not(feature = "fb"), ignore)]
    async fn test_dial_serve() {
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.try_post(123, oneshot::channel().0).unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);
        }
    }

    /// Checks `send` on each address from `addrs()`: it succeeds while the
    /// receiver is alive and fails with `ChannelError::Closed`, returning the
    /// message, after the receiver is dropped.
    #[tokio::test]
    // TODO: OSS: called `Result::unwrap()` on an `Err` value: Server(Listen(Tcp([::1]:0), Os { code: 99, kind: AddrNotAvailable, message: "Cannot assign requested address" }))
    #[cfg_attr(not(feature = "fb"), ignore)]
    async fn test_send() {
        let config = crate::config::global::lock();

        // Use temporary config for this test. A short delivery timeout and
        // acking every message presumably keep `send`'s failure after the
        // receiver is dropped prompt — TODO confirm against net config docs.
        let _guard1 = config.override_key(
            crate::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.send(123).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);

            drop(rx);
            assert_matches!(
                tx.send(123).await.unwrap_err(),
                SendError(ChannelError::Closed, 123)
            );
        }
    }
}