1use core::net::SocketAddr;
13use std::fmt;
14use std::net::IpAddr;
15use std::net::Ipv6Addr;
16#[cfg(target_os = "linux")]
17use std::os::linux::net::SocketAddrExt;
18use std::panic::Location;
19use std::str::FromStr;
20
21use async_trait::async_trait;
22use hyperactor_config::attrs::AttrValue;
23use serde::Deserialize;
24use serde::Serialize;
25use tokio::sync::mpsc;
26use tokio::sync::oneshot;
27use tokio::sync::watch;
28
29use crate as hyperactor;
30use crate::RemoteMessage;
31pub(crate) mod local;
32pub(crate) mod net;
33
34pub use net::try_tls_acceptor;
38pub use net::try_tls_connector;
39pub use net::try_tls_pem_bundle;
40
/// Public re-exports of the duplex channel items defined in the
/// crate-private `net::duplex` module.
pub mod duplex {
    pub use super::net::duplex::DuplexRx;
    pub use super::net::duplex::DuplexServer;
    pub use super::net::duplex::DuplexTx;
    pub use super::net::duplex::dial;
    pub use super::net::duplex::serve;
}
49
/// Errors produced by channel operations in this module.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// The channel has been closed; e.g. returned by `MpscRx::recv` once the
    /// underlying queue yields no more messages.
    #[error("channel closed")]
    Closed,

    /// Sending failed; the wrapped error is the underlying cause.
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// An error converted from `net::ClientError`.
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// The supplied address is not valid for the requested operation.
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// An error converted from `net::ServerError`.
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// A bincode (de)serialization error.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// An error converted from `wirevalue::Error`.
    #[error(transparent)]
    Data(#[from] wirevalue::Error),

    /// Any other error, carried as an `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// The operation did not complete within the given duration.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),
}
89
/// A failed send: the error, the message that could not be delivered (handed
/// back to the caller), and an optional human-readable reason.
#[derive(thiserror::Error, Debug)]
#[error("{error} for reason {reason:?}")]
pub struct SendError<M: RemoteMessage> {
    /// The underlying channel error.
    #[source]
    pub error: ChannelError,
    /// The message that was not delivered.
    pub message: M,
    /// Optional extra detail about why the send failed.
    pub reason: Option<String>,
}
102
103impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
104 fn from(error: SendError<M>) -> Self {
105 error.error
106 }
107}
108
/// Status of the sending side of a channel, published through a
/// `watch::Receiver` (see `Tx::status`).
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TxStatus {
    /// Initial state; `MpscTx::new` creates its watch channel with this value.
    Active,
    /// The channel is closed; `MpscRx` publishes this when it is dropped.
    Closed,
}
117
/// The transmit end of a channel carrying messages of type `M`.
#[async_trait]
pub trait Tx<M: RemoteMessage> {
    /// Core posting primitive that every implementation provides.
    ///
    /// When `return_channel` is `Some`, a delivery failure may be reported on
    /// it as a `SendError` carrying the original message (this is what the
    /// `MpscTx` implementation in this file does).
    fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);

    /// Post `message`, asking for failures to be reported on `return_channel`.
    /// Forwards to [`Tx::do_post`].
    #[allow(clippy::result_large_err)] #[tracing::instrument(level = "debug", skip_all)]
    fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
        self.do_post(message, Some(return_channel));
    }

    /// Post `message` without a return channel, so the caller is not told
    /// about delivery failures. Forwards to [`Tx::do_post`].
    #[hyperactor::instrument_infallible]
    fn post(&self, message: M) {
        self.do_post(message, None);
    }

    /// Post `message` and wait for the outcome on a fresh oneshot channel.
    ///
    /// Returns `Err` if a `SendError` arrives on that channel, and `Ok(())`
    /// if the channel's sender is dropped without sending anything.
    async fn send(&self, message: M) -> Result<(), SendError<M>> {
        let (tx, rx) = oneshot::channel();
        self.try_post(message, tx);
        match rx.await {
            // A value on the return channel always means failure.
            Ok(err) => Err(err),

            // Sender dropped without reporting an error: treated as success.
            Err(_) => Ok(()),
        }
    }

    /// The address associated with this sender.
    fn addr(&self) -> ChannelAddr;

    /// A watch receiver exposing this sender's [`TxStatus`].
    fn status(&self) -> &watch::Receiver<TxStatus>;
}
164
/// The receive end of a channel carrying messages of type `M`.
#[async_trait]
pub trait Rx<M: RemoteMessage> {
    /// Receive the next message, or a [`ChannelError`] (for example
    /// `ChannelError::Closed` from the `MpscRx` implementation).
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The address associated with this receiver.
    fn addr(&self) -> ChannelAddr;

