1use core::net::SocketAddr;
13use std::fmt;
14use std::net::IpAddr;
15use std::net::Ipv6Addr;
16#[cfg(target_os = "linux")]
17use std::os::linux::net::SocketAddrExt;
18use std::panic::Location;
19use std::str::FromStr;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23use enum_as_inner::EnumAsInner;
24use hyperactor_config::attrs::AttrValue;
25use serde::Deserialize;
26use serde::Serialize;
27use tokio::sync::mpsc;
28use tokio::sync::oneshot;
29use tokio::sync::watch;
30
31use crate as hyperactor;
32use crate::RemoteMessage;
33pub(crate) mod local;
34pub(crate) mod net;
35
36pub use net::try_tls_acceptor;
40pub use net::try_tls_connector;
41pub use net::try_tls_pem_bundle;
42
/// Public re-exports of the duplex channel API implemented in `net::duplex`
/// (its sender/receiver/server types plus the `dial` and `serve` entry points).
pub mod duplex {
    pub use super::net::duplex::DuplexRx;
    pub use super::net::duplex::DuplexServer;
    pub use super::net::duplex::DuplexTx;
    pub use super::net::duplex::dial;
    pub use super::net::duplex::serve;
}
51
/// Errors produced by channel operations.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// The channel is closed (e.g. an mpsc-backed receiver has been dropped).
    #[error("channel closed")]
    Closed,

    /// A send failed; wraps the underlying cause.
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// An error reported by the network client.
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// The address cannot be used for the requested operation
    /// (e.g. a nonzero `Local` address passed to `serve`).
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// An error reported by the network server.
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// A bincode encoding/decoding error.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// An error from `wirevalue`.
    #[error(transparent)]
    Data(#[from] wirevalue::Error),

    /// Any other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An operation did not complete within the given duration.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),
}
91
/// A failed send that hands the undelivered message back to the caller.
#[derive(thiserror::Error, Debug)]
#[error("{error} for reason {reason:?}")]
pub struct SendError<M: RemoteMessage> {
    /// Why the send failed.
    #[source]
    pub error: ChannelError,
    /// The message that was not delivered.
    pub message: M,
    /// Optional additional detail about the failure.
    pub reason: Option<String>,
}
104
105impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
106 fn from(error: SendError<M>) -> Self {
107 error.error
108 }
109}
110
/// Liveness of a sender, published over a `watch` channel (see `Tx::status`).
#[derive(Debug, Clone, PartialEq, EnumAsInner)]
pub enum TxStatus {
    /// The channel is usable; this is the initial status of `MpscTx`.
    Active,
    /// The channel is closed; the string describes why (e.g. `"receiver dropped"`).
    Closed(Arc<str>),
}
119
120#[async_trait]
122pub trait Tx<M: RemoteMessage> {
123 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
129
130 #[allow(clippy::result_large_err)] #[tracing::instrument(level = "debug", skip_all)]
135 fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
136 self.do_post(message, Some(return_channel));
137 }
138
139 #[hyperactor::instrument_infallible]
141 fn post(&self, message: M) {
142 self.do_post(message, None);
143 }
144
145 async fn send(&self, message: M) -> Result<(), SendError<M>> {
148 let (tx, rx) = oneshot::channel();
149 self.try_post(message, tx);
150 match rx.await {
151 Ok(err) => Err(err),
153
154 Err(_) => Ok(()),
157 }
158 }
159
160 fn addr(&self) -> ChannelAddr;
162
163 fn status(&self) -> &watch::Receiver<TxStatus>;
165}
166
/// The receiving half of a channel carrying messages of type `M`.
#[async_trait]
pub trait Rx<M: RemoteMessage> {
    /// Waits for the next message. Returns an error (e.g.
    /// [`ChannelError::Closed`]) when no more messages can be received.
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The address this receiver is bound to.
    fn addr(&self) -> ChannelAddr;

