1use std::fmt;
51use std::fmt::Debug;
52use std::net::ToSocketAddrs;
53use std::time::Duration;
54
55use backoff::ExponentialBackoffBuilder;
56use backoff::backoff::Backoff;
57use bytes::Bytes;
58use enum_as_inner::EnumAsInner;
59use serde::Deserialize;
60use serde::Serialize;
61use serde::de::Error;
62use tokio::io::AsyncRead;
63use tokio::io::AsyncReadExt;
64use tokio::io::AsyncWrite;
65use tokio::io::AsyncWriteExt;
66use tokio::sync::watch;
67use tokio::time::Instant;
68
69use super::*;
70use crate::RemoteMessage;
71
72pub mod duplex;
73mod framed;
74pub(super) mod server;
75pub(super) mod session;
76pub use server::ServerHandle;
77
78pub(crate) trait Stream:
79 AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static
80{
81}
82impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static> Stream for S {}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub(crate) struct SessionId(u64);
89
90impl SessionId {
91 pub fn random() -> Self {
93 Self(rand::random())
94 }
95}
96
97impl fmt::Display for SessionId {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 write!(f, "{:016x}", self.0)
100 }
101}
102
/// Direction selector for the initiator → acceptor stream (passed to `stream(..)` in `spawn`).
pub(crate) const INITIATOR_TO_ACCEPTOR: u8 = 0;

/// Direction selector for the acceptor → initiator stream.
pub(crate) const ACCEPTOR_TO_INITIATOR: u8 = 1;

/// Magic prefix that opens every LinkInit header.
const LINK_INIT_MAGIC: [u8; 4] = *b"LNK\0";

