1#![allow(dead_code)] use core::net::SocketAddr;
15use std::fmt;
16use std::net::IpAddr;
17use std::net::Ipv6Addr;
18#[cfg(target_os = "linux")]
19use std::os::linux::net::SocketAddrExt;
20use std::panic::Location;
21use std::str::FromStr;
22
23use async_trait::async_trait;
24use enum_as_inner::EnumAsInner;
25use lazy_static::lazy_static;
26use local_ip_address::local_ipv6;
27use serde::Deserialize;
28use serde::Serialize;
29use tokio::sync::mpsc;
30use tokio::sync::oneshot;
31use tokio::sync::watch;
32
33use crate as hyperactor;
34use crate::Named;
35use crate::RemoteMessage;
36use crate::attrs::AttrValue;
37use crate::channel::sim::SimAddr;
38use crate::simnet::SimNetError;
39
40pub(crate) mod local;
41pub(crate) mod net;
42pub mod sim;
43
/// Errors produced by channel operations (dial, serve, send, receive).
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
    /// The channel is closed. `MpscRx::recv` returns this when the underlying
    /// mpsc receiver yields `None`, and `MpscTx` reports it (with the message)
    /// when its receiver is gone.
    #[error("channel closed")]
    Closed,

    /// A send failed for the wrapped reason.
    #[error("send: {0}")]
    Send(#[source] anyhow::Error),

    /// Error raised by the network client (`net::ClientError`).
    #[error(transparent)]
    Client(#[from] net::ClientError),

    /// The address is not usable for the requested operation; for example
    /// `serve` returns this for a `ChannelAddr::Local` with a non-zero index.
    #[error("invalid address {0:?}")]
    InvalidAddress(String),

    /// Error raised by the network server (`net::ServerError`).
    #[error(transparent)]
    Server(#[from] net::ServerError),

    /// A bincode error.
    #[error(transparent)]
    Bincode(#[from] Box<bincode::ErrorKind>),

    /// Error from `crate::data`.
    #[error(transparent)]
    Data(#[from] crate::data::Error),

    /// Catch-all for any other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),

    /// An operation did not complete within the given duration.
    #[error("operation timed out after {0:?}")]
    Timeout(std::time::Duration),

    /// Error from the simulated network.
    #[error(transparent)]
    SimNetError(#[from] SimNetError),
}
87
88#[derive(thiserror::Error, Debug)]
90#[error("{0}")]
91pub struct SendError<M: RemoteMessage>(#[source] pub ChannelError, pub M);
92
93impl<M: RemoteMessage> From<SendError<M>> for ChannelError {
94 fn from(error: SendError<M>) -> Self {
95 error.0
96 }
97}
98
/// Liveness of a transmitter, observed through the `watch::Receiver`
/// returned by `Tx::status`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TxStatus {
    /// The channel is usable; `MpscTx::new` starts its status here.
    Active,
    /// The channel has been closed; `MpscRx` publishes this when dropped.
    Closed,
}
107
108#[async_trait]
110pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
111 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>);
117
118 #[allow(clippy::result_large_err)] fn try_post(&self, message: M, return_channel: oneshot::Sender<SendError<M>>) {
123 self.do_post(message, Some(return_channel));
124 }
125
126 fn post(&self, message: M) {
128 self.do_post(message, None);
129 }
130
131 async fn send(&self, message: M) -> Result<(), SendError<M>> {
134 let (tx, rx) = oneshot::channel();
135 self.try_post(message, tx);
136 match rx.await {
137 Ok(err) => Err(err),
139
140 Err(_) => Ok(()),
143 }
144 }
145
146 fn addr(&self) -> ChannelAddr;
148
149 fn status(&self) -> &watch::Receiver<TxStatus>;
151}
152
/// The receiving half of a channel carrying messages of type `M`.
#[async_trait]
pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
    /// Wait for the next message. The implementations in this file return
    /// `ChannelError::Closed` once no further messages can arrive (see `MpscRx`).
    async fn recv(&mut self) -> Result<M, ChannelError>;

    /// The channel address associated with this receiver.
    fn addr(&self) -> ChannelAddr;
}
163
164#[derive(Debug)]
165struct MpscTx<M: RemoteMessage> {
166 tx: mpsc::UnboundedSender<M>,
167 addr: ChannelAddr,
168 status: watch::Receiver<TxStatus>,
169}
170
171impl<M: RemoteMessage> MpscTx<M> {
172 pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
173 let (sender, receiver) = watch::channel(TxStatus::Active);
174 (
175 Self {
176 tx,
177 addr,
178 status: receiver,
179 },
180 sender,
181 )
182 }
183}
184
185#[async_trait]
186impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
187 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
188 if let Err(mpsc::error::SendError(message)) = self.tx.send(message) {
189 if let Some(return_channel) = return_channel {
190 return_channel
191 .send(SendError(ChannelError::Closed, message))
192 .unwrap_or_else(|m| tracing::warn!("failed to deliver SendError: {}", m));
193 }
194 }
195 }
196
197 fn addr(&self) -> ChannelAddr {
198 self.addr.clone()
199 }
200
201 fn status(&self) -> &watch::Receiver<TxStatus> {
202 &self.status
203 }
204}
205
206#[derive(Debug)]
207struct MpscRx<M: RemoteMessage> {
208 rx: mpsc::UnboundedReceiver<M>,
209 addr: ChannelAddr,
210 status_sender: watch::Sender<TxStatus>,
212}
213
214impl<M: RemoteMessage> MpscRx<M> {
215 pub fn new(
216 rx: mpsc::UnboundedReceiver<M>,
217 addr: ChannelAddr,
218 status_sender: watch::Sender<TxStatus>,
219 ) -> Self {
220 Self {
221 rx,
222 addr,
223 status_sender,
224 }
225 }
226}
227
228impl<M: RemoteMessage> Drop for MpscRx<M> {
229 fn drop(&mut self) {
230 let _ = self.status_sender.send(TxStatus::Closed);
231 }
232}
233
234#[async_trait]
235impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
236 async fn recv(&mut self) -> Result<M, ChannelError> {
237 self.rx.recv().await.ok_or(ChannelError::Closed)
238 }
239
240 fn addr(&self) -> ChannelAddr {
241 self.addr.clone()
242 }
243}
244
/// How a TCP address is chosen by `ChannelAddr::any`.
///
/// The strum `Display`/`EnumString` derives use the variant names
/// (`Localhost`, `Hostname`); `ChannelTransport`'s `Display`/`FromStr`
/// round-trip through them in the `tcp(<mode>)` form.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TcpMode {
    /// Loopback: `ChannelAddr::any` uses `[::1]`; `ChannelAddr::transport`
    /// reports this mode for any loopback IP.
    Localhost,
    /// The first IP address the machine's hostname resolves to.
    Hostname,
}
264
/// How a MetaTLS host is chosen by `ChannelAddr::any`.
///
/// The strum derives use the variant names, matching the `metatls(<mode>)`
/// textual form of `ChannelTransport`.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    strum::EnumIter,
    strum::Display,
    strum::EnumString
)]
pub enum TlsMode {
    /// A local IPv6 address, written as a string literal, is used as the host.
    IpV6,
    /// The machine's hostname is used as the host.
    Hostname,
}
285
/// Address of a MetaTLS endpoint: either a host name (or IP literal string)
/// plus port, or an already-resolved socket address.
///
/// NOTE: variant order affects the derived `Ord` and index-based serde
/// encodings; do not reorder.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    Ord,
    PartialOrd,
    EnumAsInner
)]
pub enum MetaTlsAddr {
    /// Host name (or IP literal string) and port. `Display` renders this as
    /// `hostname:port` without bracketing IPv6 literals.
    Host {
        /// Host name or IP literal.
        hostname: Hostname,
        /// Port number.
        port: Port,
    },
    /// A concrete socket address.
    Socket(SocketAddr),
}
312
313impl MetaTlsAddr {
314 pub fn port(&self) -> Port {
316 match self {
317 Self::Host { port, .. } => *port,
318 Self::Socket(addr) => addr.port(),
319 }
320 }
321
322 pub fn hostname(&self) -> Option<&str> {
324 match self {
325 Self::Host { hostname, .. } => Some(hostname),
326 Self::Socket(_) => None,
327 }
328 }
329}
330
331impl fmt::Display for MetaTlsAddr {
332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333 match self {
334 Self::Host { hostname, port } => write!(f, "{}:{}", hostname, port),
335 Self::Socket(addr) => write!(f, "{}", addr),
336 }
337 }
338}
339
/// The kind of transport a channel uses, independent of a concrete address.
///
/// Textual form (see `Display`/`FromStr`): `tcp(<TcpMode>)`, `metatls(<TlsMode>)`,
/// `local`, `sim(<transport>)`, `unix`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)]
pub enum ChannelTransport {
    /// TCP, with the given address-selection mode.
    Tcp(TcpMode),

