1#![allow(dead_code)] use core::net::SocketAddr;
15use std::fmt;
16use std::net::IpAddr;
17#[cfg(target_os = "linux")]
18use std::os::linux::net::SocketAddrExt;
19use std::str::FromStr;
20
21use async_trait::async_trait;
22use enum_as_inner::EnumAsInner;
23use lazy_static::lazy_static;
24use local_ip_address::local_ipv6;
25use serde::Deserialize;
26use serde::Serialize;
27use tokio::sync::mpsc;
28use tokio::sync::oneshot;
29use tokio::sync::watch;
30
31use crate as hyperactor;
32use crate::Named;
33use crate::RemoteMessage;
34use crate::attrs::AttrValue;
35use crate::channel::sim::SimAddr;
36use crate::simnet::SimNetError;
37
38pub(crate) mod local;
39pub(crate) mod net;
40pub mod sim;
41
/// Errors produced by channel operations (dialing, serving, sending and receiving).
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// The channel has been closed.
    #[error("channel closed")]
    Closed,

    /// A send failed; wraps the underlying cause.
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// Error raised by the network client (`net::ClientError`).
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// The given address is not valid here; the string describes the problem
    /// (e.g. `serve` rejects a nonzero `Local` address with this variant).
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// Error raised by the network server (`net::ServerError`).
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// Bincode (de)serialization failure.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// Error from the crate's `data` module.
    #[error(transparent)]
    Data(#[from] crate::data::Error),

    /// Any other error, carried as `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An operation did not finish within the given duration.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),

    /// Error reported by the simulated network.
    #[error(transparent)]
    SimNetError(#[from] SimNetError),
}
85
/// A failed send: the channel error that caused it, plus the message that was
/// not delivered, so the caller can recover it.
///
/// `Display` renders only the inner error's message; the same error is also
/// exposed through `source()`.
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct SendError<M: RemoteMessage>(#[source] pub ChannelError, pub M);
90
91impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
92 fn from(error: SendError<M>) -> Self {
93 error.0
94 }
95}
96
/// Liveness state of a transmit handle, observed through [`Tx::status`].
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TxStatus {
    /// The channel is open (the initial state set by `MpscTx::new`).
    Active,
    /// The channel has been closed (e.g. `MpscRx` publishes this when dropped).
    Closed,
}
105
/// The sending half of a channel.
#[async_trait]
pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
    /// Enqueue `message` for delivery without waiting for the outcome.
    ///
    /// `return_channel` is a oneshot through which an implementation may hand the
    /// message back if it turns out to be undelivered; [`Tx::send`] treats a
    /// message arriving there as a failed send. Failures detected immediately are
    /// returned as `Err(SendError)` carrying the message (as `MpscTx` does).
    // `SendError<M>` embeds the message, so the `Err` variant may be large; the
    // lint is allowed deliberately.
    #[allow(clippy::result_large_err)]
    fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>>;

    /// Fire-and-forget post. Any error is ignored, and the return channel's
    /// receiver is dropped right away, so a message handed back is discarded.
    fn post(&self, message: M) {
        let _ignore = self.try_post(message, oneshot::channel().0);
    }

    /// Post `message` and wait for the outcome.
    ///
    /// Errors from `try_post` are propagated directly. Afterwards this waits on the
    /// return channel: if the message is handed back, returns
    /// `Err(SendError(ChannelError::Closed, message))`; if the sender is dropped
    /// without returning it, returns `Ok(())`.
    async fn send(&self, message: M) -> Result<(), SendError<M>> {
        let (tx, rx) = oneshot::channel();
        self.try_post(message, tx)?;
        match rx.await {
            // The message came back: report the send as failed.
            Ok(m) => Err(SendError(ChannelError::Closed, m)),

            // The return sender was dropped without a message: treated as success.
            Err(_) => Ok(()),
        }
    }

    /// The address this handle sends to.
    fn addr(&self) -> ChannelAddr;

    /// A watch receiver reflecting this channel's current [`TxStatus`].
    fn status(&self) -> &watch::Receiver<TxStatus>;
}
146
/// The receiving half of a channel.
#[async_trait]
pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
    /// Wait for the next message. Implementations return an error (for example
    /// `ChannelError::Closed`) when no message can be produced.
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The address of this receiver.
    fn addr(&self) -> ChannelAddr;
}
157
/// A [`Tx`] backed by an unbounded tokio mpsc sender.
#[derive(Debug)]
struct MpscTx<M: RemoteMessage> {
    /// The underlying queue sender.
    tx: mpsc::UnboundedSender<M>,
    /// Address reported by `Tx::addr`.
    addr: ChannelAddr,
    /// Status receiver; the paired `watch::Sender` is returned by `MpscTx::new`.
    status: watch::Receiver<TxStatus>,
}
164
165impl<M: RemoteMessage> MpscTx<M> {
166 pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
167 let (sender, receiver) = watch::channel(TxStatus::Active);
168 (
169 Self {
170 tx,
171 addr,
172 status: receiver,
173 },
174 sender,
175 )
176 }
177}
178
179#[async_trait]
180impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
181 fn try_post(
182 &self,
183 message: M,
184 _return_channel: oneshot::Sender<M>,
185 ) -> Result<(), SendError<M>> {
186 self.tx
187 .send(message)
188 .map_err(|mpsc::error::SendError(message)| SendError(ChannelError::Closed, message))
189 }
190
191 fn addr(&self) -> ChannelAddr {
192 self.addr.clone()
193 }
194
195 fn status(&self) -> &watch::Receiver<TxStatus> {
196 &self.status
197 }
198}
199
/// An [`Rx`] backed by an unbounded tokio mpsc receiver.
#[derive(Debug)]
struct MpscRx<M: RemoteMessage> {
    /// The underlying queue receiver.
    rx: mpsc::UnboundedReceiver<M>,
    /// Address reported by `Rx::addr`.
    addr: ChannelAddr,
    /// Used to publish `TxStatus::Closed` when this receiver is dropped.
    status_sender: watch::Sender<TxStatus>,
}
207
208impl<M: RemoteMessage> MpscRx<M> {
209 pub fn new(
210 rx: mpsc::UnboundedReceiver<M>,
211 addr: ChannelAddr,
212 status_sender: watch::Sender<TxStatus>,
213 ) -> Self {
214 Self {
215 rx,
216 addr,
217 status_sender,
218 }
219 }
220}
221
222impl<M: RemoteMessage> Drop for MpscRx<M> {
223 fn drop(&mut self) {
224 let _ = self.status_sender.send(TxStatus::Closed);
225 }
226}
227
228#[async_trait]
229impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
230 async fn recv(&mut self) -> Result<M, ChannelError> {
231 self.rx.recv().await.ok_or(ChannelError::Closed)
232 }
233
234 fn addr(&self) -> ChannelAddr {
235 self.addr.clone()
236 }
237}
238
/// How a MetaTLS address identifies its host. The strum derives provide the
/// variant-name string forms (`IpV6`, `Hostname`) and iteration over all modes.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// Identify the host by an IPv6 address (`ChannelAddr::any` uses the local IPv6 address).
    IpV6,
    /// Identify the host by hostname (`ChannelAddr::any` uses the machine's hostname).
    Hostname,
}
259
/// Address of a MetaTLS endpoint: either a textual host plus port, or a
/// concrete socket address.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd,
    EnumAsInner
)]
pub enum MetaTlsAddr {
    /// A host given as text (a hostname or an IP literal such as `::`) and a port.
    Host {
        /// Hostname or textual IP address.
        hostname: Hostname,
        /// Port number.
        port: Port,
    },
    /// A concrete socket address.
    Socket(SocketAddr),
}
286
287impl MetaTlsAddr {
288 pub fn port(&self) -> Port {
290 match self {
291 Self::Host { port, .. } => *port,
292 Self::Socket(addr) => addr.port(),
293 }
294 }
295
296 pub fn hostname(&self) -> Option<&str> {
298 match self {
299 Self::Host { hostname, .. } => Some(hostname),
300 Self::Socket(_) => None,
301 }
302 }
303}
304
305impl fmt::Display for MetaTlsAddr {
306 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307 match self {
308 Self::Host { hostname, port } => write!(f, "{}:{}", hostname, port),
309 Self::Socket(addr) => write!(f, "{}", addr),
310 }
311 }
312}
313
/// The kind of transport a channel uses. The string forms (see the `Display`
/// and `FromStr` impls) are `tcp`, `metatls(<mode>)`, `local`, `sim(<transport>)`
/// and `unix`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
pub enum ChannelTransport {
    /// TCP sockets.
    Tcp,