/// Total LinkInit header size: the magic followed by a big-endian `u64` session id.
const LINK_INIT_SIZE: usize = LINK_INIT_MAGIC.len() + std::mem::size_of::<u64>();
119
120async fn write_link_init<S: AsyncWrite + Unpin>(
122 stream: &mut S,
123 session_id: SessionId,
124) -> Result<(), std::io::Error> {
125 let mut buf = [0u8; LINK_INIT_SIZE];
126 buf[0..4].copy_from_slice(&LINK_INIT_MAGIC);
127 buf[4..12].copy_from_slice(&session_id.0.to_be_bytes());
128 stream.write_all(&buf).await
129}
130
131async fn read_link_init<S: AsyncRead + Unpin>(stream: &mut S) -> Result<SessionId, std::io::Error> {
133 let mut buf = [0u8; LINK_INIT_SIZE];
134 stream.read_exact(&mut buf).await?;
135 if buf[0..4] != LINK_INIT_MAGIC {
136 return Err(std::io::Error::new(
137 std::io::ErrorKind::InvalidData,
138 format!(
139 "invalid LinkInit magic: expected {:?}, got {:?}",
140 LINK_INIT_MAGIC,
141 &buf[0..4]
142 ),
143 ));
144 }
145 let session_id = SessionId(u64::from_be_bytes(buf[4..12].try_into().unwrap()));
146 Ok(session_id)
147}
148
149#[async_trait]
153pub(crate) trait Link: Send + Sync + Debug + 'static {
154 type Stream: Stream;
156
157 fn dest(&self) -> ChannelAddr;
159
160 fn link_id(&self) -> SessionId;
162
163 async fn next(&self) -> Result<Self::Stream, ClientError>;
166}
167
168use session::Session;
169
170use crate::config;
171use crate::metrics;
172
173pub(crate) enum LinkStatus {
174 NeverConnected,
175 Connected(tokio::time::Instant),
176 Disconnected {
177 last_connected: tokio::time::Instant,
178 since: tokio::time::Instant,
179 },
180}
181
182impl LinkStatus {
183 fn connected(&mut self) {
184 *self = LinkStatus::Connected(tokio::time::Instant::now());
185 }
186
187 fn disconnected(&mut self) {
188 match *self {
189 LinkStatus::Connected(at) => {
190 *self = LinkStatus::Disconnected {
191 last_connected: at,
192 since: tokio::time::Instant::now(),
193 };
194 }
195 _ => {}
197 }
198 }
199}
200
201impl std::fmt::Display for LinkStatus {
202 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203 match self {
204 LinkStatus::NeverConnected => write!(f, "never connected"),
205 LinkStatus::Connected(at) => {
206 write!(f, "connected for {:.1}s", at.elapsed().as_secs_f64())
207 }
208 LinkStatus::Disconnected {
209 last_connected,
210 since,
211 } => {
212 write!(
213 f,
214 "last connected {:.1}s ago, disconnected for {:.1}s",
215 last_connected.elapsed().as_secs_f64(),
216 since.elapsed().as_secs_f64(),
217 )
218 }
219 }
220 }
221}
222
223fn log_send_error(
226 error: &session::SendLoopError,
227 dest: &ChannelAddr,
228 session_id: u64,
229 mode: &str,
230 link_status: &LinkStatus,
231) -> bool {
232 match error {
233 session::SendLoopError::Io(err) => {
234 tracing::info!(dest = %dest, session_id, error = %err, mode, "send error; {link_status}");
235 metrics::CHANNEL_ERRORS.add(
236 1,
237 hyperactor_telemetry::kv_pairs!(
238 "dest" => dest.to_string(),
239 "session_id" => session_id.to_string(),
240 "error_type" => metrics::ChannelErrorType::SendError.as_str(),
241 "mode" => mode.to_string(),
242 ),
243 );
244 false
245 }
246 session::SendLoopError::AppClosed => true,
247 session::SendLoopError::Rejected(reason) => {
248 tracing::error!(dest = %dest, session_id, mode, "server rejected connection: {reason}; {link_status}");
249 true
250 }
251 session::SendLoopError::ServerClosed => {
252 tracing::info!(dest = %dest, session_id, mode, "server closed the channel; {link_status}");
253 true
254 }
255 session::SendLoopError::DeliveryTimeout => {
256 let timeout = hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
257 tracing::error!(
258 dest = %dest, session_id, mode,
259 "failed to receive ack within timeout {timeout:?}; link is currently connected; {link_status}"
260 );
261 true
262 }
263 session::SendLoopError::OversizedFrame(reason) => {
264 tracing::error!(dest = %dest, session_id, mode, "oversized frame: {reason}; {link_status}");
265 true
266 }
267 }
268}
269
/// Spawns the delivery task for `link` and returns the sending half of the channel.
///
/// Posted messages go onto an unbounded mpsc channel. The background task runs on
/// `crate::init::get_runtime()` and waits for the first message before it
/// connects. It then repeats the following steps:
/// 1. Connect, bounded by the earliest delivery deadline if there is one.
/// 2. Requeue any unacknowledged messages.
/// 3. Run `session::send_connected` until it returns.
///
/// Reconnects are spaced by exponential backoff (10ms initial, x2, capped at 5s,
/// no overall limit). The backoff is reset after a connection that stayed up
/// longer than 1s.
///
/// The task stops in three cases:
/// * connecting by the deadline fails;
/// * an undeadlined connect fails (reported as "session shut down");
/// * `log_send_error` classifies a send error as terminal.
///
/// On exit it closes the receiver and passes every unacked and outbox entry to
/// `try_return` with the close reason. It sends a `SendError` back for each
/// message still waiting in the channel, then publishes
/// `TxStatus::Closed(reason)` on the status watch.
pub(crate) fn spawn<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    let dest = link.dest();
    let session_id = link.link_id();
    // Status stays `Active` until the delivery task publishes `Closed` on exit.
    let (notify, status) = watch::channel(TxStatus::Active);
    let tx = NetTx {
        sender,
        dest: dest.clone(),
        status,
    };
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id.clone()),
        };
        let mut receiver = receiver;

        // Connect lazily: nothing happens until the first message is posted.
        match receiver.recv().await {
            Some(msg) => {
                if let Err(err) = deliveries.outbox.push_back(msg) {
                    tracing::error!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        "failed to push message to outbox"
                    );
                    // NOTE(review): this early exit (and the `None` arm below) publishes
                    // `Closed` without draining `receiver`, unlike the normal shutdown
                    // path at the end of this task — confirm that is intended.
                    let _ = notify.send(TxStatus::Closed("failed to push to outbox".into()));
                    return;
                }
            }
            None => {
                // The channel closed before anything was posted (the `NetTx` sender was dropped).
                let _ = notify.send(TxStatus::Closed("sender dropped".into()));
                return;
            }
        }

        // Spacing between reconnect attempts; `None` max elapsed time means it never gives up.
        let mut reconnect_backoff = ExponentialBackoffBuilder::new()
            .with_initial_interval(Duration::from_millis(10))
            .with_multiplier(2.0)
            .with_randomization_factor(0.1)
            .with_max_interval(Duration::from_secs(5))
            .with_max_elapsed_time(None)
            .build();

        let mut link_status = LinkStatus::NeverConnected;

        // Each iteration is one connection attempt; the loop's value is the close reason.
        let reason: String = 'outer: loop {
            let connected = match deliveries.expiry_time() {
                // Pending messages have a delivery deadline: the connection must be
                // established before it, otherwise the channel fails.
                Some(deadline) => match session.connect_by(deadline).await {
                    Ok(s) => s,
                    Err(_) => {
                        let timeout =
                            hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
                        // Distinguish "never sent in time" from "sent but never acked".
                        let error_msg = if deliveries.outbox.is_expired(timeout) {
                            format!("failed to deliver message within timeout {timeout:?}; {link_status}")
                        } else {
                            format!(
                                "failed to receive ack within timeout {timeout:?}; \
                                 link is currently broken; {link_status}",
                            )
                        };
                        tracing::error!(
                            dest = %dest, session_id = session_id.0, "{}", error_msg
                        );
                        break 'outer format!("{log_id}: {error_msg}");
                    }
                },
                None => match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break 'outer "session shut down".into(),
                },
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "simplex",
                    "reason" => "link connected",
                ),
            );

            // Unacked messages at connect time mean this is a reconnect that must resend.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "simplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            // Presumably moves unacked messages back into the outbox for resending — confirm in `session`.
            deliveries.requeue_unacked();

            link_status.connected();
            let connected_at = tokio::time::Instant::now();

            let result = {
                let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                session::send_connected(&stream, &mut deliveries, &mut receiver).await
            };
            // Take the session back so the next iteration can reconnect with it.
            session = connected.release();

            link_status.disconnected();

            // A connection that stayed up for over a second counts as healthy:
            // restart the backoff from its initial interval.
            if connected_at.elapsed() > Duration::from_secs(1) {
                reconnect_backoff.reset();
            }

            match result {
                Ok(()) => {
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            "send_connected returned EOF, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
                Err(ref e) => {
                    // `true` means the error is terminal: stop instead of reconnecting.
                    if log_send_error(e, &dest, session_id.0, "simplex", &link_status) {
                        break 'outer format!("{log_id}: {e}");
                    }
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            error = %e,
                            "send_connected returned recoverable error, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        };

        tracing::info!(
            dest = %dest, session_id = session_id.0, "NetTx closing: {reason}"
        );

        // Stop accepting new messages, then hand everything still pending back to
        // its sender with the close reason.
        receiver.close();
        deliveries
            .unacked
            .deque
            .drain(..)
            .chain(deliveries.outbox.deque.drain(..))
            .for_each(|queued| queued.try_return(Some(reason.clone())));
        while let Ok((msg, return_channel, _)) = receiver.try_recv() {
            let _ = return_channel.send(SendError {
                error: ChannelError::Closed,
                message: msg,
                reason: Some(reason.clone()),
            });
        }

        let _ = notify.send(TxStatus::Closed(reason.into()));
    });
    tx
}
441
442#[derive(Debug)]
445pub(crate) enum NetLink {
446 Tcp(tcp::TcpLink),
447 Unix(unix::UnixLink),
448 Tls(tls::TlsLink),
449}
450
451pub(crate) fn link(addr: ChannelAddr) -> Result<NetLink, ClientError> {
453 match addr {
454 ChannelAddr::Tcp(socket_addr) => Ok(NetLink::Tcp(tcp::link(socket_addr))),
455 ChannelAddr::Unix(unix_addr) => Ok(NetLink::Unix(unix::link(unix_addr))),
456 ChannelAddr::Tls(tls_addr) => Ok(NetLink::Tls(tls::link(tls_addr)?)),
457 ChannelAddr::MetaTls(meta_addr) => Ok(NetLink::Tls(meta::link(meta_addr)?)),
458 other => Err(ClientError::Connect(
459 other,
460 std::io::Error::other("unsupported transport"),
461 "unsupported transport".into(),
462 )),
463 }
464}
465
466#[async_trait]
467impl Link for NetLink {
468 type Stream = Box<dyn Stream>;
469
470 fn dest(&self) -> ChannelAddr {
471 match self {
472 Self::Tcp(l) => l.dest(),
473 Self::Unix(l) => l.dest(),
474 Self::Tls(l) => l.dest(),
475 }
476 }
477
478 fn link_id(&self) -> SessionId {
479 match self {
480 Self::Tcp(l) => l.link_id(),
481 Self::Unix(l) => l.link_id(),
482 Self::Tls(l) => l.link_id(),
483 }
484 }
485
486 async fn next(&self) -> Result<Box<dyn Stream>, ClientError> {
487 match self {
488 Self::Tcp(l) => Ok(Box::new(l.next().await?)),
489 Self::Unix(l) => Ok(Box::new(l.next().await?)),
490 Self::Tls(l) => Ok(Box::new(l.next().await?)),
491 }
492 }
493}
494
495#[async_trait]
500pub(crate) trait Listener: Send + Unpin + 'static {
501 type Stream: Stream;
503
504 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError>;
506}
507
508#[derive(Debug)]
513pub(crate) enum NetListener {
514 Tcp(tcp::TcpSocketListener),
515 Unix(unix::UnixSocketListener),
516}
517
518#[async_trait]
519impl Listener for NetListener {
520 type Stream = Box<dyn Stream>;
521
522 async fn accept(&mut self) -> Result<(Box<dyn Stream>, ChannelAddr), ServerError> {
523 match self {
524 Self::Tcp(l) => {
525 let (stream, addr) = l.accept().await?;
526 Ok((Box::new(stream), addr))
527 }
528 Self::Unix(l) => {
529 let (stream, addr) = l.accept().await?;
530 Ok((Box::new(stream), addr))
531 }
532 }
533 }
534}
535
536pub(crate) fn listen(addr: ChannelAddr) -> Result<(NetListener, ChannelAddr), ServerError> {
540 match addr {
541 ChannelAddr::Tcp(socket_addr) => {
542 let std_listener = std::net::TcpListener::bind(socket_addr)
543 .map_err(|err| ServerError::Listen(ChannelAddr::Tcp(socket_addr), err))?;
544 std_listener
545 .set_nonblocking(true)
546 .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
547 let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
548 .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
549 let local_addr = tokio_listener
550 .local_addr()
551 .map_err(|err| ServerError::Resolve(ChannelAddr::Tcp(socket_addr), err))?;
552 let listener = tcp::TcpSocketListener {
553 inner: tokio_listener,
554 addr: local_addr,
555 };
556 Ok((NetListener::Tcp(listener), ChannelAddr::Tcp(local_addr)))
557 }
558 ChannelAddr::Unix(ref unix_addr) => {
559 use std::os::unix::net::UnixDatagram as StdUnixDatagram;
560 use std::os::unix::net::UnixListener as StdUnixListener;
561
562 let caddr = addr.clone();
563 let maybe_listener = match unix_addr {
564 unix::SocketAddr::Bound(sock_addr) => StdUnixListener::bind_addr(sock_addr),
565 unix::SocketAddr::Unbound => StdUnixDatagram::unbound()
566 .and_then(|u| u.local_addr())
567 .and_then(|uaddr| StdUnixListener::bind_addr(&uaddr)),
568 };
569 let std_listener =
570 maybe_listener.map_err(|err| ServerError::Listen(caddr.clone(), err))?;
571 std_listener
572 .set_nonblocking(true)
573 .map_err(|err| ServerError::Listen(caddr.clone(), err))?;
574 let local_addr = std_listener
575 .local_addr()
576 .map_err(|err| ServerError::Resolve(caddr.clone(), err))?;
577 let tokio_listener = tokio::net::UnixListener::from_std(std_listener)
578 .map_err(|err| ServerError::Io(caddr, err))?;
579 let bound_addr = unix::SocketAddr::new(local_addr);
580 let listener = unix::UnixSocketListener {
581 inner: tokio_listener,
582 addr: bound_addr.clone(),
583 };
584 Ok((NetListener::Unix(listener), ChannelAddr::Unix(bound_addr)))
585 }
586 addr @ (ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_)) => {
587 let is_meta = matches!(addr, ChannelAddr::MetaTls(_));
588 let tls_addr = match addr {
589 ChannelAddr::Tls(a) | ChannelAddr::MetaTls(a) => a,
590 _ => unreachable!(),
591 };
592 let TlsAddr { hostname, port } = tls_addr;
593 let make_channel_addr = |h: &str, p: Port| {
594 if is_meta {
595 ChannelAddr::MetaTls(TlsAddr::new(h, p))
596 } else {
597 ChannelAddr::Tls(TlsAddr::new(h, p))
598 }
599 };
600
601 let addrs: Vec<core::net::SocketAddr> = (hostname.as_ref(), port)
602 .to_socket_addrs()
603 .map_err(|err| ServerError::Resolve(make_channel_addr(&hostname, port), err))?
604 .collect();
605
606 if addrs.is_empty() {
607 return Err(ServerError::Resolve(
608 make_channel_addr(&hostname, port),
609 std::io::Error::other("no available socket addr"),
610 ));
611 }
612
613 let channel_addr = make_channel_addr(&hostname, port);
614 let std_listener = std::net::TcpListener::bind(&addrs[..])
615 .map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
616 std_listener
617 .set_nonblocking(true)
618 .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
619 let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
620 .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
621 let local_addr = tokio_listener
622 .local_addr()
623 .map_err(|err| ServerError::Resolve(channel_addr, err))?;
624 let listener = tcp::TcpSocketListener {
625 inner: tokio_listener,
626 addr: local_addr,
627 };
628 Ok((
629 NetListener::Tcp(listener),
630 make_channel_addr(&hostname, local_addr.port()),
631 ))
632 }
633 other => Err(ServerError::Listen(
634 other.clone(),
635 std::io::Error::other(format!("unsupported transport: {}", other)),
636 )),
637 }
638}
639
/// A frame carried over a link.
///
/// NOTE(review): this type is serde-serialized, so the variant layout is part of the
/// wire format. Do not reorder or insert variants without checking compatibility.
#[derive(Debug, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub(super) enum Frame<M> {
    /// A message tagged with a `u64` — presumably its sequence number, matched by
    /// `NetRxResponse::Ack` — confirm.
    Message(u64, M),
}
646
/// A response sent back by the receiving side of a link.
///
/// It is encoded with bincode (see `serialize_response` / `deserialize_response`),
/// which encodes enum variants by position. Do not reorder variants.
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
pub(super) enum NetRxResponse {
    /// Acknowledgement carrying a `u64` — presumably the highest delivered
    /// sequence number — confirm.
    Ack(u64),
    /// The connection was rejected; carries the reason.
    Reject(String),
    /// The receiving side closed the channel.
    Closed,
}
655
656pub(super) fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
657 bincode::serialize(&response).map(|bytes| bytes.into())
658}
659
660pub(super) fn deserialize_response(data: Bytes) -> Result<NetRxResponse, bincode::Error> {
661 bincode::deserialize(&data)
662}
663
664pub(crate) struct NetTx<M: RemoteMessage> {
667 sender: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
668 dest: ChannelAddr,
669 status: watch::Receiver<TxStatus>,
670}
671
672#[async_trait]
673impl<M: RemoteMessage> Tx<M> for NetTx<M> {
674 fn addr(&self) -> ChannelAddr {
675 self.dest.clone()
676 }
677
678 fn status(&self) -> &watch::Receiver<TxStatus> {
679 &self.status
680 }
681
682 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
683 tracing::trace!(
684 name = "post",
685 dest = %self.dest,
686 "sending message"
687 );
688
689 let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
690 if let Err(mpsc::error::SendError((message, return_channel, _))) =
691 self.sender
692 .send((message, return_channel, tokio::time::Instant::now()))
693 {
694 let reason = self.status.borrow().as_closed().map(|r| r.to_string());
695 let _ = return_channel.send(SendError {
696 error: ChannelError::Closed,
697 message,
698 reason,
699 });
700 }
701 }
702}
703
704pub struct NetRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr, ServerHandle);
705
706#[async_trait]
707impl<M: RemoteMessage> Rx<M> for NetRx<M> {
708 async fn recv(&mut self) -> Result<M, ChannelError> {
709 tracing::trace!(
710 name = "recv",
711 dest = %self.1,
712 "receiving message"
713 );
714 self.0.recv().await.ok_or(ChannelError::Closed)
715 }
716
717 fn addr(&self) -> ChannelAddr {
718 self.1.clone()
719 }
720
721 async fn join(mut self) {
724 self.2
725 .stop(&format!("NetRx joined; channel address: {}", self.1));
726 let _ = (&mut self.2).await;
727 }
729}
730
731impl<M: RemoteMessage> Drop for NetRx<M> {
732 fn drop(&mut self) {
733 self.2
734 .stop(&format!("NetRx dropped; channel address: {}", self.1));
735 }
736}
737
738#[derive(Debug, thiserror::Error)]
740pub enum ServerError {
741 #[error("io: {1}")]
742 Io(ChannelAddr, #[source] std::io::Error),
743 #[error("listen: {0} {1}")]
744 Listen(ChannelAddr, #[source] std::io::Error),
745 #[error("resolve: {0} {1}")]
746 Resolve(ChannelAddr, #[source] std::io::Error),
747 #[error("internal: {0} {1}")]
748 Internal(ChannelAddr, #[source] anyhow::Error),
749}
750
751#[derive(thiserror::Error, Debug)]
752pub enum ClientError {
753 #[error("connection to {0} failed: {1}: {2}")]
754 Connect(ChannelAddr, std::io::Error, String),
755 #[error("unable to resolve address: {0}")]
756 Resolve(ChannelAddr),
757 #[error("io: {0} {1}")]
758 Io(ChannelAddr, std::io::Error),
759 #[error("send {0}: serialize: {1}")]
760 Serialize(ChannelAddr, bincode::ErrorKind),
761 #[error("invalid address: {0}")]
762 InvalidAddress(String),
763}
764
/// Test helper: whether `addr` uses one of the network transports handled in this module.
#[cfg(test)]
pub(super) fn is_net_addr(addr: &ChannelAddr) -> bool {
    matches!(
        addr.transport(),
        ChannelTransport::Tcp(_)
            | ChannelTransport::MetaTls(_)
            | ChannelTransport::Tls
            | ChannelTransport::Unix
    )
}
777
778pub(crate) mod unix {
779
780 use core::str;
781 use std::os::unix::net::SocketAddr as StdSocketAddr;
782 use std::os::unix::net::UnixStream as StdUnixStream;
783
784 use rand::Rng;
785 use rand::distr::Alphanumeric;
786 use tokio::net::UnixListener;
787 use tokio::net::UnixStream;
788
789 use super::*;
790
791 #[derive(Debug)]
792 pub(crate) struct UnixLink {
793 pub(super) addr: SocketAddr,
794 pub(super) session_id: SessionId,
795 }
796
797 #[async_trait]
798 impl Link for UnixLink {
799 type Stream = UnixStream;
800
801 fn dest(&self) -> ChannelAddr {
802 ChannelAddr::Unix(self.addr.clone())
803 }
804
805 fn link_id(&self) -> SessionId {
806 self.session_id
807 }
808
809 async fn next(&self) -> Result<Self::Stream, ClientError> {
810 let session_id = self.session_id;
811 let sock_addr = match &self.addr {
812 SocketAddr::Bound(a) => a,
813 SocketAddr::Unbound => return Err(ClientError::Resolve(self.dest())),
814 };
815 let mut backoff = ExponentialBackoffBuilder::new()
816 .with_initial_interval(Duration::from_millis(1))
817 .with_multiplier(2.0)
818 .with_randomization_factor(0.1)
819 .with_max_interval(Duration::from_millis(1000))
820 .with_max_elapsed_time(None)
821 .build();
822 loop {
823 match StdUnixStream::connect_addr(sock_addr) {
824 Ok(std_stream) => {
825 std_stream
826 .set_nonblocking(true)
827 .map_err(|err| ClientError::Io(self.dest(), err))?;
828 let mut stream = UnixStream::from_std(std_stream)
829 .map_err(|err| ClientError::Io(self.dest(), err))?;
830 write_link_init(&mut stream, session_id)
831 .await
832 .map_err(|err| ClientError::Io(self.dest(), err))?;
833 return Ok(stream);
834 }
835 Err(err) => {
836 tracing::debug!(error = %err, "unix connect failed, backing off");
837 if let Some(delay) = backoff.next_backoff() {
838 tokio::time::sleep(delay).await;
839 }
840 }
841 }
842 }
843 }
844 }
845
846 #[derive(Debug)]
848 pub(crate) struct UnixSocketListener {
849 pub(super) inner: UnixListener,
850 pub(super) addr: SocketAddr,
851 }
852
853 #[async_trait]
854 impl super::Listener for UnixSocketListener {
855 type Stream = UnixStream;
856
857 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
858 let (stream, peer_addr) = self
859 .inner
860 .accept()
861 .await
862 .map_err(|err| ServerError::Io(ChannelAddr::Unix(self.addr.clone()), err))?;
863 let std_addr: StdSocketAddr = peer_addr.into();
865 Ok((stream, ChannelAddr::Unix(SocketAddr::new(std_addr))))
866 }
867 }
868
869 pub(crate) fn link(addr: SocketAddr) -> UnixLink {
871 UnixLink {
872 addr,
873 session_id: SessionId::random(),
874 }
875 }
876
877 #[derive(Clone, Debug)]
879 pub enum SocketAddr {
880 Bound(Box<StdSocketAddr>),
881 Unbound,
882 }
883
884 impl PartialOrd for SocketAddr {
885 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
886 Some(self.cmp(other))
887 }
888 }
889
890 impl Ord for SocketAddr {
891 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
892 self.to_string().cmp(&other.to_string())
893 }
894 }
895
896 impl<'de> Deserialize<'de> for SocketAddr {
897 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
898 where
899 D: serde::Deserializer<'de>,
900 {
901 let s = String::deserialize(deserializer)?;
902 Self::from_str(&s).map_err(D::Error::custom)
903 }
904 }
905
906 impl Serialize for SocketAddr {
907 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
908 where
909 S: serde::Serializer,
910 {
911 serializer.serialize_str(String::from(self).as_str())
912 }
913 }
914
915 impl From<&SocketAddr> for String {
916 fn from(value: &SocketAddr) -> Self {
917 match value {
918 SocketAddr::Bound(addr) => match addr.as_pathname() {
919 Some(path) => path
920 .to_str()
921 .expect("unable to get str for path")
922 .to_string(),
923 #[cfg(target_os = "linux")]
924 _ => match addr.as_abstract_name() {
925 Some(name) => format!("@{}", String::from_utf8_lossy(name)),
926 _ => String::from("(unnamed)"),
927 },
928 #[cfg(not(target_os = "linux"))]
929 _ => String::from("(unnamed)"),
930 },
931 SocketAddr::Unbound => String::from("(unbound)"),
932 }
933 }
934 }
935
936 impl FromStr for SocketAddr {
937 type Err = anyhow::Error;
938
939 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
940 match s {
941 "" => {
942 let random_string = rand::rng()
945 .sample_iter(&Alphanumeric)
946 .take(24)
947 .map(char::from)
948 .collect::<String>();
949 SocketAddr::from_abstract_name(&random_string)
950 }
951 name if name.starts_with("@") => {
953 SocketAddr::from_abstract_name(name.strip_prefix("@").unwrap())
954 }
955 path => SocketAddr::from_pathname(path),
956 }
957 }
958 }
959
960 impl Eq for SocketAddr {}
961 impl std::hash::Hash for SocketAddr {
962 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
963 String::from(self).hash(state);
964 }
965 }
966 impl PartialEq for SocketAddr {
967 fn eq(&self, other: &Self) -> bool {
968 match (self, other) {
969 (Self::Bound(saddr), Self::Bound(oaddr)) => {
970 if saddr.is_unnamed() || oaddr.is_unnamed() {
971 return false;
972 }
973
974 #[cfg(target_os = "linux")]
975 {
976 saddr.as_pathname() == oaddr.as_pathname()
977 && saddr.as_abstract_name() == oaddr.as_abstract_name()
978 }
979 #[cfg(not(target_os = "linux"))]
980 {
981 saddr.as_pathname() == oaddr.as_pathname()
983 }
984 }
985 (Self::Unbound, _) | (_, Self::Unbound) => false,
986 }
987 }
988 }
989
990 impl fmt::Display for SocketAddr {
991 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
992 match self {
993 Self::Bound(addr) => match addr.as_pathname() {
994 Some(path) => {
995 write!(f, "{}", path.to_string_lossy())
996 }
997 #[cfg(target_os = "linux")]
998 _ => match addr.as_abstract_name() {
999 Some(name) => {
1000 if name.starts_with(b"@") {
1001 return write!(f, "{}", String::from_utf8_lossy(name));
1002 }
1003 write!(f, "@{}", String::from_utf8_lossy(name))
1004 }
1005 _ => write!(f, "(unnamed)"),
1006 },
1007 #[cfg(not(target_os = "linux"))]
1008 _ => write!(f, "(unnamed)"),
1009 },
1010 Self::Unbound => write!(f, "(unbound)"),
1011 }
1012 }
1013 }
1014
1015 impl SocketAddr {
1016 pub fn new(addr: StdSocketAddr) -> Self {
1018 Self::Bound(Box::new(addr))
1019 }
1020
1021 #[cfg(target_os = "linux")]
1024 pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1025 Ok(Self::new(StdSocketAddr::from_abstract_name(
1026 name.strip_prefix("@").unwrap_or(name),
1027 )?))
1028 }
1029
1030 #[cfg(not(target_os = "linux"))]
1031 pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1032 let name = name.strip_prefix("@").unwrap_or(name);
1034 let path = Self::abstract_to_filesystem_path(name);
1035 Self::from_pathname(&path.to_string_lossy())
1036 }
1037
1038 #[cfg(not(target_os = "linux"))]
1039 fn abstract_to_filesystem_path(abstract_name: &str) -> std::path::PathBuf {
1040 use std::collections::hash_map::DefaultHasher;
1041 use std::hash::Hash;
1042 use std::hash::Hasher;
1043
1044 let mut hasher = DefaultHasher::new();
1046 abstract_name.hash(&mut hasher);
1047 let hash = hasher.finish();
1048
1049 let process_id = std::process::id();
1051
1052 std::path::PathBuf::from(format!("/tmp/hyperactor_{}_{:x}", process_id, hash))
1054 }
1055
1056 pub fn from_pathname(name: &str) -> anyhow::Result<Self> {
1058 Ok(Self::new(StdSocketAddr::from_pathname(name)?))
1059 }
1060 }
1061
1062 impl TryFrom<SocketAddr> for StdSocketAddr {
1063 type Error = anyhow::Error;
1064
1065 fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
1066 match value {
1067 SocketAddr::Bound(addr) => Ok(*addr),
1068 SocketAddr::Unbound => Err(anyhow::anyhow!(
1069 "std::os::unix::SocketAddr must be a bound address"
1070 )),
1071 }
1072 }
1073 }
1074}
1075
1076pub(crate) mod tcp {
1077 use tokio::net::TcpListener;
1078 use tokio::net::TcpStream;
1079
1080 use super::*;
1081
1082 #[derive(Debug)]
1083 pub(crate) struct TcpLink {
1084 pub(super) addr: SocketAddr,
1085 pub(super) session_id: SessionId,
1086 }
1087
1088 #[async_trait]
1089 impl Link for TcpLink {
1090 type Stream = TcpStream;
1091
1092 fn dest(&self) -> ChannelAddr {
1093 ChannelAddr::Tcp(self.addr)
1094 }
1095
1096 fn link_id(&self) -> SessionId {
1097 self.session_id
1098 }
1099
1100 async fn next(&self) -> Result<Self::Stream, ClientError> {
1101 let session_id = self.session_id;
1102 let mut backoff = ExponentialBackoffBuilder::new()
1103 .with_initial_interval(Duration::from_millis(1))
1104 .with_multiplier(2.0)
1105 .with_randomization_factor(0.1)
1106 .with_max_interval(Duration::from_millis(1000))
1107 .with_max_elapsed_time(None)
1108 .build();
1109 loop {
1110 match TcpStream::connect(&self.addr).await {
1111 Ok(mut stream) => {
1112 stream.set_nodelay(true).map_err(|err| {
1113 ClientError::Connect(
1114 self.dest(),
1115 err,
1116 "cannot disable Nagle algorithm".to_string(),
1117 )
1118 })?;
1119 write_link_init(&mut stream, session_id)
1120 .await
1121 .map_err(|err| ClientError::Io(self.dest(), err))?;
1122 return Ok(stream);
1123 }
1124 Err(err) => {
1125 tracing::debug!(error = %err, "tcp connect failed, backing off");
1126 if let Some(delay) = backoff.next_backoff() {
1127 tokio::time::sleep(delay).await;
1128 }
1129 }
1130 }
1131 }
1132 }
1133 }
1134
1135 #[derive(Debug)]
1137 pub(crate) struct TcpSocketListener {
1138 pub(super) inner: TcpListener,
1139 pub(super) addr: SocketAddr,
1140 }
1141
1142 #[async_trait]
1143 impl super::Listener for TcpSocketListener {
1144 type Stream = TcpStream;
1145
1146 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1147 let (stream, peer_addr) = self
1148 .inner
1149 .accept()
1150 .await
1151 .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1152 stream
1153 .set_nodelay(true)
1154 .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1155 Ok((stream, ChannelAddr::Tcp(peer_addr)))
1156 }
1157 }
1158
1159 pub(crate) fn link(addr: SocketAddr) -> TcpLink {
1161 TcpLink {
1162 addr,
1163 session_id: SessionId::random(),
1164 }
1165 }
1166}
1167
1168pub(crate) mod meta {
1170 use std::io;
1171 use std::path::PathBuf;
1172 use std::sync::Arc;
1173
1174 use anyhow::Result;
1175 use tokio_rustls::TlsAcceptor;
1176 use tokio_rustls::TlsConnector;
1177
1178 use super::*;
1179 use crate::config::Pem;
1180 use crate::config::PemBundle;
1181
1182 const THRIFT_TLS_SRV_CA_PATH_ENV: &str = "THRIFT_TLS_SRV_CA_PATH";
1183 const DEFAULT_SRV_CA_PATH: &str = "/var/facebook/rootcanal/ca.pem";
1184 const THRIFT_TLS_CL_CERT_PATH_ENV: &str = "THRIFT_TLS_CL_CERT_PATH";
1185 const THRIFT_TLS_CL_KEY_PATH_ENV: &str = "THRIFT_TLS_CL_KEY_PATH";
1186 const DEFAULT_SERVER_PEM_PATH: &str = "/var/facebook/x509_identities/server.pem";
1187
1188 #[allow(clippy::result_large_err)] pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1190 let parts = addr_string.rsplit_once(":");
1192 match parts {
1193 Some((hostname, port_str)) => {
1194 let Ok(port) = port_str.parse() else {
1195 return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1196 };
1197 Ok(ChannelAddr::MetaTls(TlsAddr::new(hostname, port)))
1198 }
1199 _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1200 }
1201 }
1202
1203 pub(super) fn get_server_pem_bundle() -> PemBundle {
1206 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1207 .map(PathBuf::from)
1208 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1209 let server_pem_path = PathBuf::from(DEFAULT_SERVER_PEM_PATH);
1210 PemBundle {
1211 ca: Pem::File(ca_path),
1212 cert: Pem::File(server_pem_path.clone()),
1213 key: Pem::File(server_pem_path),
1214 }
1215 }
1216
1217 fn get_client_pem_bundle() -> Option<PemBundle> {
1220 let cert_path = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV)?;
1221 let key_path = std::env::var_os(THRIFT_TLS_CL_KEY_PATH_ENV)?;
1222 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1223 .map(PathBuf::from)
1224 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1225 Some(PemBundle {
1226 ca: Pem::File(ca_path),
1227 cert: Pem::File(PathBuf::from(cert_path)),
1228 key: Pem::File(PathBuf::from(key_path)),
1229 })
1230 }
1231
1232 pub(crate) fn tls_acceptor(enforce_client_tls: bool) -> Result<TlsAcceptor> {
1234 let bundle = get_server_pem_bundle();
1235 tls::tls_acceptor_from_bundle(&bundle, enforce_client_tls)
1236 }
1237
1238 pub(super) fn try_tls_connector() -> Result<TlsConnector> {
1244 tls_connector()
1245 }
1246
1247 fn tls_connector() -> Result<TlsConnector> {
1250 let _ = rustls::crypto::ring::default_provider().install_default();
1253
1254 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1255 .map(PathBuf::from)
1256 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1257 let ca_pem = Pem::File(ca_path);
1258 let root_store = tls::build_root_store(&ca_pem)?;
1259
1260 let config = rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_store));
1262
1263 let config = if let Some(bundle) = get_client_pem_bundle() {
1264 let certs = tls::load_certs(&bundle.cert)?;
1265 let key = tls::load_key(&bundle.key)?;
1266 config
1267 .with_client_auth_cert(certs, key)
1268 .map_err(|e| anyhow::anyhow!("load client certs: {}", e))?
1269 } else {
1270 config.with_no_client_auth()
1271 };
1272
1273 Ok(TlsConnector::from(Arc::new(config)))
1274 }
1275
1276 pub fn link(addr: TlsAddr) -> Result<tls::TlsLink, ClientError> {
1278 let connector = tls_connector().map_err(|e| {
1279 ClientError::Connect(
1280 ChannelAddr::MetaTls(addr.clone()),
1281 io::Error::other(e.to_string()),
1282 "failed to create TLS connector".to_string(),
1283 )
1284 })?;
1285 let TlsAddr { hostname, port } = addr;
1286 Ok(tls::TlsLink {
1287 hostname,
1288 port,
1289 connector,
1290 addr_type: tls::TlsAddrType::MetaTls,
1291 session_id: SessionId::random(),
1292 })
1293 }
1294}
1295
1296pub(crate) mod tls {
1298 use std::io;
1299 use std::io::BufReader;
1300 use std::sync::Arc;
1301
1302 use anyhow::Context;
1303 use anyhow::Result;
1304 use rustls::RootCertStore;
1305 use rustls::pki_types::CertificateDer;
1306 use rustls::pki_types::PrivateKeyDer;
1307 use rustls::pki_types::ServerName;
1308 use tokio::net::TcpStream;
1309 use tokio_rustls::TlsAcceptor;
1310 use tokio_rustls::TlsConnector;
1311 use tokio_rustls::client::TlsStream;
1312
1313 use super::*;
1314 use crate::channel::TlsAddr;
1315 use crate::config::Pem;
1316 use crate::config::PemBundle;
1317 use crate::config::TLS_CA;
1318 use crate::config::TLS_CERT;
1319 use crate::config::TLS_KEY;
1320
1321 #[derive(Debug, Clone, Copy)]
1323 pub(crate) enum TlsAddrType {
1324 Tls,
1325 MetaTls,
1326 }
1327
1328 #[allow(clippy::result_large_err)]
1330 pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1331 let parts = addr_string.rsplit_once(":");
1333 match parts {
1334 Some((hostname, port_str)) => {
1335 let Ok(port) = port_str.parse() else {
1336 return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1337 };
1338 Ok(ChannelAddr::Tls(TlsAddr::new(hostname, port)))
1339 }
1340 _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1341 }
1342 }
1343
1344 pub(super) fn load_certs(pem: &Pem) -> Result<Vec<CertificateDer<'static>>> {
1346 let mut reader = BufReader::new(pem.reader()?);
1347 let certs = rustls_pemfile::certs(&mut reader)
1348 .filter_map(Result::ok)
1349 .collect();
1350 Ok(certs)
1351 }
1352
1353 pub(super) fn load_key(pem: &Pem) -> Result<PrivateKeyDer<'static>> {
1355 let mut reader = BufReader::new(pem.reader()?);
1356 loop {
1357 break match rustls_pemfile::read_one(&mut reader)? {
1358 Some(rustls_pemfile::Item::Pkcs1Key(key)) => Ok(PrivateKeyDer::Pkcs1(key)),
1359 Some(rustls_pemfile::Item::Pkcs8Key(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
1360 Some(rustls_pemfile::Item::Sec1Key(key)) => Ok(PrivateKeyDer::Sec1(key)),
1361 Some(_) => continue,
1362 None => anyhow::bail!("no private key found in TLS key file"),
1363 };
1364 }
1365 }
1366
1367 pub(super) fn build_root_store(ca_pem: &Pem) -> Result<RootCertStore> {
1369 let mut root_store = RootCertStore::empty();
1370 let certs = load_certs(ca_pem)?;
1371 root_store.add_parsable_certificates(certs);
1372 Ok(root_store)
1373 }
1374
1375 fn get_pem_bundle() -> PemBundle {
1377 PemBundle {
1378 ca: hyperactor_config::global::get_cloned(TLS_CA),
1379 cert: hyperactor_config::global::get_cloned(TLS_CERT),
1380 key: hyperactor_config::global::get_cloned(TLS_KEY),
1381 }
1382 }
1383
1384 pub(super) fn tls_acceptor_from_bundle(
1387 bundle: &PemBundle,
1388 enforce_client_tls: bool,
1389 ) -> Result<TlsAcceptor> {
1390 let _ = rustls::crypto::ring::default_provider().install_default();
1393
1394 let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1395 let key = load_key(&bundle.key).context("load TLS key")?;
1396 let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1397
1398 let config = rustls::ServerConfig::builder();
1399 let config = if enforce_client_tls {
1400 let client_verifier =
1402 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1403 .build()
1404 .map_err(|e| anyhow::anyhow!("failed to build client verifier: {}", e))?;
1405 config.with_client_cert_verifier(client_verifier)
1406 } else {
1407 config.with_no_client_auth()
1408 }
1409 .with_single_cert(certs, key)?;
1410
1411 Ok(TlsAcceptor::from(Arc::new(config)))
1412 }
1413
1414 pub(crate) fn tls_acceptor() -> Result<TlsAcceptor> {
1416 tls_acceptor_from_bundle(&get_pem_bundle(), true)
1417 }
1418
1419 pub(super) fn tls_connector_from_bundle(bundle: &PemBundle) -> Result<TlsConnector> {
1421 let _ = rustls::crypto::ring::default_provider().install_default();
1424
1425 let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1426 let key = load_key(&bundle.key).context("load TLS key")?;
1427 let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1428
1429 let config = rustls::ClientConfig::builder()
1430 .with_root_certificates(Arc::new(root_store))
1431 .with_client_auth_cert(certs, key)
1432 .context("configure client auth")?;
1433
1434 Ok(TlsConnector::from(Arc::new(config)))
1435 }
1436
1437 fn tls_connector() -> Result<TlsConnector> {
1439 tls_connector_from_bundle(&get_pem_bundle())
1440 }
1441
1442 pub(crate) struct TlsLink {
1444 pub(crate) hostname: Hostname,
1445 pub(crate) port: Port,
1446 pub(crate) connector: TlsConnector,
1447 pub(crate) addr_type: TlsAddrType,
1448 pub(crate) session_id: SessionId,
1449 }
1450
1451 impl std::fmt::Debug for TlsLink {
1452 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1453 f.debug_struct("TlsLink")
1454 .field("hostname", &self.hostname)
1455 .field("port", &self.port)
1456 .field("addr_type", &self.addr_type)
1457 .finish()
1458 }
1459 }
1460
1461 #[async_trait]
1462 impl Link for TlsLink {
1463 type Stream = TlsStream<TcpStream>;
1464
1465 fn dest(&self) -> ChannelAddr {
1466 let addr = TlsAddr::new(self.hostname.clone(), self.port);
1467 match self.addr_type {
1468 TlsAddrType::Tls => ChannelAddr::Tls(addr),
1469 TlsAddrType::MetaTls => ChannelAddr::MetaTls(addr),
1470 }
1471 }
1472
1473 fn link_id(&self) -> SessionId {
1474 self.session_id
1475 }
1476
1477 async fn next(&self) -> Result<Self::Stream, ClientError> {
1478 let session_id = self.session_id;
1479 let server_name = ServerName::try_from(self.hostname.clone()).map_err(|e| {
1480 ClientError::Connect(
1481 self.dest(),
1482 io::Error::other(e.to_string()),
1483 "invalid server name".to_string(),
1484 )
1485 })?;
1486 let mut backoff = ExponentialBackoffBuilder::new()
1487 .with_initial_interval(Duration::from_millis(1))
1488 .with_multiplier(2.0)
1489 .with_randomization_factor(0.1)
1490 .with_max_interval(Duration::from_millis(1000))
1491 .with_max_elapsed_time(None)
1492 .build();
1493 loop {
1494 let mut addrs = (self.hostname.as_ref(), self.port)
1495 .to_socket_addrs()
1496 .map_err(|_| ClientError::Resolve(self.dest()))?;
1497 let addr = addrs.next().ok_or(ClientError::Resolve(self.dest()))?;
1498 match TcpStream::connect(&addr).await {
1499 Ok(stream) => {
1500 stream.set_nodelay(true).map_err(|err| {
1501 ClientError::Connect(
1502 self.dest(),
1503 err,
1504 "cannot disable Nagle algorithm".to_string(),
1505 )
1506 })?;
1507 let mut tls_stream = self
1508 .connector
1509 .connect(server_name.clone(), stream)
1510 .await
1511 .map_err(|err| {
1512 tracing::info!(
1513 dest = %self.dest(),
1514 error = %err,
1515 "TLS handshake failed"
1516 );
1517 ClientError::Connect(
1518 self.dest(),
1519 err,
1520 format!("cannot establish TLS connection to {:?}", server_name),
1521 )
1522 })?;
1523 write_link_init(&mut tls_stream, session_id)
1524 .await
1525 .map_err(|err| ClientError::Io(self.dest(), err))?;
1526 return Ok(tls_stream);
1527 }
1528 Err(err) => {
1529 tracing::debug!(error = %err, "tls connect failed, backing off");
1530 if let Some(delay) = backoff.next_backoff() {
1531 tokio::time::sleep(delay).await;
1532 }
1533 }
1534 }
1535 }
1536 }
1537 }
1538
1539 pub fn link(addr: TlsAddr) -> Result<TlsLink, ClientError> {
1541 let connector = tls_connector().map_err(|e| {
1542 ClientError::Connect(
1543 ChannelAddr::Tls(addr.clone()),
1544 io::Error::other(e.to_string()),
1545 "failed to create TLS connector".to_string(),
1546 )
1547 })?;
1548 let TlsAddr { hostname, port } = addr;
1549 Ok(TlsLink {
1550 hostname,
1551 port,
1552 connector,
1553 addr_type: TlsAddrType::Tls,
1554 session_id: SessionId::random(),
1555 })
1556 }
1557
1558 #[cfg(test)]
1559 mod tests {
1560 use timed_test::async_timed_test;
1561
1562 use super::*;
1563 use crate::channel::Rx;
1564 use crate::channel::net::server;
1565 use crate::config::Pem;
1566 use crate::config::TLS_CA;
1567 use crate::config::TLS_CERT;
1568 use crate::config::TLS_KEY;
1569
        /// Test CA certificate, used as `TLS_CA` (the trust root) by the tests below.
        // NOTE(review): the embedded validity window appears to be
        // 2026-01-28 .. 2027-01-28; regenerate the test PKI before it expires.
        const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE-----
MIIDBTCCAe2gAwIBAgIUaGNmboiIosG+8Up0vgDr/+cg+2IwDQYJKoZIhvcNAQEL
BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
NzA4MzlaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
DwAwggEKAoIBAQC9RBoMYXCajklswt8Vi1JI1lEYzic0WNOmz45vG/7H6jTWkgL3
K5Ri+Seg3MobDNc48YHWXYm4hP9wCzkx8ih3ntT5XiY1My/G3jLUuoIEE9pF/BoJ
YQwZVoPNFhA9WhXNRsINf1cXFf8NzRfXpxBfKWtQJxYXU4JiDBQ6rLnQQABo8JmQ
vYFhJbBaYip5jTSiVNn7mB1zNr5jsVxuoSF53Pb7xQ76bwBdOq4zd6PSxL5/lr4G
cHSoxwZQdZMG7PL6hbxDQ2S2YI2lYVET1zwc2WPKCfjbEXBC/jzx828CInQtuksk
18gJt6xHkTFEA8CSA29GM3lejnwYWf51xyyBAgMBAAGjUzBRMB0GA1UdDgQWBBRX
cbxSZ70NsUkAS3Hhy6irugywJDAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
ugywJDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA7aAFfyW67
Z+uGSVYhpsT/uH/3Z3nr7X1smTz5CGEfq2czEcTC7gbYI2l8GZ47GPfnAvHTBZVm
V/XncBCsj7/thOh2jYEHFyCbPckoaSCRyCOnK7LPUlr4HN5uP9EFe45qBLCJDEoY
GTTw7MtzwdovfjchNfKQCTtkBJCXQ95WLCf6UOh02Sn28UTlgfXzF0X0FrcWqWa3
uJZd4XOo4O6hKKlHaBaQPiEr++1xc3SWPV7jZHbckI/vKBnDdEZ9JQX5fFZuypUI
sgomYHxvxrU2hWx+7k53CRdjfaIvT9Ie44z9sSdsU/+blw2S8f/ZTmuECoIAAXYO
0qpzlxZMdr7T
-----END CERTIFICATE-----"#;
1592
        /// Test leaf certificate, used as `TLS_CERT` by the tests below. Because
        /// server and client both read the global config, it serves both sides.
        // NOTE(review): presumably issued by TEST_CA_CERT for `localhost` (with
        // loopback IP SANs) and valid until 2027-01-28 — confirm when regenerating.
        const TEST_SERVER_CERT: &str = r#"-----BEGIN CERTIFICATE-----
MIIDJDCCAgygAwIBAgIUaz66DsWaH5ZXM4hCFnbVbMsyN1cwDQYJKoZIhvcNAQEL
BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
NzA4MzlaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQAD
ggEPADCCAQoCggEBAKCbp++qNyTn5LOsV0h9gLKJALBcjg2A14I3804N9UyDhPW2
QKQ2W424u2P1MfKrw/2C+CErGlrADlnco2RQVDAarAIuGdFvBOt5UezqOS7Mk4OS
9MlS7NZnMbc37KuM9UIG5ScJjXR/Z5z9dxeR0I9y3n0Ix6khbV7tOSHobiweI0FI
8LftBS+CQnXr6vbWPcHcW6Z0FHUv7IWhqMWmv9PlZRGe9Y6VzXrRp0PBnZMOnAYf
aMQUwYRswWdm9j9Z1sMdTJ14G+KVmO3Vj6XI6Sm9uIcYhlwG/kORwogJFWlVuP9o
rloFRCjyHJ1d7GZqqnRyHHDDCBms8ed+3YfEYQECAwEAAaNwMG4wLAYDVR0RBCUw
I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMB0GA1UdDgQWBBQl
J4vxUoCzqqeTwQAiLqE8wYezKzAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
ugywJDANBgkqhkiG9w0BAQsFAAOCAQEAnXHIBDQ4AHAMV71piTOuI41ShASQed6L
bi7XUMZgZDslLkfU1vnP3BlwpliraBsAytSYQC6kbytOuz1uQ4K7yLb2tAAmUgEO
EdIVt9SXr5tCcIPeLmInF0pysPqjZO8n7vtJyd9gryKqdhm1uzA7WQWq/Az8a9Sk
uW2J6Oc5p6P7Mf3/ixqXzvGRo8rzu0CUJOJ67UTE/HhbJuplQ5dep5CEEOAIsAtH
zn9O4rW92ueBkoBJM++YILS1vQ7jKc2N3RNrnHm7FeootBrtR9mBi0TH97K73ZPZ
2Cdhnym0CsCJggrllFGH32cYo7+K2PO7/4oj5XbBCSWcssicvd8ovg==
-----END CERTIFICATE-----"#;
1613
        /// Private key (PKCS#8 "PRIVATE KEY" PEM) matching `TEST_SERVER_CERT`,
        /// used as `TLS_KEY` by the tests below. Test-only material.
        const TEST_SERVER_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIIEugIBADANBgkqhkiG9w0BAQEFAASCBKQwggSgAgEAAoIBAQCgm6fvqjck5+Sz
rFdIfYCyiQCwXI4NgNeCN/NODfVMg4T1tkCkNluNuLtj9THyq8P9gvghKxpawA5Z
3KNkUFQwGqwCLhnRbwTreVHs6jkuzJODkvTJUuzWZzG3N+yrjPVCBuUnCY10f2ec
/XcXkdCPct59CMepIW1e7Tkh6G4sHiNBSPC37QUvgkJ16+r21j3B3FumdBR1L+yF
oajFpr/T5WURnvWOlc160adDwZ2TDpwGH2jEFMGEbMFnZvY/WdbDHUydeBvilZjt
1Y+lyOkpvbiHGIZcBv5DkcKICRVpVbj/aK5aBUQo8hydXexmaqp0chxwwwgZrPHn
ft2HxGEBAgMBAAECgf8G5qlQov+7ljs9fSpC8yGUik59RXzVF7Qq5DyQHglsQDp2
VF5yr+M/M7DZmq+KvdauDfKbej6np5j2Q4TByrHTX1IExfZWCW8srwnWJDpQyHmO
LcJW5DlI/SYluUFyHZxsOd+ezcpGNzM8i6eSW7GaeFUXCkmJ+uW4LnlF+7bALnnd
D6sak/58EsII+IJyd4lFn+voszlPn3CZGR0jkp21rvpaKgrMIsKVWWQO/sLDU5pr
VbpBThcLU5gRcnQouQX12e2VTCIlFu75WTsJ8V/KnEaOZUVlU/B/Bs+WQF3U+/Jo
eX4N+D6OsEcNQjERAFyWujxsl1WpD4uSsbFMN0ECgYEA2b7AdL+oKPQHku2KcBhr
Zw8K4tMDlr2VPPNwZcBTLo+O71vv/xXjMcXrXmowzkgEQckUmt1VB46riyydhwdP
/n9ciWcz0Va/nwHR6Y9F9unBiyUBP7PRhRyjQyRZZRGDSJvP+Xmc5UJFpRr07VLU
nfgMXDj37vXzKDpfhdEB2nkCgYEAvNMfA8P8w3+6246x5YHflvTkPdw+2oyge+LD
mphB/w7SF8mlyNGloj3+KBZmd9SkvT57wCvO96Y9/n+mBAVisRggc0hK4ymOVYhb
+im/JvqGQMbVeg6iCOHnWdaZf9tL8uVsegQy3kVTN7vAa+CMFgX1dt65cGBX6XkB
44pYmMkCgYALhbiRdQLlB+TOtZs5y1EDpxwgXKI3+9hF3Wv5NnAwapBZwje0++eF
3r9Rw7TJda4j/QwGFehF+hrBxp6fYpetE/hFnRx0225Qb7w368j8A+ql/lNOl6li
rd1F1EqWupKD6RrcTL8sspEU55RGaretlE6zIqCcGI/BdTVQ03qRoQKBgHDC3zWf
d7XD9HGjQGdfbIe4jQjIGxzmd/wjik4q+NZ5IkukVwWa9P/zZ3DHF8Ad05dT1hEH
2FwaAdGWpyyljq9VSiOuG1KXAXHgsZSuE4ISf9P1KYzvaiJFzaPfvOEWs79E9MfU
9A+6dJzG2X1SpjWMr26iSTlrv3QkmFUqzAfJAoGASBkn4wls+oC5rv/Mch43pBv5
UmKru4ltnEHJZdbSi2DJ+AnDLD222JCasb1VT1tm2XgW6DBqrdVRPPP6GOlB0MHU
+3ULtZxAczt7I+ST2bo0/DV2Hse89Cm63w4wLOiVZs7+1wrAzJZLokWF7Q5gesra
u19txmtkiMEH+aNmekk=
-----END PRIVATE KEY-----"#;
1645
        /// End-to-end: a single `u64` posted over a TLS link (using the embedded
        /// test PKI) arrives at the server.
        #[async_timed_test(timeout_secs = 30)]
        async fn test_tls_basic() {
            let _ = rustls::crypto::ring::default_provider().install_default();

            // Hold the global config lock for the whole test and point TLS
            // cert/key/CA at the in-memory test PEMs. The guards are bound to
            // named `_guard_*` variables (not `_`) so they live until the end of
            // the test; presumably dropping them restores the previous values.
            let config = hyperactor_config::global::lock();
            let _guard_cert =
                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
            let _guard_key =
                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
            let _guard_ca =
                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));

            // Port 0; presumably the server picks a free port and `local_addr`
            // reports the one actually bound.
            let addr = TlsAddr::new("localhost", 0);

            let (local_addr, mut rx) =
                server::serve::<u64>(ChannelAddr::Tls(addr)).expect("failed to serve");

            // `link` reads the same overridden global config, so the client
            // authenticates with the same test cert/key.
            let tx: super::NetTx<u64> = super::spawn(
                link(match &local_addr {
                    ChannelAddr::Tls(addr) => addr.clone(),
                    _ => panic!("unexpected address type"),
                })
                .expect("failed to create link"),
            );

            tx.post(42u64);

            let received = rx.recv().await.expect("failed to receive");
            assert_eq!(received, 42u64);
        }