    /// MetaTLS, with the given host-selection mode.
    MetaTls(TlsMode),

    /// In-process channel.
    Local,

    /// Simulated network wrapping another transport.
    Sim(Box<ChannelTransport>),

    /// Unix domain socket.
    Unix,
}
358
359impl fmt::Display for ChannelTransport {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 match self {
362 Self::Tcp(mode) => write!(f, "tcp({:?})", mode),
363 Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
364 Self::Local => write!(f, "local"),
365 Self::Sim(transport) => write!(f, "sim({})", transport),
366 Self::Unix => write!(f, "unix"),
367 }
368 }
369}
370
371impl FromStr for ChannelTransport {
372 type Err = anyhow::Error;
373
374 fn from_str(s: &str) -> Result<Self, Self::Err> {
375 if let Some(rest) = s.strip_prefix("sim(") {
377 if let Some(end) = rest.rfind(')') {
378 let inner = &rest[..end];
379 let inner_transport = ChannelTransport::from_str(inner)?;
380 return Ok(ChannelTransport::Sim(Box::new(inner_transport)));
381 } else {
382 return Err(anyhow::anyhow!("invalid sim transport"));
383 }
384 }
385
386 match s {
387 "tcp" => Ok(ChannelTransport::Tcp(TcpMode::Hostname)),
389 s if s.starts_with("tcp(") => {
390 let inner = &s["tcp(".len()..s.len() - 1];
391 let mode = inner.parse()?;
392 Ok(ChannelTransport::Tcp(mode))
393 }
394 "local" => Ok(ChannelTransport::Local),
395 "unix" => Ok(ChannelTransport::Unix),
396 s if s.starts_with("metatls(") && s.ends_with(")") => {
397 let inner = &s["metatls(".len()..s.len() - 1];
398 let mode = inner.parse()?;
399 Ok(ChannelTransport::MetaTls(mode))
400 }
401 unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
402 }
403 }
404}
405
406impl ChannelTransport {
407 pub fn all() -> [ChannelTransport; 3] {
409 [
410 ChannelTransport::Tcp(TcpMode::Hostname),
413 ChannelTransport::Local,
414 ChannelTransport::Unix,
415 ]
419 }
420
421 pub fn any(&self) -> ChannelAddr {
423 ChannelAddr::any(self.clone())
424 }
425
426 pub fn is_remote(&self) -> bool {
428 match self {
429 ChannelTransport::Tcp(_) => true,
430 ChannelTransport::MetaTls(_) => true,
431 ChannelTransport::Local => false,
432 ChannelTransport::Sim(_) => false,
433 ChannelTransport::Unix => false,
434 }
435 }
436}
437
438impl AttrValue for ChannelTransport {
439 fn display(&self) -> String {
440 self.to_string()
441 }
442
443 fn parse(s: &str) -> Result<Self, anyhow::Error> {
444 s.parse()
445 }
446}
447
/// Host name (or IP literal string) used in [`MetaTlsAddr::Host`].
pub type Hostname = String;

