1#![allow(dead_code)] use core::net::SocketAddr;
15use std::fmt;
16use std::net::IpAddr;
17use std::net::Ipv6Addr;
18#[cfg(target_os = "linux")]
19use std::os::linux::net::SocketAddrExt;
20use std::panic::Location;
21use std::str::FromStr;
22
23use async_trait::async_trait;
24use enum_as_inner::EnumAsInner;
25use hyperactor_config::attrs::AttrValue;
26use lazy_static::lazy_static;
27use local_ip_address::local_ipv6;
28use serde::Deserialize;
29use serde::Serialize;
30use tokio::sync::mpsc;
31use tokio::sync::oneshot;
32use tokio::sync::watch;
33
34use crate as hyperactor;
35use crate::RemoteMessage;
36use crate::channel::sim::SimAddr;
37use crate::simnet::SimNetError;
38
39pub(crate) mod local;
40pub(crate) mod net;
41pub mod sim;
42
/// The errors produced by channel operations in this module.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// The channel is closed.
    #[error("channel closed")]
    Closed,

    /// A send failed; carries the underlying cause.
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// An error originating from the `net` client, displayed transparently.
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// An address was rejected; the payload is a description of the problem.
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// An error originating from the `net` server, displayed transparently.
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// A bincode (de)serialization error.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// An error converted from `wirevalue::Error`.
    #[error(transparent)]
    Data(#[from] wirevalue::Error),

    /// Any other error, converted from `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An operation timed out; the payload is the duration reported in the message.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),

    /// An error converted from the simulated network's [`SimNetError`].
    #[error(transparent)]
    SimNetError(#[from] SimNetError),
}
86
/// The error reported for a message that failed to send. It hands the
/// message back together with the channel error that caused the failure.
#[derive(thiserror::Error, Debug)]
#[error("{error} for reason {reason:?}")]
pub struct SendError<M: RemoteMessage> {
    /// The underlying channel error (also exposed as the error source).
    #[source]
    pub error: ChannelError,
    /// The message associated with the failed send.
    pub message: M,
    /// An optional free-form reason, rendered with `{:?}` in the display output.
    pub reason: Option<String>,
}
99
100impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
101 fn from(error: SendError<M>) -> Self {
102 error.error
103 }
104}
105
/// The status of a channel sender, as published through a `watch` channel
/// (see [`Tx::status`]).
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TxStatus {
    /// The initial status of a sender in this file (see `MpscTx::new`).
    Active,
    /// The sender is closed; `MpscRx` publishes this when it is dropped.
    Closed,
}
114
115#[async_trait]
117pub trait Tx<M: RemoteMessage> {
118 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
124
125 #[allow(clippy::result_large_err)] #[hyperactor::instrument_infallible]
130 fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
131 self.do_post(message, Some(return_channel));
132 }
133
134 #[hyperactor::instrument_infallible]
136 fn post(&self, message: M) {
137 self.do_post(message, None);
138 }
139
140 async fn send(&self, message: M) -> Result<(), SendError<M>> {
143 let (tx, rx) = oneshot::channel();
144 self.try_post(message, tx);
145 match rx.await {
146 Ok(err) => Err(err),
148
149 Err(_) => Ok(()),
152 }
153 }
154
155 fn addr(&self) -> ChannelAddr;
157
158 fn status(&self) -> &watch::Receiver<TxStatus>;
160}
161
/// The receiving half of a channel that carries messages of type `M`.
#[async_trait]
pub trait Rx<M: RemoteMessage> {
    /// Receive the next message, or a [`ChannelError`] if none can be
    /// produced (e.g. `MpscRx` returns `ChannelError::Closed` once its
    /// underlying queue is closed).
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The address this receiver is associated with.
    fn addr(&self) -> ChannelAddr;
}
172
/// A [`Tx`] implementation backed by a tokio unbounded mpsc sender.
struct MpscTx<M: RemoteMessage> {
    /// The underlying queue sender.
    tx: mpsc::UnboundedSender<M>,
    /// The address reported by `addr()`.
    addr: ChannelAddr,
    /// The receiving side of the status watch channel, returned by `status()`.
    status: watch::Receiver<TxStatus>,
}
178
179impl<M: RemoteMessage> MpscTx<M> {
180 pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
181 let (sender, receiver) = watch::channel(TxStatus::Active);
182 (
183 Self {
184 tx,
185 addr,
186 status: receiver,
187 },
188 sender,
189 )
190 }
191}
192
193#[async_trait]
194impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
195 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
196 if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
197 if let Some(return_channel) = return_channel {
198 return_channel
199 .send(SendError {
200 error: ChannelError::Closed,
201 message,
202 reason: None,
203 })
204 .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
205 }
206 }
207 }
208
209 fn addr(&self) -> ChannelAddr {
210 self.addr.clone()
211 }
212
213 fn status(&self) -> &watch::Receiver<TxStatus> {
214 &self.status
215 }
216}
217
/// An [`Rx`] implementation backed by a tokio unbounded mpsc receiver.
struct MpscRx<M: RemoteMessage> {
    /// The underlying queue receiver.
    rx: mpsc::UnboundedReceiver<M>,
    /// The address reported by `addr()`.
    addr: ChannelAddr,
    /// Sender half of the status watch channel; `TxStatus::Closed` is
    /// published on it when this receiver is dropped.
    status_sender: watch::Sender<TxStatus>,
}
224
225impl<M: RemoteMessage> MpscRx<M> {
226 pub fn new(
227 rx: mpsc::UnboundedReceiver<M>,
228 addr: ChannelAddr,
229 status_sender: watch::Sender<TxStatus>,
230 ) -> Self {
231 Self {
232 rx,
233 addr,
234 status_sender,
235 }
236 }
237}
238
239impl<M: RemoteMessage> Drop for MpscRx<M> {
240 fn drop(&mut self) {
241 let _ = self.status_sender.send(TxStatus::Closed);
242 }
243}
244
245#[async_trait]
246impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
247 async fn recv(&mut self) -> Result<M, ChannelError> {
248 self.rx.recv().await.ok_or(ChannelError::Closed)
249 }
250
251 fn addr(&self) -> ChannelAddr {
252 self.addr.clone()
253 }
254}
255
/// How a TCP transport chooses the IP address used by `ChannelAddr::any`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TcpMode {
    /// Use the IPv6 loopback address.
    Localhost,
    /// Use an IP address obtained by resolving the local host name.
    Hostname,
}
275
/// How a MetaTLS transport chooses the host used by `ChannelAddr::any`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// Use the local IPv6 address (rendered as a string).
    IpV6,
    /// Use the local host name.
    Hostname,
}
296
/// An address for the MetaTLS transport: either a host name plus port, or a
/// concrete socket address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd,
    EnumAsInner
)]
pub enum MetaTlsAddr {
    /// A host name and port pair; displayed as `hostname:port`.
    Host {
        /// The host name (may also hold a textual IP address; see
        /// `ChannelAddr::transport`, which attempts to parse it as one).
        hostname: Hostname,
        /// The port number.
        port: Port,
    },
    /// A concrete socket address.
    Socket(SocketAddr),
}
323
324impl MetaTlsAddr {
325 pub fn port(&self) -> Port {
327 match self {
328 Self::Host { port, .. } => *port,
329 Self::Socket(addr) => addr.port(),
330 }
331 }
332
333 pub fn hostname(&self) -> Option<&str> {
335 match self {
336 Self::Host { hostname, .. } => Some(hostname),
337 Self::Socket(_) => None,
338 }
339 }
340}
341
342impl fmt::Display for MetaTlsAddr {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 match self {
345 Self::Host { hostname, port } => write!(f, "{}:{}", hostname, port),
346 Self::Socket(addr) => write!(f, "{}", addr),
347 }
348 }
349}
350
/// The kinds of transport a channel can use.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum ChannelTransport {
    /// TCP, with the given address-selection mode.
    Tcp(TcpMode),

    /// MetaTLS, with the given host-selection mode.
    MetaTls(TlsMode),

    /// The local transport (addresses of the form `local:N`).
    Local,

    /// A simulated channel wrapping another transport.
    Sim(Box<ChannelTransport>),

    /// Unix domain sockets.
    Unix,
}
378
379impl fmt::Display for ChannelTransport {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 match self {
382 Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
383 Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
384 Self::Local => write!(f, "local"),
385 Self::Sim(transport) => write!(f, "sim({})", transport),
386 Self::Unix => write!(f, "unix"),
387 }
388 }
389}
390
391impl FromStr for ChannelTransport {
392 type Err = anyhow::Error;
393
394 fn from_str(s: &str) -> Result<Self, Self::Err> {
395 if let Some(rest) = s.strip_prefix("sim(") {
397 if let Some(end) = rest.rfind(')') {
398 let inner = &rest[..end];
399 let inner_transport = ChannelTransport::from_str(inner)?;
400 return Ok(ChannelTransport::Sim(Box::new(inner_transport)));
401 } else {
402 return Err(anyhow::anyhow!("invalid sim transport"));
403 }
404 }
405
406 match s {
407 "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
409 s if s.starts_with("tcp(") => {
410 let inner = &s["tcp(".len()..s.len() - 1];
411 let mode = inner.parse()?;
412 Ok(ChannelTransport::Tcp(mode))
413 }
414 "local" => Ok(ChannelTransport::Local),
415 "unix" => Ok(ChannelTransport::Unix),
416 s if s.starts_with("metatls(") && s.ends_with(")") => {
417 let inner = &s["metatls(".len()..s.len() - 1];
418 let mode = inner.parse()?;
419 Ok(ChannelTransport::MetaTls(mode))
420 }
421 unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
422 }
423 }
424}
425
426impl ChannelTransport {
427 pub fn all() -> [ChannelTransport; 3] {
429 [
430 ChannelTransport::Tcp(TcpMode::Hostname),
433 ChannelTransport::Local,
434 ChannelTransport::Unix,
435 ]
439 }
440
441 pub fn any(&self) -> ChannelAddr {
443 ChannelAddr::any(self.clone())
444 }
445
446 pub fn is_remote(&self) -> bool {
448 match self {
449 ChannelTransport::Tcp(_) => true,
450 ChannelTransport::MetaTls(_) => true,
451 ChannelTransport::Local => false,
452 ChannelTransport::Sim(_) => false,
453 ChannelTransport::Unix => false,
454 }
455 }
456}
457
458impl AttrValue for ChannelTransport {
459 fn display(&self) -> String {
460 self.to_string()
461 }
462
463 fn parse(s: &str) -> Result<Self, anyhow::Error> {
464 s.parse()
465 }
466}
467
/// What to bind a channel to: either any address of a given transport, or
/// one specific address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    typeuri::Named
)]
pub enum BindSpec {
    /// Bind to the "any" address of this transport.
    Any(ChannelTransport),