1683
        /// Ten `String` messages posted over one TLS link arrive in order.
        #[async_timed_test(timeout_secs = 30)]
        async fn test_tls_multiple_messages() {
            let _ = rustls::crypto::ring::default_provider().install_default();

            // Config overrides must stay alive for the whole test (named guards).
            let config = hyperactor_config::global::lock();
            let _guard_cert =
                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
            let _guard_key =
                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
            let _guard_ca =
                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));

            let addr = TlsAddr::new("localhost", 0);

            let (local_addr, mut rx) =
                server::serve::<String>(ChannelAddr::Tls(addr)).expect("failed to serve");
            let tx: super::NetTx<String> = super::spawn(
                link(match &local_addr {
                    ChannelAddr::Tls(addr) => addr.clone(),
                    _ => panic!("unexpected address type"),
                })
                .expect("failed to create link"),
            );

            // Post everything first, then check that delivery preserves order.
            for i in 0..10 {
                tx.post(format!("message {}", i));
            }

            for i in 0..10 {
                let received = rx.recv().await.expect("failed to receive");
                assert_eq!(received, format!("message {}", i));
            }
        }
1720
        /// A DNS host name plus port parses into a `ChannelAddr::Tls`.
        #[test]
        fn test_tls_parse_hostname_port() {
            let addr = parse("localhost:8080").expect("failed to parse");
            assert!(matches!(
                addr,
                ChannelAddr::Tls(TlsAddr { hostname, port })
                if hostname == "localhost" && port == 8080
            ));
        }