    /// Consume the receiver and await its completion. What is awaited is
    /// implementation-defined; the `MpscRx` implementation in this file
    /// returns immediately.
    async fn join(self)
    where
        Self: Sized;
}
182
183#[allow(dead_code)] struct MpscTx<M: RemoteMessage> {
185 tx: mpsc::UnboundedSender<M>,
186 addr: ChannelAddr,
187 status: watch::Receiver<TxStatus>,
188}
189
190impl<M: RemoteMessage> MpscTx<M> {
191 #[allow(dead_code)] pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
193 let (sender, receiver) = watch::channel(TxStatus::Active);
194 (
195 Self {
196 tx,
197 addr,
198 status: receiver,
199 },
200 sender,
201 )
202 }
203}
204
205#[async_trait]
206impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
207 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
208 if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
209 if let Some(return_channel) = return_channel {
210 return_channel
211 .send(SendError {
212 error: ChannelError::Closed,
213 message,
214 reason: None,
215 })
216 .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
217 }
218 }
219 }
220
221 fn addr(&self) -> ChannelAddr {
222 self.addr.clone()
223 }
224
225 fn status(&self) -> &watch::Receiver<TxStatus> {
226 &self.status
227 }
228}
229
230#[allow(dead_code)] struct MpscRx<M: RemoteMessage> {
232 rx: mpsc::UnboundedReceiver<M>,
233 addr: ChannelAddr,
234 status_sender: watch::Sender<TxStatus>,
236}
237
238impl<M: RemoteMessage> MpscRx<M> {
239 #[allow(dead_code)] pub fn new(
241 rx: mpsc::UnboundedReceiver<M>,
242 addr: ChannelAddr,
243 status_sender: watch::Sender<TxStatus>,
244 ) -> Self {
245 Self {
246 rx,
247 addr,
248 status_sender,
249 }
250 }
251}
252
253impl<M: RemoteMessage> Drop for MpscRx<M> {
254 fn drop(&mut self) {
255 let _ = self.status_sender.send(TxStatus::Closed);
256 }
257}
258
259#[async_trait]
260impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
261 async fn recv(&mut self) -> Result<M, ChannelError> {
262 self.rx.recv().await.ok_or(ChannelError::Closed)
263 }
264
265 fn addr(&self) -> ChannelAddr {
266 self.addr.clone()
267 }
268
269 async fn join(self) {}
270}
271
/// How a TCP transport chooses its address when none is given explicitly
/// (see `ChannelAddr::any`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TcpMode {
    /// Use the IPv6 loopback address.
    Localhost,
    /// Resolve this machine's hostname and use the first address that is not
    /// link-local.
    Hostname,
}
291
/// How a MetaTls transport chooses its host when none is given explicitly
/// (see `ChannelAddr::any`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// Use the host's IPv6 address, as returned by `get_host_ipv6_address`.
    IpV6,
    /// Use this machine's hostname.
    Hostname,
}
312
/// A host/port pair used by the TLS-based channel addresses
/// (`ChannelAddr::MetaTls` and `ChannelAddr::Tls`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd
)]
pub struct TlsAddr {
    /// Hostname or textual IP address. `TlsAddr::new` passes it through
    /// `normalize_host`; direct field construction does not.
    pub hostname: Hostname,
    /// Port number.
    pub port: Port,
}
331
332impl TlsAddr {
333 pub fn new(hostname: impl Into<Hostname>, port: Port) -> Self {
335 Self {
336 hostname: normalize_host(&hostname.into()),
337 port,
338 }
339 }
340
341 pub fn port(&self) -> Port {
343 self.port
344 }
345
346 pub fn hostname(&self) -> &str {
348 &self.hostname
349 }
350}
351
352impl fmt::Display for TlsAddr {
353 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 write!(f, "{}:{}", self.hostname, self.port)
355 }
356}
357
/// The kind of transport a channel uses.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum ChannelTransport {
    /// TCP, with the given address-selection mode.
    Tcp(TcpMode),

    /// MetaTls, with the given host-selection mode.
    MetaTls(TlsMode),

    /// TLS.
    Tls,

    /// Local channels; reported as non-remote by `is_remote`.
    Local,

    /// Unix domain sockets; reported as non-remote by `is_remote`.
    Unix,
}
385
386impl fmt::Display for ChannelTransport {
387 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388 match self {
389 Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
390 Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
391 Self::Tls => write!(f, "tls"),
392 Self::Local => write!(f, "local"),
393 Self::Unix => write!(f, "unix"),
394 }
395 }
396}
397
398impl FromStr for ChannelTransport {
399 type Err = anyhow::Error;
400
401 fn from_str(s: &str) -> Result<Self, Self::Err> {
402 match s {
403 "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
405 s if s.starts_with("tcp(") => {
406 let inner = &s["tcp(".len()..s.len() - 1];
407 let mode = inner.parse()?;
408 Ok(ChannelTransport::Tcp(mode))
409 }
410 "local" => Ok(ChannelTransport::Local),
411 "unix" => Ok(ChannelTransport::Unix),
412 "tls" => Ok(ChannelTransport::Tls),
413 s if s.starts_with("metatls(") && s.ends_with(")") => {
414 let inner = &s["metatls(".len()..s.len() - 1];
415 let mode = inner.parse()?;
416 Ok(ChannelTransport::MetaTls(mode))
417 }
418 unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
419 }
420 }
421}
422
423impl ChannelTransport {
424 pub fn all() -> [ChannelTransport; 3] {
426 [
427 ChannelTransport::Tcp(TcpMode::Hostname),
430 ChannelTransport::Local,
431 ChannelTransport::Unix,
432 ]
435 }
436
437 pub fn any(&self) -> ChannelAddr {
439 ChannelAddr::any(self.clone())
440 }
441
442 pub fn is_remote(&self) -> bool {
444 match self {
445 ChannelTransport::Tcp(_) => true,
446 ChannelTransport::MetaTls(_) => true,
447 ChannelTransport::Tls => true,
448 ChannelTransport::Local => false,
449 ChannelTransport::Unix => false,
450 }
451 }
452}
453
454impl AttrValue for ChannelTransport {
455 fn display(&self) -> String {
456 self.to_string()
457 }
458
459 fn parse(s: &str) -> Result<Self, anyhow::Error> {
460 s.parse()
461 }
462}
463
/// What to bind a channel to: either "any address of this transport" or one
/// explicit address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum BindSpec {
    /// Bind to `ChannelAddr::any` of the given transport.
    Any(ChannelTransport),

    /// Bind to exactly this address.
    Addr(ChannelAddr),
}
482
483impl BindSpec {
484 pub fn binding_addr(&self) -> ChannelAddr {
486 match self {
487 BindSpec::Any(transport) => ChannelAddr::any(transport.clone()),
488 BindSpec::Addr(addr) => addr.clone(),
489 }
490 }
491}
492
493impl From<ChannelTransport> for BindSpec {
494 fn from(transport: ChannelTransport) -> Self {
495 BindSpec::Any(transport)
496 }
497}
498
499impl From<ChannelAddr> for BindSpec {
500 fn from(addr: ChannelAddr) -> Self {
501 BindSpec::Addr(addr)
502 }
503}
504
505impl fmt::Display for BindSpec {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 match self {
508 Self::Any(transport) => write!(f, "{}", transport),
509 Self::Addr(addr) => write!(f, "{}", addr),
510 }
511 }
512}
513
514impl FromStr for BindSpec {
515 type Err = anyhow::Error;
516
517 fn from_str(s: &str) -> Result<Self, Self::Err> {
518 if let Ok(transport) = ChannelTransport::from_str(s) {
519 Ok(BindSpec::Any(transport))
520 } else if let Ok(addr) = ChannelAddr::from_zmq_url(s) {
521 Ok(BindSpec::Addr(addr))
522 } else if let Ok(addr) = ChannelAddr::from_str(s) {
523 Ok(BindSpec::Addr(addr))
524 } else {
525 Err(anyhow::anyhow!("invalid bind spec: {}", s))
526 }
527 }
528}
529
530impl AttrValue for BindSpec {
531 fn display(&self) -> String {
532 self.to_string()
533 }
534
535 fn parse(s: &str) -> Result<Self, anyhow::Error> {
536 Self::from_str(s)
537 }
538}
539
/// A hostname (or textual IP address), stored as a plain `String`.
pub type Hostname = String;