    /// MetaTLS, with the given host-identification mode.
    MetaTls(TlsMode),

    /// In-process channels (the ZMQ `inproc://` scheme maps here).
    Local,

    /// The simulated network, wrapping the transport being simulated.
    Sim(Box<ChannelTransport>),

    /// Unix domain sockets.
    Unix,
}
332
333impl fmt::Display for ChannelTransport {
334 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
335 match self {
336 Self::Tcp => write!(f, "tcp"),
337 Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
338 Self::Local => write!(f, "local"),
339 Self::Sim(transport) => write!(f, "sim({})", transport),
340 Self::Unix => write!(f, "unix"),
341 }
342 }
343}
344
345impl FromStr for ChannelTransport {
346 type Err = anyhow::Error;
347
348 fn from_str(s: &str) -> Result<Self, Self::Err> {
349 if let Some(rest) = s.strip_prefix("sim(") {
351 if let Some(end) = rest.rfind(')') {
352 let inner = &rest[..end];
353 let inner_transport = ChannelTransport::from_str(inner)?;
354 return Ok(ChannelTransport::Sim(Box::new(inner_transport)));
355 } else {
356 return Err(anyhow::anyhow!("invalid sim transport"));
357 }
358 }
359
360 match s {
361 "tcp" => Ok(ChannelTransport::Tcp),
362 "local" => Ok(ChannelTransport::Local),
363 "unix" => Ok(ChannelTransport::Unix),
364 s if s.starts_with("metatls(") && s.ends_with(")") => {
365 let inner = &s["metatls(".len()..s.len() - 1];
366 let mode = inner.parse()?;
367 Ok(ChannelTransport::MetaTls(mode))
368 }
369 unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
370 }
371 }
372}
373
374impl ChannelTransport {
375 pub fn all() -> [ChannelTransport; 3] {
377 [
378 ChannelTransport::Tcp,
379 ChannelTransport::Local,
380 ChannelTransport::Unix,
381 ]
385 }
386
387 pub fn any(&self) -> ChannelAddr {
389 ChannelAddr::any(self.clone())
390 }
391
392 pub fn is_remote(&self) -> bool {
394 match self {
395 ChannelTransport::Tcp => true,
396 ChannelTransport::MetaTls(_) => true,
397 ChannelTransport::Local => false,
398 ChannelTransport::Sim(_) => false,
399 ChannelTransport::Unix => false,
400 }
401 }
402}
403
404impl AttrValue for ChannelTransport {
405 fn display(&self) -> String {
406 self.to_string()
407 }
408
409 fn parse(s: &str) -> Result<Self, anyhow::Error> {
410 s.parse()
411 }
412}
413
/// A host name or textual IP address, as stored in `MetaTlsAddr::Host`.
pub type Hostname = String;