1730
        /// An IPv4 literal plus port parses into a `ChannelAddr::Tls`, keeping
        /// the literal as the host name.
        #[test]
        fn test_tls_parse_socket_addr() {
            let addr = parse("127.0.0.1:8080").expect("failed to parse");
            assert!(matches!(
                addr,
                ChannelAddr::Tls(TlsAddr { hostname, port })
                if hostname == "127.0.0.1" && port == 8080
            ));
        }
1740
        /// The embedded test PEMs load through `load_certs`, `load_key` and
        /// `build_root_store`, each yielding at least one item.
        #[test]
        fn test_tls_certs_parsing() {
            let cert_pem = Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec());
            let key_pem = Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec());
            let ca_pem = Pem::Value(TEST_CA_CERT.as_bytes().to_vec());

            let certs = super::load_certs(&cert_pem).expect("failed to load certs");
            assert!(!certs.is_empty(), "expected at least one certificate");

            let _key = super::load_key(&key_pem).expect("failed to load key");

            let root_store = super::build_root_store(&ca_pem).expect("failed to build root store");
            assert!(!root_store.is_empty(), "expected at least one CA cert");
        }
1756
        /// `tls_acceptor()` (mutual TLS) builds from the test PEMs in the global config.
        #[test]
        fn test_tls_acceptor_creation() {
            let _ = rustls::crypto::ring::default_provider().install_default();

            // Guards keep the overrides in effect until the end of the test.
            let config = hyperactor_config::global::lock();
            let _guard_cert =
                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
            let _guard_key =
                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
            let _guard_ca =
                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));

            let _acceptor = super::tls_acceptor().expect("failed to create TLS acceptor");
        }