/// Port number used in [`MetaTlsAddr`].
pub type Port = u16;
453
/// The address of a channel endpoint; the variant selects the transport.
///
/// Textual form (see `Display`/`FromStr`): `tcp:<socket addr>`,
/// `metatls:<addr>`, `local:<index>`, `sim:<addr>`, `unix:<addr>`. When
/// parsing, `!` is also accepted in place of the first `:`.
///
/// NOTE: variant order determines the derived `Ord` and the encoding in
/// index-based serde formats such as bincode; do not reorder.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    Named
)]
pub enum ChannelAddr {
    /// A TCP socket address.
    Tcp(SocketAddr),

    /// A MetaTLS address (host name + port, or socket address).
    MetaTls(MetaTlsAddr),

    /// An in-process channel index; `serve` allocates a fresh index for `0`
    /// and rejects any other value.
    Local(u64),

    /// An address on the simulated network.
    Sim(SimAddr),

    /// A Unix domain socket address (pathname or abstract name).
    Unix(net::unix::SocketAddr),
}
509
510impl From<SocketAddr> for ChannelAddr {
511 fn from(value: SocketAddr) -> Self {
512 Self::Tcp(value)
513 }
514}
515
516impl From<net::unix::SocketAddr> for ChannelAddr {
517 fn from(value: net::unix::SocketAddr) -> Self {
518 Self::Unix(value)
519 }
520}
521
522impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
523 fn from(value: std::os::unix::net::SocketAddr) -> Self {
524 Self::Unix(net::unix::SocketAddr::new(value))
525 }
526}
527
528impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
529 fn from(value: tokio::net::unix::SocketAddr) -> Self {
530 std::os::unix::net::SocketAddr::from(value).into()
531 }
532}
533
534impl ChannelAddr {
535 pub fn any(transport: ChannelTransport) -> Self {
538 match transport {
539 ChannelTransport::Tcp(mode) => {
540 let ip = match mode {
541 TcpMode::Localhost => IpAddr::V6(Ipv6Addr::LOCALHOST),
542 TcpMode::Hostname => {
543 hostname::get()
544 .ok()
545 .and_then(|hostname| {
546 hostname.to_str().and_then(|hostname_str| {
548 dns_lookup::lookup_host(hostname_str)
549 .ok()
550 .and_then(|addresses| addresses.first().cloned())
551 })
552 })
553 .expect("failed to resolve hostname to ip address")
554 }
555 };
556 Self::Tcp(SocketAddr::new(ip, 0))
557 }
558 ChannelTransport::MetaTls(mode) => {
559 let host_address = match mode {
560 TlsMode::Hostname => hostname::get()
561 .ok()
562 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
563 .unwrap_or("unknown_host".to_string()),
564 TlsMode::IpV6 => local_ipv6()
565 .ok()
566 .and_then(|addr| addr.to_string().parse().ok())
567 .expect("failed to retrieve ipv6 address"),
568 };
569 Self::MetaTls(MetaTlsAddr::Host {
570 hostname: host_address,
571 port: 0,
572 })
573 }
574 ChannelTransport::Local => Self::Local(0),
575 ChannelTransport::Sim(transport) => sim::any(*transport),
576 ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
578 }
579 }
580
581 pub fn transport(&self) -> ChannelTransport {
583 match self {
584 Self::Tcp(addr) => {
585 if addr.ip().is_loopback() {
586 ChannelTransport::Tcp(TcpMode::Localhost)
587 } else {
588 ChannelTransport::Tcp(TcpMode::Hostname)
589 }
590 }
591 Self::MetaTls(addr) => match addr {
592 MetaTlsAddr::Host { hostname, .. } => match hostname.parse::<IpAddr>() {
593 Ok(IpAddr::V6(_)) => ChannelTransport::MetaTls(TlsMode::IpV6),
594 Ok(IpAddr::V4(_)) => ChannelTransport::MetaTls(TlsMode::Hostname),
595 Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
596 },
597 MetaTlsAddr::Socket(socket_addr) => match socket_addr.ip() {
598 IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
599 IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
600 },
601 },
602 Self::Local(_) => ChannelTransport::Local,
603 Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
604 Self::Unix(_) => ChannelTransport::Unix,
605 }
606 }
607}
608
609impl fmt::Display for ChannelAddr {
610 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
611 match self {
612 Self::Tcp(addr) => write!(f, "tcp:{}", addr),
613 Self::MetaTls(addr) => write!(f, "metatls:{}", addr),
614 Self::Local(index) => write!(f, "local:{}", index),
615 Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
616 Self::Unix(addr) => write!(f, "unix:{}", addr),
617 }
618 }
619}
620
621impl FromStr for ChannelAddr {
622 type Err = anyhow::Error;
623
624 fn from_str(addr: &str) -> Result<Self, Self::Err> {
625 match addr.split_once('!').or_else(|| addr.split_once(':')) {
626 Some(("local", rest)) => rest
627 .parse::<u64>()
628 .map(Self::Local)
629 .map_err(anyhow::Error::from),
630 Some(("tcp", rest)) => rest
631 .parse::<SocketAddr>()
632 .map(Self::Tcp)
633 .map_err(anyhow::Error::from),
634 Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
635 Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
636 Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
637 Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
638 None => Err(anyhow::anyhow!("no channel type specified")),
639 }
640 }
641}
642
643impl ChannelAddr {
644 pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
651 let (scheme, address) = address.split_once("://").ok_or_else(|| {
653 anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
654 })?;
655
656 match scheme {
657 "tcp" => {
658 let (host, port) = Self::split_host_port(address)?;
659
660 if host == "*" {
661 Ok(Self::Tcp(SocketAddr::new("::".parse().unwrap(), port)))
663 } else {
664 let socket_addr = Self::resolve_hostname_to_socket_addr(host, port)?;
666 Ok(Self::Tcp(socket_addr))
667 }
668 }
669 "inproc" => {
670 let port = address.parse::<u64>().map_err(|_| {
673 anyhow::anyhow!("inproc endpoint must be a valid port number: {}", address)
674 })?;
675 Ok(Self::Local(port))
676 }
677 "ipc" => {
678 Ok(Self::Unix(net::unix::SocketAddr::from_str(address)?))
680 }
681 "metatls" => {
682 let (host, port) = Self::split_host_port(address)?;
683
684 if host == "*" {
685 Ok(Self::MetaTls(MetaTlsAddr::Host {
687 hostname: std::net::Ipv6Addr::UNSPECIFIED.to_string(),
688 port,
689 }))
690 } else {
691 Ok(Self::MetaTls(MetaTlsAddr::Host {
692 hostname: host.to_string(),
693 port,
694 }))
695 }
696 }
697 scheme => Err(anyhow::anyhow!("unsupported ZMQ scheme: {}", scheme)),
698 }
699 }
700
701 fn split_host_port(address: &str) -> Result<(&str, u16), anyhow::Error> {
703 if let Some((host, port_str)) = address.rsplit_once(':') {
704 let port: u16 = port_str
705 .parse()
706 .map_err(|_| anyhow::anyhow!("invalid port: {}", port_str))?;
707 Ok((host, port))
708 } else {
709 Err(anyhow::anyhow!("invalid address format: {}", address))
710 }
711 }
712
713 fn resolve_hostname_to_socket_addr(host: &str, port: u16) -> Result<SocketAddr, anyhow::Error> {
715 let host_clean = if host.starts_with('[') && host.ends_with(']') {
717 &host[1..host.len() - 1]
718 } else {
719 host
720 };
721
722 if let Ok(ip_addr) = host_clean.parse::<IpAddr>() {
724 return Ok(SocketAddr::new(ip_addr, port));
725 }
726
727 use std::net::ToSocketAddrs;
729 let mut addrs = (host_clean, port)
730 .to_socket_addrs()
731 .map_err(|e| anyhow::anyhow!("failed to resolve hostname '{}': {}", host_clean, e))?;
732
733 addrs
734 .next()
735 .ok_or_else(|| anyhow::anyhow!("no addresses found for hostname '{}'", host_clean))
736 }
737}
738
/// Transmitting end of a channel, as returned by [`dial`].
#[derive(Debug)]
pub struct ChannelTx<M: RemoteMessage> {
    inner: ChannelTxKind<M>,
}