    /// Consumes the receiver and waits for implementation-specific teardown
    /// (a no-op for `MpscRx`).
    async fn join(self)
    where
        Self: Sized;
}
184
/// A [`Tx`] backed by an unbounded tokio mpsc sender.
// NOTE(review): `dead_code` is allowed, presumably because this type has no
// non-test users in the crate — confirm.
#[allow(dead_code)]
struct MpscTx<M: RemoteMessage> {
    // Underlying mpsc sender that messages are pushed into.
    tx: mpsc::UnboundedSender<M>,
    // Address reported by `Tx::addr`.
    addr: ChannelAddr,
    // Status watch returned by `Tx::status`.
    status: watch::Receiver<TxStatus>,
}
191
192impl<M: RemoteMessage> MpscTx<M> {
193 #[allow(dead_code)] pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
195 let (sender, receiver) = watch::channel(TxStatus::Active);
196 (
197 Self {
198 tx,
199 addr,
200 status: receiver,
201 },
202 sender,
203 )
204 }
205}
206
207#[async_trait]
208impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
209 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
210 if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
211 if let Some(return_channel) = return_channel {
212 return_channel
213 .send(SendError {
214 error: ChannelError::Closed,
215 message,
216 reason: None,
217 })
218 .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
219 }
220 }
221 }
222
223 fn addr(&self) -> ChannelAddr {
224 self.addr.clone()
225 }
226
227 fn status(&self) -> &watch::Receiver<TxStatus> {
228 &self.status
229 }
230}
231
/// An [`Rx`] backed by an unbounded tokio mpsc receiver.
// NOTE(review): `dead_code` is allowed, presumably because this type has no
// non-test users in the crate — confirm.
#[allow(dead_code)]
struct MpscRx<M: RemoteMessage> {
    // Underlying mpsc receiver.
    rx: mpsc::UnboundedReceiver<M>,
    // Address reported by `Rx::addr`.
    addr: ChannelAddr,
    // Used by `Drop` to publish `TxStatus::Closed` to the paired sender.
    status_sender: watch::Sender<TxStatus>,
}
239
240impl<M: RemoteMessage> MpscRx<M> {
241 #[allow(dead_code)] pub fn new(
243 rx: mpsc::UnboundedReceiver<M>,
244 addr: ChannelAddr,
245 status_sender: watch::Sender<TxStatus>,
246 ) -> Self {
247 Self {
248 rx,
249 addr,
250 status_sender,
251 }
252 }
253}
254
255impl<M: RemoteMessage> Drop for MpscRx<M> {
256 fn drop(&mut self) {
257 let _ = self
258 .status_sender
259 .send(TxStatus::Closed("receiver dropped".into()));
260 }
261}
262
263#[async_trait]
264impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
265 async fn recv(&mut self) -> Result<M, ChannelError> {
266 self.rx.recv().await.ok_or(ChannelError::Closed)
267 }
268
269 fn addr(&self) -> ChannelAddr {
270 self.addr.clone()
271 }
272
273 async fn join(self) {}
274}
275
/// How [`ChannelAddr::any`] picks the IP of a TCP address.
///
/// The strum derives give `Display`/`FromStr` using the variant names
/// (`Localhost`, `Hostname`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TcpMode {
    /// Use the IPv6 loopback address (`::1`).
    Localhost,
    /// Use an address resolved from the machine's hostname.
    Hostname,
}
295
/// How [`ChannelAddr::any`] picks the host of a MetaTLS address.
///
/// The strum derives give `Display`/`FromStr` using the variant names
/// (`IpV6`, `Hostname`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// Use the host's IPv6 address.
    IpV6,
    /// Use the machine's hostname.
    Hostname,
}
316
/// A host/port pair used by the TLS and MetaTLS transports.
///
/// Prefer [`TlsAddr::new`], which normalizes the host; the fields are public,
/// so direct construction bypasses normalization. The derived `Ord` compares
/// `hostname` first, then `port`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd
)]
pub struct TlsAddr {
    /// Host name or IP literal (IP literals stored without brackets when built via `new`).
    pub hostname: Hostname,
    /// Port number.
    pub port: Port,
}
335
336impl TlsAddr {
337 pub fn new(hostname: impl Into<Hostname>, port: Port) -> Self {
339 Self {
340 hostname: normalize_host(&hostname.into()),
341 port,
342 }
343 }
344
345 pub fn port(&self) -> Port {
347 self.port
348 }
349
350 pub fn hostname(&self) -> &str {
352 &self.hostname
353 }
354}
355
356impl fmt::Display for TlsAddr {
357 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358 write!(f, "{}:{}", self.hostname, self.port)
359 }
360}
361
/// The kind of transport a channel uses, independent of any concrete address.
///
/// `Display` renders `tcp(<mode>)`, `metatls(<mode>)`, `tls`, `local` or
/// `unix`, and `FromStr` parses those forms (plus bare `tcp`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum ChannelTransport {
    /// TCP, with the given IP-selection mode.
    Tcp(TcpMode),

    /// MetaTLS, with the given host-selection mode.
    MetaTls(TlsMode),

    /// TLS.
    Tls,

    /// In-process channel.
    Local,

    /// Unix domain socket.
    Unix,
}
389
390impl fmt::Display for ChannelTransport {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 match self {
393 Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
394 Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
395 Self::Tls => write!(f, "tls"),
396 Self::Local => write!(f, "local"),
397 Self::Unix => write!(f, "unix"),
398 }
399 }
400}
401
402impl FromStr for ChannelTransport {
403 type Err = anyhow::Error;
404
405 fn from_str(s: &str) -> Result<Self, Self::Err> {
406 match s {
407 "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
409 s if s.starts_with("tcp(") => {
410 let inner = &s["tcp(".len()..s.len() - 1];
411 let mode = inner.parse()?;
412 Ok(ChannelTransport::Tcp(mode))
413 }
414 "local" => Ok(ChannelTransport::Local),
415 "unix" => Ok(ChannelTransport::Unix),
416 "tls" => Ok(ChannelTransport::Tls),
417 s if s.starts_with("metatls(") && s.ends_with(")") => {
418 let inner = &s["metatls(".len()..s.len() - 1];
419 let mode = inner.parse()?;
420 Ok(ChannelTransport::MetaTls(mode))
421 }
422 unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
423 }
424 }
425}
426
427impl ChannelTransport {
428 pub fn all() -> [ChannelTransport; 3] {
430 [
431 ChannelTransport::Tcp(TcpMode::Hostname),
434 ChannelTransport::Local,
435 ChannelTransport::Unix,
436 ]
439 }
440
441 pub fn any(&self) -> ChannelAddr {
443 ChannelAddr::any(self.clone())
444 }
445
446 pub fn is_remote(&self) -> bool {
448 match self {
449 ChannelTransport::Tcp(_) => true,
450 ChannelTransport::MetaTls(_) => true,
451 ChannelTransport::Tls => true,
452 ChannelTransport::Local => false,
453 ChannelTransport::Unix => false,
454 }
455 }
456}
457
458impl AttrValue for ChannelTransport {
459 fn display(&self) -> String {
460 self.to_string()
461 }
462
463 fn parse(s: &str) -> Result<Self, anyhow::Error> {
464 s.parse()
465 }
466}
467
/// What a server should bind to: the wildcard address of a transport, or a
/// specific address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum BindSpec {
    /// Bind to `ChannelAddr::any(transport)`.
    Any(ChannelTransport),

    /// Bind to exactly this address.
    Addr(ChannelAddr),
}
486
487impl BindSpec {
488 pub fn binding_addr(&self) -> ChannelAddr {
490 match self {
491 BindSpec::Any(transport) => ChannelAddr::any(transport.clone()),
492 BindSpec::Addr(addr) => addr.clone(),
493 }
494 }
495}
496
497impl From<ChannelTransport> for BindSpec {
498 fn from(transport: ChannelTransport) -> Self {
499 BindSpec::Any(transport)
500 }
501}
502
503impl From<ChannelAddr> for BindSpec {
504 fn from(addr: ChannelAddr) -> Self {
505 BindSpec::Addr(addr)
506 }
507}
508
509impl fmt::Display for BindSpec {
510 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511 match self {
512 Self::Any(transport) => write!(f, "{}", transport),
513 Self::Addr(addr) => write!(f, "{}", addr),
514 }
515 }
516}
517
518impl FromStr for BindSpec {
519 type Err = anyhow::Error;
520
521 fn from_str(s: &str) -> Result<Self, Self::Err> {
522 if let Ok(transport) = ChannelTransport::from_str(s) {
523 Ok(BindSpec::Any(transport))
524 } else if let Ok(addr) = ChannelAddr::from_zmq_url(s) {
525 Ok(BindSpec::Addr(addr))
526 } else if let Ok(addr) = ChannelAddr::from_str(s) {
527 Ok(BindSpec::Addr(addr))
528 } else {
529 Err(anyhow::anyhow!("invalid bind spec: {}", s))
530 }
531 }
532}
533
534impl AttrValue for BindSpec {
535 fn display(&self) -> String {
536 self.to_string()
537 }
538
539 fn parse(s: &str) -> Result<Self, anyhow::Error> {
540 Self::from_str(s)
541 }
542}
543
/// A host name or IP literal, as stored in [`TlsAddr`].
pub type Hostname = String;