1775
        /// `tls_connector()` (client auth) builds from the test PEMs in the global config.
        #[test]
        fn test_tls_connector_creation() {
            let _ = rustls::crypto::ring::default_provider().install_default();

            // Guards keep the overrides in effect until the end of the test.
            let config = hyperactor_config::global::lock();
            let _guard_cert =
                config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
            let _guard_key =
                config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
            let _guard_ca =
                config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));

            let _connector = super::tls_connector().expect("failed to create TLS connector");
        }
1794 }
1795}
1796
1797fn oss_pem_bundle() -> crate::config::PemBundle {
1799 crate::config::PemBundle {
1800 ca: hyperactor_config::global::get_cloned(crate::config::TLS_CA),
1801 cert: hyperactor_config::global::get_cloned(crate::config::TLS_CERT),
1802 key: hyperactor_config::global::get_cloned(crate::config::TLS_KEY),
1803 }
1804}
1805
1806pub fn try_tls_pem_bundle() -> Option<crate::config::PemBundle> {
1816 let oss_bundle = oss_pem_bundle();
1817 if oss_bundle.ca.reader().is_ok() {
1818 return Some(oss_bundle);
1819 }
1820 tracing::debug!("OSS TLS bundle: CA not readable, trying Meta paths");
1821
1822 let meta_bundle = meta::get_server_pem_bundle();
1823 if meta_bundle.ca.reader().is_ok() {
1824 return Some(meta_bundle);
1825 }
1826 tracing::debug!("Meta TLS bundle: CA not readable, no TLS available");
1827
1828 None
1829}
1830
1831pub fn try_tls_acceptor(enforce_client_tls: bool) -> Option<tokio_rustls::TlsAcceptor> {
1851 let oss_bundle = oss_pem_bundle();
1852 if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&oss_bundle, enforce_client_tls) {
1853 return Some(acceptor);
1854 }
1855 tracing::debug!("OSS TLS acceptor failed, trying Meta paths");
1856
1857 let meta_bundle = meta::get_server_pem_bundle();
1858 if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&meta_bundle, enforce_client_tls) {
1859 return Some(acceptor);
1860 }
1861 tracing::debug!("Meta TLS acceptor failed, no TLS available");
1862
1863 None
1864}
1865
1866pub fn try_tls_connector() -> Option<tokio_rustls::TlsConnector> {
1878 let oss_bundle = oss_pem_bundle();
1879 if let Ok(connector) = tls::tls_connector_from_bundle(&oss_bundle) {
1880 return Some(connector);
1881 }
1882 tracing::debug!("OSS TLS connector failed, trying Meta paths");
1883
1884 if let Ok(connector) = meta::try_tls_connector() {
1885 return Some(connector);
1886 }
1887 tracing::debug!("Meta TLS connector failed, no TLS available");
1888
1889 None
1890}
1891
1892#[cfg(test)]
1893mod tests {
1894 use std::assert_matches::assert_matches;
1895 use std::collections::VecDeque;
1896 use std::marker::PhantomData;
1897 use std::sync::Arc;
1898 use std::sync::RwLock;
1899 use std::sync::atomic::AtomicBool;
1900 use std::sync::atomic::AtomicU64;
1901 use std::sync::atomic::Ordering;
1902 use std::time::Duration;
1903 #[cfg(target_os = "linux")] use std::time::UNIX_EPOCH;
1905
1906 #[cfg(target_os = "linux")] use anyhow::Result;
1908 use bytes::Bytes;
1909 use rand::Rng;
1910 use rand::SeedableRng;
1911 use rand::distr::Alphanumeric;
1912 use timed_test::async_timed_test;
1913 use tokio::io::AsyncWrite;
1914 use tokio::io::DuplexStream;
1915 use tokio::io::ReadHalf;
1916 use tokio::io::WriteHalf;
1917 use tokio::task::JoinHandle;
1918 use tokio_util::sync::CancellationToken;
1919
1920 use super::server;
1921 use super::*;
1922 use crate::channel;
1923 use crate::channel::net::framed::FrameReader;
1924 use crate::channel::net::framed::FrameWrite;
1925 use crate::channel::net::server::AcceptorLink;
1926 use crate::config;
1927 use crate::metrics;
1928 use crate::sync::mvar::MVar;
1929
1930 fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
1934 let buf = tracing_test::internal::global_buf().lock().unwrap();
1935 let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
1936 let lines: Vec<&str> = logs_str.lines().collect();
1937 match f(&lines) {
1938 Ok(()) => {}
1939 Err(msg) => panic!("{}", msg),
1940 }
1941 }
1942
    /// Unix-socket round trip: single post, several posts in order, and a
    /// `try_post` that is returned with `ChannelError::Closed` once the
    /// receiver has been dropped.
    #[cfg(target_os = "linux")]
    #[tracing_test::traced_test]
    #[tokio::test]
    async fn test_unix_basic() -> Result<()> {
        // A nanosecond timestamp keeps the abstract socket name unique per run.
        let timestamp = std::time::SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let unique_address = format!("test_unix_basic_{}", timestamp);

        let (addr, mut rx) = server::serve::<u64>(ChannelAddr::Unix(
            unix::SocketAddr::from_abstract_name(&unique_address)?,
        ))
        .unwrap();

        // Each scope dials a fresh sender, which is dropped at the end of the scope.
        {
            let tx: ChannelTx<u64> = channel::dial::<u64>(addr.clone()).unwrap();
            tx.post(123);
            assert_eq!(rx.recv().await.unwrap(), 123);
        }

        {
            let tx = channel::dial::<u64>(addr.clone()).unwrap();
            tx.post(321);
            tx.post(111);
            tx.post(444);

            assert_eq!(rx.recv().await.unwrap(), 321);
            assert_eq!(rx.recv().await.unwrap(), 111);
            assert_eq!(rx.recv().await.unwrap(), 444);
        }

        {
            let tx = channel::dial::<u64>(addr).unwrap();
            // Drop the receiver first so the next post cannot be delivered.
            drop(rx);

            let (return_tx, return_rx) = oneshot::channel();
            tx.try_post(123, return_tx);
            assert_matches!(
                return_rx.await,
                Ok(SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    ..
                })
            );
        }

        Ok(())
    }
1999
    /// A client that dials and posts before the server exists still has its
    /// messages delivered, in order, once the server starts.
    #[cfg(target_os = "linux")]
    #[tracing_test::traced_test]
    #[tokio::test]
    async fn test_unix_basic_client_before_server() -> Result<()> {
        // A nanosecond timestamp keeps the abstract socket name unique per run.
        let timestamp = std::time::SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let socket_addr =
            unix::SocketAddr::from_abstract_name(&format!("test_unix_basic_{}", timestamp))
                .unwrap();

        // Dial and post while nothing is listening on the address yet.
        let addr = ChannelAddr::Unix(socket_addr.clone());
        let tx = crate::channel::dial::<u64>(addr.clone()).unwrap();
        tx.post(123);

        // Start the server afterwards; the earlier post must still arrive.
        let (_, mut rx) = server::serve::<u64>(ChannelAddr::Unix(socket_addr)).unwrap();
        assert_eq!(rx.recv().await.unwrap(), 123);

        tx.post(321);
        tx.post(111);
        tx.post(444);

        assert_eq!(rx.recv().await.unwrap(), 321);
        assert_eq!(rx.recv().await.unwrap(), 111);
        assert_eq!(rx.recv().await.unwrap(), 444);

        Ok(())
    }
2031
    /// TCP round trip over IPv6 loopback. It covers a single post, several
    /// posts in order, and a `try_post` that is returned with
    /// `ChannelError::Closed` after the receiver is dropped.
    // Ignored outside fbcode builds (see cfg_attr); the reason is not visible
    // here — presumably an environment dependency such as IPv6 loopback.
    #[tracing_test::traced_test]
    #[async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_tcp_basic() {
        // Port 0 on [::1]; `addr` is what the server reports back.
        let (addr, mut rx) =
            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        {
            let tx = channel::dial::<u64>(addr.clone()).unwrap();
            tx.post(123);
            assert_eq!(rx.recv().await.unwrap(), 123);
        }

        {
            let tx = channel::dial::<u64>(addr.clone()).unwrap();
            tx.post(321);
            tx.post(111);
            tx.post(444);

            assert_eq!(rx.recv().await.unwrap(), 321);
            assert_eq!(rx.recv().await.unwrap(), 111);
            assert_eq!(rx.recv().await.unwrap(), 444);
        }

        {
            let tx = channel::dial::<u64>(addr).unwrap();
            // Drop the receiver first so the next post cannot be delivered.
            drop(rx);

            let (return_tx, return_rx) = oneshot::channel();
            tx.try_post(123, return_tx);
            assert_matches!(
                return_rx.await,
                Ok(SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    ..
                })
            );
        }
    }
2072
    /// Frame-size limit: a message just under `CODEC_MAX_FRAME_LENGTH` is
    /// delivered, while one just over it is handed back on the return channel.
    #[async_timed_test(timeout_secs = 5)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_tcp_message_size() {
        let default_size_in_bytes = 100 * 1024 * 1024;
        // Guards must outlive the test body; a 1s delivery timeout keeps the
        // oversized case from hanging.
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
        let _guard2 = config.override_key(config::CODEC_MAX_FRAME_LENGTH, default_size_in_bytes);

        let (addr, mut rx) =
            server::serve::<String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();

        let tx = channel::dial::<String>(addr.clone()).unwrap();
        {
            // 1 KiB under the limit: delivered intact.
            let message = "a".repeat(default_size_in_bytes - 1024);
            tx.post(message.clone());
            assert_eq!(rx.recv().await.unwrap(), message);
        }
        {
            // 1 KiB over the limit: returned to the sender.
            let (return_channel, return_receiver) = oneshot::channel();
            let message = "a".repeat(default_size_in_bytes + 1024);
            tx.try_post(message.clone(), return_channel);
            let returned = return_receiver.await.unwrap();
            assert_eq!(message, returned.message);
        }
    }
2104
    /// After a message is received and the receiver is dropped, the sender's
    /// return channel closes without the message coming back (`rx.await` errs).
    // NOTE(review): with acks effectively disabled by count (every 100M
    // messages), this presumably checks that a pending ack is flushed when the
    // receiver goes away — confirm against the server implementation.
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_ack_flush() {
        let config = hyperactor_config::global::lock();
        let _guard_message_ack =
            config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
        let _guard_delivery_timeout =
            config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(5));

        let (addr, mut net_rx) =
            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        let net_tx = channel::dial::<u64>(addr.clone()).unwrap();
        let (tx, rx) = oneshot::channel();
        net_tx.try_post(1, tx);
        assert_eq!(net_rx.recv().await.unwrap(), 1);
        drop(net_rx);
        assert!(rx.await.is_err());
    }