/// The concrete transmitter behind a [`ChannelTx`], one per transport.
/// TCP, MetaTLS and Unix all use `net::NetTx`.
#[derive(Debug)]
enum ChannelTxKind<M: RemoteMessage> {
    Local(local::LocalTx<M>),
    Tcp(net::NetTx<M>),
    MetaTls(net::NetTx<M>),
    Unix(net::NetTx<M>),
    Sim(sim::SimTx<M>),
}
754
755#[async_trait]
756impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
757 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
758 match &self.inner {
759 ChannelTxKind::Local(tx) => tx.do_post(message, return_channel),
760 ChannelTxKind::Tcp(tx) => tx.do_post(message, return_channel),
761 ChannelTxKind::MetaTls(tx) => tx.do_post(message, return_channel),
762 ChannelTxKind::Sim(tx) => tx.do_post(message, return_channel),
763 ChannelTxKind::Unix(tx) => tx.do_post(message, return_channel),
764 }
765 }
766
767 fn addr(&self) -> ChannelAddr {
768 match &self.inner {
769 ChannelTxKind::Local(tx) => tx.addr(),
770 ChannelTxKind::Tcp(tx) => Tx::<M>::addr(tx),
771 ChannelTxKind::MetaTls(tx) => Tx::<M>::addr(tx),
772 ChannelTxKind::Sim(tx) => tx.addr(),
773 ChannelTxKind::Unix(tx) => Tx::<M>::addr(tx),
774 }
775 }
776
777 fn status(&self) -> &watch::Receiver<TxStatus> {
778 match &self.inner {
779 ChannelTxKind::Local(tx) => tx.status(),
780 ChannelTxKind::Tcp(tx) => tx.status(),
781 ChannelTxKind::MetaTls(tx) => tx.status(),
782 ChannelTxKind::Sim(tx) => tx.status(),
783 ChannelTxKind::Unix(tx) => tx.status(),
784 }
785 }
786}
787
/// Receiving end of a channel, as returned by [`serve`] and [`serve_local`].
#[derive(Debug)]
pub struct ChannelRx<M: RemoteMessage> {
    inner: ChannelRxKind<M>,
}