/// A port number.
pub type Port = u16;
545
/// The address of a channel, tagged by transport.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    typeuri::Named
)]
pub enum ChannelAddr {
    /// A TCP socket address.
    Tcp(SocketAddr),

    /// A MetaTls host/port address.
    MetaTls(TlsAddr),

    /// A TLS host/port address.
    Tls(TlsAddr),

    /// A local channel, identified by a number. `serve` only accepts `0`
    /// here and returns the actual number assigned by `local::serve`.
    Local(u64),

    /// A unix domain socket address.
    Unix(net::unix::SocketAddr),

    /// A pair of addresses: `dial` connects to `dial_to`, while `serve`
    /// binds `bind_to`.
    Alias {
        /// Address used when dialing.
        dial_to: Box<ChannelAddr>,
        /// Address used when serving.
        bind_to: Box<ChannelAddr>,
    },
}
622
623impl From<SocketAddr> for ChannelAddr {
624 fn from(value: SocketAddr) -> Self {
625 Self::Tcp(value)
626 }
627}
628
629impl From<net::unix::SocketAddr> for ChannelAddr {
630 fn from(value: net::unix::SocketAddr) -> Self {
631 Self::Unix(value)
632 }
633}
634
635impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
636 fn from(value: std::os::unix::net::SocketAddr) -> Self {
637 Self::Unix(net::unix::SocketAddr::new(value))
638 }
639}
640
641impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
642 fn from(value: tokio::net::unix::SocketAddr) -> Self {
643 std::os::unix::net::SocketAddr::from(value).into()
644 }
645}
646
/// Returns the first address in `addresses` that is not link-local
/// (IPv4 `169.254.0.0/16`, IPv6 unicast link-local), or `None` if every
/// address is link-local or the slice is empty.
fn find_routable_address(addresses: &[IpAddr]) -> Option<IpAddr> {
    for candidate in addresses {
        let link_local = match candidate {
            IpAddr::V4(v4) => v4.is_link_local(),
            IpAddr::V6(v6) => v6.is_unicast_link_local(),
        };
        if !link_local {
            return Some(*candidate);
        }
    }
    None
}
657
658impl ChannelAddr {
659 pub fn any(transport: ChannelTransport) -> Self {
662 match transport {
663 ChannelTransport::Tcp(mode) => {
664 let ip = match mode {
665 TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
666 TcpMode::Hostname => {
667 hostname::get()
668 .ok()
669 .and_then(|hostname| {
670 hostname.to_str().and_then(|hostname_str| {
672 dns_lookup::lookup_host(hostname_str)
673 .ok()
674 .and_then(|addresses| find_routable_address(&addresses))
675 })
676 })
677 .expect("failed to resolve hostname to ip address")
678 }
679 };
680 Self::Tcp(SocketAddr::new(ip, 0))
681 }
682 ChannelTransport::MetaTls(mode) => {
683 let host_address = match mode {
684 TlsMode::Hostname => hostname::get()
685 .ok()
686 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
687 .unwrap_or("unknown_host".to_string()),
688 TlsMode::IpV6 => {
689 get_host_ipv6_address().expect("failed to retrieve ipv6 address")
690 }
691 };
692 Self::MetaTls(TlsAddr::new(host_address, 0))
693 }
694 ChannelTransport::Local => Self::Local(0),
695 ChannelTransport::Tls => {
696 let host_address = hostname::get()
697 .ok()
698 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
699 .unwrap_or("localhost".to_string());
700 Self::Tls(TlsAddr::new(host_address, 0))
701 }
702 ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
704 }
705 }
706
707 pub fn transport(&self) -> ChannelTransport {
709 match self {
710 Self::Tcp(addr) => {
711 if addr.ip().is_loopback() {
712 ChannelTransport::Tcp(TcpMode::Localhost)
713 } else {
714 ChannelTransport::Tcp(TcpMode::Hostname)
715 }
716 }
717 Self::MetaTls(addr) => match addr.hostname.parse::<IpAddr>() {
718 Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
719 Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
720 Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
721 },
722 Self::Tls(_) => ChannelTransport::Tls,
723 Self::Local(_) => ChannelTransport::Local,
724 Self::Unix(_) => ChannelTransport::Unix,
725 Self::Alias { bind_to, .. } => bind_to.transport(),
728 }
729 }
730}
731
732#[cfg(fbcode_build)]
733fn get_host_ipv6_address() -> anyhow::Result<String> {
734 crate::meta::host_ip::host_ipv6_address()
735}
736
737#[cfg(not(fbcode_build))]
738fn get_host_ipv6_address() -> anyhow::Result<String> {
739 Ok(local_ip_address::local_ipv6()?.to_string())
740}
741
742impl fmt::Display for ChannelAddr {
743 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
744 match self {
745 Self::Tcp(addr) => write!(f, "tcp:{}", addr),
746 Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
747 Self::Tls(addr) => write!(f, "tls:{}", addr),
748 Self::Local(index) => write!(f, "local:{}", index),
749 Self::Unix(addr) => write!(f, "unix:{}", addr),
750 Self::Alias { dial_to, bind_to } => {
751 write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
752 }
753 }
754 }
755}
756
757impl FromStr for ChannelAddr {
758 type Err = anyhow::Error;
759
760 fn from_str(addr: &str) -> Result<Self, Self::Err> {
761 match addr.split_once('!').or_else(|| addr.split_once(':')) {
762 Some(("local", rest)) => rest
763 .parse::<u64>()
764 .map(Self::Local)
765 .map_err(anyhow::Error::from),
766 Some(("tcp", rest)) => rest
767 .parse::<SocketAddr>()
768 .map(Self::Tcp)
769 .map_err(anyhow::Error::from),
770 Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
771 Some(("tls", rest)) => net::tls::parse(rest).map_err(|e| e.into()),
772 Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
773 Some(("alias", _)) => Err(anyhow::anyhow!(
774 "detect possible alias address, but we currently do not support \
775 parsing alias' string representation since we only want to \
776 support parsing its zmq url format."
777 )),
778 Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
779 None => Err(anyhow::anyhow!("no channel type specified")),
780 }
781 }
782}
783
/// Canonicalize a host string.
///
/// If `host` (optionally wrapped in one pair of `[`...`]`) parses as an IP
/// address, its canonical textual form is returned without brackets.
/// Anything else is returned unchanged, brackets included.
pub(crate) fn normalize_host(host: &str) -> String {
    // Only strip brackets when both the opening and closing one are present.
    let unbracketed = match host.strip_prefix('[') {
        Some(inner) => inner.strip_suffix(']').unwrap_or(host),
        None => host,
    };

    match unbracketed.parse::<IpAddr>() {
        Ok(ip) => ip.to_string(),
        Err(_) => host.to_string(),
    }
}
800
801impl ChannelAddr {
802 pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
811 if let Some(at_pos) = address.find('@') {
814 let dial_to_str = &address[..at_pos];
815 let bind_to_str = &address[at_pos + 1..];
816
817 if !dial_to_str.starts_with("tcp://") {
819 return Err(anyhow::anyhow!(
820 "alias format is only supported for TCP addresses, got dial_to: {}",
821 dial_to_str
822 ));
823 }
824 if !bind_to_str.starts_with("tcp://") {
825 return Err(anyhow::anyhow!(
826 "alias format is only supported for TCP addresses, got bind_to: {}",
827 bind_to_str
828 ));
829 }
830
831 let dial_to = Self::from_zmq_url(dial_to_str)?;
832 let bind_to = Self::from_zmq_url(bind_to_str)?;
833
834 return Ok(Self::Alias {
835 dial_to: Box::new(dial_to),
836 bind_to: Box::new(bind_to),
837 });
838 }
839
840 let (scheme, address) = address.split_once("://").ok_or_else(|| {
842 anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
843 })?;
844
845 match scheme {
846 "tcp" => {
847 let (host, port) = Self::split_host_port(address)?;
848
849 if host == "*" {
850 Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
852 } else {
853 let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
855 Ok(Self::Tcp(socket_addr))
856 }
857 }
858 "inproc" => {
859 let port = address.parse::<u64>().map_err(|_| {
862 anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
863 })?;
864 Ok(Self::Local(port))
865 }
866 "ipc" => {
867 Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
869 }
870 "metatls" => {
871 let (host, port) = Self::split_host_port(address)?;
872
873 if host == "*" {
874 Ok(Self::MetaTls(TlsAddr::new(
876 std::net::Ipv6Addr::UNSPECIFIED.to_string(),
877 port,
878 )))
879 } else {
880 Ok(Self::MetaTls(TlsAddr::new(host, port)))
881 }
882 }
883 "tls" => {
884 let (host, port) = Self::split_host_port(address)?;
885
886 if host == "*" {
887 Ok(Self::Tls(TlsAddr::new(
889 std::net::Ipv6Addr::UNSPECIFIED.to_string(),
890 port,
891 )))
892 } else {
893 Ok(Self::Tls(TlsAddr::new(host, port)))
894 }
895 }
896 scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
897 }
898 }
899
900 fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
902 if let Some((host, port_str)) = address.rsplit_once(':') {
903 let port: u16 = port_str
904 .parse()
905 .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
906 Ok((host, port))
907 } else {
908 Err(anyhow::anyhow!("invalid address format: {}", address))
909 }
910 }
911
912 pub fn to_zmq_url(&self) -> String {
914 match self {
915 Self::Tcp(addr) => format!("tcp://{}", addr),
916 Self::MetaTls(addr) => format!("metatls://{}:{}", addr.hostname, addr.port),
917 Self::Tls(addr) => format!("tls://{}:{}", addr.hostname, addr.port),
918 Self::Local(index) => format!("inproc://{}", index),
919 Self::Unix(addr) => format!("ipc://{}", addr),
920 Self::Alias { dial_to, bind_to } => {
921 format!("{}@{}", dial_to.to_zmq_url(), bind_to.to_zmq_url())
922 }
923 }
924 }
925
926 fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
928 let host_clean = if host.starts_with('[') && host.ends_with(']') {
930 &host[1..host.len() - 1]
931 } else {
932 host
933 };
934
935 if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
937 return Ok(SocketAddr::new(ip_addr, port));
938 }
939
940 use std::net::ToSocketAddrs;
942 let mut addrs = (host_clean, port)
943 .to_socket_addrs()
944 .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
945
946 addrs
947 .next()
948 .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
949 }
950}
951
952pub struct ChannelTx<M: RemoteMessage> {
954 inner: ChannelTxKind<M>,
955}
956
957impl<M: RemoteMessage> fmt::Debug for ChannelTx<M> {
958 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
959 f.debug_struct("ChannelTx")
960 .field("addr", &self.addr())
961 .finish()
962 }
963}
964
/// The concrete sender behind a [`ChannelTx`].
enum ChannelTxKind<M: RemoteMessage> {
    /// Sender produced by `local::dial`.
    Local(local::LocalTx<M>),
    /// Sender produced by `net::spawn`.
    Net(net::NetTx<M>),
}
970
971#[async_trait]
972impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
973 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
974 match &self.inner {
975 ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
976 ChannelTxKind::Net(tx) => tx.do_post(message, return_channel),
977 }
978 }
979
980 fn addr(&self) -> ChannelAddr {
981 match &self.inner {
982 ChannelTxKind::Local(tx) => tx.addr(),
983 ChannelTxKind::Net(tx) => Tx::<M>::addr(tx),
984 }
985 }
986
987 fn status(&self) -> &watch::Receiver<TxStatus> {
988 match &self.inner {
989 ChannelTxKind::Local(tx) => tx.status(),
990 ChannelTxKind::Net(tx) => tx.status(),
991 }
992 }
993}
994
995pub struct ChannelRx<M: RemoteMessage> {
997 inner: ChannelRxKind<M>,
998}
999
1000impl<M: RemoteMessage> fmt::Debug for ChannelRx<M> {
1001 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1002 f.debug_struct("ChannelRx")
1003 .field("addr", &self.addr())
1004 .finish()
1005 }
1006}
1007
/// The concrete receiver behind a [`ChannelRx`].
enum ChannelRxKind<M: RemoteMessage> {
    /// Receiver produced by `local::serve`.
    Local(local::LocalRx<M>),
    /// Receiver produced by `net::server::serve`.
    Net(net::NetRx<M>),
}
1013
1014#[async_trait]
1015impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
1016 #[tracing::instrument(level = "debug", skip_all)]
1017 async fn recv(&mut self) -> Result<M, ChannelError> {
1018 match &mut self.inner {
1019 ChannelRxKind::Local(rx) => rx.recv().await,
1020 ChannelRxKind::Net(rx) => rx.recv().await,
1021 }
1022 }
1023
1024 fn addr(&self) -> ChannelAddr {
1025 match &self.inner {
1026 ChannelRxKind::Local(rx) => rx.addr(),
1027 ChannelRxKind::Net(rx) => rx.addr(),
1028 }
1029 }
1030
1031 async fn join(self) {
1032 match self.inner {
1033 ChannelRxKind::Local(rx) => rx.join().await,
1034 ChannelRxKind::Net(rx) => rx.join().await,
1035 }
1036 }
1037}
1038
1039#[allow(clippy::result_large_err)] #[track_caller]
1044pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
1045 tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
1046 let inner = match addr {
1047 ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
1048 ChannelAddr::Tcp(_)
1049 | ChannelAddr::Unix(_)
1050 | ChannelAddr::Tls(_)
1051 | ChannelAddr::MetaTls(_) => ChannelTxKind::Net(net::spawn(net::link(addr)?)),
1052 ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
1053 };
1054 Ok(ChannelTx { inner })
1055}
1056
1057#[track_caller]
1060pub fn serve<M: RemoteMessage>(
1061 addr: ChannelAddr,
1062) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1063 let caller = Location::caller();
1064 serve_inner(addr).map(|(addr, inner)| {
1065 tracing::debug!(
1066 name = "serve",
1067 %addr,
1068 %caller,
1069 );
1070 (addr, ChannelRx { inner })
1071 })
1072}
1073
1074fn serve_inner<M: RemoteMessage>(
1075 addr: ChannelAddr,
1076) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
1077 match addr {
1078 ChannelAddr::Tcp(_)
1079 | ChannelAddr::Unix(_)
1080 | ChannelAddr::Tls(_)
1081 | ChannelAddr::MetaTls(_) => {
1082 let (addr, rx) = net::server::serve::<M>(addr)?;
1083 Ok((addr, ChannelRxKind::Net(rx)))
1084 }
1085 ChannelAddr::Local(0) => {
1086 let (port, rx) = local::serve::<M>();
1087 Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
1088 }
1089 ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
1090 "invalid local addr: {}",
1091 a
1092 ))),
1093 ChannelAddr::Alias { dial_to, bind_to } => {
1094 let (bound_addr, rx) = serve_inner::<M>(*bind_to)?;
1095 let alias_addr = ChannelAddr::Alias {
1096 dial_to,
1097 bind_to: Box::new(bound_addr),
1098 };
1099 Ok((alias_addr, rx))
1100 }
1101 }
1102}
1103
1104pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
1107 let (port, rx) = local::serve::<M>();
1108 (
1109 ChannelAddr::Local(port),
1110 ChannelRx {
1111 inner: ChannelRxKind::Local(rx),
1112 },
1113 )
1114}
1115
1116#[cfg(test)]
1117mod tests {
1118 use std::assert_matches::assert_matches;
1119 use std::collections::HashSet;
1120 use std::net::IpAddr;
1121 use std::net::Ipv4Addr;
1122 use std::net::Ipv6Addr;
1123 use std::time::Duration;
1124
1125 use tokio::task::JoinSet;
1126
1127 use super::net::*;
1128 use super::*;
1129 #[test]
1130 fn test_channel_addr() {
1131 let cases_ok = vec![
1132 (
1133 "tcp<DELIM>[::1]:1234",
1134 ChannelAddr::Tcp(SocketAddr::new(
1135 IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1136 1234,
1137 )),
1138 ),
1139 (
1140 "tcp<DELIM>127.0.0.1:8080",
1141 ChannelAddr::Tcp(SocketAddr::new(
1142 IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1143 8080,
1144 )),
1145 ),
1146 #[cfg(target_os = "linux")]
1147 ("local<DELIM>123", ChannelAddr::Local(123)),
1148 (
1149 "unix<DELIM>@yolo",
1150 ChannelAddr::Unix(
1151 unix::SocketAddr::from_abstract_name("yolo")
1152 .expect("can't make socket from abstract name"),
1153 ),
1154 ),
1155 (
1156 "unix<DELIM>/cool/socket-path",
1157 ChannelAddr::Unix(
1158 unix::SocketAddr::from_pathname("/cool/socket-path")
1159 .expect("can't make socket from path"),
1160 ),
1161 ),
1162 ];
1163
1164 for (raw, parsed) in cases_ok {
1165 for delim in ["!", ":"] {
1166 let raw = raw.replace("<DELIM>", delim);
1167 assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
1168 }
1169 }
1170
1171 let cases_err = vec![
1172 ("tcp:abcdef..123124", "invalid socket address syntax"),
1173 ("xxx:foo", "no such channel type: xxx"),
1174 ("127.0.0.1", "no channel type specified"),
1175 ("local:abc", "invalid digit found in string"),
1176 ];
1177
1178 for (raw, error) in cases_err {
1179 let Err(err) = raw.parse::<ChannelAddr>() else {
1180 panic!("expected error parsing: {}", &raw)
1181 };
1182 assert_eq!(format!("{}", err), error);
1183 }
1184 }
1185
    /// Exercises `ChannelAddr::from_zmq_url` across the supported schemes.
    #[test]
    fn test_zmq_style_channel_addr() {
        // tcp with an IPv4 literal.
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
        );

        // tcp wildcard host maps to the IPv6 unspecified address.
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
        );

        // inproc maps to a local channel.
        assert_eq!(
            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
            ChannelAddr::Local(12345)
        );

        // ipc maps to a unix socket path.
        assert_eq!(
            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
        );

        // metatls with a hostname.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("example.com", 443))
        );

        // metatls with an IPv4 literal.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("192.168.1.1", 443))
        );

        // metatls wildcard host maps to "::".
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("::", 8443))
        );

        // tcp with a hostname goes through name resolution, so only success
        // is checked, not the resolved address.
        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
        assert!(tcp_hostname_result.is_ok());

        // tcp with a bracketed IPv6 literal.
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
        );

        // Unsupported scheme, missing ports, and a non-numeric inproc endpoint.
        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());

        // Unbracketed IPv6 host: the port is split off at the last ':' and
        // the host is normalized (leading zeros dropped).
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
                .unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
        );

        // Two spellings of the same IPv6 address compare equal.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
                .unwrap(),
            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443")
                .unwrap(),
        );

        // Brackets are stripped from a bracketed IPv6 host.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://[::1]:443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("::1", 443))
        );

        // The same three IPv6 checks for the tls scheme.
        assert_eq!(
            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
            ChannelAddr::Tls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
        );
        assert_eq!(
            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443").unwrap(),
        );
        assert_eq!(
            ChannelAddr::from_zmq_url("tls://[::1]:443").unwrap(),
            ChannelAddr::Tls(TlsAddr::new("::1", 443))
        );
    }