2128
    /// MetaTls round trip over IPv6. It covers a post from a sender dropped
    /// before the receive, several posts in order, and a `try_post` returned
    /// with `ChannelError::Closed` after the receiver is dropped.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_meta_tls_basic() {
        hyperactor_telemetry::initialize_logging_for_test();

        let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
        let meta_addr = match addr {
            ChannelAddr::MetaTls(meta_addr) => meta_addr,
            _ => panic!("expected MetaTls address"),
        };
        let (local_addr, mut rx) = server::serve::<u64>(ChannelAddr::MetaTls(meta_addr)).unwrap();
        {
            // The sender is dropped at the end of this scope, before the
            // receive below; the message is still expected to arrive.
            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
            tx.post(123);
        }
        assert_eq!(rx.recv().await.unwrap(), 123);

        {
            let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
            tx.post(321);
            tx.post(111);
            tx.post(444);
            assert_eq!(rx.recv().await.unwrap(), 321);
            assert_eq!(rx.recv().await.unwrap(), 111);
            assert_eq!(rx.recv().await.unwrap(), 444);
        }

        {
            let tx = channel::dial::<u64>(local_addr).unwrap();
            // Drop the receiver first so the next post cannot be delivered.
            drop(rx);

            let (return_tx, return_rx) = oneshot::channel();
            tx.try_post(123, return_tx);
            assert_matches!(
                return_rx.await,
                Ok(SendError {
                    error: ChannelError::Closed,
                    message: 123,
                    ..
                })
            );
        }
    }
2173
    /// Fault-injection settings for `MockLink`. `Default` leaves both fields
    /// `None`, which disables all injection.
    #[derive(Clone, Debug, Default)]
    struct NetworkFlakiness {
        // `(probability, max_disconnects, min_interval)`. A disconnect is only
        // considered once more than `min_interval` has passed since the
        // previous one and while fewer than `max_disconnects` have happened;
        // it then occurs with `probability` (see `should_disconnect`).
        disconnect_params: Option<(f64, u64, Duration)>,
        // `(min, max)` delivery delay. The head-of-queue frame is released
        // once a delay drawn uniformly from `[min, max]` has passed since it
        // was enqueued (the delay is re-drawn each time the relay waits).
        latency_range: Option<(Duration, Duration)>,
    }
2188
2189 impl NetworkFlakiness {
2190 async fn should_disconnect(
2192 &self,
2193 rng: &mut impl rand::Rng,
2194 disconnected_count: u64,
2195 prev_disconnected_at: &RwLock<Instant>,
2196 ) -> bool {
2197 let Some((prob, max_disconnects, duration)) = &self.disconnect_params else {
2198 return false;
2199 };
2200
2201 let disconnected_at = prev_disconnected_at.read().unwrap();
2202 if disconnected_at.elapsed() > *duration && disconnected_count < *max_disconnects {
2203 rng.random_bool(*prob)
2204 } else {
2205 false
2206 }
2207 }
2208 }
2209
    /// In-memory link for tests whose streams are `DuplexStream`s.
    ///
    /// Supports injected connect failures and network flakiness.
    struct MockLink<M> {
        // NOTE(review): presumably the `DuplexStream` buffer size used for new
        // connections; the use site is outside this view.
        buffer_size: usize,
        // Identifier returned by `link_id()`.
        session_id: SessionId,
        // Shared slot holding a `DuplexStream`; presumably the server-side end
        // of each new connection is placed here for the test — verify.
        receiver_storage: Arc<MVar<DuplexStream>>,
        // When true, `next()` fails immediately with `ClientError::Connect`.
        fail_connects: Arc<AtomicBool>,
        // Sender half of a watch channel; the relay tasks take a
        // `watch::Receiver<()>`, presumably to observe forced disconnects.
        disconnect_signal: watch::Sender<()>,
        // Fault-injection settings applied by the relay.
        network_flakiness: NetworkFlakiness,
        // Number of disconnects injected so far (incremented by the relay).
        disconnected_count: Arc<AtomicU64>,
        // Time of the most recent injected disconnect; initialized at creation.
        prev_disconnected_at: Arc<RwLock<Instant>>,
        // When `Some(n)`, the relay debug-logs only sampled sends
        // (`send_count % n == 1`).
        debug_log_sampling_rate: Option<u64>,
        // Ties the link to message type `M` without storing one.
        _message_type: PhantomData<M>,
    }