/// A network port number.
pub type Port = u16;
419
/// The address of a channel endpoint.
///
/// The string form (see the `Display` and `FromStr` impls) is
/// `<transport>:<address>`; parsing also accepts `!` as the separator, e.g.
/// `tcp:[::1]:1234`, `local!123`, `unix:/path/to/socket`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    Named
)]
pub enum ChannelAddr {
    /// A TCP socket address.
    Tcp(SocketAddr),

    /// A MetaTLS address.
    MetaTls(MetaTlsAddr),

    /// An in-process channel identified by a number; `serve` allocates a fresh
    /// one when given `Local(0)`.
    Local(u64),

    /// An address on the simulated network.
    Sim(SimAddr),

    /// A Unix domain socket address: a filesystem path, or an abstract name
    /// written `@name` in the string form.
    Unix(net::unix::SocketAddr),
}
475
476impl From<SocketAddr> for ChannelAddr {
477 fn from(value: SocketAddr) -> Self {
478 Self::Tcp(value)
479 }
480}
481
482impl From<net::unix::SocketAddr> for ChannelAddr {
483 fn from(value: net::unix::SocketAddr) -> Self {
484 Self::Unix(value)
485 }
486}
487
488impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
489 fn from(value: std::os::unix::net::SocketAddr) -> Self {
490 Self::Unix(net::unix::SocketAddr::new(value))
491 }
492}
493
494impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
495 fn from(value: tokio::net::unix::SocketAddr) -> Self {
496 std::os::unix::net::SocketAddr::from(value).into()
497 }
498}
499
500impl ChannelAddr {
501 pub fn any(transport: ChannelTransport) -> Self {
504 match transport {
505 ChannelTransport::Tcp => {
506 let ip = hostname::get()
507 .ok()
508 .and_then(|hostname| {
509 hostname.to_str().and_then(|hostname_str| {
511 dns_lookup::lookup_host(hostname_str)
512 .ok()
513 .and_then(|addresses| addresses.first().cloned())
514 })
515 })
516 .unwrap_or_else(|| IpAddr::from_str("::1").unwrap());
517 Self::Tcp(SocketAddr::new(ip, 0))
518 }
519 ChannelTransport::MetaTls(mode) => {
520 let host_address = match mode {
521 TlsMode::Hostname => hostname::get()
522 .ok()
523 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
524 .unwrap_or("unknown_host".to_string()),
525 TlsMode::IpV6 => local_ipv6()
526 .ok()
527 .and_then(|addr| addr.to_string().parse().ok())
528 .expect("failed to retrieve ipv6 address"),
529 };
530 Self::MetaTls(MetaTlsAddr::Host {
531 hostname: host_address,
532 port: 0,
533 })
534 }
535 ChannelTransport::Local => Self::Local(0),
536 ChannelTransport::Sim(transport) => sim::any(*transport),
537 ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
539 }
540 }
541
542 pub fn transport(&self) -> ChannelTransport {
544 match self {
545 Self::Tcp(_) => ChannelTransport::Tcp,
546 Self::MetaTls(addr) => match addr {
547 MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
548 Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
549 Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
550 Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
551 },
552 MetaTlsAddr::Socket(socket_addr) => match socket_addr.ip() {
553 IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
554 IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
555 },
556 },
557 Self::Local(_) => ChannelTransport::Local,
558 Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
559 Self::Unix(_) => ChannelTransport::Unix,
560 }
561 }
562}
563
564impl fmt::Display for ChannelAddr {
565 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566 match self {
567 Self::Tcp(addr) => write!(f, "tcp:{}", addr),
568 Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
569 Self::Local(index) => write!(f, "local:{}", index),
570 Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
571 Self::Unix(addr) => write!(f, "unix:{}", addr),
572 }
573 }
574}
575
576impl FromStr for ChannelAddr {
577 type Err = anyhow::Error;
578
579 fn from_str(addr: &str) -> Result<Self, Self::Err> {
580 match addr.split_once('!').or_else(|| addr.split_once(':')) {
581 Some(("local", rest)) => rest
582 .parse::<u64>()
583 .map(Self::Local)
584 .map_err(anyhow::Error::from),
585 Some(("tcp", rest)) => rest
586 .parse::<SocketAddr>()
587 .map(Self::Tcp)
588 .map_err(anyhow::Error::from),
589 Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
590 Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
591 Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
592 Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
593 None => Err(anyhow::anyhow!("no channel type specified")),
594 }
595 }
596}
597
598impl ChannelAddr {
599 pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
606 let (scheme, address) = address.split_once("://").ok_or_else(|| {
608 anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
609 })?;
610
611 match scheme {
612 "tcp" => {
613 let (host, port) = Self::split_host_port(address)?;
614
615 if host == "*" {
616 Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
618 } else {
619 let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
621 Ok(Self::Tcp(socket_addr))
622 }
623 }
624 "inproc" => {
625 let port = address.parse::<u64>().map_err(|_| {
628 anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
629 })?;
630 Ok(Self::Local(port))
631 }
632 "ipc" => {
633 Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
635 }
636 "metatls" => {
637 let (host, port) = Self::split_host_port(address)?;
638
639 if host == "*" {
640 Ok(Self::MetaTls(MetaTlsAddr::Host {
642 hostname: std::net::Ipv6Addr::UNSPECIFIED.to_string(),
643 port,
644 }))
645 } else {
646 Ok(Self::MetaTls(MetaTlsAddr::Host {
647 hostname: host.to_string(),
648 port,
649 }))
650 }
651 }
652 scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
653 }
654 }
655
656 fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
658 if let Some((host, port_str)) = address.rsplit_once(':') {
659 let port: u16 = port_str
660 .parse()
661 .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
662 Ok((host, port))
663 } else {
664 Err(anyhow::anyhow!("invalid address format: {}", address))
665 }
666 }
667
668 fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
670 let host_clean = if host.starts_with('[') && host.ends_with(']') {
672 &host[1..host.len() - 1]
673 } else {
674 host
675 };
676
677 if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
679 return Ok(SocketAddr::new(ip_addr, port));
680 }
681
682 use std::net::ToSocketAddrs;
684 let mut addrs = (host_clean, port)
685 .to_socket_addrs()
686 .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
687
688 addrs
689 .next()
690 .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
691 }
692}
693
/// A transmitter for any supported transport, as returned by [`dial`].
#[derive(Debug)]
pub struct ChannelTx<M: RemoteMessage> {
    /// The transport-specific transmitter.
    inner: ChannelTxKind<M>,
}
699
/// Transport-specific transmitters wrapped by [`ChannelTx`]. TCP, MetaTLS and
/// Unix all share the network transmitter type `net::NetTx`.
#[derive(Debug)]
enum ChannelTxKind<M: RemoteMessage> {
    Local(local::LocalTx<M>),
    Tcp(net::NetTx<M>),
    MetaTls(net::NetTx<M>),
    Unix(net::NetTx<M>),
    Sim(sim::SimTx<M>),
}
709
710#[async_trait]
711impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
712 fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>> {
713 match &self.inner {
714 ChannelTxKind::Local(tx) => tx.try_post(message, return_channel),
715 ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel),
716 ChannelTxKind::MetaTls(tx) => tx.try_post(message, return_channel),
717 ChannelTxKind::Sim(tx) => tx.try_post(message, return_channel),
718 ChannelTxKind::Unix(tx) => tx.try_post(message, return_channel),
719 }
720 }
721
722 fn addr(&self) -> ChannelAddr {
723 match &self.inner {
724 ChannelTxKind::Local(tx) => tx.addr(),
725 ChannelTxKind::Tcp(tx) => Tx::<M>::addr(tx),
726 ChannelTxKind::MetaTls(tx) => Tx::<M>::addr(tx),
727 ChannelTxKind::Sim(tx) => tx.addr(),
728 ChannelTxKind::Unix(tx) => Tx::<M>::addr(tx),
729 }
730 }
731
732 fn status(&self) -> &watch::Receiver<TxStatus> {
733 match &self.inner {
734 ChannelTxKind::Local(tx) => tx.status(),
735 ChannelTxKind::Tcp(tx) => tx.status(),
736 ChannelTxKind::MetaTls(tx) => tx.status(),
737 ChannelTxKind::Sim(tx) => tx.status(),
738 ChannelTxKind::Unix(tx) => tx.status(),
739 }
740 }
741}
742
/// A receiver for any supported transport, as returned by [`serve`] and
/// [`serve_local`].
#[derive(Debug)]
pub struct ChannelRx<M: RemoteMessage> {
    /// The transport-specific receiver.
    inner: ChannelRxKind<M>,
}
748
/// Transport-specific receivers wrapped by [`ChannelRx`]. TCP, MetaTLS and
/// Unix all share the network receiver type `net::NetRx`.
#[derive(Debug)]
enum ChannelRxKind<M: RemoteMessage> {
    Local(local::LocalRx<M>),
    Tcp(net::NetRx<M>),
    MetaTls(net::NetRx<M>),
    Unix(net::NetRx<M>),
    Sim(sim::SimRx<M>),
}
758
759#[async_trait]
760impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
761 async fn recv(&mut self) -> Result<M, ChannelError> {
762 match &mut self.inner {
763 ChannelRxKind::Local(rx) => rx.recv().await,
764 ChannelRxKind::Tcp(rx) => rx.recv().await,
765 ChannelRxKind::MetaTls(rx) => rx.recv().await,
766 ChannelRxKind::Sim(rx) => rx.recv().await,
767 ChannelRxKind::Unix(rx) => rx.recv().await,
768 }
769 }
770
771 fn addr(&self) -> ChannelAddr {
772 match &self.inner {
773 ChannelRxKind::Local(rx) => rx.addr(),
774 ChannelRxKind::Tcp(rx) => rx.addr(),
775 ChannelRxKind::MetaTls(rx) => rx.addr(),
776 ChannelRxKind::Sim(rx) => rx.addr(),
777 ChannelRxKind::Unix(rx) => rx.addr(),
778 }
779 }
780}
781
782#[allow(clippy::result_large_err)] pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
787 tracing::debug!(name = "dial", "dialing channel {}", addr);
788 let inner = match addr {
789 ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
790 ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
791 ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
792 ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
793 ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
794 };
795 Ok(ChannelTx { inner })
796}
797
798#[crate::instrument]
801pub fn serve<M: RemoteMessage>(
802 addr: ChannelAddr,
803) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
804 tracing::debug!(name = "serve", "serving channel address {}", addr);
805 match addr {
806 ChannelAddr::Tcp(addr) => {
807 let (addr, rx) = net::tcp::serve::<M>(addr)?;
808 Ok((addr, ChannelRxKind::Tcp(rx)))
809 }
810 ChannelAddr::MetaTls(meta_addr) => {
811 let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
812 Ok((addr, ChannelRxKind::MetaTls(rx)))
813 }
814 ChannelAddr::Unix(path) => {
815 let (addr, rx) = net::unix::serve::<M>(path)?;
816 Ok((addr, ChannelRxKind::Unix(rx)))
817 }
818 ChannelAddr::Local(0) => {
819 let (port, rx) = local::serve::<M>();
820 Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
821 }
822 ChannelAddr::Sim(sim_addr) => {
823 let (addr, rx) = sim::serve::<M>(sim_addr)?;
824 Ok((addr, ChannelRxKind::Sim(rx)))
825 }
826 ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
827 "invalid local addr: {}",
828 a
829 ))),
830 }
831 .map(|(addr, inner)| (addr, ChannelRx { inner }))
832}
833
834pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
837 let (port, rx) = local::serve::<M>();
838 (
839 ChannelAddr::Local(port),
840 ChannelRx {
841 inner: ChannelRxKind::Local(rx),
842 },
843 )
844}
845
/// Tests for address parsing and formatting, and for end-to-end dial/serve
/// behaviour across transports.
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::Ipv6Addr;
    use std::time::Duration;

    use tokio::task::JoinSet;

    use super::net::*;
    use super::*;
    use crate::clock::Clock;
    use crate::clock::RealClock;

    /// Parses valid address strings with both the `!` and `:` delimiters, checks
    /// that each also parses when wrapped in `sim`, and checks the exact error
    /// messages for malformed inputs.
    #[test]
    fn test_channel_addr() {
        // `<DELIM>` is a placeholder, replaced below by each delimiter in turn.
        let cases_ok = vec![
            (
                "tcp<DELIM>[::1]:1234",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    1234,
                )),
            ),
            (
                "tcp<DELIM>127.0.0.1:8080",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                    8080,
                )),
            ),
            // NOTE(review): this cfg gates the `local` case, while the abstract
            // `@yolo` Unix case below is not gated; confirm which was meant to
            // be Linux-only.
            #[cfg(target_os = "linux")]
            ("local<DELIM>123", ChannelAddr::Local(123)),
            (
                "unix<DELIM>@yolo",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_abstract_name("yolo")
                        .expect("can't make socket from abstract name"),
                ),
            ),
            (
                "unix<DELIM>/cool/socket-path",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_pathname("/cool/socket-path")
                        .expect("can't make socket from path"),
                ),
            ),
        ];

        for (raw, parsed) in cases_ok.clone() {
            for delim in ["!", ":"] {
                let raw = raw.replace("<DELIM>", delim);
                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
            }
        }

        // The same addresses wrapped as `sim<delim><address>`.
        for (raw, parsed) in cases_ok {
            for delim in ["!", ":"] {
                let raw = format!("sim{}{}", delim, raw.replace("<DELIM>", delim));
                assert_eq!(
                    raw.parse::<ChannelAddr>().unwrap(),
                    ChannelAddr::Sim(SimAddr::new(parsed.clone()).unwrap())
                );
            }
        }

        // Malformed inputs and the exact error message each must produce.
        let cases_err = vec![
            ("tcp:abcdef..123124", "invalid socket address syntax"),
            ("xxx:foo", "no such channel type: xxx"),
            ("127.0.0.1", "no channel type specified"),
            ("local:abc", "invalid digit found in string"),
        ];

        for (raw, error) in cases_err {
            let Err(err) = raw.parse::<ChannelAddr>() else {
                panic!("expected error parsing: {}", &raw)
            };
            assert_eq!(format!("{}", err), error);
        }
    }

    /// Exercises `ChannelAddr::from_zmq_url` for each supported scheme, the `*`
    /// wildcard host, hostname resolution, bracketed IPv6, and error cases.
    #[test]
    fn test_zmq_style_channel_addr() {
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
        );

        // `*` maps to the IPv6 unspecified address.
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
            ChannelAddr::Local(12345)
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
        );

        // MetaTLS keeps the host text as written.
        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "example.com".to_string(),
                port: 443
            })
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "192.168.1.1".to_string(),
                port: 443
            })
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "::".to_string(),
                port: 8443
            })
        );

        // A hostname goes through resolution; only success is asserted because
        // the resolved address depends on the machine.
        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
        assert!(tcp_hostname_result.is_ok());

        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
        );

        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
    }

    /// For each transport in `ChannelTransport::all()`, 100 concurrent tasks
    /// each dial the server and post one distinct message. The server must
    /// receive all of them, and no sender task may panic.
    #[tokio::test]
    async fn test_multiple_connections() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();

            let mut sends: JoinSet<()> = JoinSet::new();
            for message in 0u64..100u64 {
                let addr = listen_addr.clone();
                sends.spawn(async move {
                    let tx = dial::<u64>(addr).unwrap();
                    tx.try_post(message, oneshot::channel().0).unwrap();
                });
            }

            let mut received: HashSet<u64> = HashSet::new();
            while received.len() < 100 {
                received.insert(rx.recv().await.unwrap());
            }

            for message in 0u64..100u64 {
                assert!(received.contains(&message));
            }

            // Surface any panic raised inside a sender task.
            loop {
                match sends.join_next().await {
                    Some(Ok(())) => (),
                    Some(Err(err)) => panic!("{}", err),
                    None => break,
                }
            }
        }
    }

    /// For non-network transports: after the receiver is dropped, `try_post`
    /// must fail with `ChannelError::Closed` and hand the message back.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            // Network transports are skipped by this test.
            if net::is_net_addr(&addr) {
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.try_post(123, oneshot::channel().0).unwrap();
            drop(rx);

            let start = RealClock.now();

            // Retry `try_post` until it fails or 10 seconds have elapsed.
            let result = loop {
                let result = tx.try_post(123, oneshot::channel().0);
                if result.is_err() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
            };
            assert_matches!(result, Err(SendError(ChannelError::Closed, 123)));
        }
    }

    /// Addresses used by the dial/serve tests: TCP loopback, local, and on Linux
    /// the address parsed from `unix:` plus a random 10-letter abstract Unix name.
    fn addrs() -> Vec<ChannelAddr> {
        use rand::Rng;
        use rand::distributions::Uniform;

        let rng = rand::thread_rng();
        vec![
            "tcp:[::1]:0".parse().unwrap(),
            "local:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            "unix:".parse().unwrap(),
            #[cfg(target_os = "linux")]
            format!(
                "unix:@{}",
                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
                    .take(10)
                    .collect::<String>()
            )
            .parse()
            .unwrap(),
        ]
    }

    /// Serve each address from `addrs()`, dial it, post one message and
    /// receive it. Ignored unless the `fb` feature is enabled.
    #[tokio::test]
    #[cfg_attr(not(feature = "fb"), ignore)]
    async fn test_dial_serve() {
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.try_post(123, oneshot::channel().0).unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);
        }
    }

    /// `send` succeeds while the receiver is alive, and fails with
    /// `ChannelError::Closed` (returning the message) once it is dropped.
    /// Ignored unless the `fb` feature is enabled.
    #[tokio::test]
    #[cfg_attr(not(feature = "fb"), ignore)]
    async fn test_send() {
        let config = crate::config::global::lock();

        // A 1s delivery timeout and an ack after every message, presumably so
        // the failed send after `drop(rx)` is detected quickly; confirm intent.
        let _guard1 = config.override_key(
            crate::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.send(123).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);

            drop(rx);
            assert_matches!(
                tx.send(123).await.unwrap_err(),
                SendError(ChannelError::Closed, 123)
            );
        }
    }
}