1283
1284 #[test]
1285 fn test_normalize_host() {
1286 assert_eq!(normalize_host("192.168.1.1"), "192.168.1.1");
1288
1289 assert_eq!(normalize_host("example.com"), "example.com");
1291
1292 assert_eq!(
1294 normalize_host("2a03:83e4:5000:c000:56d7:00cf:75ce:144a"),
1295 "2a03:83e4:5000:c000:56d7:cf:75ce:144a"
1296 );
1297
1298 assert_eq!(normalize_host("[::1]"), "::1");
1300
1301 assert!("[::1]".parse::<IpAddr>().is_err());
1305 }
1306
1307 #[test]
1308 fn test_zmq_style_alias_channel_addr() {
1309 let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
1315 match alias_addr {
1316 ChannelAddr::Alias { dial_to, bind_to } => {
1317 assert_eq!(
1318 *dial_to,
1319 ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
1320 );
1321 assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
1322 }
1323 _ => panic!("Expected Alias"),
1324 }
1325
1326 assert!(
1328 ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err()
1329 );
1330
1331 assert!(
1333 ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
1334 );
1335
1336 assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());
1338
1339 assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());
1341
1342 assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());
1344
1345 assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
1347 }
1348
1349 #[tokio::test]
1350 async fn test_multiple_connections() {
1351 for addr in ChannelTransport::all().map(ChannelAddr::any) {
1352 let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
1353
1354 let mut sends: JoinSet<()> = JoinSet::new();
1355 for message in 0u64..100u64 {
1356 let addr = listen_addr.clone();
1357 sends.spawn(async move {
1358 let tx = dial::<u64>(addr).unwrap();
1359 tx.post(message);
1360 });
1361 }
1362
1363 let mut received: HashSet<u64> = HashSet::new();
1364 while received.len() < 100 {
1365 received.insert(rx.recv().await.unwrap());
1366 }
1367
1368 for message in 0u64..100u64 {
1369 assert!(received.contains(&message));
1370 }
1371
1372 loop {
1373 match sends.join_next().await {
1374 Some(Ok(())) => (),
1375 Some(Err(err)) => panic!("{}", err),
1376 None => break,
1377 }
1378 }
1379 }
1380 }
1381
    /// After the receiver is dropped, a posted message comes back on the
    /// return channel as a `Closed` send error. Net addresses are skipped.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            if net::is_net_addr(&addr) {
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.post(123);
            drop(rx);

            let start = tokio::time::Instant::now();

            // Keep posting until an error is reported on the return channel,
            // giving up after 10 seconds.
            let result = loop {
                let (return_tx, return_rx) = oneshot::channel();
                tx.try_post(123, return_tx);
                let result = return_rx.await;

                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
            };
            assert_matches!(
                result,
                Ok(SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    reason: None
                })
            );
        }
    }