    /// Bind to this specific address.
    Addr(ChannelAddr),
}
486
487impl BindSpec {
488 pub fn binding_addr(&self) -> ChannelAddr {
490 match self {
491 BindSpec::Any(transport) => ChannelAddr::any(transport.clone()),
492 BindSpec::Addr(addr) => addr.clone(),
493 }
494 }
495}
496
497impl From<ChannelTransport> for BindSpec {
498 fn from(transport: ChannelTransport) -> Self {
499 BindSpec::Any(transport)
500 }
501}
502
503impl From<ChannelAddr> for BindSpec {
504 fn from(addr: ChannelAddr) -> Self {
505 BindSpec::Addr(addr)
506 }
507}
508
509impl fmt::Display for BindSpec {
510 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
511 match self {
512 Self::Any(transport) => write!(f, "{}", transport),
513 Self::Addr(addr) => write!(f, "{}", addr),
514 }
515 }
516}
517
518impl FromStr for BindSpec {
519 type Err = anyhow::Error;
520
521 fn from_str(s: &str) -> Result<Self, Self::Err> {
522 if let Ok(transport) = ChannelTransport::from_str(s) {
523 Ok(BindSpec::Any(transport))
524 } else if let Ok(addr) = ChannelAddr::from_zmq_url(s) {
525 Ok(BindSpec::Addr(addr))
526 } else if let Ok(addr) = ChannelAddr::from_str(s) {
527 Ok(BindSpec::Addr(addr))
528 } else {
529 Err(anyhow::anyhow!("invalid bind spec: {}", s))
530 }
531 }
532}
533
534impl AttrValue for BindSpec {
535 fn display(&self) -> String {
536 self.to_string()
537 }
538
539 fn parse(s: &str) -> Result<Self, anyhow::Error> {
540 Self::from_str(s)
541 }
542}
543
/// A host name, held as an owned string.
pub type Hostname = String;

