1#![allow(dead_code)] use core::net::SocketAddr;
15use std::fmt;
16use std::net::IpAddr;
17#[cfg(target_os = "linux")]
18use std::os::linux::net::SocketAddrExt;
19use std::str::FromStr;
20
21use async_trait::async_trait;
22use lazy_static::lazy_static;
23use local_ip_address::local_ipv6;
24use serde::Deserialize;
25use serde::Serialize;
26use tokio::sync::mpsc;
27use tokio::sync::oneshot;
28use tokio::sync::watch;
29
30use crate as hyperactor;
31use crate::Named;
32use crate::RemoteMessage;
33use crate::channel::sim::SimAddr;
34use crate::simnet::SimNetError;
35
36pub(crate) mod local;
37pub(crate) mod net;
38pub mod sim;
39
40#[derive(thiserror::Error, Debug)]
42pub enum ChannelError {
43 #[error("channel closed")]
45 Closed,
46
47 #[error("send: {0}")]
49 Send(#[source] anyhow::Error),
50
51 #[error(transparent)]
53 Client(#[from] net::ClientError),
54
55 #[error("invalid address {0:?}")]
57 InvalidAddress(String),
58
59 #[error(transparent)]
61 Server(#[from] net::ServerError),
62
63 #[error(transparent)]
65 Bincode(#[from] Box<bincode::ErrorKind>),
66
67 #[error(transparent)]
69 Other(#[from] anyhow::Error),
70
71 #[error("operation timed out after {0:?}")]
73 Timeout(std::time::Duration),
74
75 #[error(transparent)]
77 SimNetError(#[from] SimNetError),
78}
79
80#[derive(thiserror::Error, Debug)]
82#[error("{0}")]
83pub struct SendError<M: RemoteMessage>(#[source] pub ChannelError, pub M);
84
/// The status of a channel sender, as observed through `Tx::status`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TxStatus {
    /// The sender is usable. This is the initial value published by the
    /// status channels created in this module.
    Active,
    /// The sender has been closed.
    Closed,
}
93
94#[async_trait]
96pub trait Tx<M: RemoteMessage>: std::fmt::Debug {
97 #[allow(clippy::result_large_err)] fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>>;
104
105 fn post(&self, message: M) {
108 let _ignore = self.try_post(message, oneshot::channel().0);
111 }
112
113 async fn send(&self, message: M) -> Result<(), SendError<M>> {
116 let (tx, rx) = oneshot::channel();
117 self.try_post(message, tx)?;
118 match rx.await {
119 Ok(m) => Err(SendError(ChannelError::Closed, m)),
121
122 Err(_) => Ok(()),
125 }
126 }
127
128 fn addr(&self) -> ChannelAddr;
130
131 fn status(&self) -> &watch::Receiver<TxStatus>;
133}
134
135#[async_trait]
137pub trait Rx<M: RemoteMessage>: std::fmt::Debug {
138 async fn recv(&mut self) -> Result<M, ChannelError>;
141
142 fn addr(&self) -> ChannelAddr;
144}
145
146#[derive(Debug)]
147struct MpscTx<M: RemoteMessage> {
148 tx: mpsc::UnboundedSender<M>,
149 addr: ChannelAddr,
150 status: watch::Receiver<TxStatus>,
151}
152
153impl<M: RemoteMessage> MpscTx<M> {
154 pub fn new(tx: mpsc::UnboundedSender<M>, addr: ChannelAddr) -> (Self, watch::Sender<TxStatus>) {
155 let (sender, receiver) = watch::channel(TxStatus::Active);
156 (
157 Self {
158 tx,
159 addr,
160 status: receiver,
161 },
162 sender,
163 )
164 }
165}
166
167#[async_trait]
168impl<M: RemoteMessage> Tx<M> for MpscTx<M> {
169 fn try_post(
170 &self,
171 message: M,
172 _return_channel: oneshot::Sender<M>,
173 ) -> Result<(), SendError<M>> {
174 self.tx
175 .send(message)
176 .map_err(|mpsc::error::SendError(message)| SendError(ChannelError::Closed, message))
177 }
178
179 fn addr(&self) -> ChannelAddr {
180 self.addr.clone()
181 }
182
183 fn status(&self) -> &watch::Receiver<TxStatus> {
184 &self.status
185 }
186}
187
188#[derive(Debug)]
189struct MpscRx<M: RemoteMessage> {
190 rx: mpsc::UnboundedReceiver<M>,
191 addr: ChannelAddr,
192 status_sender: watch::Sender<TxStatus>,
194}
195
196impl<M: RemoteMessage> MpscRx<M> {
197 pub fn new(
198 rx: mpsc::UnboundedReceiver<M>,
199 addr: ChannelAddr,
200 status_sender: watch::Sender<TxStatus>,
201 ) -> Self {
202 Self {
203 rx,
204 addr,
205 status_sender,
206 }
207 }
208}
209
210impl<M: RemoteMessage> Drop for MpscRx<M> {
211 fn drop(&mut self) {
212 let _ = self.status_sender.send(TxStatus::Closed);
213 }
214}
215
216#[async_trait]
217impl<M: RemoteMessage> Rx<M> for MpscRx<M> {
218 async fn recv(&mut self) -> Result<M, ChannelError> {
219 self.rx.recv().await.ok_or(ChannelError::Closed)
220 }
221
222 fn addr(&self) -> ChannelAddr {
223 self.addr.clone()
224 }
225}
226
227#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
229pub enum TlsMode {
230 IpV6,
232 Hostname,
234 }
236
237#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
239pub enum ChannelTransport {
240 Tcp,
242
243 MetaTls(TlsMode),
245
246 Local,
248
249 Sim(Box<ChannelTransport>),
251
252 Unix,
254}
255
256impl fmt::Display for ChannelTransport {
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 match self {
259 Self::Tcp => write!(f, "tcp"),
260 Self::MetaTls(mode) => write!(f, "metatls({:?})", mode),
261 Self::Local => write!(f, "local"),
262 Self::Sim(transport) => write!(f, "sim({})", transport),
263 Self::Unix => write!(f, "unix"),
264 }
265 }
266}
267
268impl ChannelTransport {
269 pub fn all() -> [ChannelTransport; 3] {
271 [
272 ChannelTransport::Tcp,
273 ChannelTransport::Local,
274 ChannelTransport::Unix,
275 ]
279 }
280}
281
/// The host part of a `MetaTls` channel address.
pub type Hostname = String;

/// The port part of a `MetaTls` channel address.
pub type Port = u16;
287
288#[derive(
311 Clone,
312 Debug,
313 PartialEq,
314 Eq,
315 Ord,
316 PartialOrd,
317 Serialize,
318 Deserialize,
319 Hash,
320 Named
321)]
322pub enum ChannelAddr {
323 Tcp(SocketAddr),
326
327 MetaTls(Hostname, Port),
330
331 Local(u64),
334
335 Sim(SimAddr),
337
338 Unix(net::unix::SocketAddr),
341}
342
343impl From<SocketAddr> for ChannelAddr {
344 fn from(value: SocketAddr) -> Self {
345 Self::Tcp(value)
346 }
347}
348
349impl From<net::unix::SocketAddr> for ChannelAddr {
350 fn from(value: net::unix::SocketAddr) -> Self {
351 Self::Unix(value)
352 }
353}
354
355impl From<std::os::unix::net::SocketAddr> for ChannelAddr {
356 fn from(value: std::os::unix::net::SocketAddr) -> Self {
357 Self::Unix(net::unix::SocketAddr::new(value))
358 }
359}
360
361impl From<tokio::net::unix::SocketAddr> for ChannelAddr {
362 fn from(value: tokio::net::unix::SocketAddr) -> Self {
363 std::os::unix::net::SocketAddr::from(value).into()
364 }
365}
366
367impl ChannelAddr {
368 pub fn any(transport: ChannelTransport) -> Self {
371 match transport {
372 ChannelTransport::Tcp => {
373 let ip = hostname::get()
374 .ok()
375 .and_then(|hostname| {
376 hostname.to_str().and_then(|hostname_str| {
378 dns_lookup::lookup_host(hostname_str)
379 .ok()
380 .and_then(|addresses| addresses.first().cloned())
381 })
382 })
383 .unwrap_or_else(|| IpAddr::from_str("::1").unwrap());
384 Self::Tcp(SocketAddr::new(ip, 0))
385 }
386 ChannelTransport::MetaTls(mode) => {
387 let host_address = match mode {
388 TlsMode::Hostname => hostname::get()
389 .ok()
390 .and_then(|hostname| hostname.to_str().map(|s| s.to_string()))
391 .unwrap_or("unknown_host".to_string()),
392 TlsMode::IpV6 => local_ipv6()
393 .ok()
394 .and_then(|addr| addr.to_string().parse().ok())
395 .expect("failed to retrieve ipv6 address"),
396 };
397 Self::MetaTls(host_address, 0)
398 }
399 ChannelTransport::Local => Self::Local(0),
400 ChannelTransport::Sim(transport) => sim::any(*transport),
401 ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
403 }
404 }
405
406 pub fn transport(&self) -> ChannelTransport {
408 match self {
409 Self::Tcp(_) => ChannelTransport::Tcp,
410 Self::MetaTls(address, _) => match address.parse::<IpAddr>() {
411 Ok(ip) => match ip {
412 IpAddr::V6(_) => ChannelTransport::MetaTls(TlsMode::IpV6),
413 IpAddr::V4(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
414 },
415 Err(_) => ChannelTransport::MetaTls(TlsMode::Hostname),
416 },
417 Self::Local(_) => ChannelTransport::Local,
418 Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
419 Self::Unix(_) => ChannelTransport::Unix,
420 }
421 }
422}
423
424impl fmt::Display for ChannelAddr {
425 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426 match self {
427 Self::Tcp(addr) => write!(f, "tcp:{}", addr),
428 Self::MetaTls(hostname, port) => write!(f, "metatls:{}:{}", hostname, port),
429 Self::Local(index) => write!(f, "local:{}", index),
430 Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
431 Self::Unix(addr) => write!(f, "unix:{}", addr),
432 }
433 }
434}
435
436impl FromStr for ChannelAddr {
437 type Err = anyhow::Error;
438
439 fn from_str(addr: &str) -> Result<Self, Self::Err> {
440 match addr.split_once('!').or_else(|| addr.split_once(':')) {
442 Some(("local", rest)) => rest
443 .parse::<u64>()
444 .map(Self::Local)
445 .map_err(anyhow::Error::from),
446 Some(("tcp", rest)) => rest
447 .parse::<SocketAddr>()
448 .map(Self::Tcp)
449 .map_err(anyhow::Error::from),
450 Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
451 Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
452 Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
453 Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
454 None => Err(anyhow::anyhow!("no channel type specified")),
455 }
456 }
457}
458
459#[derive(Debug)]
461pub struct ChannelTx<M: RemoteMessage> {
462 inner: ChannelTxKind<M>,
463}
464
465#[derive(Debug)]
467enum ChannelTxKind<M: RemoteMessage> {
468 Local(local::LocalTx<M>),
469 Tcp(net::NetTx<M>),
470 MetaTls(net::NetTx<M>),
471 Unix(net::NetTx<M>),
472 Sim(sim::SimTx<M>),
473}
474
475#[async_trait]
476impl<M: RemoteMessage> Tx<M> for ChannelTx<M> {
477 fn try_post(&self, message: M, return_channel: oneshot::Sender<M>) -> Result<(), SendError<M>> {
478 match &self.inner {
479 ChannelTxKind::Local(tx) => tx.try_post(message, return_channel),
480 ChannelTxKind::Tcp(tx) => tx.try_post(message, return_channel),
481 ChannelTxKind::MetaTls(tx) => tx.try_post(message, return_channel),
482 ChannelTxKind::Sim(tx) => tx.try_post(message, return_channel),
483 ChannelTxKind::Unix(tx) => tx.try_post(message, return_channel),
484 }
485 }
486
487 fn addr(&self) -> ChannelAddr {
488 match &self.inner {
489 ChannelTxKind::Local(tx) => tx.addr(),
490 ChannelTxKind::Tcp(tx) => Tx::<M>::addr(tx),
491 ChannelTxKind::MetaTls(tx) => Tx::<M>::addr(tx),
492 ChannelTxKind::Sim(tx) => tx.addr(),
493 ChannelTxKind::Unix(tx) => Tx::<M>::addr(tx),
494 }
495 }
496
497 fn status(&self) -> &watch::Receiver<TxStatus> {
498 match &self.inner {
499 ChannelTxKind::Local(tx) => tx.status(),
500 ChannelTxKind::Tcp(tx) => tx.status(),
501 ChannelTxKind::MetaTls(tx) => tx.status(),
502 ChannelTxKind::Sim(tx) => tx.status(),
503 ChannelTxKind::Unix(tx) => tx.status(),
504 }
505 }
506}
507
508#[derive(Debug)]
510pub struct ChannelRx<M: RemoteMessage> {
511 inner: ChannelRxKind<M>,
512}
513
514#[derive(Debug)]
516enum ChannelRxKind<M: RemoteMessage> {
517 Local(local::LocalRx<M>),
518 Tcp(net::NetRx<M>),
519 MetaTls(net::NetRx<M>),
520 Unix(net::NetRx<M>),
521 Sim(sim::SimRx<M>),
522}
523
524#[async_trait]
525impl<M: RemoteMessage> Rx<M> for ChannelRx<M> {
526 async fn recv(&mut self) -> Result<M, ChannelError> {
527 match &mut self.inner {
528 ChannelRxKind::Local(rx) => rx.recv().await,
529 ChannelRxKind::Tcp(rx) => rx.recv().await,
530 ChannelRxKind::MetaTls(rx) => rx.recv().await,
531 ChannelRxKind::Sim(rx) => rx.recv().await,
532 ChannelRxKind::Unix(rx) => rx.recv().await,
533 }
534 }
535
536 fn addr(&self) -> ChannelAddr {
537 match &self.inner {
538 ChannelRxKind::Local(rx) => rx.addr(),
539 ChannelRxKind::Tcp(rx) => rx.addr(),
540 ChannelRxKind::MetaTls(rx) => rx.addr(),
541 ChannelRxKind::Sim(rx) => rx.addr(),
542 ChannelRxKind::Unix(rx) => rx.addr(),
543 }
544 }
545}
546
547#[allow(clippy::result_large_err)] pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, ChannelError> {
552 tracing::debug!(name = "dial", "dialing channel {}", addr);
553 let inner = match addr {
554 ChannelAddr::Local(port) => ChannelTxKind::Local(local::dial(port)?),
555 ChannelAddr::Tcp(addr) => ChannelTxKind::Tcp(net::tcp::dial(addr)),
556 ChannelAddr::MetaTls(host, port) => ChannelTxKind::MetaTls(net::meta::dial(host, port)),
557 ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
558 ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
559 };
560 Ok(ChannelTx { inner })
561}
562
563#[crate::instrument]
566pub async fn serve<M: RemoteMessage>(
567 addr: ChannelAddr,
568) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
569 tracing::debug!(name = "serve", "serving channel address {}", addr);
570 match addr {
571 ChannelAddr::Tcp(addr) => {
572 let (addr, rx) = net::tcp::serve::<M>(addr).await?;
573 Ok((addr, ChannelRxKind::Tcp(rx)))
574 }
575 ChannelAddr::MetaTls(hostname, port) => {
576 let (addr, rx) = net::meta::serve::<M>(hostname, port).await?;
577 Ok((addr, ChannelRxKind::MetaTls(rx)))
578 }
579 ChannelAddr::Unix(path) => {
580 let (addr, rx) = net::unix::serve::<M>(path).await?;
581 Ok((addr, ChannelRxKind::Unix(rx)))
582 }
583 ChannelAddr::Local(0) => {
584 let (port, rx) = local::serve::<M>();
585 Ok((ChannelAddr::Local(port), ChannelRxKind::Local(rx)))
586 }
587 ChannelAddr::Sim(sim_addr) => {
588 let (addr, rx) = sim::serve::<M>(sim_addr)?;
589 Ok((addr, ChannelRxKind::Sim(rx)))
590 }
591 ChannelAddr::Local(a) => Err(ChannelError::InvalidAddress(format!(
592 "invalid local addr: {}",
593 a
594 ))),
595 }
596 .map(|(addr, inner)| (addr, ChannelRx { inner }))
597}
598
599pub fn serve_local<M: RemoteMessage>() -> (ChannelAddr, ChannelRx<M>) {
602 let (port, rx) = local::serve::<M>();
603 (
604 ChannelAddr::Local(port),
605 ChannelRx {
606 inner: ChannelRxKind::Local(rx),
607 },
608 )
609}
610
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::collections::HashSet;
    use std::net::IpAddr;
    use std::net::Ipv4Addr;
    use std::net::Ipv6Addr;
    use std::time::Duration;

    use tokio::task::JoinSet;

    use super::net::*;
    use super::*;
    use crate::clock::Clock;
    use crate::clock::RealClock;

    /// Checks that addresses parse with both the `!` and `:` delimiters,
    /// that the same addresses parse when nested under `sim`, and that
    /// malformed inputs produce the expected error messages.
    #[test]
    fn test_channel_addr() {
        let cases_ok = vec![
            (
                "tcp<DELIM>[::1]:1234",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    1234,
                )),
            ),
            (
                "tcp<DELIM>127.0.0.1:8080",
                ChannelAddr::Tcp(SocketAddr::new(
                    IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                    8080,
                )),
            ),
            // Parsing a local address does not depend on the platform.
            ("local<DELIM>123", ChannelAddr::Local(123)),
            // The abstract-name case is gated to Linux, matching `addrs()` below.
            #[cfg(target_os = "linux")]
            (
                "unix<DELIM>@yolo",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_abstract_name("yolo")
                        .expect("can't make socket from abstract name"),
                ),
            ),
            (
                "unix<DELIM>/cool/socket-path",
                ChannelAddr::Unix(
                    unix::SocketAddr::from_pathname("/cool/socket-path")
                        .expect("can't make socket from path"),
                ),
            ),
        ];

        for (raw, parsed) in cases_ok.clone() {
            for delim in ["!", ":"] {
                let raw = raw.replace("<DELIM>", delim);
                assert_eq!(raw.parse::<ChannelAddr>().unwrap(), parsed);
            }
        }

        // The same addresses, wrapped in a sim address.
        for (raw, parsed) in cases_ok {
            for delim in ["!", ":"] {
                let raw = format!("sim{}{}", delim, raw.replace("<DELIM>", delim));
                assert_eq!(
                    raw.parse::<ChannelAddr>().unwrap(),
                    ChannelAddr::Sim(SimAddr::new(parsed.clone()).unwrap())
                );
            }
        }

        let cases_err = vec![
            ("tcp:abcdef..123124", "invalid socket address syntax"),
            ("xxx:foo", "no such channel type: xxx"),
            ("127.0.0.1", "no channel type specified"),
            ("local:abc", "invalid digit found in string"),
        ];

        for (raw, error) in cases_err {
            let Err(err) = raw.parse::<ChannelAddr>() else {
                panic!("expected error parsing: {}", &raw)
            };
            assert_eq!(format!("{}", err), error);
        }
    }

    /// For every transport in `ChannelTransport::all()`, dials the same
    /// server from 100 concurrent tasks and checks that every message
    /// arrives.
    #[tokio::test]
    async fn test_multiple_connections() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            let (listen_addr, mut rx) = crate::channel::serve::<u64>(addr).await.unwrap();

            let mut sends: JoinSet<()> = JoinSet::new();
            for message in 0u64..100u64 {
                let addr = listen_addr.clone();
                sends.spawn(async move {
                    let tx = dial::<u64>(addr).unwrap();
                    tx.try_post(message, oneshot::channel().0).unwrap();
                });
            }

            let mut received: HashSet<u64> = HashSet::new();
            while received.len() < 100 {
                received.insert(rx.recv().await.unwrap());
            }

            for message in 0u64..100u64 {
                assert!(received.contains(&message));
            }

            // Surface any panic from the sending tasks.
            loop {
                match sends.join_next().await {
                    Some(Ok(())) => (),
                    Some(Err(err)) => panic!("{}", err),
                    None => break,
                }
            }
        }
    }

    /// For non-net transports, checks that posting after the receiver has
    /// been dropped eventually fails with `ChannelError::Closed`.
    #[tokio::test]
    async fn test_server_close() {
        for addr in ChannelTransport::all().map(ChannelAddr::any) {
            // Net addresses are skipped by this test.
            if net::is_net_addr(&addr) {
                continue;
            }

            let (listen_addr, rx) = crate::channel::serve::<u64>(addr).await.unwrap();

            let tx = dial::<u64>(listen_addr).unwrap();
            tx.try_post(123, oneshot::channel().0).unwrap();
            drop(rx);

            let start = RealClock.now();

            // Keep posting until a post fails or ten seconds have elapsed.
            let result = loop {
                let result = tx.try_post(123, oneshot::channel().0);
                if result.is_err() || start.elapsed() > Duration::from_secs(10) {
                    break result;
                }
            };
            assert_matches!(result, Err(SendError(ChannelError::Closed, 123)));
        }
    }

    /// Addresses exercised by the dial/serve tests. The Unix entries are
    /// only included on Linux; the last one uses a random ten-letter
    /// abstract name.
    fn addrs() -> Vec<ChannelAddr> {
        use rand::Rng;
        use rand::distributions::Uniform;

        let rng = rand::thread_rng();
        vec![
            "tcp:[::1]:0".parse().unwrap(),
            "local:0".parse().unwrap(),
            #[cfg(target_os = "linux")]
            "unix:".parse().unwrap(),
            #[cfg(target_os = "linux")]
            format!(
                "unix:@{}",
                rng.sample_iter(Uniform::new_inclusive('a', 'z'))
                    .take(10)
                    .collect::<String>()
            )
            .parse()
            .unwrap(),
        ]
    }

    /// Serves, dials and posts one message on each address from `addrs()`.
    #[tokio::test]
    async fn test_dial_serve() {
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).await.unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.try_post(123, oneshot::channel().0).unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);
        }
    }

    /// Checks that `send` succeeds while the receiver is alive and reports
    /// `ChannelError::Closed` (with the message) once it has been dropped.
    #[tokio::test]
    async fn test_send() {
        let config = crate::config::global::lock();

        // Override the delivery timeout to one second and ack every message.
        let _guard1 = config.override_key(
            crate::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(crate::config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        for addr in addrs() {
            let (listen_addr, mut rx) = crate::channel::serve::<i32>(addr).await.unwrap();
            let tx = crate::channel::dial(listen_addr).unwrap();
            tx.send(123).await.unwrap();
            assert_eq!(rx.recv().await.unwrap(), 123);

            drop(rx);
            assert_matches!(
                tx.send(123).await.unwrap_err(),
                SendError(ChannelError::Closed, 123)
            );
        }
    }
}