1422
    /// Addresses used by the dial/serve tests: TCP on IPv6 loopback, local,
    /// and (Linux only) an empty unix address plus a unix abstract name made
    /// of 10 random lowercase letters.
    fn addrs() -> Vec<ChannelAddr> {
        use rand::Rng;
        use rand::distributions::Uniform;

        let rng = rand::thread_rng();
        vec![
            "tcp:[::1]:0".parse().unwrap(),
            "local:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            "unix:".parse().unwrap(),
            #[cfg(target_os = "linux")]
            format!(
                "unix:@{}",
                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
                    .take(10)
                    .collect::<String>()
            )
            .parse()
            .unwrap(),
        ]
    }
1444
1445 #[test]
1446 fn test_bind_spec_from_str() {
1447 assert_eq!(
1449 BindSpec::from_str("tcp").unwrap(),
1450 BindSpec::Any(ChannelTransport::Tcp(TcpMode::Hostname))
1451 );
1452 assert_eq!(
1453 BindSpec::from_str("metatls(Hostname)").unwrap(),
1454 BindSpec::Any(ChannelTransport::MetaTls(TlsMode::Hostname))
1455 );
1456
1457 assert_eq!(
1459 BindSpec::from_str("tcp:127.0.0.1:8080").unwrap(),
1460 BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()))
1461 );
1462
1463 assert_eq!(
1465 BindSpec::from_str("tcp://127.0.0.1:9000").unwrap(),
1466 BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()))
1467 );
1468 assert_eq!(
1469 BindSpec::from_str("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap(),
1470 BindSpec::Addr(
1471 ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap()
1472 )
1473 );
1474
1475 assert!(BindSpec::from_str("invalid_spec").is_err());
1477 assert!(BindSpec::from_str("unknown://scheme").is_err());
1478 assert!(BindSpec::from_str("").is_err());
1479 }
1480
    /// Round trip: serve, dial, post one message, receive it. Ignored
    /// outside fbcode builds.
    #[tokio::test]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_dial_serve() {
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.post(123);
            assert_eq!(rx.recv().await.unwrap(), 123);
        }
    }