/// A network port number.
pub type Port = u16;
549
/// The address of a channel endpoint.
///
/// `Display` renders a `type:rest` form (e.g. `tcp:[::1]:1234`, `local:3`).
/// `FromStr` accepts that form but rejects aliases; aliases round-trip through
/// [`ChannelAddr::to_zmq_url`] / [`ChannelAddr::from_zmq_url`] instead.
///
/// NOTE: `Ord`/`PartialOrd` are derived, so reordering variants changes address
/// ordering. It also changes the encoding for index-based serde formats such
/// as bincode.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    typeuri::Named
)]
pub enum ChannelAddr {
    /// A TCP socket address.
    Tcp(SocketAddr),

    /// A MetaTLS host/port.
    MetaTls(TlsAddr),

    /// A TLS host/port.
    Tls(TlsAddr),

    /// An in-process channel identified by index. `serve` treats `Local(0)` as
    /// "allocate a new index" and rejects any other value.
    Local(u64),

    /// A Unix domain socket address.
    Unix(net::unix::SocketAddr),

    /// An endpoint that is dialed at one address and bound at another.
    Alias {
        /// The address `dial` connects to.
        dial_to: Box<ChannelAddr>,
        /// The address `serve` binds; in the address returned by `serve` this
        /// is replaced with the address actually bound.
        bind_to: Box<ChannelAddr>,
    },
}
626
627impl From<SocketAddr> for ChannelAddr {
628 fn from(value: SocketAddr) -> Self {
629 Self::Tcp(value)
630 }
631}
632
633impl From<net::unix::SocketAddr> for ChannelAddr {
634 fn from(value: net::unix::SocketAddr) -> Self {
635 Self::Unix(value)
636 }
637}
638
639impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
640 fn from(value: std::os::unix::net::SocketAddr) -> Self {
641 Self::Unix(net::unix::SocketAddr::new(value))
642 }
643}
644
645impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
646 fn from(value: tokio::net::unix::SocketAddr) -> Self {
647 std::os::unix::net::SocketAddr::from(value).into()
648 }
649}
650
/// Returns the first address in `addresses` that is not link-local, or `None`
/// if the slice is empty or every entry is link-local.
///
/// Link-local means IPv4 `169.254.0.0/16` or IPv6 unicast `fe80::/10`. Other
/// non-global addresses (loopback, private ranges) are not filtered out.
fn find_routable_address(addresses: &[IpAddr]) -> Option<IpAddr> {
    for &candidate in addresses {
        let link_local = match candidate {
            IpAddr::V4(v4) => v4.is_link_local(),
            IpAddr::V6(v6) => v6.is_unicast_link_local(),
        };
        if !link_local {
            return Some(candidate);
        }
    }
    None
}
661
662impl ChannelAddr {
663 pub fn any(transport: ChannelTransport) -> Self {
666 match transport {
667 ChannelTransport::Tcp(mode) => {
668 let ip = match mode {
669 TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
670 TcpMode::Hostname => {
671 hostname::get()
672 .ok()
673 .and_then(|hostname| {
674 hostname.to_str().and_then(|hostname_str| {
676 dns_lookup::lookup_host(hostname_str)
677 .ok()
678 .and_then(|addresses| find_routable_address(&addresses))
679 })
680 })
681 .expect("failed to resolve hostname to ip address")
682 }
683 };
684 Self::Tcp(SocketAddr::new(ip, 0))
685 }
686 ChannelTransport::MetaTls(mode) => {
687 let host_address = match mode {
688 TlsMode::Hostname => hostname::get()
689 .ok()
690 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
691 .unwrap_or("unknown_host".to_string()),
692 TlsMode::IpV6 => {
693 get_host_ipv6_address().expect("failed to retrieve ipv6 address")
694 }
695 };
696 Self::MetaTls(TlsAddr::new(host_address, 0))
697 }
698 ChannelTransport::Local => Self::Local(0),
699 ChannelTransport::Tls => {
700 let host_address = hostname::get()
701 .ok()
702 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
703 .unwrap_or("localhost".to_string());
704 Self::Tls(TlsAddr::new(host_address, 0))
705 }
706 ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
708 }
709 }
710
711 pub fn transport(&self) -> ChannelTransport {
713 match self {
714 Self::Tcp(addr) => {
715 if addr.ip().is_loopback() {
716 ChannelTransport::Tcp(TcpMode::Localhost)
717 } else {
718 ChannelTransport::Tcp(TcpMode::Hostname)
719 }
720 }
721 Self::MetaTls(addr) => match addr.hostname.parse::<IpAddr>() {
722 Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
723 Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
724 Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
725 },
726 Self::Tls(_) => ChannelTransport::Tls,
727 Self::Local(_) => ChannelTransport::Local,
728 Self::Unix(_) => ChannelTransport::Unix,
729 Self::Alias { bind_to, .. } => bind_to.transport(),
732 }
733 }
734}
735
736#[cfg(fbcode_build)]
737fn get_host_ipv6_address() -> anyhow::Result<String> {
738 crate::meta::host_ip::host_ipv6_address()
739}
740
741#[cfg(not(fbcode_build))]
742fn get_host_ipv6_address() -> anyhow::Result<String> {
743 Ok(local_ip_address::local_ipv6()?.to_string())
744}
745
746impl fmt::Display for ChannelAddr {
747 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
748 match self {
749 Self::Tcp(addr) => write!(f, "tcp:{}", addr),
750 Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
751 Self::Tls(addr) => write!(f, "tls:{}", addr),
752 Self::Local(index) => write!(f, "local:{}", index),
753 Self::Unix(addr) => write!(f, "unix:{}", addr),
754 Self::Alias { dial_to, bind_to } => {
755 write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
756 }
757 }
758 }
759}
760
761impl FromStr for ChannelAddr {
762 type Err = anyhow::Error;
763
764 fn from_str(addr: &str) -> Result<Self, Self::Err> {
765 match addr.split_once('!').or_else(|| addr.split_once(':')) {
766 Some(("local", rest)) => rest
767 .parse::<u64>()
768 .map(Self::Local)
769 .map_err(anyhow::Error::from),
770 Some(("tcp", rest)) => rest
771 .parse::<SocketAddr>()
772 .map(Self::Tcp)
773 .map_err(anyhow::Error::from),
774 Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
775 Some(("tls", rest)) => net::tls::parse(rest).map_err(|e| e.into()),
776 Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
777 Some(("alias", _)) => Err(anyhow::anyhow!(
778 "detect possible alias address, but we currently do not support \
779 parsing alias' string representation since we only want to \
780 support parsing its zmq url format."
781 )),
782 Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
783 None => Err(anyhow::anyhow!("no channel type specified")),
784 }
785 }
786}
787
/// Normalizes a host string for storage in a [`TlsAddr`].
///
/// If `host` (with one pair of surrounding `[...]` removed, if present) is an
/// IP literal, returns its canonical text form without brackets. IPv6 is
/// compressed, e.g. `00cf` becomes `cf`. Otherwise returns `host` unchanged,
/// including any brackets.
pub(crate) fn normalize_host(host: &str) -> String {
    let unbracketed = match host.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
        Some(inner) => inner,
        None => host,
    };
    unbracketed
        .parse::<IpAddr>()
        .map_or_else(|_| host.to_string(), |ip| ip.to_string())
}
804
805impl ChannelAddr {
806 pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
815 if let Some(at_pos) = address.find('@') {
818 let dial_to_str = &address[..at_pos];
819 let bind_to_str = &address[at_pos + 1..];
820
821 if !dial_to_str.starts_with("tcp://") {
823 return Err(anyhow::anyhow!(
824 "alias format is only supported for TCP addresses, got dial_to: {}",
825 dial_to_str
826 ));
827 }
828 if !bind_to_str.starts_with("tcp://") {
829 return Err(anyhow::anyhow!(
830 "alias format is only supported for TCP addresses, got bind_to: {}",
831 bind_to_str
832 ));
833 }
834
835 let dial_to = Self::from_zmq_url(dial_to_str)?;
836 let bind_to = Self::from_zmq_url(bind_to_str)?;
837
838 return Ok(Self::Alias {
839 dial_to: Box::new(dial_to),
840 bind_to: Box::new(bind_to),
841 });
842 }
843
844 let (scheme, address) = address.split_once("://").ok_or_else(|| {
846 anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
847 })?;
848
849 match scheme {
850 "tcp" => {
851 let (host, port) = Self::split_host_port(address)?;
852
853 if host == "*" {
854 Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
856 } else {
857 let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
859 Ok(Self::Tcp(socket_addr))
860 }
861 }
862 "inproc" => {
863 let port = address.parse::<u64>().map_err(|_| {
866 anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
867 })?;
868 Ok(Self::Local(port))
869 }
870 "ipc" => {
871 Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
873 }
874 "metatls" => {
875 let (host, port) = Self::split_host_port(address)?;
876
877 if host == "*" {
878 Ok(Self::MetaTls(TlsAddr::new(
880 std::net::Ipv6Addr::UNSPECIFIED.to_string(),
881 port,
882 )))
883 } else {
884 Ok(Self::MetaTls(TlsAddr::new(host, port)))
885 }
886 }
887 "tls" => {
888 let (host, port) = Self::split_host_port(address)?;
889
890 if host == "*" {
891 Ok(Self::Tls(TlsAddr::new(
893 std::net::Ipv6Addr::UNSPECIFIED.to_string(),
894 port,
895 )))
896 } else {
897 Ok(Self::Tls(TlsAddr::new(host, port)))
898 }
899 }
900 scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
901 }
902 }
903
904 fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
906 if let Some((host, port_str)) = address.rsplit_once(':') {
907 let port: u16 = port_str
908 .parse()
909 .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
910 Ok((host, port))
911 } else {
912 Err(anyhow::anyhow!("invalid address format: {}", address))
913 }
914 }
915
916 pub fn to_zmq_url(&self) -> String {
918 match self {
919 Self::Tcp(addr) => format!("tcp://{}", addr),
920 Self::MetaTls(addr) => format!("metatls://{}:{}", addr.hostname, addr.port),
921 Self::Tls(addr) => format!("tls://{}:{}", addr.hostname, addr.port),
922 Self::Local(index) => format!("inproc://{}", index),
923 Self::Unix(addr) => format!("ipc://{}", addr),
924 Self::Alias { dial_to, bind_to } => {
925 format!("{}@{}", dial_to.to_zmq_url(), bind_to.to_zmq_url())
926 }
927 }
928 }
929
930 fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
932 let host_clean = if host.starts_with('[') && host.ends_with(']') {
934 &host[1..host.len() - 1]
935 } else {
936 host
937 };
938
939 if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
941 return Ok(SocketAddr::new(ip_addr, port));
942 }
943
944 use std::net::ToSocketAddrs;
946 let mut addrs = (host_clean, port)
947 .to_socket_addrs()
948 .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
949
950 addrs
951 .next()
952 .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
953 }
954}
955
/// A type-erased channel sender returned by [`dial`]; dispatches to a local or
/// network implementation.
pub struct ChannelTx<M: RemoteMessage> {
    inner: ChannelTxKind<M>,
}
960
961impl<M: RemoteMessage> fmt::Debug for ChannelTx<M> {
962 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
963 f.debug_struct("ChannelTx")
964 .field("addr", &self.addr())
965 .finish()
966 }
967}
968
/// The concrete sender behind a [`ChannelTx`].
enum ChannelTxKind<M: RemoteMessage> {
    /// In-process sender, used for `ChannelAddr::Local`.
    Local(local::LocalTx<M>),
    /// Network sender, used for TCP, Unix, TLS and MetaTLS addresses.
    Net(net::NetTx<M>),
}
974
975#[async_trait]
976impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
977 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
978 match &self.inner {
979 ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
980 ChannelTxKind::Net(tx) => tx.do_post(message, return_channel),
981 }
982 }
983
984 fn addr(&self) -> ChannelAddr {
985 match &self.inner {
986 ChannelTxKind::Local(tx) => tx.addr(),
987 ChannelTxKind::Net(tx) => Tx::<M>::addr(tx),
988 }
989 }
990
991 fn status(&self) -> &watch::Receiver<TxStatus> {
992 match &self.inner {
993 ChannelTxKind::Local(tx) => tx.status(),
994 ChannelTxKind::Net(tx) => tx.status(),
995 }
996 }
997}
998
/// A type-erased channel receiver returned by [`serve`] and [`serve_local`];
/// dispatches to a local or network implementation.
pub struct ChannelRx<M: RemoteMessage> {
    inner: ChannelRxKind<M>,
}
1003
1004impl<M: RemoteMessage> fmt::Debug for ChannelRx<M> {
1005 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1006 f.debug_struct("ChannelRx")
1007 .field("addr", &self.addr())
1008 .finish()
1009 }
1010}
1011
/// The concrete receiver behind a [`ChannelRx`].
enum ChannelRxKind<M: RemoteMessage> {
    /// In-process receiver, used for `ChannelAddr::Local`.
    Local(local::LocalRx<M>),
    /// Network receiver, used for TCP, Unix, TLS and MetaTLS addresses.
    Net(net::NetRx<M>),
}
1017
1018#[async_trait]
1019impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
1020 #[tracing::instrument(level = "debug", skip_all)]
1021 async fn recv(&mut self) -> Result<M, ChannelError> {
1022 match &mut self.inner {
1023 ChannelRxKind::Local(rx) => rx.recv().await,
1024 ChannelRxKind::Net(rx) => rx.recv().await,
1025 }
1026 }
1027
1028 fn addr(&self) -> ChannelAddr {
1029 match &self.inner {
1030 ChannelRxKind::Local(rx) => rx.addr(),
1031 ChannelRxKind::Net(rx) => rx.addr(),
1032 }
1033 }
1034
1035 async fn join(self) {
1036 match self.inner {
1037 ChannelRxKind::Local(rx) => rx.join().await,
1038 ChannelRxKind::Net(rx) => rx.join().await,
1039 }
1040 }
1041}
1042
1043#[allow(clippy::result_large_err)] #[track_caller]
1048pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
1049 tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
1050 let inner = match addr {
1051 ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
1052 ChannelAddr::Tcp(_)
1053 | ChannelAddr::Unix(_)
1054 | ChannelAddr::Tls(_)
1055 | ChannelAddr::MetaTls(_) => ChannelTxKind::Net(net::spawn(net::link(addr)?)),
1056 ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
1057 };
1058 Ok(ChannelTx { inner })
1059}
1060
1061#[track_caller]
1064pub fn serve<M: RemoteMessage>(
1065 addr: ChannelAddr,
1066) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1067 let caller = Location::caller();
1068 serve_inner(addr).map(|(addr, inner)| {
1069 tracing::debug!(
1070 name = "serve",
1071 %addr,
1072 %caller,
1073 );
1074 (addr, ChannelRx { inner })
1075 })
1076}
1077
1078fn serve_inner<M: RemoteMessage>(
1079 addr: ChannelAddr,
1080) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
1081 match addr {
1082 ChannelAddr::Tcp(_)
1083 | ChannelAddr::Unix(_)
1084 | ChannelAddr::Tls(_)
1085 | ChannelAddr::MetaTls(_) => {
1086 let (addr, rx) = net::server::serve::<M>(addr)?;
1087 Ok((addr, ChannelRxKind::Net(rx)))
1088 }
1089 ChannelAddr::Local(0) => {
1090 let (port, rx) = local::serve::<M>();
1091 Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
1092 }
1093 ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
1094 "invalid local addr: {}",
1095 a
1096 ))),
1097 ChannelAddr::Alias { dial_to, bind_to } => {
1098 let (bound_addr, rx) = serve_inner::<M>(*bind_to)?;
1099 let alias_addr = ChannelAddr::Alias {
1100 dial_to,
1101 bind_to: Box::new(bound_addr),
1102 };
1103 Ok((alias_addr, rx))
1104 }
1105 }
1106}
1107
1108pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
1111 let (port, rx) = local::serve::<M>();
1112 (
1113 ChannelAddr::Local(port),
1114 ChannelRx {
1115 inner: ChannelRxKind::Local(rx),
1116 },
1117 )
1118}
1119
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::Ipv6Addr;
    use std::time::Duration;

    use tokio::task::JoinSet;

    use super::net::*;
    use super::*;

    /// Checks `ChannelAddr::from_str` with both the `!` and `:` delimiters, and
    /// checks the error messages for malformed input.
    #[test]
    fn test_channel_addr() {
        let cases_ok = vec![
            (
                "tcp<DELIM>[::1]:1234",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    1234,
                )),
            ),
            (
                "tcp<DELIM>127.0.0.1:8080",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                    8080,
                )),
            ),
            // NOTE(review): this cfg guards the `local` case, while the
            // abstract-socket `unix<DELIM>@yolo` case below is unguarded.
            // Abstract names are Linux-specific (see the linux-only
            // `SocketAddrExt` import), so confirm the attribute is on the
            // intended entry.
            #[cfg(target_os = "linux")]
            ("local<DELIM>123", ChannelAddr::Local(123)),
            (
                "unix<DELIM>@yolo",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_abstract_name("yolo")
                        .expect("can't make socket from abstract name"),
                ),
            ),
            (
                "unix<DELIM>/cool/socket-path",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_pathname("/cool/socket-path")
                        .expect("can't make socket from path"),
                ),
            ),
        ];

        for (raw, parsed) in cases_ok {
            for delim in ["!", ":"] {
                let raw = raw.replace("<DELIM>", delim);
                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
            }
        }

        let cases_err = vec![
            ("tcp:abcdef..123124", "invalid socket address syntax"),
            ("xxx:foo", "no such channel type: xxx"),
            ("127.0.0.1", "no channel type specified"),
            ("local:abc", "invalid digit found in string"),
        ];

        for (raw, error) in cases_err {
            let Err(err) = raw.parse::<ChannelAddr>() else {
                panic!("expected error parsing: {}", &raw)
            };
            assert_eq!(format!("{}", err), error);
        }
    }

    /// Checks `ChannelAddr::from_zmq_url` for every supported scheme, the `*`
    /// wildcard, IPv6 normalization, and malformed input.
    #[test]
    fn test_zmq_style_channel_addr() {
        // TCP with an IPv4 literal.
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
        );

        // `*` maps to the IPv6 unspecified address.
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
        );

        // inproc maps to a local channel index.
        assert_eq!(
            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
            ChannelAddr::Local(12345)
        );

        // ipc maps to a Unix socket path.
        assert_eq!(
            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
        );

        // MetaTLS with a hostname, an IPv4 literal, and the wildcard.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("example.com", 443))
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("192.168.1.1", 443))
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("::", 8443))
        );

        // Hostname resolution: only success is checked, since the resolved IP
        // depends on the machine.
        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
        assert!(tcp_hostname_result.is_ok());

        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
        );

        // Malformed input.
        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());

        // Unbracketed IPv6 hosts: the port is split off at the last `:` and the
        // host is normalized (leading zeros dropped).
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
                .unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443")
                .unwrap(),
            ChannelAddr::from_zmq_url("metatls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443")
                .unwrap(),
        );

        // Bracketed IPv6 hosts are stored without brackets.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://[::1]:443").unwrap(),
            ChannelAddr::MetaTls(TlsAddr::new("::1", 443))
        );

        // The same IPv6 handling for plain TLS.
        assert_eq!(
            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
            ChannelAddr::Tls(TlsAddr::new("2a03:83e4:5000:c000:56d7:cf:75ce:144a", 443))
        );
        assert_eq!(
            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:00cf:75ce:144a:443").unwrap(),
            ChannelAddr::from_zmq_url("tls://2a03:83e4:5000:c000:56d7:cf:75ce:144a:443").unwrap(),
        );
        assert_eq!(
            ChannelAddr::from_zmq_url("tls://[::1]:443").unwrap(),
            ChannelAddr::Tls(TlsAddr::new("::1", 443))
        );
    }

    /// Checks `normalize_host`: IPs are canonicalized and unbracketed;
    /// hostnames pass through unchanged.
    #[test]
    fn test_normalize_host() {
        assert_eq!(normalize_host("192.168.1.1"), "192.168.1.1");

        assert_eq!(normalize_host("example.com"), "example.com");

        assert_eq!(
            normalize_host("2a03:83e4:5000:c000:56d7:00cf:75ce:144a"),
            "2a03:83e4:5000:c000:56d7:cf:75ce:144a"
        );

        assert_eq!(normalize_host("[::1]"), "::1");

        // A bracketed IPv6 literal is not a valid `IpAddr`, which is why
        // `normalize_host` strips the brackets before parsing.
        assert!("[::1]".parse::<IpAddr>().is_err());
    }

    /// Checks parsing of `tcp://dial_to@tcp://bind_to` alias URLs, and
    /// rejection of aliases with non-TCP or port-less halves.
    #[test]
    fn test_zmq_style_alias_channel_addr() {
        let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
        match alias_addr {
            ChannelAddr::Alias { dial_to, bind_to } => {
                assert_eq!(
                    *dial_to,
                    ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
                );
                assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
            }
            _ => panic!("Expected Alias"),
        }

        // Non-TCP dial_to.
        assert!(
            ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err()
        );

        // Non-TCP bind_to.
        assert!(
            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
        );

        // Unknown scheme in dial_to.
        assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());

        // Unknown scheme in bind_to.
        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());

        // dial_to without a port.
        assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());

        // bind_to without a port.
        assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
    }

    /// For each transport in `ChannelTransport::all()`, 100 concurrent tasks
    /// each dial the server and post one message; every message must arrive.
    #[tokio::test]
    async fn test_multiple_connections() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();

            let mut sends: JoinSet<()> = JoinSet::new();
            for message in 0u64..100u64 {
                let addr = listen_addr.clone();
                sends.spawn(async move {
                    let tx = dial::<u64>(addr).unwrap();
                    tx.post(message);
                });
            }

            let mut received: HashSet<u64> = HashSet::new();
            while received.len() < 100 {
                received.insert(rx.recv().await.unwrap());
            }

            for message in 0u64..100u64 {
                assert!(received.contains(&message));
            }

            // Surface any panic from a sender task.
            loop {
                match sends.join_next().await {
                    Some(Ok(())) => (),
                    Some(Err(err)) => panic!("{}", err),
                    None => break,
                }
            }
        }
    }

    /// After the receiver is dropped, `try_post` must eventually hand the
    /// message back with `ChannelError::Closed`; the test keeps retrying for
    /// up to 10 seconds.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            // NOTE(review): network addresses are skipped; the reason is not
            // visible here (presumably they report closure differently) — confirm.
            if net::is_net_addr(&addr) {
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.post(123);
            drop(rx);

            let start = tokio::time::Instant::now();

            // `Ok` from the oneshot is a failure report. `Err` means the
            // return channel was dropped, i.e. the post was accepted, so retry
            // until the deadline.
            let result = loop {
                let (return_tx, return_rx) = oneshot::channel();
                tx.try_post(123, return_tx);
                let result = return_rx.await;

                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
            };
            assert_matches!(
                result,
                Ok(SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    reason: None
                })
            );
        }
    }

    /// Addresses for the dial/serve tests:
    /// - IPv6 loopback TCP;
    /// - `local:0`;
    /// - on Linux only, an empty `unix:` address and a random 10-letter
    ///   abstract socket name.
    fn addrs() -> Vec<ChannelAddr> {
        use rand::Rng;
        use rand::distr::Uniform;

        let rng = rand::rng();
        let uniform = Uniform::new_inclusive('a', 'z').unwrap();
        vec![
            "tcp:[::1]:0".parse().unwrap(),
            "local:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            "unix:".parse().unwrap(),
            #[cfg(target_os = "linux")]
            format!(
                "unix:@{}",
                rng.sample_iter(uniform).take(10).collect::<String>()
            )
            .parse()
            .unwrap(),
        ]
    }

    /// Checks `BindSpec::from_str` with transport names, `type:rest`
    /// addresses, ZMQ URLs (including aliases), and invalid input.
    #[test]
    fn test_bind_spec_from_str() {
        // Transport names become `BindSpec::Any`.
        assert_eq!(
            BindSpec::from_str("tcp").unwrap(),
            BindSpec::Any(ChannelTransport::Tcp(TcpMode::Hostname))
        );
        assert_eq!(
            BindSpec::from_str("metatls(Hostname)").unwrap(),
            BindSpec::Any(ChannelTransport::MetaTls(TlsMode::Hostname))
        );

        // `type:rest` form.
        assert_eq!(
            BindSpec::from_str("tcp:127.0.0.1:8080").unwrap(),
            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()))
        );

        // ZMQ URL form, including aliases.
        assert_eq!(
            BindSpec::from_str("tcp://127.0.0.1:9000").unwrap(),
            BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()))
        );
        assert_eq!(
            BindSpec::from_str("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap(),
            BindSpec::Addr(
                ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap()
            )
        );

        // Invalid input.
        assert!(BindSpec::from_str("invalid_spec").is_err());
        assert!(BindSpec::from_str("unknown://scheme").is_err());
        assert!(BindSpec::from_str("").is_err());
    }

    /// Round-trips one posted message over each address from `addrs()`.
    // NOTE(review): ignored outside fbcode builds; the reason is not visible here.
    #[tokio::test]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_dial_serve() {
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.post(123);
            assert_eq!(rx.recv().await.unwrap(), 123);
        }
    }

    /// `send` succeeds while the receiver is alive. After the receiver is
    /// dropped, it fails with `ChannelError::Closed` and returns the message.
    // NOTE(review): ignored outside fbcode builds; the reason is not visible here.
    #[tokio::test]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_send() {
        let config = hyperactor_config::global::lock();

        // A short delivery timeout and an ack for every message, presumably so
        // that the failing send below resolves quickly — confirm.
        let _guard1 = config.override_key(
            crate::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.send(123).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);

            drop(rx);
            assert_matches!(
                tx.send(123).await.unwrap_err(),
                SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    ..
                }
            );
        }
    }

    #[test]
    fn test_find_routable_address_skips_link_local_ipv6() {
        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
        let routable_v6: IpAddr = "2001:db8::1".parse().unwrap();
        let addrs = vec![link_local_v6, routable_v6];
        assert_eq!(find_routable_address(&addrs), Some(routable_v6));
    }

    #[test]
    fn test_find_routable_address_skips_link_local_ipv4() {
        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
        let routable_v4: IpAddr = "192.168.1.1".parse().unwrap();
        let addrs = vec![link_local_v4, routable_v4];
        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
    }

    #[test]
    fn test_find_routable_address_returns_none_when_all_link_local() {
        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
        let link_local_v4: IpAddr = "169.254.1.1".parse().unwrap();
        let addrs = vec![link_local_v6, link_local_v4];
        assert_eq!(find_routable_address(&addrs), None);
    }

    #[test]
    fn test_find_routable_address_mixed() {
        let link_local_v6: IpAddr = "fe80::1".parse().unwrap();
        let link_local_v4: IpAddr = "169.254.0.1".parse().unwrap();
        let routable_v4: IpAddr = "10.0.0.1".parse().unwrap();
        let routable_v6: IpAddr = "2001:db8::2".parse().unwrap();

        let addrs = vec![link_local_v6, link_local_v4, routable_v4, routable_v6];
        // The first non-link-local entry wins, regardless of address family.
        assert_eq!(find_routable_address(&addrs), Some(routable_v4));
    }
}