/// The concrete receiver behind a [`ChannelRx`], one per transport.
/// TCP, MetaTLS and Unix all use `net::NetRx`.
#[derive(Debug)]
enum ChannelRxKind<M: RemoteMessage> {
    Local(local::LocalRx<M>),
    Tcp(net::NetRx<M>),
    MetaTls(net::NetRx<M>),
    Unix(net::NetRx<M>),
    Sim(sim::SimRx<M>),
}
803
804#[async_trait]
805impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
806 async fn recv(&mut self) -> Result<M, ChannelError> {
807 match &mut self.inner {
808 ChannelRxKind::Local(rx) => rx.recv().await,
809 ChannelRxKind::Tcp(rx) => rx.recv().await,
810 ChannelRxKind::MetaTls(rx) => rx.recv().await,
811 ChannelRxKind::Sim(rx) => rx.recv().await,
812 ChannelRxKind::Unix(rx) => rx.recv().await,
813 }
814 }
815
816 fn addr(&self) -> ChannelAddr {
817 match &self.inner {
818 ChannelRxKind::Local(rx) => rx.addr(),
819 ChannelRxKind::Tcp(rx) => rx.addr(),
820 ChannelRxKind::MetaTls(rx) => rx.addr(),
821 ChannelRxKind::Sim(rx) => rx.addr(),
822 ChannelRxKind::Unix(rx) => rx.addr(),
823 }
824 }
825}
826
827#[allow(clippy::result_large_err)] #[track_caller]
832pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
833 tracing::debug!(name = "dial", caller = %Location::caller(), %addr, "dialing channel {}", addr);
834 let inner = match addr {
835 ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
836 ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
837 ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
838 ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
839 ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
840 };
841 Ok(ChannelTx { inner })
842}
843
844#[crate::instrument]
847#[track_caller]
848pub fn serve<M: RemoteMessage>(
849 addr: ChannelAddr,
850) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
851 let caller = Location::caller();
852 match addr {
853 ChannelAddr::Tcp(addr) => {
854 let (addr, rx) = net::tcp::serve::<M>(addr)?;
855 Ok((addr, ChannelRxKind::Tcp(rx)))
856 }
857 ChannelAddr::MetaTls(meta_addr) => {
858 let (addr, rx) = net::meta::serve::<M>(meta_addr)?;
859 Ok((addr, ChannelRxKind::MetaTls(rx)))
860 }
861 ChannelAddr::Unix(path) => {
862 let (addr, rx) = net::unix::serve::<M>(path)?;
863 Ok((addr, ChannelRxKind::Unix(rx)))
864 }
865 ChannelAddr::Local(0) => {
866 let (port, rx) = local::serve::<M>();
867 Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
868 }
869 ChannelAddr::Sim(sim_addr) => {
870 let (addr, rx) = sim::serve::<M>(sim_addr)?;
871 Ok((addr, ChannelRxKind::Sim(rx)))
872 }
873 ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
874 "invalid local addr: {}",
875 a
876 ))),
877 }
878 .map(|(addr, inner)| {
879 tracing::debug!(
880 name = "serve",
881 %addr,
882 %caller,
883 );
884 (addr, ChannelRx { inner })
885 })
886}
887
888pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
891 let (port, rx) = local::serve::<M>();
892 (
893 ChannelAddr::Local(port),
894 ChannelRx {
895 inner: ChannelRxKind::Local(rx),
896 },
897 )
898}
899
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::Ipv6Addr;
    use std::time::Duration;

    use tokio::task::JoinSet;

    use super::net::*;
    use super::*;
    use crate::clock::Clock;
    use crate::clock::RealClock;

    // Parses textual channel addresses with both the `!` and `:` separators,
    // bare and wrapped in `sim`, and checks the error text for malformed input.
    #[test]
    fn test_channel_addr() {
        let cases_ok = vec![
            (
                "tcp<DELIM>[::1]:1234",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    1234,
                )),
            ),
            (
                "tcp<DELIM>127.0.0.1:8080",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                    8080,
                )),
            ),
            #[cfg(target_os = "linux")]
            ("local<DELIM>123", ChannelAddr::Local(123)),
            (
                "unix<DELIM>@yolo",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_abstract_name("yolo")
                        .expect("can't make socket from abstract name"),
                ),
            ),
            (
                "unix<DELIM>/cool/socket-path",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_pathname("/cool/socket-path")
                        .expect("can't make socket from path"),
                ),
            ),
        ];

        // `<DELIM>` is replaced by each supported separator in turn.
        for (raw, parsed) in cases_ok.clone() {
            for delim in ["!", ":"] {
                let raw = raw.replace("<DELIM>", delim);
                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
            }
        }

        // The same addresses prefixed with `sim<DELIM>` parse to `ChannelAddr::Sim`.
        for (raw, parsed) in cases_ok {
            for delim in ["!", ":"] {
                let raw = format!("sim{}{}", delim, raw.replace("<DELIM>", delim));
                assert_eq!(
                    raw.parse::<ChannelAddr>().unwrap(),
                    ChannelAddr::Sim(SimAddr::new(parsed.clone()).unwrap())
                );
            }
        }

        // Malformed inputs paired with the exact error message each must produce.
        let cases_err = vec![
            ("tcp:abcdef..123124", "invalid socket address syntax"),
            ("xxx:foo", "no such channel type: xxx"),
            ("127.0.0.1", "no channel type specified"),
            ("local:abc", "invalid digit found in string"),
        ];

        for (raw, error) in cases_err {
            let Err(err) = raw.parse::<ChannelAddr>() else {
                panic!("expected error parsing: {}", &raw)
            };
            assert_eq!(format!("{}", err), error);
        }
    }

    // Covers every `from_zmq_url` scheme, the `*` wildcard for tcp and
    // metatls, bracketed IPv6 hosts, and several rejected inputs.
    #[test]
    fn test_zmq_style_channel_addr() {
        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080").unwrap(),
            ChannelAddr::Tcp("127.0.0.1:8080".parse().unwrap())
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://*:5555").unwrap(),
            ChannelAddr::Tcp("[::]:5555".parse().unwrap())
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("inproc://12345").unwrap(),
            ChannelAddr::Local(12345)
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("ipc:///tmp/my-socket").unwrap(),
            ChannelAddr::Unix(unix::SocketAddr::from_pathname("/tmp/my-socket").unwrap())
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://example.com:443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "example.com".to_string(),
                port: 443
            })
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://192.168.1.1:443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "192.168.1.1".to_string(),
                port: 443
            })
        );

        assert_eq!(
            ChannelAddr::from_zmq_url("metatls://*:8443").unwrap(),
            ChannelAddr::MetaTls(MetaTlsAddr::Host {
                hostname: "::".to_string(),
                port: 8443
            })
        );

        // Only success is checked: the resolved address depends on the host's resolver.
        let tcp_hostname_result = ChannelAddr::from_zmq_url("tcp://localhost:8080");
        assert!(tcp_hostname_result.is_ok());

        assert_eq!(
            ChannelAddr::from_zmq_url("tcp://[::1]:1234").unwrap(),
            ChannelAddr::Tcp("[::1]:1234".parse().unwrap())
        );

        assert!(ChannelAddr::from_zmq_url("invalid://scheme").is_err());
        assert!(ChannelAddr::from_zmq_url("tcp://invalid-port").is_err());
        assert!(ChannelAddr::from_zmq_url("metatls://no-port").is_err());
        assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
    }

    // For each transport in `ChannelTransport::all()`, 100 concurrent tasks each
    // dial the served address and post one message; all 100 must arrive.
    #[tokio::test]
    async fn test_multiple_connections() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).unwrap();

            let mut sends: JoinSet<()> = JoinSet::new();
            for message in 0u64..100u64 {
                let addr = listen_addr.clone();
                sends.spawn(async move {
                    let tx = dial::<u64>(addr).unwrap();
                    tx.post(message);
                });
            }

            let mut received: HashSet<u64> = HashSet::new();
            while received.len() < 100 {
                received.insert(rx.recv().await.unwrap());
            }

            for message in 0u64..100u64 {
                assert!(received.contains(&message));
            }

            // Surface any panic raised inside a sender task.
            loop {
                match sends.join_next().await {
                    Some(Ok(())) => (),
                    Some(Err(err)) => panic!("{}", err),
                    None => break,
                }
            }
        }
    }

    // After the receiver is dropped, `try_post` must eventually hand the
    // message back as `SendError(ChannelError::Closed, 123)`.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            // Network transports are skipped by this test.
            if net::is_net_addr(&addr) {
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.post(123);
            drop(rx);

            let start = RealClock.now();

            // Retry until a failure report comes back or 10 seconds have elapsed.
            let result = loop {
                let (return_tx, return_rx) = oneshot::channel();
                tx.try_post(123, return_tx);
                let result = return_rx.await;

                if result.is_ok() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
            };
            assert_matches!(result, Ok(SendError(ChannelError::Closed, 123)));
        }
    }

    // Addresses used by the dial/serve tests. The Unix entries are Linux-only;
    // the abstract socket name is a random 10-letter lowercase string.
    fn addrs() -> Vec<ChannelAddr> {
        use rand::Rng;
        use rand::distributions::Uniform;

        let rng = rand::thread_rng();
        vec![
            "tcp:[::1]:0".parse().unwrap(),
            "local:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            "unix:".parse().unwrap(),
            #[cfg(target_os = "linux")]
            format!(
                "unix:@{}",
                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
                    .take(10)
                    .collect::<String>()
            )
            .parse()
            .unwrap(),
        ]
    }

    // Serve, dial and round-trip one message per address (ignored outside fbcode builds).
    #[tokio::test]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_dial_serve() {
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.post(123);
            assert_eq!(rx.recv().await.unwrap(), 123);
        }
    }

    // `send` succeeds while the receiver is alive and returns
    // `SendError(ChannelError::Closed, 123)` after it is dropped. Delivery timeout
    // and ack frequency are overridden for the duration of the test (ignored
    // outside fbcode builds).
    #[tokio::test]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_send() {
        let config = crate::config::global::lock();

        let _guard1 = config.override_key(
            crate::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.send(123).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);

            drop(rx);
            assert_matches!(
                tx.send(123).await.unwrap_err(),
                SendError(ChannelError::Closed, 123)
            );
        }
    }
}