1492
    /// `Tx::send` succeeds while the receiver is alive and returns a
    /// `Closed` send error carrying the message once the receiver is
    /// dropped. Ignored outside fbcode builds.
    #[tokio::test]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_send() {
        let config = hyperactor_config::global::lock();

        // Override the delivery timeout to one second and ack every message;
        // the guards are held for the rest of the test.
        let _guard1 = config.override_key(
            crate::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.send(123).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);

            drop(rx);
            assert_matches!(
                tx.send(123).await.unwrap_err(),
                SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    ..
                }
            );
        }
    }
1522
1523 #[test]
1524 fn test_find_routable_address_skips_link_local_ipv6() {
1525 let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1526 let routable_v6: IpAddr = "2001:db8::1".parse().unwrap();
1527 let addrs = vec![link_local_v6, routable_v6];
1528 assert_eq!(find_routable_address(&addrs), Some(routable_v6));
1529 }
1530
1531 #[test]
1532 fn test_find_routable_address_skips_link_local_ipv4() {
1533 let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1534 let routable_v4: IpAddr = "192.168.1.1".parse().unwrap();
1535 let addrs = vec![link_local_v4, routable_v4];
1536 assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1537 }
1538
1539 #[test]
1540 fn test_find_routable_address_returns_none_when_all_link_local() {
1541 let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1542 let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
1543 let addrs = vec![link_local_v6, link_local_v4];
1544 assert_eq!(find_routable_address(&addrs), None);
1545 }
1546
1547 #[test]
1548 fn test_find_routable_address_mixed() {
1549 let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
1550 let link_local_v4: IpAddr = "169.254.0.1".parse().unwrap();
1551 let routable_v4: IpAddr = "10.0.0.1".parse().unwrap();
1552 let routable_v6: IpAddr = "2001:db8::2".parse().unwrap();
1553
1554 let addrs = vec![link_local_v6, link_local_v4, routable_v4, routable_v6];
1556 assert_eq!(find_routable_address(&addrs), Some(routable_v4));
1557 }
1558}