2227
2228 impl<M> fmt::Debug for MockLink<M> {
2229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2230 f.debug_struct("MockLink")
2231 .field("buffer_size", &self.buffer_size)
2232 .field("receiver_storage", &"<MVar<DuplexStream>>")
2233 .field("fail_connects", &self.fail_connects)
2234 .field("disconnect_signal", &"<watch::Sender>")
2235 .field("network_flakiness", &self.network_flakiness)
2236 .field("disconnected_count", &self.disconnected_count)
2237 .field("prev_disconnected_at", &"<RwLock<Instant>>")
2238 .field("debug_log_sampling_rate", &self.debug_log_sampling_rate)
2239 .finish()
2240 }
2241 }
2242
2243 impl<M: RemoteMessage> MockLink<M> {
2244 fn new() -> Self {
2245 let (sender, _) = watch::channel(());
2246 Self {
2247 buffer_size: 64,
2248 session_id: SessionId::random(),
2249 receiver_storage: Arc::new(MVar::empty()),
2250 fail_connects: Arc::new(AtomicBool::new(false)),
2251 disconnect_signal: sender,
2252 network_flakiness: NetworkFlakiness::default(),
2253 disconnected_count: Arc::new(AtomicU64::new(0)),
2254 prev_disconnected_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
2255 debug_log_sampling_rate: None,
2256 _message_type: PhantomData,
2257 }
2258 }
2259
2260 fn fail_connects() -> Self {
2263 Self {
2264 fail_connects: Arc::new(AtomicBool::new(true)),
2265 ..Self::new()
2266 }
2267 }
2268
2269 fn with_network_flakiness(network_flakiness: NetworkFlakiness) -> Self {
2270 if let Some((min, max)) = network_flakiness.latency_range {
2271 assert!(min < max);
2272 }
2273
2274 Self {
2275 network_flakiness,
2276 ..Self::new()
2277 }
2278 }
2279
2280 fn receiver_storage(&self) -> Arc<MVar<DuplexStream>> {
2281 self.receiver_storage.clone()
2282 }
2283
2284 fn disconnected_count(&self) -> Arc<AtomicU64> {
2285 self.disconnected_count.clone()
2286 }
2287
2288 fn disconnect_signal(&self) -> &watch::Sender<()> {
2289 &self.disconnect_signal
2290 }
2291
2292 fn fail_connects_switch(&self) -> Arc<AtomicBool> {
2293 self.fail_connects.clone()
2294 }
2295
2296 fn set_buffer_size(&mut self, size: usize) {
2297 self.buffer_size = size;
2298 }
2299
2300 fn set_sampling_rate(&mut self, sampling_rate: u64) {
2301 self.debug_log_sampling_rate = Some(sampling_rate);
2302 }
2303 }
2304
2305 #[async_trait]
2306 impl<M: RemoteMessage> Link for MockLink<M> {
2307 type Stream = DuplexStream;
2308
2309 fn dest(&self) -> ChannelAddr {
2310 ChannelAddr::Local(u64::MAX)
2311 }
2312
2313 fn link_id(&self) -> SessionId {
2314 self.session_id
2315 }
2316
2317 async fn next(&self) -> Result<Self::Stream, ClientError> {
2318 let session_id = self.session_id;
2319 tracing::debug!("MockLink starts to connect.");
2320 if self.fail_connects.load(Ordering::Acquire) {
2321 return Err(ClientError::Connect(
2322 self.dest(),
2323 std::io::Error::other("intentional error"),
2324 "expected failure injected by the mock".to_string(),
2325 ));
2326 }
2327
2328 async fn relay_message<M: RemoteMessage>(
2334 mut disconnect_signal: watch::Receiver<()>,
2335 network_flakiness: NetworkFlakiness,
2336 disconnected_count: Arc<AtomicU64>,
2337 prev_disconnected_at: Arc<RwLock<Instant>>,
2338 mut reader: FrameReader<ReadHalf<DuplexStream>>,
2339 mut writer: WriteHalf<DuplexStream>,
2340 task_coordination_token: CancellationToken,
2343 debug_log_sampling_rate: Option<u64>,
2344 is_from_client: bool,
2347 ) {
2348 async fn wait_for_latency_elapse(
2352 queue: &VecDeque<(Bytes, Instant)>,
2353 network_flakiness: &NetworkFlakiness,
2354 rng: &mut impl rand::Rng,
2355 ) {
2356 if let Some((min, max)) = network_flakiness.latency_range {
2357 let diff = max.abs_diff(min);
2358 let factor = rng.random_range(0.0..=1.0);
2359 let latency = min + diff.mul_f64(factor);
2360 tokio::time::sleep_until(queue.front().unwrap().1 + latency).await;
2361 }
2362 }
2363
2364 let mut rng = rand::rngs::SmallRng::from_os_rng();
2365 let mut queue: VecDeque<(Bytes, Instant)> = VecDeque::new();
2366 let mut send_count = 0u64;
2367
2368 loop {
2369 tokio::select! {
2370 read_res = reader.next() => {
2371 match read_res {
2372 Ok(Some((_, data))) => {
2373 queue.push_back((data, tokio::time::Instant::now()));
2374 }
2375 Ok(None) | Err(_) => {
2376 tracing::debug!("The upstream is closed or dropped. MockLink disconnects");
2377 break;
2378 }
2379 }
2380 }
2381 _ = wait_for_latency_elapse(&queue, &network_flakiness, &mut rng), if !queue.is_empty() => {
2382 let count = disconnected_count.load(Ordering::Relaxed);
2383 if network_flakiness.should_disconnect(&mut rng, count, &prev_disconnected_at).await {
2384 tracing::debug!("MockLink disconnects");
2385 disconnected_count.fetch_add(1, Ordering::Relaxed);
2386
2387 metrics::CHANNEL_RECONNECTIONS.add(
2388 1,
2389 hyperactor_telemetry::kv_pairs!(
2390 "transport" => "mock",
2391 "reason" => "network_flakiness",
2392 ),
2393 );
2394
2395 let mut w = prev_disconnected_at.write().unwrap();
2396 *w = tokio::time::Instant::now();
2397 break;
2398 }
2399 let data = queue.pop_front().unwrap().0;
2400 let is_sampled = debug_log_sampling_rate.is_some_and(|sample_rate| send_count % sample_rate == 1);
2401 if is_sampled {
2402 if is_from_client {
2403 if let Ok(Frame::Message(_seq, _msg)) = bincode::deserialize::<Frame<M>>(&data) {
2404 tracing::debug!("MockLink relays a msg from client. msg type: {}", std::any::type_name::<M>());
2405 }
2406 } else {
2407 let result = deserialize_response(data.clone());
2408 if let Ok(NetRxResponse::Ack(seq)) = result {
2409 tracing::debug!("MockLink relays an ack from server. seq: {}", seq);
2410 }
2411 }
2412 }
2413 let mut fw = FrameWrite::new(writer, data, hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH), 0).unwrap();
2414 if fw.send().await.is_err() {
2415 break;
2416 }
2417 writer = fw.complete();
2418 send_count += 1;
2419 }
2420 _ = task_coordination_token.cancelled() => break,
2421
2422 changed = disconnect_signal.changed() => {
2423 tracing::debug!("MockLink disconnects per disconnect_signal {:?}", changed);
2424 break;
2425 }
2426 }
2427 }
2428
2429 task_coordination_token.cancel();
2430 }
2431
2432 let (server, mut server_relay) = tokio::io::duplex(self.buffer_size);
2433 let (client, client_relay) = tokio::io::duplex(self.buffer_size);
2434
2435 write_link_init(&mut server_relay, session_id)
2439 .await
2440 .map_err(|err| ClientError::Io(self.dest(), err))?;
2441
2442 let (server_r, server_writer) = tokio::io::split(server_relay);
2443 let (client_r, client_writer) = tokio::io::split(client_relay);
2444
2445 let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
2446 let server_reader = FrameReader::new(server_r, max_len);
2447 let client_reader = FrameReader::new(client_r, max_len);
2448
2449 let task_coordination_token = CancellationToken::new();
2450 let _server_relay_task_handle = tokio::spawn(relay_message::<M>(
2451 self.disconnect_signal.subscribe(),
2452 self.network_flakiness.clone(),
2453 self.disconnected_count.clone(),
2454 self.prev_disconnected_at.clone(),
2455 server_reader,
2456 client_writer,
2457 task_coordination_token.clone(),
2458 self.debug_log_sampling_rate.clone(),
2459 false,
2460 ));
2461 let _client_relay_task_handle = tokio::spawn(relay_message::<M>(
2462 self.disconnect_signal.subscribe(),
2463 self.network_flakiness.clone(),
2464 self.disconnected_count.clone(),
2465 self.prev_disconnected_at.clone(),
2466 client_reader,
2467 server_writer,
2468 task_coordination_token,
2469 self.debug_log_sampling_rate.clone(),
2470 true,
2471 ));
2472
2473 self.receiver_storage.put(server).await;
2474 Ok(client)
2475 }
2476 }
2477
2478 struct MockLinkListener {
2479 receiver_storage: Arc<MVar<DuplexStream>>,
2480 channel_addr: ChannelAddr,
2481 }
2482
2483 impl MockLinkListener {
2484 fn new(receiver_storage: Arc<MVar<DuplexStream>>, channel_addr: ChannelAddr) -> Self {
2485 Self {
2486 receiver_storage,
2487 channel_addr,
2488 }
2489 }
2490 }
2491
2492 impl fmt::Debug for MockLinkListener {
2493 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2494 f.debug_struct("MockLinkListener")
2495 .field("channel_addr", &self.channel_addr)
2496 .finish()
2497 }
2498 }
2499
2500 #[async_trait]
2501 impl super::Listener for MockLinkListener {
2502 type Stream = DuplexStream;
2503
2504 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
2505 let stream = self.receiver_storage.take().await;
2506 Ok((stream, self.channel_addr.clone()))
2507 }
2508 }
2509
2510 fn serve_acceptor_test<M: RemoteMessage>(
2514 session_id: SessionId,
2515 ) -> (
2516 JoinHandle<()>,
2517 crate::sync::mvar::MVar<DuplexStream>,
2518 mpsc::Receiver<M>,
2519 CancellationToken,
2520 ) {
2521 let mvar = crate::sync::mvar::MVar::empty();
2522 let cancel_token = CancellationToken::new();
2523 let link = AcceptorLink {
2524 dest: ChannelAddr::Local(u64::MAX),
2525 session_id,
2526 stream: mvar.clone(),
2527 cancel: cancel_token.clone(),
2528 };
2529 let (tx, rx) = mpsc::channel::<M>(1024);
2530 let ct = cancel_token.clone();
2531 let handle = tokio::spawn(async move {
2532 let mut session = Session::new(link);
2533 let mut next = session::Next { seq: 0, ack: 0 };
2534
2535 loop {
2536 let connected = match session.connect().await {
2537 Ok(s) => s,
2538 Err(_) => break,
2539 };
2540
2541 let result = {
2542 let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2543 tokio::select! {
2544 r = session::recv_connected::<M, _, _>(&stream, &tx, &mut next) => r,
2545 _ = ct.cancelled() => Err(session::RecvLoopError::Cancelled),
2546 }
2547 };
2548
2549 if next.ack < next.seq {
2551 let ack = serialize_response(NetRxResponse::Ack(next.seq - 1)).unwrap();
2552 let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2553 let mut completion = stream.write(ack);
2554 match completion.drive().await {
2555 Ok(()) => {
2556 next.ack = next.seq;
2557 }
2558 Err(e) => {
2559 tracing::debug!(
2560 error = %e,
2561 "failed to flush acks during cleanup"
2562 );
2563 }
2564 }
2565 }
2566
2567 let terminal_response = match &result {
2569 Err(session::RecvLoopError::SequenceError(reason)) => {
2570 Some(NetRxResponse::Reject(reason.clone()))
2571 }
2572 Err(session::RecvLoopError::Cancelled) => Some(NetRxResponse::Closed),
2573 _ => None,
2574 };
2575 if let Some(rsp) = terminal_response {
2576 let data = serialize_response(rsp).unwrap();
2577 let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
2578 let mut completion = stream.write(data);
2579 let _ = completion.drive().await;
2580 }
2581
2582 let recoverable = matches!(&result, Ok(()) | Err(session::RecvLoopError::Io(_)));
2583 session = connected.release();
2584 if recoverable {
2585 continue;
2586 }
2587 break;
2588 }
2589 });
2590 (handle, mvar, rx, cancel_token)
2591 }
2592
2593 async fn write_stream<M, W>(
2594 mut writer: W,
2595 _session_id: u64,
2596 messages: &[(u64, M)],
2597 _init: bool,
2598 ) -> W
2599 where
2600 M: RemoteMessage + PartialEq + Clone,
2601 W: AsyncWrite + Unpin,
2602 {
2603 for (seq, message) in messages {
2604 let message =
2605 serde_multipart::serialize_bincode(&Frame::<M>::Message(*seq, message.clone()))
2606 .unwrap();
2607 let mut fw = FrameWrite::new(
2608 writer,
2609 message.framed(),
2610 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2611 0,
2612 )
2613 .map_err(|(_w, e)| e)
2614 .unwrap();
2615 fw.send().await.unwrap();
2616 writer = fw.complete();
2617 }
2618
2619 writer
2620 }
2621
2622 #[async_timed_test(timeout_secs = 60)]
2623 async fn test_persistent_server_session() {
2624 let config = hyperactor_config::global::lock();
2625 let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2626
2627 async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
2628 let mut last_acked: i128 = -1;
2629 loop {
2630 let (_, bytes) = reader.next().await.unwrap().unwrap();
2631 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2632 assert!(
2633 acked as i128 > last_acked,
2634 "acks should be delivered in ascending order"
2635 );
2636 last_acked = acked as i128;
2637 assert!(acked <= expected_last);
2638 if acked == expected_last {
2639 break;
2640 }
2641 }
2642 }
2643
2644 let session_id = SessionId(123);
2645 let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2646
2647 {
2649 let (sender, receiver) = tokio::io::duplex(5000);
2650 mvar.put(receiver).await;
2651
2652 let (r, writer) = tokio::io::split(sender);
2653 let mut reader = FrameReader::new(
2654 r,
2655 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2656 );
2657
2658 let _writer = write_stream(
2659 writer,
2660 123,
2661 &[
2662 (0u64, 100u64),
2663 (1u64, 101u64),
2664 (2u64, 102u64),
2665 (3u64, 103u64),
2666 ],
2667 true,
2668 )
2669 .await;
2670
2671 assert_eq!(rx.recv().await, Some(100));
2672 assert_eq!(rx.recv().await, Some(101));
2673 assert_eq!(rx.recv().await, Some(102));
2674 assert_eq!(rx.recv().await, Some(103));
2675
2676 verify_ack(&mut reader, 3).await;
2677 }
2679
2680 {
2682 let (sender2, receiver2) = tokio::io::duplex(5000);
2683 mvar.put(receiver2).await;
2684
2685 let (r2, writer2) = tokio::io::split(sender2);
2686 let mut reader2 = FrameReader::new(
2687 r2,
2688 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2689 );
2690
2691 let _ = write_stream(
2692 writer2,
2693 123,
2694 &[
2695 (2u64, 102u64),
2696 (3u64, 103u64),
2697 (4u64, 104u64),
2698 (5u64, 105u64),
2699 ],
2700 true,
2701 )
2702 .await;
2703
2704 assert_eq!(rx.recv().await, Some(104));
2706 assert_eq!(rx.recv().await, Some(105));
2707
2708 verify_ack(&mut reader2, 5).await;
2709
2710 cancel_token.cancel();
2711 }
2712 }
2713
2714 #[async_timed_test(timeout_secs = 60)]
2715 async fn test_ack_from_server_session() {
2716 let config = hyperactor_config::global::lock();
2717 let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2718 let session_id = SessionId(123);
2719 let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2720
2721 let (sender, receiver) = tokio::io::duplex(5000);
2722 mvar.put(receiver).await;
2723 let (r, mut writer) = tokio::io::split(sender);
2724 let mut reader = FrameReader::new(
2725 r,
2726 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2727 );
2728
2729 for i in 0u64..100u64 {
2730 writer = write_stream(writer, 123, &[(i, 100u64 + i)], i == 0u64).await;
2731 assert_eq!(rx.recv().await, Some(100u64 + i));
2732 let (_, bytes) = reader.next().await.unwrap().unwrap();
2733 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2734 assert_eq!(acked, i);
2735 }
2736
2737 tokio::time::sleep(Duration::from_secs(5)).await;
2739
2740 cancel_token.cancel();
2741
2742 let (_, bytes) = reader.next().await.unwrap().unwrap();
2744 assert!(deserialize_response(bytes).unwrap().is_closed());
2745 }
2746
2747 #[tracing_test::traced_test]
2748 async fn verify_tx_closed(tx_status: &mut watch::Receiver<TxStatus>, expected_log: &str) {
2749 match tokio::time::timeout(Duration::from_secs(5), tx_status.changed()).await {
2750 Ok(Ok(())) => {
2751 let current_status = tx_status.borrow().clone();
2752 assert!(current_status.is_closed());
2753 logs_assert_unscoped(|logs| {
2754 if logs.iter().any(|log| log.contains(expected_log)) {
2755 Ok(())
2756 } else {
2757 Err("expected log not found".to_string())
2758 }
2759 });
2760 }
2761 Ok(Err(_)) => panic!("watch::Receiver::changed() failed because sender is dropped."),
2762 Err(_) => panic!("timeout before tx_status changed"),
2763 }
2764 }
2765
2766 #[tracing_test::traced_test]
2767 #[tokio::test]
2768 #[cfg_attr(not(fbcode_build), ignore)]
2770 async fn test_tcp_tx_delivery_timeout() {
2771 let link = MockLink::<u64>::fail_connects();
2773 let tx = spawn::<u64>(link);
2774 let config = hyperactor_config::global::lock();
2776 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2777 let mut tx_receiver = tx.status().clone();
2778 let (return_channel, _return_receiver) = oneshot::channel();
2779 tx.try_post(123, return_channel);
2780 verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await;
2781 }
2782
2783 async fn take_receiver(
2784 receiver_storage: &MVar<DuplexStream>,
2785 ) -> (FrameReader<ReadHalf<DuplexStream>>, WriteHalf<DuplexStream>) {
2786 let mut receiver = receiver_storage.take().await;
2787 let _session_id = read_link_init(&mut receiver).await.expect("read LinkInit");
2789 let (r, writer) = tokio::io::split(receiver);
2790 let reader = FrameReader::new(
2791 r,
2792 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2793 );
2794 (reader, writer)
2795 }
2796
2797 async fn verify_message<M: RemoteMessage + PartialEq + std::fmt::Debug>(
2798 reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2799 expect: (u64, M),
2800 loc: u32,
2801 ) {
2802 let expected = Frame::Message(expect.0, expect.1);
2803 let (_, bytes) = reader.next().await.unwrap().expect("unexpected EOF");
2804 let message = serde_multipart::Message::from_framed(bytes).unwrap();
2805 let frame: Frame<M> = serde_multipart::deserialize_bincode(message).unwrap();
2806
2807 assert_eq!(frame, expected, "from ln={loc}");
2808 }
2809
2810 async fn verify_stream<M: RemoteMessage + PartialEq + std::fmt::Debug + Clone>(
2811 reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2812 expects: &[(u64, M)],
2813 _expect_session_id: Option<u64>,
2814 loc: u32,
2815 ) {
2816 for expect in expects {
2817 verify_message(reader, expect.clone(), loc).await;
2818 }
2819 }
2820
2821 async fn net_tx_send(tx: &NetTx<u64>, msgs: &[u64]) {
2822 for msg in msgs {
2823 tx.post(*msg);
2824 }
2825 }
2826
2827 #[async_timed_test(timeout_secs = 30)]
2829 async fn test_ack_in_net_tx_basic() {
2830 let link = MockLink::<u64>::new();
2831 let receiver_storage = link.receiver_storage();
2832 let tx = spawn::<u64>(link);
2833
2834 net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
2836 {
2837 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2838 verify_stream(
2839 &mut reader,
2840 &[
2841 (0u64, 100u64),
2842 (1u64, 101u64),
2843 (2u64, 102u64),
2844 (3u64, 103u64),
2845 (4u64, 104u64),
2846 ],
2847 None,
2848 line!(),
2849 )
2850 .await;
2851
2852 for i in 0u64..5u64 {
2853 writer = FrameWrite::write_frame(
2854 writer,
2855 serialize_response(NetRxResponse::Ack(i)).unwrap(),
2856 1024,
2857 0,
2858 )
2859 .await
2860 .map_err(|(_, e)| e)
2861 .unwrap();
2862 }
2863 tokio::time::sleep(Duration::from_secs(3)).await;
2865 drop(reader);
2867 drop(writer);
2868 };
2869
2870 net_tx_send(&tx, &[105u64]).await;
2872 {
2873 let (mut reader, _writer) = take_receiver(&receiver_storage).await;
2874 verify_stream(&mut reader, &[(5u64, 105u64)], None, line!()).await;
2875 };
2877 }
2878
2879 #[async_timed_test(timeout_secs = 60)]
2881 async fn test_persistent_net_tx() {
2882 let link = MockLink::<u64>::new();
2883 let receiver_storage = link.receiver_storage();
2884
2885 let tx = spawn::<u64>(link);
2886
2887 net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
2889
2890 let n = 3;
2894
2895 for i in 0..n {
2898 {
2899 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2900 verify_stream(
2901 &mut reader,
2902 &[
2903 (0u64, 100u64),
2904 (1u64, 101u64),
2905 (2u64, 102u64),
2906 (3u64, 103u64),
2907 (4u64, 104u64),
2908 ],
2909 None,
2910 line!(),
2911 )
2912 .await;
2913
2914 if i == n - 1 {
2917 writer = FrameWrite::write_frame(
2918 writer,
2919 serialize_response(NetRxResponse::Ack(1)).unwrap(),
2920 1024,
2921 0,
2922 )
2923 .await
2924 .map_err(|(_, e)| e)
2925 .unwrap();
2926 tokio::time::sleep(Duration::from_secs(3)).await;
2928 }
2929 drop(reader);
2931 drop(writer);
2932 };
2933 }
2934
2935 for _ in 0..n {
2937 {
2938 let (mut reader, mut _writer) = take_receiver(&receiver_storage).await;
2939 verify_stream(
2940 &mut reader,
2941 &[(2u64, 102u64), (3u64, 103u64), (4u64, 104u64)],
2942 None,
2943 line!(),
2944 )
2945 .await;
2946 };
2948 }
2949
2950 net_tx_send(&tx, &[105u64, 106u64, 107u64, 108u64, 109u64]).await;
2952 for i in 0..n {
2955 {
2956 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
2957 verify_stream(
2958 &mut reader,
2959 &[
2960 (2u64, 102u64),
2962 (3u64, 103u64),
2963 (4u64, 104u64),
2964 (5u64, 105u64),
2966 (6u64, 106u64),
2967 (7u64, 107u64),
2968 (8u64, 108u64),
2969 (9u64, 109u64),
2970 ],
2971 None,
2972 line!(),
2973 )
2974 .await;
2975
2976 if i == n - 1 {
2979 writer = FrameWrite::write_frame(
2982 writer,
2983 serialize_response(NetRxResponse::Ack(1)).unwrap(),
2984 1024,
2985 0,
2986 )
2987 .await
2988 .map_err(|(_, e)| e)
2989 .unwrap();
2990 writer = FrameWrite::write_frame(
2991 writer,
2992 serialize_response(NetRxResponse::Ack(2)).unwrap(),
2993 1024,
2994 0,
2995 )
2996 .await
2997 .map_err(|(_, e)| e)
2998 .unwrap();
2999 writer = FrameWrite::write_frame(
3000 writer,
3001 serialize_response(NetRxResponse::Ack(3)).unwrap(),
3002 1024,
3003 0,
3004 )
3005 .await
3006 .map_err(|(_, e)| e)
3007 .unwrap();
3008 tokio::time::sleep(Duration::from_secs(3)).await;
3010 }
3011 drop(reader);
3013 drop(writer);
3014 };
3015 }
3016
3017 for i in 0..n {
3018 {
3019 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3020 verify_stream(
3021 &mut reader,
3022 &[
3023 (4u64, 104),
3025 (5u64, 105u64),
3027 (6u64, 106u64),
3028 (7u64, 107u64),
3029 (8u64, 108u64),
3030 (9u64, 109u64),
3031 ],
3032 None,
3033 line!(),
3034 )
3035 .await;
3036
3037 if i == n - 1 {
3039 writer = FrameWrite::write_frame(
3040 writer,
3041 serialize_response(NetRxResponse::Ack(7)).unwrap(),
3042 1024,
3043 0,
3044 )
3045 .await
3046 .map_err(|(_, e)| e)
3047 .unwrap();
3048 tokio::time::sleep(Duration::from_secs(3)).await;
3050 }
3051 drop(reader);
3053 drop(writer);
3054 };
3055 }
3056
3057 for _ in 0..n {
3058 {
3059 let (mut reader, writer) = take_receiver(&receiver_storage).await;
3060 verify_stream(
3061 &mut reader,
3062 &[
3063 (8u64, 108u64),
3065 (9u64, 109u64),
3066 ],
3067 None,
3068 line!(),
3069 )
3070 .await;
3071 drop(reader);
3073 drop(writer);
3074 };
3075 }
3076 }
3077
3078 #[async_timed_test(timeout_secs = 15)]
3079 async fn test_ack_before_redelivery_in_net_tx() {
3080 let link = MockLink::<u64>::new();
3081 let receiver_storage = link.receiver_storage();
3082 let net_tx = spawn::<u64>(link);
3083
3084 let (return_channel_tx, return_channel_rx) = oneshot::channel();
3087 net_tx.try_post(100, return_channel_tx);
3088 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3089 verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3090 writer = FrameWrite::write_frame(
3092 writer,
3093 serialize_response(NetRxResponse::Ack(0)).unwrap(),
3094 1024,
3095 0,
3096 )
3097 .await
3098 .map_err(|(_, e)| e)
3099 .unwrap();
3100 assert!(return_channel_rx.await.is_err());
3105
3106 let _ = FrameWrite::write_frame(
3111 writer,
3112 serialize_response(NetRxResponse::Ack(1)).unwrap(),
3113 1024,
3114 0,
3115 )
3116 .await
3117 .map_err(|(_, e)| e)
3118 .unwrap();
3119
3120 let (return_channel_tx, return_channel_rx) = oneshot::channel();
3121 net_tx.try_post(101, return_channel_tx);
3122 verify_message(&mut reader, (1u64, 101u64), line!()).await;
3124 assert!(return_channel_rx.await.is_err());
3131 }
3132
    /// Checks that a NetTx closes when a delivered message is not acked within
    /// `MESSAGE_DELIVERY_TIMEOUT` (overridden to 2s here).
    ///
    /// Sequence:
    /// 1. Message 100 is delivered and acked. After 3s the status must still be `Active`, so
    ///    an acked message does not trip the timeout.
    /// 2. Message 101 is delivered but never acked.
    /// 3. If `disconnect_before_ack` is true, further connects are made to fail and the relays
    ///    are signalled to disconnect, so the link is broken when the timeout fires.
    /// 4. The expected log line reflects whether the link was broken or still connected.
    async fn verify_ack_exceeded_limit(disconnect_before_ack: bool) {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(2));

        let link: MockLink<u64> = MockLink::<u64>::new();
        let disconnect_signal = link.disconnect_signal().clone();
        let fail_connect_switch = link.fail_connects_switch();
        let receiver_storage = link.receiver_storage();
        let tx = spawn::<u64>(link);
        let mut tx_status = tx.status().clone();
        tx.post(100);
        let (mut reader, writer) = take_receiver(&receiver_storage).await;
        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
        // Ack seq 0. The returned writer is dropped here; `reader` keeps the stream open.
        let _ = FrameWrite::write_frame(
            writer,
            serialize_response(NetRxResponse::Ack(0)).unwrap(),
            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
            0,
        )
        .await
        .map_err(|(_, e)| e)
        .unwrap();
        // Wait longer than the 2s timeout: the acked message must not close the sender.
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert!(!tx_status.has_changed().unwrap());
        assert_eq!(*tx_status.borrow(), TxStatus::Active);

        // This message is delivered but never acked.
        tx.post(101);
        verify_message(&mut reader, (1u64, 101u64), line!()).await;

        if disconnect_before_ack {
            // Block reconnects, then tear down the current relays so the link stays broken.
            fail_connect_switch.store(true, Ordering::Release);
            disconnect_signal.send(()).unwrap();
        }

        let expected_log: &str = if disconnect_before_ack {
            "failed to receive ack within timeout 2s; link is currently broken"
        } else {
            "failed to receive ack within timeout 2s; link is currently connected"
        };

        verify_tx_closed(&mut tx_status, expected_log).await;
    }