/// A port number.
pub type Port = u16;
549
/// The address of a channel. Each variant corresponds to a transport; the
/// display form is `<type>:<address>` (e.g. `tcp:...`, `local:...`).
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    typeuri::Named
)]
pub enum ChannelAddr {
    /// A TCP socket address.
    Tcp(SocketAddr),

    /// A MetaTLS address (host name + port, or socket address).
    MetaTls(MetaTlsAddr),

    /// A local channel, identified by a numeric index. Index 0 is the only
    /// value accepted when serving (see `serve_inner`).
    Local(u64),

    /// A simulated channel address.
    Sim(SimAddr),

    /// A Unix domain socket address.
    Unix(net::unix::SocketAddr),

    /// A pair of addresses for one channel: `dial` connects to `dial_to`,
    /// while serving binds `bind_to`.
    Alias {
        /// The address used when dialing.
        dial_to: Box<ChannelAddr>,
        /// The address used when binding/serving.
        bind_to: Box<ChannelAddr>,
    },
}
626
627impl From<SocketAddr> for ChannelAddr {
628 fn from(value: SocketAddr) -> Self {
629 Self::Tcp(value)
630 }
631}
632
633impl From<net::unix::SocketAddr> for ChannelAddr {
634 fn from(value: net::unix::SocketAddr) -> Self {
635 Self::Unix(value)
636 }
637}
638
639impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
640 fn from(value: std::os::unix::net::SocketAddr) -> Self {
641 Self::Unix(net::unix::SocketAddr::new(value))
642 }
643}
644
645impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
646 fn from(value: tokio::net::unix::SocketAddr) -> Self {
647 std::os::unix::net::SocketAddr::from(value).into()
648 }
649}
650
651impl ChannelAddr {
652 pub fn any(transport: ChannelTransport) -> Self {
655 match transport {
656 ChannelTransport::Tcp(mode) => {
657 let ip = match mode {
658 TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
659 TcpMode::Hostname => {
660 hostname::get()
661 .ok()
662 .and_then(|hostname| {
663 hostname.to_str().and_then(|hostname_str| {
665 dns_lookup::lookup_host(hostname_str)
666 .ok()
667 .and_then(|addresses| addresses.first().cloned())
668 })
669 })
670 .expect("failed to resolve hostname to ip address")
671 }
672 };
673 Self::Tcp(SocketAddr::new(ip, 0))
674 }
675 ChannelTransport::MetaTls(mode) => {
676 let host_address = match mode {
677 TlsMode::Hostname => hostname::get()
678 .ok()
679 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
680 .unwrap_or("unknown_host".to_string()),
681 TlsMode::IpV6 => local_ipv6()
682 .ok()
683 .and_then(|addr| addr.to_string().parse().ok())
684 .expect("failed to retrieve ipv6 address"),
685 };
686 Self::MetaTls(MetaTlsAddr::Host {
687 hostname: host_address,
688 port: 0,
689 })
690 }
691 ChannelTransport::Local => Self::Local(0),
692 ChannelTransport::Sim(transport) => sim::any(*transport),
693 ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
695 }
696 }
697
698 pub fn transport(&self) -> ChannelTransport {
700 match self {
701 Self::Tcp(addr) => {
702 if addr.ip().is_loopback() {
703 ChannelTransport::Tcp(TcpMode::Localhost)
704 } else {
705 ChannelTransport::Tcp(TcpMode::Hostname)
706 }
707 }
708 Self::MetaTls(addr) => match addr {
709 MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
710 Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
711 Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
712 Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
713 },
714 MetaTlsAddr::Socket(socket_addr) => match socket_addr.ip() {
715 IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
716 IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
717 },
718 },
719 Self::Local(_) => ChannelTransport::Local,
720 Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
721 Self::Unix(_) => ChannelTransport::Unix,
722 Self::Alias { bind_to, .. } => bind_to.transport(),
725 }
726 }
727}
728
729impl fmt::Display for ChannelAddr {
730 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
731 match self {
732 Self::Tcp(addr) => write!(f, "tcp:{}", addr),
733 Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
734 Self::Local(index) => write!(f, "local:{}", index),
735 Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
736 Self::Unix(addr) => write!(f, "unix:{}", addr),
737 Self::Alias { dial_to, bind_to } => {
738 write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
739 }
740 }
741 }
742}
743
744impl FromStr for ChannelAddr {
745 type Err = anyhow::Error;
746
747 fn from_str(addr: &str) -> Result<Self, Self::Err> {
748 match addr.split_once('!').or_else(|| addr.split_once(':')) {
749 Some(("local", rest)) => rest
750 .parse::<u64>()
751 .map(Self::Local)
752 .map_err(anyhow::Error::from),
753 Some(("tcp", rest)) => rest
754 .parse::<SocketAddr>()
755 .map(Self::Tcp)
756 .map_err(anyhow::Error::from),
757 Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
758 Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
759 Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
760 Some(("alias", _)) => Err(anyhow::anyhow!(
761 "detect possible alias address, but we currently do not support \
762 parsing alias' string representation since we only want to \
763 support parsing its zmq url format."
764 )),
765 Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
766 None => Err(anyhow::anyhow!("no channel type specified")),
767 }
768 }
769}
770
771impl ChannelAddr {
772 pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
781 if let Some(at_pos) = address.find('@') {
784 let dial_to_str = &address[..at_pos];
785 let bind_to_str = &address[at_pos + 1..];
786
787 if !dial_to_str.starts_with("tcp://") {
789 return Err(anyhow::anyhow!(
790 "alias format is only supported for TCP addresses, got dial_to: {}",
791 dial_to_str
792 ));
793 }
794 if !bind_to_str.starts_with("tcp://") {
795 return Err(anyhow::anyhow!(
796 "alias format is only supported for TCP addresses, got bind_to: {}",
797 bind_to_str
798 ));
799 }
800
801 let dial_to = Self::from_zmq_url(dial_to_str)?;
802 let bind_to = Self::from_zmq_url(bind_to_str)?;
803
804 return Ok(Self::Alias {
805 dial_to: Box::new(dial_to),
806 bind_to: Box::new(bind_to),
807 });
808 }
809
810 let (scheme, address) = address.split_once("://").ok_or_else(|| {
812 anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
813 })?;
814
815 match scheme {
816 "tcp" => {
817 let (host, port) = Self::split_host_port(address)?;
818
819 if host == "*" {
820 Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
822 } else {
823 let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
825 Ok(Self::Tcp(socket_addr))
826 }
827 }
828 "inproc" => {
829 let port = address.parse::<u64>().map_err(|_| {
832 anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
833 })?;
834 Ok(Self::Local(port))
835 }
836 "ipc" => {
837 Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
839 }
840 "metatls" => {
841 let (host, port) = Self::split_host_port(address)?;
842
843 if host == "*" {
844 Ok(Self::MetaTls(MetaTlsAddr::Host {
846 hostname: std::net::Ipv6Addr::UNSPECIFIED.to_string(),
847 port,
848 }))
849 } else {
850 Ok(Self::MetaTls(MetaTlsAddr::Host {
851 hostname: host.to_string(),
852 port,
853 }))
854 }
855 }
856 scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
857 }
858 }
859
860 fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
862 if let Some((host, port_str)) = address.rsplit_once(':') {
863 let port: u16 = port_str
864 .parse()
865 .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
866 Ok((host, port))
867 } else {
868 Err(anyhow::anyhow!("invalid address format: {}", address))
869 }
870 }
871
872 fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
874 let host_clean = if host.starts_with('[') && host.ends_with(']') {
876 &host[1..host.len() - 1]
877 } else {
878 host
879 };
880
881 if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
883 return Ok(SocketAddr::new(ip_addr, port));
884 }
885
886 use std::net::ToSocketAddrs;
888 let mut addrs = (host_clean, port)
889 .to_socket_addrs()
890 .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
891
892 addrs
893 .next()
894 .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
895 }
896}
897
/// A channel sender that works across all supported transports; it wraps
/// the transport-specific sender.
pub struct ChannelTx<M: RemoteMessage> {
    /// The transport-specific sender this value dispatches to.
    inner: ChannelTxKind<M>,
}
902
903impl<M: RemoteMessage> fmt::Debug for ChannelTx<M> {
904 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
905 f.debug_struct("ChannelTx")
906 .field("addr", &self.addr())
907 .finish()
908 }
909}
910
/// The transport-specific sender held by a [`ChannelTx`]. TCP, MetaTLS and
/// Unix all share the `net::NetTx` sender type.
enum ChannelTxKind<M: RemoteMessage> {
    /// Sender for the local transport.
    Local(local::LocalTx<M>),
    /// Sender for the TCP transport.
    Tcp(net::NetTx<M>),
    /// Sender for the MetaTLS transport.
    MetaTls(net::NetTx<M>),
    /// Sender for the Unix domain socket transport.
    Unix(net::NetTx<M>),
    /// Sender for the simulated transport.
    Sim(sim::SimTx<M>),
}
919
920#[async_trait]
921impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
922 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
923 match &self.inner {
924 ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
925 ChannelTxKind::Tcp(tx) => tx.do_post(message, return_channel),
926 ChannelTxKind::MetaTls(tx) => tx.do_post(message, return_channel),
927 ChannelTxKind::Sim(tx) => tx.do_post(message, return_channel),
928 ChannelTxKind::Unix(tx) => tx.do_post(message, return_channel),
929 }
930 }
931
932 fn addr(&self) -> ChannelAddr {
933 match &self.inner {
934 ChannelTxKind::Local(tx) => tx.addr(),
935 ChannelTxKind::Tcp(tx) => Tx::<M>::addr(tx),
936 ChannelTxKind::MetaTls(tx) => Tx::<M>::addr(tx),
937 ChannelTxKind::Sim(tx) => tx.addr(),
938 ChannelTxKind::Unix(tx) => Tx::<M>::addr(tx),
939 }
940 }
941
942 fn status(&self) -> &watch::Receiver<TxStatus> {
943 match &self.inner {
944 ChannelTxKind::Local(tx) => tx.status(),
945 ChannelTxKind::Tcp(tx) => tx.status(),
946 ChannelTxKind::MetaTls(tx) => tx.status(),
947 ChannelTxKind::Sim(tx) => tx.status(),
948 ChannelTxKind::Unix(tx) => tx.status(),
949 }
950 }
951}
952
/// A channel receiver that works across all supported transports; it wraps
/// the transport-specific receiver.
pub struct ChannelRx<M: RemoteMessage> {
    /// The transport-specific receiver this value dispatches to.
    inner: ChannelRxKind<M>,
}
957
958impl<M: RemoteMessage> fmt::Debug for ChannelRx<M> {
959 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
960 f.debug_struct("ChannelRx")
961 .field("addr", &self.addr())
962 .finish()
963 }
964}
965
/// The transport-specific receiver held by a [`ChannelRx`]. TCP, MetaTLS
/// and Unix all share the `net::NetRx` receiver type.
enum ChannelRxKind<M: RemoteMessage> {
    /// Receiver for the local transport.
    Local(local::LocalRx<M>),
    /// Receiver for the TCP transport.
    Tcp(net::NetRx<M>),
    /// Receiver for the MetaTLS transport.
    MetaTls(net::NetRx<M>),
    /// Receiver for the Unix domain socket transport.
    Unix(net::NetRx<M>),
    /// Receiver for the simulated transport.
    Sim(sim::SimRx<M>),
}
974
975#[async_trait]
976impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
977 #[hyperactor::instrument]
978 async fn recv(&mut self) -> Result<M, ChannelError> {
979 match &mut self.inner {
980 ChannelRxKind::Local(rx) => rx.recv().await,
981 ChannelRxKind::Tcp(rx) => rx.recv().await,
982 ChannelRxKind::MetaTls(rx) => rx.recv().await,
983 ChannelRxKind::Sim(rx) => rx.recv().await,
984 ChannelRxKind::Unix(rx) => rx.recv().await,
985 }
986 }
987
988 fn addr(&self) -> ChannelAddr {
989 match &self.inner {
990 ChannelRxKind::Local(rx) => rx.addr(),
991 ChannelRxKind::Tcp(rx) => rx.addr(),
992 ChannelRxKind::MetaTls(rx) => rx.addr(),
993 ChannelRxKind::Sim(rx) => rx.addr(),
994 ChannelRxKind::Unix(rx) => rx.addr(),
995 }
996 }
997}
998
999#[allow(clippy::result_large_err)] #[track_caller]
1004pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
1005 tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
1006 let inner = match addr {
1007 ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
1008 ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
1009 ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
1010 ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
1011 ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
1012 ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
1013 };
1014 Ok(ChannelTx { inner })
1015}
1016
1017#[crate::instrument]
1020#[track_caller]
1021pub fn serve<M: RemoteMessage>(
1022 addr: ChannelAddr,
1023) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
1024 let caller = Location::caller();
1025 serve_inner(addr).map(|(addr, inner)| {
1026 tracing::debug!(
1027 name = "serve",
1028 %addr,
1029 %caller,
1030 );
1031 (addr, ChannelRx { inner })
1032 })
1033}
1034
1035fn serve_inner<M: RemoteMessage>(
1036 addr: ChannelAddr,
1037) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
1038 match addr {
1039 ChannelAddr::Tcp(addr) => {
1040 let (addr, rx) = net::tcp::serve::<M>(addr)?;
1041 Ok((addr, ChannelRxKind::Tcp(rx)))
1042 }
1043 ChannelAddr::MetaTls(meta_addr) => {
1044 let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
1045 Ok((addr, ChannelRxKind::MetaTls(rx)))
1046 }
1047 ChannelAddr::Unix(path) => {
1048 let (addr, rx) = net::unix::serve::<M>(path)?;
1049 Ok((addr, ChannelRxKind::Unix(rx)))
1050 }
1051 ChannelAddr::Local(0) => {
1052 let (port, rx) = local::serve::<M>();
1053 Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
1054 }
1055 ChannelAddr::Sim(sim_addr) => {
1056 let (addr, rx) = sim::serve::<M>(sim_addr)?;
1057 Ok((addr, ChannelRxKind::Sim(rx)))
1058 }
1059 ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
1060 "invalid local addr: {}",
1061 a
1062 ))),
1063 ChannelAddr::Alias { dial_to, bind_to } => {
1064 let (bound_addr, rx) = serve_inner::<M>(*bind_to)?;
1065 let alias_addr = ChannelAddr::Alias {
1066 dial_to,
1067 bind_to: Box::new(bound_addr),
1068 };
1069 Ok((alias_addr, rx))
1070 }
1071 }
1072}
1073
1074pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
1077 let (port, rx) = local::serve::<M>();
1078 (
1079 ChannelAddr::Local(port),
1080 ChannelRx {
1081 inner: ChannelRxKind::Local(rx),
1082 },
1083 )
1084}
1085
1086#[cfg(test)]
1087mod tests {
1088 use std::assert_matches::assert_matches;
1089 use std::collections::HashSet;
1090 use std::net::IpAddr;
1091 use std::net::Ipv4Addr;
1092 use std::net::Ipv6Addr;
1093 use std::time::Duration;
1094
1095 use tokio::task::JoinSet;
1096
1097 use super::net::*;
1098 use super::*;
1099 use crate::clock::Clock;
1100 use crate::clock::RealClock;
1101
1102 #[test]
1103 fn test_channel_addr() {
1104 let cases_ok = vec![
1105 (
1106 "tcp<DELIM>[::1]:1234",
1107 ChannelAddr::Tcp(SocketAddr::new(
1108 IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1109 1234,
1110 )),
1111 ),
1112 (
1113 "tcp<DELIM>127.0.0.1:8080",
1114 ChannelAddr::Tcp(SocketAddr::new(
1115 IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1116 8080,
1117 )),
1118 ),
1119 #[cfg(target_os = "linux")]
1120 ("local<DELIM>123", ChannelAddr::Local(123)),
1121 (
1122 "unix<DELIM>@yolo",
1123 ChannelAddr::Unix(
1124 unix::SocketAddr::from_abstract_name("yolo")
1125 .expect("can't make socket from abstract name"),
1126 ),
1127 ),
1128 (
1129 "unix<DELIM>/cool/socket-path",
1130 ChannelAddr::Unix(
1131 unix::SocketAddr::from_pathname("/cool/socket-path")
1132 .expect("can't make socket from path"),
1133 ),
1134 ),
1135 ];
1136
1137 for (raw, parsed) in cases_ok.clone() {
1138 for delim in ["!", ":"] {
1139 let raw = raw.replace("<DELIM>", delim);
1140 assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
1141 }
1142 }
1143
1144 for (raw, parsed) in cases_ok {
1145 for delim in ["!", ":"] {
1146 let raw = format!("sim{}{}", delim, raw.replace("<DELIM>", delim));
1148 assert_eq!(
1149 raw.parse::<ChannelAddr>().unwrap(),
1150 ChannelAddr::Sim(SimAddr::new(parsed.clone()).unwrap())
1151 );
1152 }
1153 }
1154
1155 let cases_err = vec![
1156 ("tcp:abcdef..123124", "invalid socket address syntax"),
1157 ("xxx:foo", "no such channel type: xxx"),
1158 ("127.0.0.1", "no channel type specified"),
1159 ("local:abc", "invalid digit found in string"),
1160 ];
1161
1162 for (raw, error) in cases_err {
1163 let Err(err) = raw.parse::<ChannelAddr>() else {
1164 panic!("expected error parsing: {}", &raw)
1165 };
1166 assert_eq!(format!("{}", err), error);
1167 }
1168 }
1169
1170 #[test]
1171 fn test_zmq_style_channel_addr() {
1172 assert_eq!(
1174 ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
1175 ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
1176 );
1177
1178 assert_eq!(
1180 ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
1181 ChannelAddr::Tcp("[::]:5555".parse().unwrap())
1182 );
1183
1184 assert_eq!(
1186 ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
1187 ChannelAddr::Local(12345)
1188 );
1189
1190 assert_eq!(
1192 ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
1193 ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
1194 );
1195
1196 assert_eq!(
1198 ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
1199 ChannelAddr::MetaTls(MetaTlsAddr::Host {
1200 hostname: "example.com".to_string(),
1201 port: 443
1202 })
1203 );
1204
1205 assert_eq!(
1207 ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
1208 ChannelAddr::MetaTls(MetaTlsAddr::Host {
1209 hostname: "192.168.1.1".to_string(),
1210 port: 443
1211 })
1212 );
1213
1214 assert_eq!(
1216 ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
1217 ChannelAddr::MetaTls(MetaTlsAddr::Host {
1218 hostname: "::".to_string(),
1219 port: 8443
1220 })
1221 );
1222
1223 let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
1227 assert!(tcp_hostname_result.is_ok());
1228
1229 assert_eq!(
1231 ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
1232 ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
1233 );
1234
1235 assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
1237 assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
1238 assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
1239 assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
1240 }
1241
1242 #[test]
1243 fn test_zmq_style_alias_channel_addr() {
1244 let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
1250 match alias_addr {
1251 ChannelAddr::Alias { dial_to, bind_to } => {
1252 assert_eq!(
1253 *dial_to,
1254 ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
1255 );
1256 assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
1257 }
1258 _ => panic!("Expected Alias"),
1259 }
1260
1261 assert!(
1263 ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err()
1264 );
1265
1266 assert!(
1268 ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
1269 );
1270
1271 assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());
1273
1274 assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());
1276
1277 assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());
1279
1280 assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
1282 }
1283
1284 #[tokio::test]
1285 async fn test_multiple_connections() {
1286 for addr in ChannelTransport::all().map(ChannelAddr::any) {
1287 let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();
1288
1289 let mut sends: JoinSet<()> = JoinSet::new();
1290 for message in 0u64..100u64 {
1291 let addr = listen_addr.clone();
1292 sends.spawn(async move {
1293 let tx = dial::<u64>(addr).unwrap();
1294 tx.post(message);
1295 });
1296 }
1297
1298 let mut received: HashSet<u64> = HashSet::new();
1299 while received.len() < 100 {
1300 received.insert(rx.recv().await.unwrap());
1301 }
1302
1303 for message in 0u64..100u64 {
1304 assert!(received.contains(&message));
1305 }
1306
1307 loop {
1308 match sends.join_next().await {
1309 Some(Ok(())) => (),
1310 Some(Err(err)) => panic!("{}", err),
1311 None => break,
1312 }
1313 }
1314 }
1315 }
1316
1317 #[tokio::test]
1318 async fn test_server_close() {
1319 for addr in ChannelTransport::all().map(ChannelAddr::any) {
1320 if net::is_net_addr(&addr) {
1321 continue;
1324 }
1325
1326 let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();
1327
1328 let tx = dial::<u64>(listen_addr).unwrap();
1329 tx.post(123);
1330 drop(rx);
1331
1332 let start = RealClock.now();
1337
1338 let result = loop {
1339 let (return_tx, return_rx) = oneshot::channel();
1340 tx.try_post(123, return_tx);
1341 let result = return_rx.await;
1342
1343 if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
1344 break result;
1345 }
1346 };
1347 assert_matches!(
1348 result,
1349 Ok(SendError {
1350 error: ChannelError::Closed,
1351 message: 123,
1352 reason: None
1353 })
1354 );
1355 }
1356 }
1357
1358 fn addrs() -> Vec<ChannelAddr> {
1359 use rand::Rng;
1360 use rand::distributions::Uniform;
1361
1362 let rng = rand::thread_rng();
1363 vec![
1364 "tcp:[::1]:0".parse().unwrap(),
1365 "local:0".parse().unwrap(),
1366 #[cfg(target_os = "linux")]
1367 "unix:".parse().unwrap(),
1368 #[cfg(target_os = "linux")]
1369 format!(
1370 "unix:@{}",
1371 rng.sample_iter(Uniform::new_inclusive('a', 'z'))
1372 .take(10)
1373 .collect::<String>()
1374 )
1375 .parse()
1376 .unwrap(),
1377 ]
1378 }
1379
1380 #[test]
1381 fn test_bind_spec_from_str() {
1382 assert_eq!(
1384 BindSpec::from_str("tcp").unwrap(),
1385 BindSpec::Any(ChannelTransport::Tcp(TcpMode::Hostname))
1386 );
1387 assert_eq!(
1388 BindSpec::from_str("metatls(Hostname)").unwrap(),
1389 BindSpec::Any(ChannelTransport::MetaTls(TlsMode::Hostname))
1390 );
1391
1392 assert_eq!(
1394 BindSpec::from_str("tcp:127.0.0.1:8080").unwrap(),
1395 BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap()))
1396 );
1397
1398 assert_eq!(
1400 BindSpec::from_str("tcp://127.0.0.1:9000").unwrap(),
1401 BindSpec::Addr(ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()))
1402 );
1403 assert_eq!(
1404 BindSpec::from_str("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap(),
1405 BindSpec::Addr(
1406 ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::1]:7200").unwrap()
1407 )
1408 );
1409
1410 assert!(BindSpec::from_str("invalid_spec").is_err());
1412 assert!(BindSpec::from_str("unknown://scheme").is_err());
1413 assert!(BindSpec::from_str("").is_err());
1414 }
1415
1416 #[tokio::test]
1417 #[cfg_attr(not(fbcode_build), ignore)]
1419 async fn test_dial_serve() {
1420 for addr in addrs() {
1421 let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1422 let tx = crate::channel::dial(listen_addr).unwrap();
1423 tx.post(123);
1424 assert_eq!(rx.recv().await.unwrap(), 123);
1425 }
1426 }
1427
1428 #[tokio::test]
1429 #[cfg_attr(not(fbcode_build), ignore)]
1431 async fn test_send() {
1432 let config = hyperactor_config::global::lock();
1433
1434 let _guard1 = config.override_key(
1436 crate::config::MESSAGE_DELIVERY_TIMEOUT,
1437 Duration::from_secs(1),
1438 );
1439 let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
1440 for addr in addrs() {
1441 let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
1442 let tx = crate::channel::dial(listen_addr).unwrap();
1443 tx.send(123).await.unwrap();
1444 assert_eq!(rx.recv().await.unwrap(), 123);
1445
1446 drop(rx);
1447 assert_matches!(
1448 tx.send(123).await.unwrap_err(),
1449 SendError {
1450 error: ChannelError::Closed,
1451 message: 123,
1452 ..
1453 }
1454 );
1455 }
1456 }
1457}