3184
3185 #[tracing_test::traced_test]
3186 #[async_timed_test(timeout_secs = 30)]
3187 #[cfg_attr(not(fbcode_build), ignore)]
3189 async fn test_ack_exceeded_limit_with_connected_link() {
3190 verify_ack_exceeded_limit(false).await;
3191 }
3192
3193 #[tracing_test::traced_test]
3194 #[async_timed_test(timeout_secs = 30)]
3195 #[cfg_attr(not(fbcode_build), ignore)]
3197 async fn test_ack_exceeded_limit_with_broken_link() {
3198 verify_ack_exceeded_limit(true).await;
3199 }
3200
3201 #[async_timed_test(timeout_secs = 60)]
3204 async fn test_network_flakiness_in_channel() {
3205 hyperactor_telemetry::initialize_logging_for_test();
3206
3207 let sampling_rate = 100;
3208 let mut link = MockLink::<u64>::with_network_flakiness(NetworkFlakiness {
3209 disconnect_params: Some((0.001, 15, Duration::from_millis(400))),
3210 latency_range: Some((Duration::from_millis(100), Duration::from_millis(200))),
3211 });
3212 link.set_sampling_rate(sampling_rate);
3213 link.set_buffer_size(1024000);
3215 let disconnected_count = link.disconnected_count();
3216 let receiver_storage = link.receiver_storage();
3217 let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3218 let local_addr = listener.channel_addr.clone();
3219 let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3220 super::server::serve_with_listener(listener, local_addr).unwrap();
3221 let tx = spawn::<u64>(link);
3222 let messages: Vec<_> = (0..10001).collect();
3223 let messages_clone = messages.clone();
3224 let send_task_handle = tokio::spawn(async move {
3227 for message in messages_clone {
3228 tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3240 tx.post(message);
3241 }
3242 tracing::debug!("NetTx sent all messages");
3243 tx
3246 });
3247
3248 for message in &messages {
3249 if message % sampling_rate == 0 {
3250 tracing::debug!("NetRx received a message: {message}");
3251 }
3252 assert_eq!(nx.recv().await.unwrap(), *message);
3253 }
3254 tracing::debug!("NetRx received all messages");
3255
3256 let send_result = send_task_handle.await;
3257 assert!(send_result.is_ok());
3258
3259 tracing::debug!(
3260 "MockLink disconnected {} times.",
3261 disconnected_count.load(Ordering::SeqCst)
3262 );
3263 }
3266
3267 #[async_timed_test(timeout_secs = 60)]
3268 async fn test_ack_every_n_messages() {
3269 let config = hyperactor_config::global::lock();
3270 let _guard_message_ack = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 600);
3271 let _guard_time_interval =
3272 config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(1000));
3273 sparse_ack().await;
3274 }
3275
3276 #[async_timed_test(timeout_secs = 60)]
3277 async fn test_ack_every_time_interval() {
3278 let config = hyperactor_config::global::lock();
3279 let _guard_message_ack =
3280 config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
3281 let _guard_time_interval = config.override_key(
3282 config::MESSAGE_ACK_TIME_INTERVAL,
3283 Duration::from_millis(500),
3284 );
3285 sparse_ack().await;
3286 }
3287
3288 async fn sparse_ack() {
3289 let mut link = MockLink::<u64>::new();
3290 link.set_buffer_size(1024000);
3292 let disconnected_count = link.disconnected_count();
3293 let receiver_storage = link.receiver_storage();
3294 let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3295 let local_addr = listener.channel_addr.clone();
3296 let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3297 super::server::serve_with_listener(listener, local_addr).unwrap();
3298 let tx = spawn::<u64>(link);
3299 let messages: Vec<_> = (0..20001).collect();
3300 let messages_clone = messages.clone();
3301 let send_task_handle = tokio::spawn(async move {
3304 for message in messages_clone {
3305 tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3306 tx.post(message);
3307 }
3308 tokio::time::sleep(Duration::from_secs(5)).await;
3309 tracing::debug!("NetTx sent all messages");
3310 tx
3311 });
3312
3313 for message in &messages {
3314 assert_eq!(nx.recv().await.unwrap(), *message);
3315 }
3316 tracing::debug!("NetRx received all messages");
3317
3318 let send_result = send_task_handle.await;
3319 assert!(send_result.is_ok());
3320
3321 tracing::debug!(
3322 "MockLink disconnected {} times.",
3323 disconnected_count.load(Ordering::SeqCst)
3324 );
3325 }
3326
3327 #[test]
3328 fn test_metatls_parsing() {
3329 let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3331 assert_eq!(
3332 channel,
3333 ChannelAddr::MetaTls(TlsAddr::new("localhost", 1234))
3334 );
3335 let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3337 assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("1.2.3.4", 1234)));
3338 let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
3340 .parse()
3341 .unwrap();
3342 assert_eq!(
3343 channel,
3344 ChannelAddr::MetaTls(TlsAddr::new("2401:db00:33c:6902:face:0:2a2:0", 1234))
3345 );
3346
3347 let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3348 assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("::", 1234)));
3349 }
3350
3351 #[async_timed_test(timeout_secs = 300)]
3352 #[cfg_attr(not(fbcode_build), ignore)]
3354 async fn test_tcp_throughput() {
3355 let config = hyperactor_config::global::lock();
3356 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3357
3358 let socket_addr: SocketAddr = "[::1]:0".parse().unwrap();
3359 let (local_addr, mut rx) = server::serve::<String>(ChannelAddr::Tcp(socket_addr)).unwrap();
3360
3361 let total_num_msgs = 500000;
3363
3364 let receive_handle = tokio::spawn(async move {
3365 let mut num = 0;
3366 for _ in 0..10 * total_num_msgs {
3367 rx.recv().await.unwrap();
3368 num += 1;
3369
3370 if num % 100000 == 0 {
3371 tracing::info!("total number of received messages: {}", num);
3372 }
3373 }
3374 });
3375
3376 let mut tx_handles = vec![];
3377 let mut txs = vec![];
3378 for _ in 0..10 {
3379 let server_addr = local_addr.clone();
3380 let tx = Arc::new(channel::dial::<String>(server_addr).unwrap());
3381 let tx2 = Arc::clone(&tx);
3382 txs.push(tx);
3383 tx_handles.push(tokio::spawn(async move {
3384 let random_string = rand::rng()
3385 .sample_iter(&Alphanumeric)
3386 .take(2048)
3387 .map(char::from)
3388 .collect::<String>();
3389 for _ in 0..total_num_msgs {
3390 tx2.post(random_string.clone());
3391 }
3392 }));
3393 }
3394
3395 receive_handle.await.unwrap();
3396 for handle in tx_handles {
3397 handle.await.unwrap();
3398 }
3399 }
3400
3401 #[tracing_test::traced_test]
3402 #[async_timed_test(timeout_secs = 60)]
3403 #[cfg_attr(not(fbcode_build), ignore)]
3405 async fn test_net_tx_closed_on_server_reject() {
3406 let link = MockLink::<u64>::new();
3407 let receiver_storage = link.receiver_storage();
3408 let mut tx = spawn::<u64>(link);
3409 net_tx_send(&tx, &[100]).await;
3410
3411 {
3412 let (_reader, writer) = take_receiver(&receiver_storage).await;
3413 let _ = FrameWrite::write_frame(
3414 writer,
3415 serialize_response(NetRxResponse::Reject("testing".to_string())).unwrap(),
3416 1024,
3417 0,
3418 )
3419 .await
3420 .map_err(|(_, e)| e);
3421
3422 tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
3425 }
3426
3427 verify_tx_closed(&mut tx.status, "server rejected connection").await;
3428 }
3429
3430 #[async_timed_test(timeout_secs = 60)]
3431 async fn test_server_rejects_conn_on_out_of_sequence_message() {
3432 let config = hyperactor_config::global::lock();
3433 let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3434 let session_id = SessionId(123);
3435 let (_handle, mvar, mut rx, _cancel_token) = serve_acceptor_test::<u64>(session_id);
3436
3437 let (sender, receiver) = tokio::io::duplex(5000);
3438 mvar.put(receiver).await;
3439 let (r, writer) = tokio::io::split(sender);
3440 let mut reader = FrameReader::new(
3441 r,
3442 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3443 );
3444
3445 let _ = write_stream(writer, 123, &[(0, 100u64), (1, 101u64), (3, 103u64)], true).await;
3446 assert_eq!(rx.recv().await, Some(100u64));
3447 assert_eq!(rx.recv().await, Some(101u64));
3448 let (_, bytes) = reader.next().await.unwrap().unwrap();
3449 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3450 assert_eq!(acked, 0);
3451 let (_, bytes) = reader.next().await.unwrap().unwrap();
3452 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3453 assert_eq!(acked, 1);
3454 let (_, bytes) = reader.next().await.unwrap().unwrap();
3455 assert!(deserialize_response(bytes).unwrap().is_reject());
3456 }
3457
3458 #[async_timed_test(timeout_secs = 60)]
3459 #[cfg_attr(not(fbcode_build), ignore)]
3461 async fn test_stop_net_tx_after_stopping_net_rx() {
3462 hyperactor_telemetry::initialize_logging_for_test();
3463
3464 let config = hyperactor_config::global::lock();
3465 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3466 let (addr, mut rx) =
3467 server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
3468 let socket_addr = match addr {
3469 ChannelAddr::Tcp(a) => a,
3470 _ => panic!("unexpected channel type"),
3471 };
3472 let tx: NetTx<u64> = spawn(tcp::link(socket_addr));
3473 tx.send(100).await.unwrap();
3478 assert_eq!(rx.recv().await.unwrap(), 100);
3479 rx.2.stop("testing");
3481 assert!(rx.recv().await.is_err());
3482
3483 tx.post(101);
3486 let mut watcher = tx.status().clone();
3487 let _ = watcher.wait_for(|val| val.is_closed()).await;
3489 assert!(watcher.borrow().is_closed());
3493 }
3494}