1use std::fmt;
51use std::fmt::Debug;
52use std::net::ToSocketAddrs;
53use std::time::Duration;
54
55use backoff::ExponentialBackoffBuilder;
56use backoff::backoff::Backoff;
57use bytes::Bytes;
58use enum_as_inner::EnumAsInner;
59use serde::Deserialize;
60use serde::Serialize;
61use serde::de::Error;
62use tokio::io::AsyncRead;
63use tokio::io::AsyncReadExt;
64use tokio::io::AsyncWrite;
65use tokio::io::AsyncWriteExt;
66use tokio::sync::watch;
67use tokio::time::Instant;
68
69use super::*;
70use crate::RemoteMessage;
71
72pub mod duplex;
73mod framed;
74pub(super) mod server;
75pub(super) mod session;
76pub use server::ServerHandle;
77
78pub(crate) trait Stream:
79 AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static
80{
81}
82impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static> Stream for S {}
83
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
88pub(crate) struct SessionId(u64);
89
90impl SessionId {
91 pub fn random() -> Self {
93 Self(rand::random())
94 }
95}
96
97impl fmt::Display for SessionId {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 write!(f, "{:016x}", self.0)
100 }
101}
102
/// Direction tag passed to `connected.stream(..)`; `spawn` uses it for the
/// initiator's outbound (send) direction.
pub(crate) const INITIATOR_TO_ACCEPTOR: u8 = 0;

/// Direction tag for the opposite direction, presumably acceptor → initiator
/// (not used in this part of the file — confirm at its call sites).
pub(crate) const ACCEPTOR_TO_INITIATOR: u8 = 1;
108
/// Magic bytes that open every link-init header.
const LINK_INIT_MAGIC: [u8; 4] = *b"LNK\0";

/// Total link-init header size: the magic followed by a big-endian `u64`
/// session id.
const LINK_INIT_SIZE: usize = LINK_INIT_MAGIC.len() + std::mem::size_of::<u64>();
119
120async fn write_link_init<S: AsyncWrite + Unpin>(
122 stream: &mut S,
123 session_id: SessionId,
124) -> Result<(), std::io::Error> {
125 let mut buf = [0u8; LINK_INIT_SIZE];
126 buf[0..4].copy_from_slice(&LINK_INIT_MAGIC);
127 buf[4..12].copy_from_slice(&session_id.0.to_be_bytes());
128 stream.write_all(&buf).await
129}
130
131async fn read_link_init<S: AsyncRead + Unpin>(stream: &mut S) -> Result<SessionId, std::io::Error> {
133 let mut buf = [0u8; LINK_INIT_SIZE];
134 stream.read_exact(&mut buf).await?;
135 if buf[0..4] != LINK_INIT_MAGIC {
136 return Err(std::io::Error::new(
137 std::io::ErrorKind::InvalidData,
138 format!(
139 "invalid LinkInit magic: expected {:?}, got {:?}",
140 LINK_INIT_MAGIC,
141 &buf[0..4]
142 ),
143 ));
144 }
145 let session_id = SessionId(u64::from_be_bytes(buf[4..12].try_into().unwrap()));
146 Ok(session_id)
147}
148
/// Client-side factory for streams to a single destination.
///
/// The implementations in this file write `link_id()` into the link-init header
/// of every stream returned by `next`, so all streams from one link (e.g. across
/// reconnects in `spawn`) carry the same session id.
#[async_trait]
pub(crate) trait Link: Send + Sync + Debug + 'static {
    /// The concrete stream type produced by [`Link::next`].
    type Stream: Stream;

    /// The address this link connects to.
    fn dest(&self) -> ChannelAddr;

    /// The session id announced on every stream this link opens.
    fn link_id(&self) -> SessionId;

    /// Open a new stream to the destination.
    async fn next(&self) -> Result<Self::Stream, ClientError>;
}
167
168use session::Session;
169
170use crate::config;
171use crate::metrics;
172
173fn log_send_error(
176 error: &session::SendLoopError,
177 dest: &ChannelAddr,
178 session_id: u64,
179 mode: &str,
180) -> bool {
181 match error {
182 session::SendLoopError::Io(err) => {
183 tracing::info!(dest = %dest, session_id, error = %err, mode, "send error");
184 metrics::CHANNEL_ERRORS.add(
185 1,
186 hyperactor_telemetry::kv_pairs!(
187 "dest" => dest.to_string(),
188 "session_id" => session_id.to_string(),
189 "error_type" => metrics::ChannelErrorType::SendError.as_str(),
190 "mode" => mode.to_string(),
191 ),
192 );
193 false
194 }
195 session::SendLoopError::AppClosed => true,
196 session::SendLoopError::Rejected(reason) => {
197 tracing::error!(dest = %dest, session_id, mode, "server rejected connection: {reason}");
198 true
199 }
200 session::SendLoopError::ServerClosed => {
201 tracing::info!(dest = %dest, session_id, mode, "server closed the channel");
202 true
203 }
204 session::SendLoopError::DeliveryTimeout => {
205 let timeout = hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
206 tracing::error!(
207 dest = %dest, session_id, mode,
208 "failed to receive ack within timeout {timeout:?}; link is currently connected"
209 );
210 true
211 }
212 session::SendLoopError::OversizedFrame(reason) => {
213 tracing::error!(dest = %dest, session_id, mode, "oversized frame: {reason}");
214 true
215 }
216 }
217}
218
/// Spawn the background send loop for `link` and return the [`NetTx`] that
/// feeds it.
///
/// Posted messages go through an unbounded queue to a task spawned on
/// `crate::init::get_runtime()`. The task waits for the first message before
/// connecting. After that it repeatedly:
/// 1. connects through a [`Session`];
/// 2. re-queues unacknowledged messages;
/// 3. runs `session::send_connected` until it returns.
///
/// An `Ok` result or an I/O error (see `log_send_error`) leads to a reconnect.
/// Terminal send errors and a failed deadline-bounded connect end the loop, as
/// does a session shutdown.
///
/// On exit, all unacked, outboxed, and still-buffered messages are handed back
/// with the close reason. The status watch is then set to `TxStatus::Closed`.
pub(crate) fn spawn<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    let dest = link.dest();
    let session_id = link.link_id();
    let (notify, status) = watch::channel(TxStatus::Active);
    let tx = NetTx {
        sender,
        dest: dest.clone(),
        status,
    };
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id.clone()),
        };
        let mut receiver = receiver;

        // Do not connect until there is something to send. If every `NetTx`
        // is dropped first, close without ever connecting.
        // NOTE(review): both early `return`s below skip the drain at the end of
        // this task. Messages already buffered in `receiver` are dropped without a
        // `SendError` reaching their return channels — confirm this is intended.
        match receiver.recv().await {
            Some(msg) => {
                if let Err(err) = deliveries.outbox.push_back(msg) {
                    tracing::error!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        "failed to push message to outbox"
                    );
                    let _ = notify.send(TxStatus::Closed);
                    return;
                }
            }
            None => {
                let _ = notify.send(TxStatus::Closed);
                return;
            }
        }

        // Connect/send loop; the loop's value is the human-readable close reason.
        let reason: String = 'outer: loop {
            // With an expiry from `deliveries`, connecting is bounded by that
            // deadline. Otherwise wait until connected or the session shuts down.
            let connected = match deliveries.expiry_time() {
                Some(deadline) => match session.connect_by(deadline).await {
                    Ok(s) => s,
                    Err(_) => {
                        let timeout =
                            hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
                        // Distinguish "never delivered" from "delivered but not acked".
                        let error_msg = if deliveries.outbox.is_expired(timeout) {
                            format!("failed to deliver message within timeout {timeout:?}",)
                        } else {
                            format!(
                                "failed to receive ack within timeout {timeout:?}; \
                                 link is currently broken",
                            )
                        };
                        tracing::error!(
                            dest = %dest, session_id = session_id.0, "{}", error_msg
                        );
                        break 'outer format!("{log_id}: {error_msg}");
                    }
                },
                None => match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break 'outer "session shut down".into(),
                },
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "simplex",
                    "reason" => "link connected",
                ),
            );

            // Unacked messages at connect time mean this is a reconnect after
            // an interrupted connection.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "simplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            // Presumably moves unacked messages back into the outbox so they are
            // resent on this connection — confirm in `session`.
            deliveries.requeue_unacked();

            // Scope the stream so it is released before the session is reclaimed.
            let result = {
                let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                session::send_connected(&stream, &mut deliveries, &mut receiver).await
            };
            // Take the session back so the next iteration can reconnect with it.
            session = connected.release();

            match result {
                Ok(()) => {
                    // Not terminal: loop around and reconnect.
                }
                Err(ref e) => {
                    // `log_send_error` returns true for terminal errors.
                    if log_send_error(e, &dest, session_id.0, "simplex") {
                        break 'outer format!("{log_id}: {e}");
                    }
                }
            }
        };

        tracing::info!(
            dest = %dest, session_id = session_id.0, "NetTx closing: {reason}"
        );

        // Stop accepting new posts, then hand back everything still pending:
        // unacked messages first, then the outbox, then whatever is still
        // buffered in the receiver.
        receiver.close();
        deliveries
            .unacked
            .deque
            .drain(..)
            .chain(deliveries.outbox.deque.drain(..))
            .for_each(|queued| queued.try_return(Some(reason.clone())));
        while let Ok((msg, return_channel, _)) = receiver.try_recv() {
            let _ = return_channel.send(SendError {
                error: ChannelError::Closed,
                message: msg,
                reason: Some(reason.clone()),
            });
        }

        let _ = notify.send(TxStatus::Closed);
    });
    tx
}
349
/// Concrete [`Link`] returned by [`link`], one variant per transport.
///
/// Both `ChannelAddr::Tls` and `ChannelAddr::MetaTls` addresses map to the
/// `Tls` variant.
#[derive(Debug)]
pub(crate) enum NetLink {
    Tcp(tcp::TcpLink),
    Unix(unix::UnixLink),
    Tls(tls::TlsLink),
}
358
359pub(crate) fn link(addr: ChannelAddr) -> Result<NetLink, ClientError> {
361 match addr {
362 ChannelAddr::Tcp(socket_addr) => Ok(NetLink::Tcp(tcp::link(socket_addr))),
363 ChannelAddr::Unix(unix_addr) => Ok(NetLink::Unix(unix::link(unix_addr))),
364 ChannelAddr::Tls(tls_addr) => Ok(NetLink::Tls(tls::link(tls_addr)?)),
365 ChannelAddr::MetaTls(meta_addr) => Ok(NetLink::Tls(meta::link(meta_addr)?)),
366 other => Err(ClientError::Connect(
367 other,
368 std::io::Error::other("unsupported transport"),
369 "unsupported transport".into(),
370 )),
371 }
372}
373
374#[async_trait]
375impl Link for NetLink {
376 type Stream = Box<dyn Stream>;
377
378 fn dest(&self) -> ChannelAddr {
379 match self {
380 Self::Tcp(l) => l.dest(),
381 Self::Unix(l) => l.dest(),
382 Self::Tls(l) => l.dest(),
383 }
384 }
385
386 fn link_id(&self) -> SessionId {
387 match self {
388 Self::Tcp(l) => l.link_id(),
389 Self::Unix(l) => l.link_id(),
390 Self::Tls(l) => l.link_id(),
391 }
392 }
393
394 async fn next(&self) -> Result<Box<dyn Stream>, ClientError> {
395 match self {
396 Self::Tcp(l) => Ok(Box::new(l.next().await?)),
397 Self::Unix(l) => Ok(Box::new(l.next().await?)),
398 Self::Tls(l) => Ok(Box::new(l.next().await?)),
399 }
400 }
401}
402
/// Server-side source of inbound streams.
#[async_trait]
pub(crate) trait Listener: Send + Unpin + 'static {
    /// The concrete stream type yielded by [`Listener::accept`].
    type Stream: Stream;

    /// Wait for the next inbound connection; returns the stream and the peer's address.
    async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError>;
}
415
/// Concrete [`Listener`] returned by [`listen`].
///
/// TLS and MetaTLS addresses are also served by the `Tcp` variant (a plain TCP
/// listener). Presumably the TLS handshake happens in the server layer —
/// confirm in `server`.
#[derive(Debug)]
pub(crate) enum NetListener {
    Tcp(tcp::TcpSocketListener),
    Unix(unix::UnixSocketListener),
}
425
426#[async_trait]
427impl Listener for NetListener {
428 type Stream = Box<dyn Stream>;
429
430 async fn accept(&mut self) -> Result<(Box<dyn Stream>, ChannelAddr), ServerError> {
431 match self {
432 Self::Tcp(l) => {
433 let (stream, addr) = l.accept().await?;
434 Ok((Box::new(stream), addr))
435 }
436 Self::Unix(l) => {
437 let (stream, addr) = l.accept().await?;
438 Ok((Box::new(stream), addr))
439 }
440 }
441 }
442}
443
444pub(crate) fn listen(addr: ChannelAddr) -> Result<(NetListener, ChannelAddr), ServerError> {
448 match addr {
449 ChannelAddr::Tcp(socket_addr) => {
450 let std_listener = std::net::TcpListener::bind(socket_addr)
451 .map_err(|err| ServerError::Listen(ChannelAddr::Tcp(socket_addr), err))?;
452 std_listener
453 .set_nonblocking(true)
454 .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
455 let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
456 .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
457 let local_addr = tokio_listener
458 .local_addr()
459 .map_err(|err| ServerError::Resolve(ChannelAddr::Tcp(socket_addr), err))?;
460 let listener = tcp::TcpSocketListener {
461 inner: tokio_listener,
462 addr: local_addr,
463 };
464 Ok((NetListener::Tcp(listener), ChannelAddr::Tcp(local_addr)))
465 }
466 ChannelAddr::Unix(ref unix_addr) => {
467 use std::os::unix::net::UnixDatagram as StdUnixDatagram;
468 use std::os::unix::net::UnixListener as StdUnixListener;
469
470 let caddr = addr.clone();
471 let maybe_listener = match unix_addr {
472 unix::SocketAddr::Bound(sock_addr) => StdUnixListener::bind_addr(sock_addr),
473 unix::SocketAddr::Unbound => StdUnixDatagram::unbound()
474 .and_then(|u| u.local_addr())
475 .and_then(|uaddr| StdUnixListener::bind_addr(&uaddr)),
476 };
477 let std_listener =
478 maybe_listener.map_err(|err| ServerError::Listen(caddr.clone(), err))?;
479 std_listener
480 .set_nonblocking(true)
481 .map_err(|err| ServerError::Listen(caddr.clone(), err))?;
482 let local_addr = std_listener
483 .local_addr()
484 .map_err(|err| ServerError::Resolve(caddr.clone(), err))?;
485 let tokio_listener = tokio::net::UnixListener::from_std(std_listener)
486 .map_err(|err| ServerError::Io(caddr, err))?;
487 let bound_addr = unix::SocketAddr::new(local_addr);
488 let listener = unix::UnixSocketListener {
489 inner: tokio_listener,
490 addr: bound_addr.clone(),
491 };
492 Ok((NetListener::Unix(listener), ChannelAddr::Unix(bound_addr)))
493 }
494 addr @ (ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_)) => {
495 let is_meta = matches!(addr, ChannelAddr::MetaTls(_));
496 let tls_addr = match addr {
497 ChannelAddr::Tls(a) | ChannelAddr::MetaTls(a) => a,
498 _ => unreachable!(),
499 };
500 let TlsAddr { hostname, port } = tls_addr;
501 let make_channel_addr = |h: &str, p: Port| {
502 if is_meta {
503 ChannelAddr::MetaTls(TlsAddr::new(h, p))
504 } else {
505 ChannelAddr::Tls(TlsAddr::new(h, p))
506 }
507 };
508
509 let addrs: Vec<core::net::SocketAddr> = (hostname.as_ref(), port)
510 .to_socket_addrs()
511 .map_err(|err| ServerError::Resolve(make_channel_addr(&hostname, port), err))?
512 .collect();
513
514 if addrs.is_empty() {
515 return Err(ServerError::Resolve(
516 make_channel_addr(&hostname, port),
517 std::io::Error::other("no available socket addr"),
518 ));
519 }
520
521 let channel_addr = make_channel_addr(&hostname, port);
522 let std_listener = std::net::TcpListener::bind(&addrs[..])
523 .map_err(|err| ServerError::Listen(channel_addr.clone(), err))?;
524 std_listener
525 .set_nonblocking(true)
526 .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
527 let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
528 .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
529 let local_addr = tokio_listener
530 .local_addr()
531 .map_err(|err| ServerError::Resolve(channel_addr, err))?;
532 let listener = tcp::TcpSocketListener {
533 inner: tokio_listener,
534 addr: local_addr,
535 };
536 Ok((
537 NetListener::Tcp(listener),
538 make_channel_addr(&hostname, local_addr.port()),
539 ))
540 }
541 other => Err(ServerError::Listen(
542 other.clone(),
543 std::io::Error::other(format!("unsupported transport: {}", other)),
544 )),
545 }
546}
547
/// A frame sent over a link by the sending side.
///
/// Variant order is part of the serde encoding (bincode encodes the variant
/// index if this is bincode-encoded like `NetRxResponse`), so do not reorder.
#[derive(Debug, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub(super) enum Frame<M> {
    /// A message payload tagged with a `u64`. Presumably this is the sequence
    /// number that `NetRxResponse::Ack` refers back to — confirm in `session`.
    Message(u64, M),
}
554
/// Responses sent back from the receiving side of a link.
///
/// Encoded with bincode by `serialize_response`, so the variant order is part
/// of the wire format; do not reorder.
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
pub(super) enum NetRxResponse {
    /// Acknowledges delivery; presumably of messages up to this sequence
    /// number — confirm in `session`.
    Ack(u64),
    /// The receiver rejected the connection, with a reason (presumably surfaced
    /// as `SendLoopError::Rejected`).
    Reject(String),
    /// The receiver closed the channel (presumably surfaced as
    /// `SendLoopError::ServerClosed`).
    Closed,
}
563
564pub(super) fn serialize_response(response: NetRxResponse) -> Result<Bytes, bincode::Error> {
565 bincode::serialize(&response).map(|bytes| bytes.into())
566}
567
568pub(super) fn deserialize_response(data: Bytes) -> Result<NetRxResponse, bincode::Error> {
569 bincode::deserialize(&data)
570}
571
/// Sending half of a network channel, created by [`spawn`].
pub(crate) struct NetTx<M: RemoteMessage> {
    /// Queue into the background send task: the message, the channel on which
    /// it is returned if undeliverable, and the time it was posted.
    sender: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
    /// Destination address, reported by `Tx::addr`.
    dest: ChannelAddr,
    /// Link status; the send task sets it to `TxStatus::Closed` when it exits.
    status: watch::Receiver<TxStatus>,
}
579
580#[async_trait]
581impl<M: RemoteMessage> Tx<M> for NetTx<M> {
582 fn addr(&self) -> ChannelAddr {
583 self.dest.clone()
584 }
585
586 fn status(&self) -> &watch::Receiver<TxStatus> {
587 &self.status
588 }
589
590 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
591 tracing::trace!(
592 name = "post",
593 dest = %self.dest,
594 "sending message"
595 );
596
597 let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
598 if let Err(mpsc::error::SendError((message, return_channel, _))) =
599 self.sender
600 .send((message, return_channel, tokio::time::Instant::now()))
601 {
602 let _ = return_channel.send(SendError {
603 error: ChannelError::Closed,
604 message,
605 reason: None,
606 });
607 }
608 }
609}
610
/// Receiving half of a network channel.
///
/// Fields:
/// 0. the queue of received messages (presumably filled by the server);
/// 1. the channel's address;
/// 2. the server's handle.
///
/// The server is told to stop when the `NetRx` is joined or dropped.
pub struct NetRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr, ServerHandle);
612
613#[async_trait]
614impl<M: RemoteMessage> Rx<M> for NetRx<M> {
615 async fn recv(&mut self) -> Result<M, ChannelError> {
616 tracing::trace!(
617 name = "recv",
618 dest = %self.1,
619 "receiving message"
620 );
621 self.0.recv().await.ok_or(ChannelError::Closed)
622 }
623
624 fn addr(&self) -> ChannelAddr {
625 self.1.clone()
626 }
627
628 async fn join(mut self) {
631 self.2
632 .stop(&format!("NetRx joined; channel address: {}", self.1));
633 let _ = (&mut self.2).await;
634 }
636}
637
638impl<M: RemoteMessage> Drop for NetRx<M> {
639 fn drop(&mut self) {
640 self.2
641 .stop(&format!("NetRx dropped; channel address: {}", self.1));
642 }
643}
644
645#[derive(Debug, thiserror::Error)]
647pub enum ServerError {
648 #[error("io: {1}")]
649 Io(ChannelAddr, #[source] std::io::Error),
650 #[error("listen: {0} {1}")]
651 Listen(ChannelAddr, #[source] std::io::Error),
652 #[error("resolve: {0} {1}")]
653 Resolve(ChannelAddr, #[source] std::io::Error),
654 #[error("internal: {0} {1}")]
655 Internal(ChannelAddr, #[source] anyhow::Error),
656}
657
/// Errors raised on the client (sending) side, e.g. by [`link`] and `Link::next`.
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    /// A connection step failed: address, the I/O error, and a description of the step.
    #[error("connection to {0} failed: {1}: {2}")]
    Connect(ChannelAddr, std::io::Error, String),
    /// The address cannot be connected to (e.g. an unbound Unix address).
    #[error("unable to resolve address: {0}")]
    Resolve(ChannelAddr),
    /// An I/O error while setting up a stream to the address.
    #[error("io: {0} {1}")]
    Io(ChannelAddr, std::io::Error),
    /// Serializing a message for the address failed.
    #[error("send {0}: serialize: {1}")]
    Serialize(ChannelAddr, bincode::ErrorKind),
    /// The given address text is invalid.
    #[error("invalid address: {0}")]
    InvalidAddress(String),
}
671
/// Test helper: whether `addr` uses one of the transports served by this module
/// (TCP, MetaTLS, TLS, or Unix).
#[cfg(test)]
pub(super) fn is_net_addr(addr: &ChannelAddr) -> bool {
    matches!(
        addr.transport(),
        ChannelTransport::Tcp(_)
            | ChannelTransport::MetaTls(_)
            | ChannelTransport::Tls
            | ChannelTransport::Unix
    )
}
684
685pub(crate) mod unix {
686
687 use core::str;
688 use std::os::unix::net::SocketAddr as StdSocketAddr;
689 use std::os::unix::net::UnixStream as StdUnixStream;
690
691 use rand::Rng;
692 use rand::distributions::Alphanumeric;
693 use tokio::net::UnixListener;
694 use tokio::net::UnixStream;
695
696 use super::*;
697
698 #[derive(Debug)]
699 pub(crate) struct UnixLink {
700 pub(super) addr: SocketAddr,
701 pub(super) session_id: SessionId,
702 }
703
704 #[async_trait]
705 impl Link for UnixLink {
706 type Stream = UnixStream;
707
708 fn dest(&self) -> ChannelAddr {
709 ChannelAddr::Unix(self.addr.clone())
710 }
711
712 fn link_id(&self) -> SessionId {
713 self.session_id
714 }
715
716 async fn next(&self) -> Result<Self::Stream, ClientError> {
717 let session_id = self.session_id;
718 let sock_addr = match &self.addr {
719 SocketAddr::Bound(a) => a,
720 SocketAddr::Unbound => return Err(ClientError::Resolve(self.dest())),
721 };
722 let mut backoff = ExponentialBackoffBuilder::new()
723 .with_initial_interval(Duration::from_millis(1))
724 .with_multiplier(2.0)
725 .with_randomization_factor(0.1)
726 .with_max_interval(Duration::from_millis(1000))
727 .with_max_elapsed_time(None)
728 .build();
729 loop {
730 match StdUnixStream::connect_addr(sock_addr) {
731 Ok(std_stream) => {
732 std_stream
733 .set_nonblocking(true)
734 .map_err(|err| ClientError::Io(self.dest(), err))?;
735 let mut stream = UnixStream::from_std(std_stream)
736 .map_err(|err| ClientError::Io(self.dest(), err))?;
737 write_link_init(&mut stream, session_id)
738 .await
739 .map_err(|err| ClientError::Io(self.dest(), err))?;
740 return Ok(stream);
741 }
742 Err(err) => {
743 tracing::debug!(error = %err, "unix connect failed, backing off");
744 if let Some(delay) = backoff.next_backoff() {
745 tokio::time::sleep(delay).await;
746 }
747 }
748 }
749 }
750 }
751 }
752
753 #[derive(Debug)]
755 pub(crate) struct UnixSocketListener {
756 pub(super) inner: UnixListener,
757 pub(super) addr: SocketAddr,
758 }
759
760 #[async_trait]
761 impl super::Listener for UnixSocketListener {
762 type Stream = UnixStream;
763
764 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
765 let (stream, peer_addr) = self
766 .inner
767 .accept()
768 .await
769 .map_err(|err| ServerError::Io(ChannelAddr::Unix(self.addr.clone()), err))?;
770 let std_addr: StdSocketAddr = peer_addr.into();
772 Ok((stream, ChannelAddr::Unix(SocketAddr::new(std_addr))))
773 }
774 }
775
776 pub(crate) fn link(addr: SocketAddr) -> UnixLink {
778 UnixLink {
779 addr,
780 session_id: SessionId::random(),
781 }
782 }
783
784 #[derive(Clone, Debug)]
786 pub enum SocketAddr {
787 Bound(Box<StdSocketAddr>),
788 Unbound,
789 }
790
791 impl PartialOrd for SocketAddr {
792 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
793 Some(self.cmp(other))
794 }
795 }
796
797 impl Ord for SocketAddr {
798 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
799 self.to_string().cmp(&other.to_string())
800 }
801 }
802
803 impl<'de> Deserialize<'de> for SocketAddr {
804 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
805 where
806 D: serde::Deserializer<'de>,
807 {
808 let s = String::deserialize(deserializer)?;
809 Self::from_str(&s).map_err(D::Error::custom)
810 }
811 }
812
813 impl Serialize for SocketAddr {
814 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
815 where
816 S: serde::Serializer,
817 {
818 serializer.serialize_str(String::from(self).as_str())
819 }
820 }
821
822 impl From<&SocketAddr> for String {
823 fn from(value: &SocketAddr) -> Self {
824 match value {
825 SocketAddr::Bound(addr) => match addr.as_pathname() {
826 Some(path) => path
827 .to_str()
828 .expect("unable to get str for path")
829 .to_string(),
830 #[cfg(target_os = "linux")]
831 _ => match addr.as_abstract_name() {
832 Some(name) => format!("@{}", String::from_utf8_lossy(name)),
833 _ => String::from("(unnamed)"),
834 },
835 #[cfg(not(target_os = "linux"))]
836 _ => String::from("(unnamed)"),
837 },
838 SocketAddr::Unbound => String::from("(unbound)"),
839 }
840 }
841 }
842
843 impl FromStr for SocketAddr {
844 type Err = anyhow::Error;
845
846 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
847 match s {
848 "" => {
849 let random_string = rand::thread_rng()
852 .sample_iter(&Alphanumeric)
853 .take(24)
854 .map(char::from)
855 .collect::<String>();
856 SocketAddr::from_abstract_name(&random_string)
857 }
858 name if name.starts_with("@") => {
860 SocketAddr::from_abstract_name(name.strip_prefix("@").unwrap())
861 }
862 path => SocketAddr::from_pathname(path),
863 }
864 }
865 }
866
867 impl Eq for SocketAddr {}
868 impl std::hash::Hash for SocketAddr {
869 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
870 String::from(self).hash(state);
871 }
872 }
873 impl PartialEq for SocketAddr {
874 fn eq(&self, other: &Self) -> bool {
875 match (self, other) {
876 (Self::Bound(saddr), Self::Bound(oaddr)) => {
877 if saddr.is_unnamed() || oaddr.is_unnamed() {
878 return false;
879 }
880
881 #[cfg(target_os = "linux")]
882 {
883 saddr.as_pathname() == oaddr.as_pathname()
884 && saddr.as_abstract_name() == oaddr.as_abstract_name()
885 }
886 #[cfg(not(target_os = "linux"))]
887 {
888 saddr.as_pathname() == oaddr.as_pathname()
890 }
891 }
892 (Self::Unbound, _) | (_, Self::Unbound) => false,
893 }
894 }
895 }
896
897 impl fmt::Display for SocketAddr {
898 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899 match self {
900 Self::Bound(addr) => match addr.as_pathname() {
901 Some(path) => {
902 write!(f, "{}", path.to_string_lossy())
903 }
904 #[cfg(target_os = "linux")]
905 _ => match addr.as_abstract_name() {
906 Some(name) => {
907 if name.starts_with(b"@") {
908 return write!(f, "{}", String::from_utf8_lossy(name));
909 }
910 write!(f, "@{}", String::from_utf8_lossy(name))
911 }
912 _ => write!(f, "(unnamed)"),
913 },
914 #[cfg(not(target_os = "linux"))]
915 _ => write!(f, "(unnamed)"),
916 },
917 Self::Unbound => write!(f, "(unbound)"),
918 }
919 }
920 }
921
922 impl SocketAddr {
923 pub fn new(addr: StdSocketAddr) -> Self {
925 Self::Bound(Box::new(addr))
926 }
927
928 #[cfg(target_os = "linux")]
931 pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
932 Ok(Self::new(StdSocketAddr::from_abstract_name(
933 name.strip_prefix("@").unwrap_or(name),
934 )?))
935 }
936
937 #[cfg(not(target_os = "linux"))]
938 pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
939 let name = name.strip_prefix("@").unwrap_or(name);
941 let path = Self::abstract_to_filesystem_path(name);
942 Self::from_pathname(&path.to_string_lossy())
943 }
944
945 #[cfg(not(target_os = "linux"))]
946 fn abstract_to_filesystem_path(abstract_name: &str) -> std::path::PathBuf {
947 use std::collections::hash_map::DefaultHasher;
948 use std::hash::Hash;
949 use std::hash::Hasher;
950
951 let mut hasher = DefaultHasher::new();
953 abstract_name.hash(&mut hasher);
954 let hash = hasher.finish();
955
956 let process_id = std::process::id();
958
959 std::path::PathBuf::from(format!("/tmp/hyperactor_{}_{:x}", process_id, hash))
961 }
962
963 pub fn from_pathname(name: &str) -> anyhow::Result<Self> {
965 Ok(Self::new(StdSocketAddr::from_pathname(name)?))
966 }
967 }
968
969 impl TryFrom<SocketAddr> for StdSocketAddr {
970 type Error = anyhow::Error;
971
972 fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
973 match value {
974 SocketAddr::Bound(addr) => Ok(*addr),
975 SocketAddr::Unbound => Err(anyhow::anyhow!(
976 "std::os::unix::SocketAddr must be a bound address"
977 )),
978 }
979 }
980 }
981}
982
983pub(crate) mod tcp {
984 use tokio::net::TcpListener;
985 use tokio::net::TcpStream;
986
987 use super::*;
988
989 #[derive(Debug)]
990 pub(crate) struct TcpLink {
991 pub(super) addr: SocketAddr,
992 pub(super) session_id: SessionId,
993 }
994
995 #[async_trait]
996 impl Link for TcpLink {
997 type Stream = TcpStream;
998
999 fn dest(&self) -> ChannelAddr {
1000 ChannelAddr::Tcp(self.addr)
1001 }
1002
1003 fn link_id(&self) -> SessionId {
1004 self.session_id
1005 }
1006
1007 async fn next(&self) -> Result<Self::Stream, ClientError> {
1008 let session_id = self.session_id;
1009 let mut backoff = ExponentialBackoffBuilder::new()
1010 .with_initial_interval(Duration::from_millis(1))
1011 .with_multiplier(2.0)
1012 .with_randomization_factor(0.1)
1013 .with_max_interval(Duration::from_millis(1000))
1014 .with_max_elapsed_time(None)
1015 .build();
1016 loop {
1017 match TcpStream::connect(&self.addr).await {
1018 Ok(mut stream) => {
1019 stream.set_nodelay(true).map_err(|err| {
1020 ClientError::Connect(
1021 self.dest(),
1022 err,
1023 "cannot disable Nagle algorithm".to_string(),
1024 )
1025 })?;
1026 write_link_init(&mut stream, session_id)
1027 .await
1028 .map_err(|err| ClientError::Io(self.dest(), err))?;
1029 return Ok(stream);
1030 }
1031 Err(err) => {
1032 tracing::debug!(error = %err, "tcp connect failed, backing off");
1033 if let Some(delay) = backoff.next_backoff() {
1034 tokio::time::sleep(delay).await;
1035 }
1036 }
1037 }
1038 }
1039 }
1040 }
1041
1042 #[derive(Debug)]
1044 pub(crate) struct TcpSocketListener {
1045 pub(super) inner: TcpListener,
1046 pub(super) addr: SocketAddr,
1047 }
1048
1049 #[async_trait]
1050 impl super::Listener for TcpSocketListener {
1051 type Stream = TcpStream;
1052
1053 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1054 let (stream, peer_addr) = self
1055 .inner
1056 .accept()
1057 .await
1058 .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1059 stream
1060 .set_nodelay(true)
1061 .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1062 Ok((stream, ChannelAddr::Tcp(peer_addr)))
1063 }
1064 }
1065
1066 pub(crate) fn link(addr: SocketAddr) -> TcpLink {
1068 TcpLink {
1069 addr,
1070 session_id: SessionId::random(),
1071 }
1072 }
1073}
1074
1075pub(crate) mod meta {
1077 use std::io;
1078 use std::path::PathBuf;
1079 use std::sync::Arc;
1080
1081 use anyhow::Result;
1082 use tokio_rustls::TlsAcceptor;
1083 use tokio_rustls::TlsConnector;
1084
1085 use super::*;
1086 use crate::config::Pem;
1087 use crate::config::PemBundle;
1088
1089 const THRIFT_TLS_SRV_CA_PATH_ENV: &str = "THRIFT_TLS_SRV_CA_PATH";
1090 const DEFAULT_SRV_CA_PATH: &str = "/var/facebook/rootcanal/ca.pem";
1091 const THRIFT_TLS_CL_CERT_PATH_ENV: &str = "THRIFT_TLS_CL_CERT_PATH";
1092 const THRIFT_TLS_CL_KEY_PATH_ENV: &str = "THRIFT_TLS_CL_KEY_PATH";
1093 const DEFAULT_SERVER_PEM_PATH: &str = "/var/facebook/x509_identities/server.pem";
1094
1095 #[allow(clippy::result_large_err)] pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1097 let parts = addr_string.rsplit_once(":");
1099 match parts {
1100 Some((hostname, port_str)) => {
1101 let Ok(port) = port_str.parse() else {
1102 return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1103 };
1104 Ok(ChannelAddr::MetaTls(TlsAddr::new(hostname, port)))
1105 }
1106 _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1107 }
1108 }
1109
1110 pub(super) fn get_server_pem_bundle() -> PemBundle {
1113 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1114 .map(PathBuf::from)
1115 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1116 let server_pem_path = PathBuf::from(DEFAULT_SERVER_PEM_PATH);
1117 PemBundle {
1118 ca: Pem::File(ca_path),
1119 cert: Pem::File(server_pem_path.clone()),
1120 key: Pem::File(server_pem_path),
1121 }
1122 }
1123
1124 fn get_client_pem_bundle() -> Option<PemBundle> {
1127 let cert_path = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV)?;
1128 let key_path = std::env::var_os(THRIFT_TLS_CL_KEY_PATH_ENV)?;
1129 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1130 .map(PathBuf::from)
1131 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1132 Some(PemBundle {
1133 ca: Pem::File(ca_path),
1134 cert: Pem::File(PathBuf::from(cert_path)),
1135 key: Pem::File(PathBuf::from(key_path)),
1136 })
1137 }
1138
1139 pub(crate) fn tls_acceptor(enforce_client_tls: bool) -> Result<TlsAcceptor> {
1141 let bundle = get_server_pem_bundle();
1142 tls::tls_acceptor_from_bundle(&bundle, enforce_client_tls)
1143 }
1144
1145 pub(super) fn try_tls_connector() -> Result<TlsConnector> {
1151 tls_connector()
1152 }
1153
1154 fn tls_connector() -> Result<TlsConnector> {
1157 let _ = rustls::crypto::ring::default_provider().install_default();
1160
1161 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1162 .map(PathBuf::from)
1163 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1164 let ca_pem = Pem::File(ca_path);
1165 let root_store = tls::build_root_store(&ca_pem)?;
1166
1167 let config = rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_store));
1169
1170 let config = if let Some(bundle) = get_client_pem_bundle() {
1171 let certs = tls::load_certs(&bundle.cert)?;
1172 let key = tls::load_key(&bundle.key)?;
1173 config
1174 .with_client_auth_cert(certs, key)
1175 .map_err(|e| anyhow::anyhow!("load client certs: {}", e))?
1176 } else {
1177 config.with_no_client_auth()
1178 };
1179
1180 Ok(TlsConnector::from(Arc::new(config)))
1181 }
1182
1183 pub fn link(addr: TlsAddr) -> Result<tls::TlsLink, ClientError> {
1185 let connector = tls_connector().map_err(|e| {
1186 ClientError::Connect(
1187 ChannelAddr::MetaTls(addr.clone()),
1188 io::Error::other(e.to_string()),
1189 "failed to create TLS connector".to_string(),
1190 )
1191 })?;
1192 let TlsAddr { hostname, port } = addr;
1193 Ok(tls::TlsLink {
1194 hostname,
1195 port,
1196 connector,
1197 addr_type: tls::TlsAddrType::MetaTls,
1198 session_id: SessionId::random(),
1199 })
1200 }
1201}
1202
1203pub(crate) mod tls {
1205 use std::io;
1206 use std::io::BufReader;
1207 use std::sync::Arc;
1208
1209 use anyhow::Context;
1210 use anyhow::Result;
1211 use rustls::RootCertStore;
1212 use rustls::pki_types::CertificateDer;
1213 use rustls::pki_types::PrivateKeyDer;
1214 use rustls::pki_types::ServerName;
1215 use tokio::net::TcpStream;
1216 use tokio_rustls::TlsAcceptor;
1217 use tokio_rustls::TlsConnector;
1218 use tokio_rustls::client::TlsStream;
1219
1220 use super::*;
1221 use crate::channel::TlsAddr;
1222 use crate::config::Pem;
1223 use crate::config::PemBundle;
1224 use crate::config::TLS_CA;
1225 use crate::config::TLS_CERT;
1226 use crate::config::TLS_KEY;
1227
    /// Which address flavour a TLS link was created from.
    ///
    /// `meta::link` tags its links `MetaTls`; presumably `tls::link` uses `Tls`.
    #[derive(Debug, Clone, Copy)]
    pub(crate) enum TlsAddrType {
        /// A `ChannelAddr::Tls` address.
        Tls,
        /// A `ChannelAddr::MetaTls` address.
        MetaTls,
    }
1234
1235 #[allow(clippy::result_large_err)]
1237 pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1238 let parts = addr_string.rsplit_once(":");
1240 match parts {
1241 Some((hostname, port_str)) => {
1242 let Ok(port) = port_str.parse() else {
1243 return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1244 };
1245 Ok(ChannelAddr::Tls(TlsAddr::new(hostname, port)))
1246 }
1247 _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1248 }
1249 }
1250
1251 pub(super) fn load_certs(pem: &Pem) -> Result<Vec<CertificateDer<'static>>> {
1253 let mut reader = BufReader::new(pem.reader()?);
1254 let certs = rustls_pemfile::certs(&mut reader)
1255 .filter_map(Result::ok)
1256 .collect();
1257 Ok(certs)
1258 }
1259
1260 pub(super) fn load_key(pem: &Pem) -> Result<PrivateKeyDer<'static>> {
1262 let mut reader = BufReader::new(pem.reader()?);
1263 loop {
1264 break match rustls_pemfile::read_one(&mut reader)? {
1265 Some(rustls_pemfile::Item::Pkcs1Key(key)) => Ok(PrivateKeyDer::Pkcs1(key)),
1266 Some(rustls_pemfile::Item::Pkcs8Key(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
1267 Some(rustls_pemfile::Item::Sec1Key(key)) => Ok(PrivateKeyDer::Sec1(key)),
1268 Some(_) => continue,
1269 None => anyhow::bail!("no private key found in TLS key file"),
1270 };
1271 }
1272 }
1273
1274 pub(super) fn build_root_store(ca_pem: &Pem) -> Result<RootCertStore> {
1276 let mut root_store = RootCertStore::empty();
1277 let certs = load_certs(ca_pem)?;
1278 root_store.add_parsable_certificates(certs);
1279 Ok(root_store)
1280 }
1281
1282 fn get_pem_bundle() -> PemBundle {
1284 PemBundle {
1285 ca: hyperactor_config::global::get_cloned(TLS_CA),
1286 cert: hyperactor_config::global::get_cloned(TLS_CERT),
1287 key: hyperactor_config::global::get_cloned(TLS_KEY),
1288 }
1289 }
1290
1291 pub(super) fn tls_acceptor_from_bundle(
1294 bundle: &PemBundle,
1295 enforce_client_tls: bool,
1296 ) -> Result<TlsAcceptor> {
1297 let _ = rustls::crypto::ring::default_provider().install_default();
1300
1301 let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1302 let key = load_key(&bundle.key).context("load TLS key")?;
1303 let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1304
1305 let config = rustls::ServerConfig::builder();
1306 let config = if enforce_client_tls {
1307 let client_verifier =
1309 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1310 .build()
1311 .map_err(|e| anyhow::anyhow!("failed to build client verifier: {}", e))?;
1312 config.with_client_cert_verifier(client_verifier)
1313 } else {
1314 config.with_no_client_auth()
1315 }
1316 .with_single_cert(certs, key)?;
1317
1318 Ok(TlsAcceptor::from(Arc::new(config)))
1319 }
1320
1321 pub(crate) fn tls_acceptor() -> Result<TlsAcceptor> {
1323 tls_acceptor_from_bundle(&get_pem_bundle(), true)
1324 }
1325
1326 pub(super) fn tls_connector_from_bundle(bundle: &PemBundle) -> Result<TlsConnector> {
1328 let _ = rustls::crypto::ring::default_provider().install_default();
1331
1332 let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1333 let key = load_key(&bundle.key).context("load TLS key")?;
1334 let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1335
1336 let config = rustls::ClientConfig::builder()
1337 .with_root_certificates(Arc::new(root_store))
1338 .with_client_auth_cert(certs, key)
1339 .context("configure client auth")?;
1340
1341 Ok(TlsConnector::from(Arc::new(config)))
1342 }
1343
1344 fn tls_connector() -> Result<TlsConnector> {
1346 tls_connector_from_bundle(&get_pem_bundle())
1347 }
1348
    /// A [`Link`] that dials a TLS server at `hostname:port`.
    pub(crate) struct TlsLink {
        /// Host to resolve; also converted to the `ServerName` used for the
        /// TLS handshake.
        pub(crate) hostname: Hostname,
        /// Destination TCP port.
        pub(crate) port: Port,
        /// rustls client connector used for every handshake on this link.
        pub(crate) connector: TlsConnector,
        /// Which `ChannelAddr` variant `dest()` reports.
        pub(crate) addr_type: TlsAddrType,
        /// Returned by `link_id()` and written in the link-init header of each
        /// new stream.
        pub(crate) session_id: SessionId,
    }
1357
1358 impl std::fmt::Debug for TlsLink {
1359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1360 f.debug_struct("TlsLink")
1361 .field("hostname", &self.hostname)
1362 .field("port", &self.port)
1363 .field("addr_type", &self.addr_type)
1364 .finish()
1365 }
1366 }
1367
1368 #[async_trait]
1369 impl Link for TlsLink {
1370 type Stream = TlsStream<TcpStream>;
1371
1372 fn dest(&self) -> ChannelAddr {
1373 let addr = TlsAddr::new(self.hostname.clone(), self.port);
1374 match self.addr_type {
1375 TlsAddrType::Tls => ChannelAddr::Tls(addr),
1376 TlsAddrType::MetaTls => ChannelAddr::MetaTls(addr),
1377 }
1378 }
1379
1380 fn link_id(&self) -> SessionId {
1381 self.session_id
1382 }
1383
1384 async fn next(&self) -> Result<Self::Stream, ClientError> {
1385 let session_id = self.session_id;
1386 let server_name = ServerName::try_from(self.hostname.clone()).map_err(|e| {
1387 ClientError::Connect(
1388 self.dest(),
1389 io::Error::other(e.to_string()),
1390 "invalid server name".to_string(),
1391 )
1392 })?;
1393 let mut backoff = ExponentialBackoffBuilder::new()
1394 .with_initial_interval(Duration::from_millis(1))
1395 .with_multiplier(2.0)
1396 .with_randomization_factor(0.1)
1397 .with_max_interval(Duration::from_millis(1000))
1398 .with_max_elapsed_time(None)
1399 .build();
1400 loop {
1401 let mut addrs = (self.hostname.as_ref(), self.port)
1402 .to_socket_addrs()
1403 .map_err(|_| ClientError::Resolve(self.dest()))?;
1404 let addr = addrs.next().ok_or(ClientError::Resolve(self.dest()))?;
1405 match TcpStream::connect(&addr).await {
1406 Ok(stream) => {
1407 stream.set_nodelay(true).map_err(|err| {
1408 ClientError::Connect(
1409 self.dest(),
1410 err,
1411 "cannot disable Nagle algorithm".to_string(),
1412 )
1413 })?;
1414 let mut tls_stream = self
1415 .connector
1416 .connect(server_name.clone(), stream)
1417 .await
1418 .map_err(|err| {
1419 ClientError::Connect(
1420 self.dest(),
1421 err,
1422 format!("cannot establish TLS connection to {:?}", server_name),
1423 )
1424 })?;
1425 write_link_init(&mut tls_stream, session_id)
1426 .await
1427 .map_err(|err| ClientError::Io(self.dest(), err))?;
1428 return Ok(tls_stream);
1429 }
1430 Err(err) => {
1431 tracing::debug!(error = %err, "tls connect failed, backing off");
1432 if let Some(delay) = backoff.next_backoff() {
1433 tokio::time::sleep(delay).await;
1434 }
1435 }
1436 }
1437 }
1438 }
1439 }
1440
1441 pub fn link(addr: TlsAddr) -> Result<TlsLink, ClientError> {
1443 let connector = tls_connector().map_err(|e| {
1444 ClientError::Connect(
1445 ChannelAddr::Tls(addr.clone()),
1446 io::Error::other(e.to_string()),
1447 "failed to create TLS connector".to_string(),
1448 )
1449 })?;
1450 let TlsAddr { hostname, port } = addr;
1451 Ok(TlsLink {
1452 hostname,
1453 port,
1454 connector,
1455 addr_type: TlsAddrType::Tls,
1456 session_id: SessionId::random(),
1457 })
1458 }
1459
1460 #[cfg(test)]
1461 mod tests {
1462 use timed_test::async_timed_test;
1463
1464 use super::*;
1465 use crate::channel::Rx;
1466 use crate::channel::net::server;
1467 use crate::config::Pem;
1468 use crate::config::TLS_CA;
1469 use crate::config::TLS_CERT;
1470 use crate::config::TLS_KEY;
1471
        // Self-signed test CA (subject and issuer CN "Test CA", CA:TRUE) that
        // issued TEST_SERVER_CERT. Validity per the encoded dates:
        // 2026-01-28 to 2027-01-28; the TLS tests will fail once it expires.
        const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE-----
MIIDBTCCAe2gAwIBAgIUaGNmboiIosG+8Up0vgDr/+cg+2IwDQYJKoZIhvcNAQEL
BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
NzA4MzlaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
DwAwggEKAoIBAQC9RBoMYXCajklswt8Vi1JI1lEYzic0WNOmz45vG/7H6jTWkgL3
K5Ri+Seg3MobDNc48YHWXYm4hP9wCzkx8ih3ntT5XiY1My/G3jLUuoIEE9pF/BoJ
YQwZVoPNFhA9WhXNRsINf1cXFf8NzRfXpxBfKWtQJxYXU4JiDBQ6rLnQQABo8JmQ
vYFhJbBaYip5jTSiVNn7mB1zNr5jsVxuoSF53Pb7xQ76bwBdOq4zd6PSxL5/lr4G
cHSoxwZQdZMG7PL6hbxDQ2S2YI2lYVET1zwc2WPKCfjbEXBC/jzx828CInQtuksk
18gJt6xHkTFEA8CSA29GM3lejnwYWf51xyyBAgMBAAGjUzBRMB0GA1UdDgQWBBRX
cbxSZ70NsUkAS3Hhy6irugywJDAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
ugywJDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA7aAFfyW67
Z+uGSVYhpsT/uH/3Z3nr7X1smTz5CGEfq2czEcTC7gbYI2l8GZ47GPfnAvHTBZVm
V/XncBCsj7/thOh2jYEHFyCbPckoaSCRyCOnK7LPUlr4HN5uP9EFe45qBLCJDEoY
GTTw7MtzwdovfjchNfKQCTtkBJCXQ95WLCf6UOh02Sn28UTlgfXzF0X0FrcWqWa3
uJZd4XOo4O6hKKlHaBaQPiEr++1xc3SWPV7jZHbckI/vKBnDdEZ9JQX5fFZuypUI
sgomYHxvxrU2hWx+7k53CRdjfaIvT9Ie44z9sSdsU/+blw2S8f/ZTmuECoIAAXYO
0qpzlxZMdr7T
-----END CERTIFICATE-----"#;
1494
        // Leaf certificate for CN "localhost", signed by TEST_CA_CERT, with
        // SANs DNS:localhost, IP:127.0.0.1 and IP:::1. The tests install it as
        // TLS_CERT, so it serves as both the server and the client certificate.
        // Validity per the encoded dates: 2026-01-28 to 2027-01-28.
        const TEST_SERVER_CERT: &str = r#"-----BEGIN CERTIFICATE-----
MIIDJDCCAgygAwIBAgIUaz66DsWaH5ZXM4hCFnbVbMsyN1cwDQYJKoZIhvcNAQEL
BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
NzA4MzlaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQAD
ggEPADCCAQoCggEBAKCbp++qNyTn5LOsV0h9gLKJALBcjg2A14I3804N9UyDhPW2
QKQ2W424u2P1MfKrw/2C+CErGlrADlnco2RQVDAarAIuGdFvBOt5UezqOS7Mk4OS
9MlS7NZnMbc37KuM9UIG5ScJjXR/Z5z9dxeR0I9y3n0Ix6khbV7tOSHobiweI0FI
8LftBS+CQnXr6vbWPcHcW6Z0FHUv7IWhqMWmv9PlZRGe9Y6VzXrRp0PBnZMOnAYf
aMQUwYRswWdm9j9Z1sMdTJ14G+KVmO3Vj6XI6Sm9uIcYhlwG/kORwogJFWlVuP9o
rloFRCjyHJ1d7GZqqnRyHHDDCBms8ed+3YfEYQECAwEAAaNwMG4wLAYDVR0RBCUw
I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMB0GA1UdDgQWBBQl
J4vxUoCzqqeTwQAiLqE8wYezKzAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
ugywJDANBgkqhkiG9w0BAQsFAAOCAQEAnXHIBDQ4AHAMV71piTOuI41ShASQed6L
bi7XUMZgZDslLkfU1vnP3BlwpliraBsAytSYQC6kbytOuz1uQ4K7yLb2tAAmUgEO
EdIVt9SXr5tCcIPeLmInF0pysPqjZO8n7vtJyd9gryKqdhm1uzA7WQWq/Az8a9Sk
uW2J6Oc5p6P7Mf3/ixqXzvGRo8rzu0CUJOJ67UTE/HhbJuplQ5dep5CEEOAIsAtH
zn9O4rW92ueBkoBJM++YILS1vQ7jKc2N3RNrnHm7FeootBrtR9mBi0TH97K73ZPZ
2Cdhnym0CsCJggrllFGH32cYo7+K2PO7/4oj5XbBCSWcssicvd8ovg==
-----END CERTIFICATE-----"#;
1515
        // PKCS#8-encoded RSA private key for TEST_SERVER_CERT (same modulus).
        // Test-only material; never use outside tests.
        const TEST_SERVER_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
MIIEugIBADANBgkqhkiG9w0BAQEFAASCBKQwggSgAgEAAoIBAQCgm6fvqjck5+Sz
rFdIfYCyiQCwXI4NgNeCN/NODfVMg4T1tkCkNluNuLtj9THyq8P9gvghKxpawA5Z
3KNkUFQwGqwCLhnRbwTreVHs6jkuzJODkvTJUuzWZzG3N+yrjPVCBuUnCY10f2ec
/XcXkdCPct59CMepIW1e7Tkh6G4sHiNBSPC37QUvgkJ16+r21j3B3FumdBR1L+yF
oajFpr/T5WURnvWOlc160adDwZ2TDpwGH2jEFMGEbMFnZvY/WdbDHUydeBvilZjt
1Y+lyOkpvbiHGIZcBv5DkcKICRVpVbj/aK5aBUQo8hydXexmaqp0chxwwwgZrPHn
ft2HxGEBAgMBAAECgf8G5qlQov+7ljs9fSpC8yGUik59RXzVF7Qq5DyQHglsQDp2
VF5yr+M/M7DZmq+KvdauDfKbej6np5j2Q4TByrHTX1IExfZWCW8srwnWJDpQyHmO
LcJW5DlI/SYluUFyHZxsOd+ezcpGNzM8i6eSW7GaeFUXCkmJ+uW4LnlF+7bALnnd
D6sak/58EsII+IJyd4lFn+voszlPn3CZGR0jkp21rvpaKgrMIsKVWWQO/sLDU5pr
VbpBThcLU5gRcnQouQX12e2VTCIlFu75WTsJ8V/KnEaOZUVlU/B/Bs+WQF3U+/Jo
eX4N+D6OsEcNQjERAFyWujxsl1WpD4uSsbFMN0ECgYEA2b7AdL+oKPQHku2KcBhr
Zw8K4tMDlr2VPPNwZcBTLo+O71vv/xXjMcXrXmowzkgEQckUmt1VB46riyydhwdP
/n9ciWcz0Va/nwHR6Y9F9unBiyUBP7PRhRyjQyRZZRGDSJvP+Xmc5UJFpRr07VLU
nfgMXDj37vXzKDpfhdEB2nkCgYEAvNMfA8P8w3+6246x5YHflvTkPdw+2oyge+LD
mphB/w7SF8mlyNGloj3+KBZmd9SkvT57wCvO96Y9/n+mBAVisRggc0hK4ymOVYhb
+im/JvqGQMbVeg6iCOHnWdaZf9tL8uVsegQy3kVTN7vAa+CMFgX1dt65cGBX6XkB
44pYmMkCgYALhbiRdQLlB+TOtZs5y1EDpxwgXKI3+9hF3Wv5NnAwapBZwje0++eF
3r9Rw7TJda4j/QwGFehF+hrBxp6fYpetE/hFnRx0225Qb7w368j8A+ql/lNOl6li
rd1F1EqWupKD6RrcTL8sspEU55RGaretlE6zIqCcGI/BdTVQ03qRoQKBgHDC3zWf
d7XD9HGjQGdfbIe4jQjIGxzmd/wjik4q+NZ5IkukVwWa9P/zZ3DHF8Ad05dT1hEH
2FwaAdGWpyyljq9VSiOuG1KXAXHgsZSuE4ISf9P1KYzvaiJFzaPfvOEWs79E9MfU
9A+6dJzG2X1SpjWMr26iSTlrv3QkmFUqzAfJAoGASBkn4wls+oC5rv/Mch43pBv5
UmKru4ltnEHJZdbSi2DJ+AnDLD222JCasb1VT1tm2XgW6DBqrdVRPPP6GOlB0MHU
+3ULtZxAczt7I+ST2bo0/DV2Hse89Cm63w4wLOiVZs7+1wrAzJZLokWF7Q5gesra
u19txmtkiMEH+aNmekk=
-----END PRIVATE KEY-----"#;
1547
1548 #[async_timed_test(timeout_secs = 30)]
1549 async fn test_tls_basic() {
1550 let _ = rustls::crypto::ring::default_provider().install_default();
1553
1554 let config = hyperactor_config::global::lock();
1556 let _guard_cert =
1557 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1558 let _guard_key =
1559 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1560 let _guard_ca =
1561 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1562
1563 let addr = TlsAddr::new("localhost", 0);
1565
1566 let (local_addr, mut rx) =
1567 server::serve::<u64>(ChannelAddr::Tls(addr)).expect("failed to serve");
1568
1569 let tx: super::NetTx<u64> = super::spawn(
1571 link(match &local_addr {
1572 ChannelAddr::Tls(addr) => addr.clone(),
1573 _ => panic!("unexpected address type"),
1574 })
1575 .expect("failed to create link"),
1576 );
1577
1578 tx.post(42u64);
1580
1581 let received = rx.recv().await.expect("failed to receive");
1583 assert_eq!(received, 42u64);
1584 }
1585
1586 #[async_timed_test(timeout_secs = 30)]
1587 async fn test_tls_multiple_messages() {
1588 let _ = rustls::crypto::ring::default_provider().install_default();
1589
1590 let config = hyperactor_config::global::lock();
1592 let _guard_cert =
1593 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1594 let _guard_key =
1595 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1596 let _guard_ca =
1597 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1598
1599 let addr = TlsAddr::new("localhost", 0);
1600
1601 let (local_addr, mut rx) =
1602 server::serve::<String>(ChannelAddr::Tls(addr)).expect("failed to serve");
1603 let tx: super::NetTx<String> = super::spawn(
1604 link(match &local_addr {
1605 ChannelAddr::Tls(addr) => addr.clone(),
1606 _ => panic!("unexpected address type"),
1607 })
1608 .expect("failed to create link"),
1609 );
1610
1611 for i in 0..10 {
1613 tx.post(format!("message {}", i));
1614 }
1615
1616 for i in 0..10 {
1618 let received = rx.recv().await.expect("failed to receive");
1619 assert_eq!(received, format!("message {}", i));
1620 }
1621 }
1622
1623 #[test]
1624 fn test_tls_parse_hostname_port() {
1625 let addr = parse("localhost:8080").expect("failed to parse");
1626 assert!(matches!(
1627 addr,
1628 ChannelAddr::Tls(TlsAddr { hostname, port })
1629 if hostname == "localhost" && port == 8080
1630 ));
1631 }
1632
1633 #[test]
1634 fn test_tls_parse_socket_addr() {
1635 let addr = parse("127.0.0.1:8080").expect("failed to parse");
1636 assert!(matches!(
1637 addr,
1638 ChannelAddr::Tls(TlsAddr { hostname, port })
1639 if hostname == "127.0.0.1" && port == 8080
1640 ));
1641 }
1642
1643 #[test]
1644 fn test_tls_certs_parsing() {
1645 let cert_pem = Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec());
1647 let key_pem = Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec());
1648 let ca_pem = Pem::Value(TEST_CA_CERT.as_bytes().to_vec());
1649
1650 let certs = super::load_certs(&cert_pem).expect("failed to load certs");
1651 assert!(!certs.is_empty(), "expected at least one certificate");
1652
1653 let _key = super::load_key(&key_pem).expect("failed to load key");
1654
1655 let root_store = super::build_root_store(&ca_pem).expect("failed to build root store");
1656 assert!(!root_store.is_empty(), "expected at least one CA cert");
1657 }
1658
1659 #[test]
1660 fn test_tls_acceptor_creation() {
1661 let _ = rustls::crypto::ring::default_provider().install_default();
1664
1665 let config = hyperactor_config::global::lock();
1667 let _guard_cert =
1668 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1669 let _guard_key =
1670 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1671 let _guard_ca =
1672 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1673
1674 let _acceptor = super::tls_acceptor().expect("failed to create TLS acceptor");
1676 }
1677
1678 #[test]
1679 fn test_tls_connector_creation() {
1680 let _ = rustls::crypto::ring::default_provider().install_default();
1683
1684 let config = hyperactor_config::global::lock();
1686 let _guard_cert =
1687 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1688 let _guard_key =
1689 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1690 let _guard_ca =
1691 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1692
1693 let _connector = super::tls_connector().expect("failed to create TLS connector");
1695 }
1696 }
1697}
1698
1699fn oss_pem_bundle() -> crate::config::PemBundle {
1701 crate::config::PemBundle {
1702 ca: hyperactor_config::global::get_cloned(crate::config::TLS_CA),
1703 cert: hyperactor_config::global::get_cloned(crate::config::TLS_CERT),
1704 key: hyperactor_config::global::get_cloned(crate::config::TLS_KEY),
1705 }
1706}
1707
1708pub fn try_tls_pem_bundle() -> Option<crate::config::PemBundle> {
1718 let oss_bundle = oss_pem_bundle();
1719 if oss_bundle.ca.reader().is_ok() {
1720 return Some(oss_bundle);
1721 }
1722 tracing::debug!("OSS TLS bundle: CA not readable, trying Meta paths");
1723
1724 let meta_bundle = meta::get_server_pem_bundle();
1725 if meta_bundle.ca.reader().is_ok() {
1726 return Some(meta_bundle);
1727 }
1728 tracing::debug!("Meta TLS bundle: CA not readable, no TLS available");
1729
1730 None
1731}
1732
1733pub fn try_tls_acceptor(enforce_client_tls: bool) -> Option<tokio_rustls::TlsAcceptor> {
1753 let oss_bundle = oss_pem_bundle();
1754 if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&oss_bundle, enforce_client_tls) {
1755 return Some(acceptor);
1756 }
1757 tracing::debug!("OSS TLS acceptor failed, trying Meta paths");
1758
1759 let meta_bundle = meta::get_server_pem_bundle();
1760 if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&meta_bundle, enforce_client_tls) {
1761 return Some(acceptor);
1762 }
1763 tracing::debug!("Meta TLS acceptor failed, no TLS available");
1764
1765 None
1766}
1767
1768pub fn try_tls_connector() -> Option<tokio_rustls::TlsConnector> {
1780 let oss_bundle = oss_pem_bundle();
1781 if let Ok(connector) = tls::tls_connector_from_bundle(&oss_bundle) {
1782 return Some(connector);
1783 }
1784 tracing::debug!("OSS TLS connector failed, trying Meta paths");
1785
1786 if let Ok(connector) = meta::try_tls_connector() {
1787 return Some(connector);
1788 }
1789 tracing::debug!("Meta TLS connector failed, no TLS available");
1790
1791 None
1792}
1793
1794#[cfg(test)]
1795mod tests {
1796 use std::assert_matches::assert_matches;
1797 use std::collections::VecDeque;
1798 use std::marker::PhantomData;
1799 use std::sync::Arc;
1800 use std::sync::RwLock;
1801 use std::sync::atomic::AtomicBool;
1802 use std::sync::atomic::AtomicU64;
1803 use std::sync::atomic::Ordering;
1804 use std::time::Duration;
1805 #[cfg(target_os = "linux")] use std::time::UNIX_EPOCH;
1807
1808 #[cfg(target_os = "linux")] use anyhow::Result;
1810 use bytes::Bytes;
1811 use rand::Rng;
1812 use rand::SeedableRng;
1813 use rand::distributions::Alphanumeric;
1814 use timed_test::async_timed_test;
1815 use tokio::io::AsyncWrite;
1816 use tokio::io::DuplexStream;
1817 use tokio::io::ReadHalf;
1818 use tokio::io::WriteHalf;
1819 use tokio::task::JoinHandle;
1820 use tokio_util::sync::CancellationToken;
1821
1822 use super::server;
1823 use super::*;
1824 use crate::channel;
1825 use crate::channel::net::framed::FrameReader;
1826 use crate::channel::net::framed::FrameWrite;
1827 use crate::channel::net::server::AcceptorLink;
1828 use crate::config;
1829 use crate::metrics;
1830 use crate::sync::mvar::MVar;
1831
1832 fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
1836 let buf = tracing_test::internal::global_buf().lock().unwrap();
1837 let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
1838 let lines: Vec<&str> = logs_str.lines().collect();
1839 match f(&lines) {
1840 Ok(()) => {}
1841 Err(msg) => panic!("{}", msg),
1842 }
1843 }
1844
1845 #[cfg(target_os = "linux")] #[tracing_test::traced_test]
1847 #[tokio::test]
1848 async fn test_unix_basic() -> Result<()> {
1849 let timestamp = std::time::SystemTime::now()
1850 .duration_since(UNIX_EPOCH)
1851 .unwrap()
1852 .as_nanos();
1853 let unique_address = format!("test_unix_basic_{}", timestamp);
1854
1855 let (addr, mut rx) = server::serve::<u64>(ChannelAddr::Unix(
1856 unix::SocketAddr::from_abstract_name(&unique_address)?,
1857 ))
1858 .unwrap();
1859
1860 {
1867 let tx: ChannelTx<u64> = channel::dial::<u64>(addr.clone()).unwrap();
1868 tx.post(123);
1869 assert_eq!(rx.recv().await.unwrap(), 123);
1870 }
1871
1872 {
1873 let tx = channel::dial::<u64>(addr.clone()).unwrap();
1874 tx.post(321);
1875 tx.post(111);
1876 tx.post(444);
1877
1878 assert_eq!(rx.recv().await.unwrap(), 321);
1879 assert_eq!(rx.recv().await.unwrap(), 111);
1880 assert_eq!(rx.recv().await.unwrap(), 444);
1881 }
1882
1883 {
1884 let tx = channel::dial::<u64>(addr).unwrap();
1885 drop(rx);
1886
1887 let (return_tx, return_rx) = oneshot::channel();
1888 tx.try_post(123, return_tx);
1889 assert_matches!(
1890 return_rx.await,
1891 Ok(SendError {
1892 error: ChannelError::Closed,
1893 message: 123,
1894 ..
1895 })
1896 );
1897 }
1898
1899 Ok(())
1900 }
1901
1902 #[cfg(target_os = "linux")] #[tracing_test::traced_test]
1904 #[tokio::test]
1905 async fn test_unix_basic_client_before_server() -> Result<()> {
1906 let timestamp = std::time::SystemTime::now()
1908 .duration_since(UNIX_EPOCH)
1909 .unwrap()
1910 .as_nanos();
1911 let socket_addr =
1912 unix::SocketAddr::from_abstract_name(&format!("test_unix_basic_{}", timestamp))
1913 .unwrap();
1914
1915 let addr = ChannelAddr::Unix(socket_addr.clone());
1917 let tx = crate::channel::dial::<u64>(addr.clone()).unwrap();
1918 tx.post(123);
1919
1920 let (_, mut rx) = server::serve::<u64>(ChannelAddr::Unix(socket_addr)).unwrap();
1921 assert_eq!(rx.recv().await.unwrap(), 123);
1922
1923 tx.post(321);
1924 tx.post(111);
1925 tx.post(444);
1926
1927 assert_eq!(rx.recv().await.unwrap(), 321);
1928 assert_eq!(rx.recv().await.unwrap(), 111);
1929 assert_eq!(rx.recv().await.unwrap(), 444);
1930
1931 Ok(())
1932 }
1933
1934 #[tracing_test::traced_test]
1935 #[async_timed_test(timeout_secs = 60)]
1936 #[cfg_attr(not(fbcode_build), ignore)]
1938 async fn test_tcp_basic() {
1939 let (addr, mut rx) =
1940 server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
1941 {
1942 let tx = channel::dial::<u64>(addr.clone()).unwrap();
1943 tx.post(123);
1944 assert_eq!(rx.recv().await.unwrap(), 123);
1945 }
1946
1947 {
1948 let tx = channel::dial::<u64>(addr.clone()).unwrap();
1949 tx.post(321);
1950 tx.post(111);
1951 tx.post(444);
1952
1953 assert_eq!(rx.recv().await.unwrap(), 321);
1954 assert_eq!(rx.recv().await.unwrap(), 111);
1955 assert_eq!(rx.recv().await.unwrap(), 444);
1956 }
1957
1958 {
1959 let tx = channel::dial::<u64>(addr).unwrap();
1960 drop(rx);
1961
1962 let (return_tx, return_rx) = oneshot::channel();
1963 tx.try_post(123, return_tx);
1964 assert_matches!(
1965 return_rx.await,
1966 Ok(SendError {
1967 error: ChannelError::Closed,
1968 message: 123,
1969 ..
1970 })
1971 );
1972 }
1973 }
1974
1975 #[async_timed_test(timeout_secs = 5)]
1977 #[cfg_attr(not(fbcode_build), ignore)]
1979 async fn test_tcp_message_size() {
1980 let default_size_in_bytes = 100 * 1024 * 1024;
1981 let config = hyperactor_config::global::lock();
1983 let _guard1 = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
1984 let _guard2 = config.override_key(config::CODEC_MAX_FRAME_LENGTH, default_size_in_bytes);
1985
1986 let (addr, mut rx) =
1987 server::serve::<String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
1988
1989 let tx = channel::dial::<String>(addr.clone()).unwrap();
1990 {
1992 let message = "a".repeat(default_size_in_bytes - 1024);
1994 tx.post(message.clone());
1995 assert_eq!(rx.recv().await.unwrap(), message);
1996 }
1997 {
1999 let (return_channel, return_receiver) = oneshot::channel();
2000 let message = "a".repeat(default_size_in_bytes + 1024);
2001 tx.try_post(message.clone(), return_channel);
2002 let returned = return_receiver.await.unwrap();
2003 assert_eq!(message, returned.message);
2004 }
2005 }
2006
2007 #[async_timed_test(timeout_secs = 30)]
2008 #[cfg_attr(not(fbcode_build), ignore)]
2010 async fn test_ack_flush() {
2011 let config = hyperactor_config::global::lock();
2012 let _guard_message_ack =
2015 config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
2016 let _guard_delivery_timeout =
2017 config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(5));
2018
2019 let (addr, mut net_rx) =
2020 server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
2021 let net_tx = channel::dial::<u64>(addr.clone()).unwrap();
2022 let (tx, rx) = oneshot::channel();
2023 net_tx.try_post(1, tx);
2024 assert_eq!(net_rx.recv().await.unwrap(), 1);
2025 drop(net_rx);
2026 assert!(rx.await.is_err());
2029 }
2030
2031 #[async_timed_test(timeout_secs = 60)]
2032 #[cfg_attr(not(fbcode_build), ignore)]
2034 async fn test_meta_tls_basic() {
2035 hyperactor_telemetry::initialize_logging_for_test();
2036
2037 let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
2038 let meta_addr = match addr {
2039 ChannelAddr::MetaTls(meta_addr) => meta_addr,
2040 _ => panic!("expected MetaTls address"),
2041 };
2042 let (local_addr, mut rx) = server::serve::<u64>(ChannelAddr::MetaTls(meta_addr)).unwrap();
2043 {
2044 let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2045 tx.post(123);
2046 }
2047 assert_eq!(rx.recv().await.unwrap(), 123);
2048
2049 {
2050 let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2051 tx.post(321);
2052 tx.post(111);
2053 tx.post(444);
2054 assert_eq!(rx.recv().await.unwrap(), 321);
2055 assert_eq!(rx.recv().await.unwrap(), 111);
2056 assert_eq!(rx.recv().await.unwrap(), 444);
2057 }
2058
2059 {
2060 let tx = channel::dial::<u64>(local_addr).unwrap();
2061 drop(rx);
2062
2063 let (return_tx, return_rx) = oneshot::channel();
2064 tx.try_post(123, return_tx);
2065 assert_matches!(
2066 return_rx.await,
2067 Ok(SendError {
2068 error: ChannelError::Closed,
2069 message: 123,
2070 ..
2071 })
2072 );
2073 }
2074 }
2075
    /// Fault-injection settings applied by `MockLink` relays. `Default` (both
    /// fields `None`) injects no latency and never disconnects.
    #[derive(Clone, Debug, Default)]
    struct NetworkFlakiness {
        /// `(probability, max_disconnects, min_interval)`. Before relaying a
        /// frame, the relay disconnects with `probability`, but only if more
        /// than `min_interval` has passed since the previous injected
        /// disconnect and fewer than `max_disconnects` have happened so far.
        disconnect_params: Option<(f64, u64, Duration)>,
        /// `(min, max)`. Each relayed frame is held until a latency drawn
        /// uniformly from this range has elapsed since the relay read it.
        latency_range: Option<(Duration, Duration)>,
    }
2090
2091 impl NetworkFlakiness {
2092 async fn should_disconnect(
2094 &self,
2095 rng: &mut impl rand::Rng,
2096 disconnected_count: u64,
2097 prev_disconnected_at: &RwLock<Instant>,
2098 ) -> bool {
2099 let Some((prob, max_disconnects, duration)) = &self.disconnect_params else {
2100 return false;
2101 };
2102
2103 let disconnected_at = prev_disconnected_at.read().unwrap();
2104 if disconnected_at.elapsed() > *duration && disconnected_count < *max_disconnects {
2105 rng.gen_bool(*prob)
2106 } else {
2107 false
2108 }
2109 }
2110 }
2111
    /// In-memory `Link` used by the tests. `M` is the message type carried on
    /// the link; it is used to decode frames for sampled debug logs.
    struct MockLink<M> {
        /// Buffer capacity passed to `tokio::io::duplex` for each pipe.
        buffer_size: usize,
        /// Returned by `link_id()` and written in each link-init header.
        session_id: SessionId,
        /// `next()` puts the server end of each new connection here.
        receiver_storage: Arc<MVar<DuplexStream>>,
        /// When true, `next()` fails with an injected connect error.
        fail_connects: Arc<AtomicBool>,
        /// Relays subscribe to this and exit when it changes (or is dropped).
        disconnect_signal: watch::Sender<()>,
        /// Latency / disconnect injection applied by the relays.
        network_flakiness: NetworkFlakiness,
        /// Number of flakiness-induced disconnects so far, shared with relays.
        disconnected_count: Arc<AtomicU64>,
        /// Time of the last flakiness-induced disconnect (construction time
        /// until the first one).
        prev_disconnected_at: Arc<RwLock<Instant>>,
        /// When set, relays emit sampled debug logs of relayed frames.
        debug_log_sampling_rate: Option<u64>,
        _message_type: PhantomData<M>,
    }
2129
2130 impl<M> fmt::Debug for MockLink<M> {
2131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2132 f.debug_struct("MockLink")
2133 .field("buffer_size", &self.buffer_size)
2134 .field("receiver_storage", &"<MVar<DuplexStream>>")
2135 .field("fail_connects", &self.fail_connects)
2136 .field("disconnect_signal", &"<watch::Sender>")
2137 .field("network_flakiness", &self.network_flakiness)
2138 .field("disconnected_count", &self.disconnected_count)
2139 .field("prev_disconnected_at", &"<RwLock<Instant>>")
2140 .field("debug_log_sampling_rate", &self.debug_log_sampling_rate)
2141 .finish()
2142 }
2143 }
2144
2145 impl<M: RemoteMessage> MockLink<M> {
2146 fn new() -> Self {
2147 let (sender, _) = watch::channel(());
2148 Self {
2149 buffer_size: 64,
2150 session_id: SessionId::random(),
2151 receiver_storage: Arc::new(MVar::empty()),
2152 fail_connects: Arc::new(AtomicBool::new(false)),
2153 disconnect_signal: sender,
2154 network_flakiness: NetworkFlakiness::default(),
2155 disconnected_count: Arc::new(AtomicU64::new(0)),
2156 prev_disconnected_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
2157 debug_log_sampling_rate: None,
2158 _message_type: PhantomData,
2159 }
2160 }
2161
2162 fn fail_connects() -> Self {
2165 Self {
2166 fail_connects: Arc::new(AtomicBool::new(true)),
2167 ..Self::new()
2168 }
2169 }
2170
2171 fn with_network_flakiness(network_flakiness: NetworkFlakiness) -> Self {
2172 if let Some((min, max)) = network_flakiness.latency_range {
2173 assert!(min < max);
2174 }
2175
2176 Self {
2177 network_flakiness,
2178 ..Self::new()
2179 }
2180 }
2181
2182 fn receiver_storage(&self) -> Arc<MVar<DuplexStream>> {
2183 self.receiver_storage.clone()
2184 }
2185
2186 fn disconnected_count(&self) -> Arc<AtomicU64> {
2187 self.disconnected_count.clone()
2188 }
2189
2190 fn disconnect_signal(&self) -> &watch::Sender<()> {
2191 &self.disconnect_signal
2192 }
2193
2194 fn fail_connects_switch(&self) -> Arc<AtomicBool> {
2195 self.fail_connects.clone()
2196 }
2197
2198 fn set_buffer_size(&mut self, size: usize) {
2199 self.buffer_size = size;
2200 }
2201
2202 fn set_sampling_rate(&mut self, sampling_rate: u64) {
2203 self.debug_log_sampling_rate = Some(sampling_rate);
2204 }
2205 }
2206
    // `Link` for the mock. Each `next()` builds a fresh pair of in-memory duplex
    // pipes and spawns two relay tasks (one per direction) between them.
    #[async_trait]
    impl<M: RemoteMessage> Link for MockLink<M> {
        type Stream = DuplexStream;

        // The mock always reports the same placeholder destination.
        fn dest(&self) -> ChannelAddr {
            ChannelAddr::Local(u64::MAX)
        }

        fn link_id(&self) -> SessionId {
            self.session_id
        }

        /// Open a new mock connection and return its client end.
        ///
        /// The matching server end, already carrying the link-init header, is
        /// placed into `receiver_storage`. Traffic between the two ends goes
        /// through two `relay_message` tasks.
        ///
        /// # Errors
        ///
        /// - `ClientError::Connect` when `fail_connects` is set.
        /// - `ClientError::Io` if writing the link-init header fails.
        async fn next(&self) -> Result<Self::Stream, ClientError> {
            let session_id = self.session_id;
            tracing::debug!("MockLink starts to connect.");
            if self.fail_connects.load(Ordering::Acquire) {
                return Err(ClientError::Connect(
                    self.dest(),
                    std::io::Error::other("intentional error"),
                    "expected failure injected by the mock".to_string(),
                ));
            }

            // Copy frames from `reader` to `writer`, applying the configured
            // latency and disconnect injection.
            //
            // The relay exits when:
            // - the upstream closes or errors;
            // - a write fails;
            // - the shared token is cancelled;
            // - an injected disconnect fires;
            // - `disconnect_signal` changes.
            //
            // On exit it cancels the shared token so the sibling relay stops too.
            async fn relay_message<M: RemoteMessage>(
                mut disconnect_signal: watch::Receiver<()>,
                network_flakiness: NetworkFlakiness,
                disconnected_count: Arc<AtomicU64>,
                prev_disconnected_at: Arc<RwLock<Instant>>,
                mut reader: FrameReader<ReadHalf<DuplexStream>>,
                mut writer: WriteHalf<DuplexStream>,
                // Shared by both relays of one connection.
                task_coordination_token: CancellationToken,
                // Sampling rate for debug logs of relayed frames, if any.
                debug_log_sampling_rate: Option<u64>,
                // true: client -> server (frames decode as `Frame<M>`);
                // false: server -> client (frames decode as responses).
                is_from_client: bool,
            ) {
                // Sleep until the frame at the head of `queue` has aged by a
                // latency drawn uniformly from `latency_range`. Returns at once
                // when no latency is configured. A new latency is drawn each
                // time the `select!` below is re-entered.
                async fn wait_for_latency_elapse(
                    queue: &VecDeque<(Bytes, Instant)>,
                    network_flakiness: &NetworkFlakiness,
                    rng: &mut impl rand::Rng,
                ) {
                    if let Some((min, max)) = network_flakiness.latency_range {
                        let diff = max.abs_diff(min);
                        let factor = rng.gen_range(0.0..=1.0);
                        let latency = min + diff.mul_f64(factor);
                        tokio::time::sleep_until(queue.front().unwrap().1 + latency).await;
                    }
                }

                let mut rng = rand::rngs::SmallRng::from_entropy();
                // Frames read but not yet forwarded, with the time each was read.
                let mut queue: VecDeque<(Bytes, Instant)> = VecDeque::new();
                let mut send_count = 0u64;

                loop {
                    tokio::select! {
                        read_res = reader.next() => {
                            match read_res {
                                Ok(Some((_, data))) => {
                                    queue.push_back((data, tokio::time::Instant::now()));
                                }
                                Ok(None) | Err(_) => {
                                    tracing::debug!("The upstream is closed or dropped. MockLink disconnects");
                                    break;
                                }
                            }
                        }
                        // Forward the head frame once its latency has elapsed.
                        _ = wait_for_latency_elapse(&queue, &network_flakiness, &mut rng), if !queue.is_empty() => {
                            let count = disconnected_count.load(Ordering::Relaxed);
                            if network_flakiness.should_disconnect(&mut rng, count, &prev_disconnected_at).await {
                                tracing::debug!("MockLink disconnects");
                                disconnected_count.fetch_add(1, Ordering::Relaxed);

                                metrics::CHANNEL_RECONNECTIONS.add(
                                    1,
                                    hyperactor_telemetry::kv_pairs!(
                                        "transport" => "mock",
                                        "reason" => "network_flakiness",
                                    ),
                                );

                                let mut w = prev_disconnected_at.write().unwrap();
                                *w = tokio::time::Instant::now();
                                break;
                            }
                            let data = queue.pop_front().unwrap().0;
                            // NOTE(review): frames are sampled when
                            // `send_count % rate == 1`. A rate of 1 therefore never
                            // logs, and a rate of 0 would panic (remainder by zero).
                            let is_sampled = debug_log_sampling_rate.is_some_and(|sample_rate| send_count % sample_rate == 1);
                            if is_sampled {
                                if is_from_client {
                                    if let Ok(Frame::Message(_seq, _msg)) = bincode::deserialize::<Frame<M>>(&data) {
                                        tracing::debug!("MockLink relays a msg from client. msg type: {}", std::any::type_name::<M>());
                                    }
                                } else {
                                    let result = deserialize_response(data.clone());
                                    if let Ok(NetRxResponse::Ack(seq)) = result {
                                        tracing::debug!("MockLink relays an ack from server. seq: {}", seq);
                                    }
                                }
                            }
                            // `FrameWrite` takes ownership of the writer and hands it
                            // back via `complete()` after a successful send.
                            let mut fw = FrameWrite::new(writer, data, hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH), 0).unwrap();
                            if fw.send().await.is_err() {
                                break;
                            }
                            writer = fw.complete();
                            send_count += 1;
                        }
                        _ = task_coordination_token.cancelled() => break,

                        changed = disconnect_signal.changed() => {
                            tracing::debug!("MockLink disconnects per disconnect_signal {:?}", changed);
                            break;
                        }
                    }
                }

                // Tear down the sibling relay as well.
                task_coordination_token.cancel();
            }

            // `server` is handed to the listener and `client` is returned; the
            // `*_relay` ends are driven by the relay tasks.
            let (server, mut server_relay) = tokio::io::duplex(self.buffer_size);
            let (client, client_relay) = tokio::io::duplex(self.buffer_size);

            // The link-init header is written directly toward `server`,
            // bypassing the relay's latency/disconnect handling.
            write_link_init(&mut server_relay, session_id)
                .await
                .map_err(|err| ClientError::Io(self.dest(), err))?;

            let (server_r, server_writer) = tokio::io::split(server_relay);
            let (client_r, client_writer) = tokio::io::split(client_relay);

            let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
            let server_reader = FrameReader::new(server_r, max_len);
            let client_reader = FrameReader::new(client_r, max_len);

            let task_coordination_token = CancellationToken::new();
            // Server -> client direction.
            let _server_relay_task_handle = tokio::spawn(relay_message::<M>(
                self.disconnect_signal.subscribe(),
                self.network_flakiness.clone(),
                self.disconnected_count.clone(),
                self.prev_disconnected_at.clone(),
                server_reader,
                client_writer,
                task_coordination_token.clone(),
                self.debug_log_sampling_rate.clone(),
                false,
            ));
            // Client -> server direction.
            let _client_relay_task_handle = tokio::spawn(relay_message::<M>(
                self.disconnect_signal.subscribe(),
                self.network_flakiness.clone(),
                self.disconnected_count.clone(),
                self.prev_disconnected_at.clone(),
                client_reader,
                server_writer,
                task_coordination_token,
                self.debug_log_sampling_rate.clone(),
                true,
            ));

            self.receiver_storage.put(server).await;
            Ok(client)
        }
    }
2379
2380 struct MockLinkListener {
2381 receiver_storage: Arc<MVar<DuplexStream>>,
2382 channel_addr: ChannelAddr,
2383 }
2384
2385 impl MockLinkListener {
2386 fn new(receiver_storage: Arc<MVar<DuplexStream>>, channel_addr: ChannelAddr) -> Self {
2387 Self {
2388 receiver_storage,
2389 channel_addr,
2390 }
2391 }
2392 }
2393
2394 impl fmt::Debug for MockLinkListener {
2395 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2396 f.debug_struct("MockLinkListener")
2397 .field("channel_addr", &self.channel_addr)
2398 .finish()
2399 }
2400 }
2401
2402 #[async_trait]
2403 impl super::Listener for MockLinkListener {
2404 type Stream = DuplexStream;
2405
2406 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
2407 let stream = self.receiver_storage.take().await;
2408 Ok((stream, self.channel_addr.clone()))
2409 }
2410 }
2411
    /// Spawns a test acceptor session for `session_id` and returns the handles
    /// needed to drive it: `(task handle, stream mvar, delivered-message rx,
    /// cancel token)`.
    ///
    /// - The `AcceptorLink` is built over the returned `MVar`. Tests hand the
    ///   acceptor a new connection by `put`ting a `DuplexStream` into it.
    /// - Every message delivered by the session is forwarded to the returned
    ///   `mpsc::Receiver`.
    /// - The cancel token is shared by the link and the receive loop below.
    ///
    /// The spawned loop:
    /// 1. connects, then runs `recv_connected` until it returns or the token fires;
    /// 2. flushes a pending ack (`Ack(next.seq - 1)`) if `next.ack < next.seq`;
    /// 3. writes a best-effort terminal response: `Reject` on a sequence
    ///    error, `Closed` on cancellation, nothing otherwise;
    /// 4. releases the connection and reconnects after `Ok(())` or an I/O
    ///    error, or exits on anything else.
    ///
    /// `next` lives outside the loop, so sequence/ack state persists across
    /// reconnects.
    fn serve_acceptor_test<M: RemoteMessage>(
        session_id: SessionId,
    ) -> (
        JoinHandle<()>,
        crate::sync::mvar::MVar<DuplexStream>,
        mpsc::Receiver<M>,
        CancellationToken,
    ) {
        let mvar = crate::sync::mvar::MVar::empty();
        let cancel_token = CancellationToken::new();
        let link = AcceptorLink {
            // Fixed dummy address for the test link.
            dest: ChannelAddr::Local(u64::MAX),
            session_id,
            stream: mvar.clone(),
            cancel: cancel_token.clone(),
        };
        // Messages delivered by the session are forwarded through this channel.
        let (tx, rx) = mpsc::channel::<M>(1024);
        let ct = cancel_token.clone();
        let handle = tokio::spawn(async move {
            let mut session = Session::new(link);
            // `seq` ends up one past the last received sequence number (the
            // flushed ack is `seq - 1`); `ack` records how far acks have been
            // flushed. Declared outside the loop so it survives reconnects.
            let mut next = session::Next { seq: 0, ack: 0 };

            loop {
                let connected = match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break,
                };

                // Receive until the stream ends or errors, or the test cancels.
                let result = {
                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                    tokio::select! {
                        r = session::recv_connected::<M, _, _>(&stream, &tx, &mut next) => r,
                        _ = ct.cancelled() => Err(session::RecvLoopError::Cancelled),
                    }
                };

                // Flush any outstanding ack before tearing the connection down.
                if next.ack < next.seq {
                    let ack = serialize_response(NetRxResponse::Ack(next.seq - 1)).unwrap();
                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                    let mut completion = stream.write(ack);
                    match completion.drive().await {
                        Ok(()) => {
                            next.ack = next.seq;
                        }
                        Err(e) => {
                            tracing::debug!(
                                error = %e,
                                "failed to flush acks during cleanup"
                            );
                        }
                    }
                }

                // Tell the peer why the session is ending, when there is a reason to.
                let terminal_response = match &result {
                    Err(session::RecvLoopError::SequenceError(reason)) => {
                        Some(NetRxResponse::Reject(reason.clone()))
                    }
                    Err(session::RecvLoopError::Cancelled) => Some(NetRxResponse::Closed),
                    _ => None,
                };
                if let Some(rsp) = terminal_response {
                    let data = serialize_response(rsp).unwrap();
                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                    let mut completion = stream.write(data);
                    // Best effort: a failed write of the terminal response is ignored.
                    let _ = completion.drive().await;
                }

                // A clean end or an I/O error means "wait for the next stream";
                // sequence errors and cancellation end the task.
                let recoverable = matches!(&result, Ok(()) | Err(session::RecvLoopError::Io(_)));
                session = connected.release();
                if recoverable {
                    continue;
                }
                break;
            }
        });
        (handle, mvar, rx, cancel_token)
    }
2494
2495 async fn write_stream<M, W>(
2496 mut writer: W,
2497 _session_id: u64,
2498 messages: &[(u64, M)],
2499 _init: bool,
2500 ) -> W
2501 where
2502 M: RemoteMessage + PartialEq + Clone,
2503 W: AsyncWrite + Unpin,
2504 {
2505 for (seq, message) in messages {
2506 let message =
2507 serde_multipart::serialize_bincode(&Frame::<M>::Message(*seq, message.clone()))
2508 .unwrap();
2509 let mut fw = FrameWrite::new(
2510 writer,
2511 message.framed(),
2512 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2513 0,
2514 )
2515 .map_err(|(_w, e)| e)
2516 .unwrap();
2517 fw.send().await.unwrap();
2518 writer = fw.complete();
2519 }
2520
2521 writer
2522 }
2523
    /// Verifies that acceptor-side session state persists across reconnects.
    ///
    /// The first stream delivers seq 0..=3 and is acked up to 3. A second
    /// stream then replays seq 2 and 3 along with new seq 4 and 5. Only 104
    /// and 105 reach `rx`, and acks continue up to seq 5.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_persistent_server_session() {
        let config = hyperactor_config::global::lock();
        // Ack cadence of 1 message, so acks can be observed per message.
        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);

        /// Reads ack frames until one equal to `expected_last` arrives.
        /// Asserts that acks are strictly ascending and never exceed `expected_last`.
        async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
            // i128 so that -1 can mean "no ack seen yet" and an initial ack of 0 passes.
            let mut last_acked: i128 = -1;
            loop {
                let (_, bytes) = reader.next().await.unwrap().unwrap();
                let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
                assert!(
                    acked as i128 > last_acked,
                    "acks should be delivered in ascending order"
                );
                last_acked = acked as i128;
                assert!(acked <= expected_last);
                if acked == expected_last {
                    break;
                }
            }
        }

        let session_id = SessionId(123);
        let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);

        // First connection: deliver seq 0..=3.
        {
            let (sender, receiver) = tokio::io::duplex(5000);
            mvar.put(receiver).await;

            let (r, writer) = tokio::io::split(sender);
            let mut reader = FrameReader::new(
                r,
                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
            );

            // Keep the write half alive until the acks have been read.
            let _writer = write_stream(
                writer,
                123,
                &[
                    (0u64, 100u64),
                    (1u64, 101u64),
                    (2u64, 102u64),
                    (3u64, 103u64),
                ],
                true,
            )
            .await;

            assert_eq!(rx.recv().await, Some(100));
            assert_eq!(rx.recv().await, Some(101));
            assert_eq!(rx.recv().await, Some(102));
            assert_eq!(rx.recv().await, Some(103));

            verify_ack(&mut reader, 3).await;
        }
        // Leaving the scope drops the first stream. Presumably the acceptor
        // sees this as a recoverable end (`Ok` or `Io`) and reconnects via the MVar.

        // Second connection: seq 2 and 3 are replays and must not be re-delivered.
        {
            let (sender2, receiver2) = tokio::io::duplex(5000);
            mvar.put(receiver2).await;

            let (r2, writer2) = tokio::io::split(sender2);
            let mut reader2 = FrameReader::new(
                r2,
                hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
            );

            let _ = write_stream(
                writer2,
                123,
                &[
                    (2u64, 102u64),
                    (3u64, 103u64),
                    (4u64, 104u64),
                    (5u64, 105u64),
                ],
                true,
            )
            .await;

            assert_eq!(rx.recv().await, Some(104));
            assert_eq!(rx.recv().await, Some(105));

            verify_ack(&mut reader2, 5).await;

            cancel_token.cancel();
        }
    }
2615
    /// Sends 100 messages one at a time on a single stream. Checks that each
    /// is delivered and answered with an ack for its own seq. After an idle
    /// period, cancelling the session must produce a `Closed` response as the
    /// very next frame.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_ack_from_server_session() {
        let config = hyperactor_config::global::lock();
        // Ack cadence of 1: the loop below expects one ack frame per message.
        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        let session_id = SessionId(123);
        let (_handle, mvar, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);

        let (sender, receiver) = tokio::io::duplex(5000);
        mvar.put(receiver).await;
        let (r, mut writer) = tokio::io::split(sender);
        let mut reader = FrameReader::new(
            r,
            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
        );

        for i in 0u64..100u64 {
            writer = write_stream(writer, 123, &[(i, 100u64 + i)], i == 0u64).await;
            assert_eq!(rx.recv().await, Some(100u64 + i));
            let (_, bytes) = reader.next().await.unwrap().unwrap();
            let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
            assert_eq!(acked, i);
        }

        // Stay idle for a while. The next frame read below must still be the
        // `Closed` response, so no extra acks may be emitted during this period.
        tokio::time::sleep(Duration::from_secs(5)).await;

        cancel_token.cancel();

        let (_, bytes) = reader.next().await.unwrap().unwrap();
        assert!(deserialize_response(bytes).unwrap().is_closed());
    }
2648
2649 #[tracing_test::traced_test]
2650 async fn verify_tx_closed(tx_status: &mut watch::Receiver<TxStatus>, expected_log: &str) {
2651 match tokio::time::timeout(Duration::from_secs(5), tx_status.changed()).await {
2652 Ok(Ok(())) => {
2653 let current_status = *tx_status.borrow();
2654 assert_eq!(current_status, TxStatus::Closed);
2655 logs_assert_unscoped(|logs| {
2656 if logs.iter().any(|log| log.contains(expected_log)) {
2657 Ok(())
2658 } else {
2659 Err("expected log not found".to_string())
2660 }
2661 });
2662 }
2663 Ok(Err(_)) => panic!("watch::Receiver::changed() failed because sender is dropped."),
2664 Err(_) => panic!("timeout before tx_status changed"),
2665 }
2666 }
2667
2668 #[tracing_test::traced_test]
2669 #[tokio::test]
2670 #[cfg_attr(not(fbcode_build), ignore)]
2672 async fn test_tcp_tx_delivery_timeout() {
2673 let link = MockLink::<u64>::fail_connects();
2675 let tx = spawn::<u64>(link);
2676 let config = hyperactor_config::global::lock();
2678 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2679 let mut tx_receiver = tx.status().clone();
2680 let (return_channel, _return_receiver) = oneshot::channel();
2681 tx.try_post(123, return_channel);
2682 verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await;
2683 }
2684
2685 async fn take_receiver(
2686 receiver_storage: &MVar<DuplexStream>,
2687 ) -> (FrameReader<ReadHalf<DuplexStream>>, WriteHalf<DuplexStream>) {
2688 let mut receiver = receiver_storage.take().await;
2689 let _session_id = read_link_init(&mut receiver).await.expect("read LinkInit");
2691 let (r, writer) = tokio::io::split(receiver);
2692 let reader = FrameReader::new(
2693 r,
2694 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2695 );
2696 (reader, writer)
2697 }
2698
2699 async fn verify_message<M: RemoteMessage + PartialEq + std::fmt::Debug>(
2700 reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2701 expect: (u64, M),
2702 loc: u32,
2703 ) {
2704 let expected = Frame::Message(expect.0, expect.1);
2705 let (_, bytes) = reader.next().await.unwrap().expect("unexpected EOF");
2706 let message = serde_multipart::Message::from_framed(bytes).unwrap();
2707 let frame: Frame<M> = serde_multipart::deserialize_bincode(message).unwrap();
2708
2709 assert_eq!(frame, expected, "from ln={loc}");
2710 }
2711
2712 async fn verify_stream<M: RemoteMessage + PartialEq + std::fmt::Debug + Clone>(
2713 reader: &mut FrameReader<ReadHalf<DuplexStream>>,
2714 expects: &[(u64, M)],
2715 _expect_session_id: Option<u64>,
2716 loc: u32,
2717 ) {
2718 for expect in expects {
2719 verify_message(reader, expect.clone(), loc).await;
2720 }
2721 }
2722
2723 async fn net_tx_send(tx: &NetTx<u64>, msgs: &[u64]) {
2724 for msg in msgs {
2725 tx.post(*msg);
2726 }
2727 }
2728
    /// Checks that acked messages are not re-sent after a reconnect.
    ///
    /// Five messages are delivered on the first stream and acked (0..=4)
    /// before that stream is dropped. After posting 105, the next stream must
    /// carry only `(5, 105)`.
    #[async_timed_test(timeout_secs = 30)]
    async fn test_ack_in_net_tx_basic() {
        let link = MockLink::<u64>::new();
        let receiver_storage = link.receiver_storage();
        let tx = spawn::<u64>(link);

        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
        // First stream: receive all five, then ack each one.
        {
            let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
            verify_stream(
                &mut reader,
                &[
                    (0u64, 100u64),
                    (1u64, 101u64),
                    (2u64, 102u64),
                    (3u64, 103u64),
                    (4u64, 104u64),
                ],
                None,
                line!(),
            )
            .await;

            for i in 0u64..5u64 {
                writer = FrameWrite::write_frame(
                    writer,
                    serialize_response(NetRxResponse::Ack(i)).unwrap(),
                    1024,
                    0,
                )
                .await
                .map_err(|(_, e)| e)
                .unwrap();
            }
            // Presumably gives the NetTx time to process the acks before the stream goes away.
            tokio::time::sleep(Duration::from_secs(3)).await;
            drop(reader);
            drop(writer);
        };

        net_tx_send(&tx, &[105u64]).await;
        // Second stream: only the new, unacked message should appear.
        {
            let (mut reader, _writer) = take_receiver(&receiver_storage).await;
            verify_stream(&mut reader, &[(5u64, 105u64)], None, line!()).await;
        };
    }
2780
    /// Exercises NetTx redelivery across many reconnects. Every new stream
    /// must replay all unacked messages starting from the oldest unacked seq.
    /// Acks written on one stream determine where later streams start.
    ///
    /// Each phase repeats `n` times with a fresh stream per iteration. Acks
    /// are only written on a phase's last iteration.
    /// 1. Post 100..=104, nothing acked: every stream replays seq 0..=4. Then ack 1.
    /// 2. Every stream replays seq 2..=4.
    /// 3. Post 105..=109: every stream replays seq 2..=9. Then ack 1 (stale), 2, 3.
    /// 4. Every stream replays seq 4..=9. Then ack 7.
    /// 5. Every stream replays seq 8..=9.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_persistent_net_tx() {
        let link = MockLink::<u64>::new();
        let receiver_storage = link.receiver_storage();

        let tx = spawn::<u64>(link);

        net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;

        // Number of reconnects per phase.
        let n = 10;

        // Phase 1: nothing acked yet, so every stream replays from seq 0.
        for i in 0..n {
            {
                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
                verify_stream(
                    &mut reader,
                    &[
                        (0u64, 100u64),
                        (1u64, 101u64),
                        (2u64, 102u64),
                        (3u64, 103u64),
                        (4u64, 104u64),
                    ],
                    None,
                    line!(),
                )
                .await;

                if i == n - 1 {
                    // Ack through seq 1 on the last stream of this phase.
                    writer = FrameWrite::write_frame(
                        writer,
                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
                        1024,
                        0,
                    )
                    .await
                    .map_err(|(_, e)| e)
                    .unwrap();
                    // Presumably gives the NetTx time to observe the ack before the stream drops.
                    tokio::time::sleep(Duration::from_secs(3)).await;
                }
                drop(reader);
                drop(writer);
            };
        }

        // Phase 2: seq 0 and 1 are acked, so replay starts at seq 2.
        for _ in 0..n {
            {
                let (mut reader, mut _writer) = take_receiver(&receiver_storage).await;
                verify_stream(
                    &mut reader,
                    &[(2u64, 102u64), (3u64, 103u64), (4u64, 104u64)],
                    None,
                    line!(),
                )
                .await;
            };
        }

        // Phase 3: new messages are appended after the still-unacked 2..=4.
        net_tx_send(&tx, &[105u64, 106u64, 107u64, 108u64, 109u64]).await;
        for i in 0..n {
            {
                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
                verify_stream(
                    &mut reader,
                    &[
                        (2u64, 102u64),
                        (3u64, 103u64),
                        (4u64, 104u64),
                        (5u64, 105u64),
                        (6u64, 106u64),
                        (7u64, 107u64),
                        (8u64, 108u64),
                        (9u64, 109u64),
                    ],
                    None,
                    line!(),
                )
                .await;

                if i == n - 1 {
                    // Ack 1 again (already acked), then ack 2 and 3.
                    writer = FrameWrite::write_frame(
                        writer,
                        serialize_response(NetRxResponse::Ack(1)).unwrap(),
                        1024,
                        0,
                    )
                    .await
                    .map_err(|(_, e)| e)
                    .unwrap();
                    writer = FrameWrite::write_frame(
                        writer,
                        serialize_response(NetRxResponse::Ack(2)).unwrap(),
                        1024,
                        0,
                    )
                    .await
                    .map_err(|(_, e)| e)
                    .unwrap();
                    writer = FrameWrite::write_frame(
                        writer,
                        serialize_response(NetRxResponse::Ack(3)).unwrap(),
                        1024,
                        0,
                    )
                    .await
                    .map_err(|(_, e)| e)
                    .unwrap();
                    tokio::time::sleep(Duration::from_secs(3)).await;
                }
                drop(reader);
                drop(writer);
            };
        }

        // Phase 4: acked through seq 3, so replay starts at seq 4.
        for i in 0..n {
            {
                let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
                verify_stream(
                    &mut reader,
                    &[
                        (4u64, 104),
                        (5u64, 105u64),
                        (6u64, 106u64),
                        (7u64, 107u64),
                        (8u64, 108u64),
                        (9u64, 109u64),
                    ],
                    None,
                    line!(),
                )
                .await;

                if i == n - 1 {
                    // Ack through seq 7 in one step.
                    writer = FrameWrite::write_frame(
                        writer,
                        serialize_response(NetRxResponse::Ack(7)).unwrap(),
                        1024,
                        0,
                    )
                    .await
                    .map_err(|(_, e)| e)
                    .unwrap();
                    tokio::time::sleep(Duration::from_secs(3)).await;
                }
                drop(reader);
                drop(writer);
            };
        }

        // Phase 5: only seq 8 and 9 remain unacked.
        for _ in 0..n {
            {
                let (mut reader, writer) = take_receiver(&receiver_storage).await;
                verify_stream(
                    &mut reader,
                    &[
                        (8u64, 108u64),
                        (9u64, 109u64),
                    ],
                    None,
                    line!(),
                )
                .await;
                drop(reader);
                drop(writer);
            };
        }
    }
2977
    /// Checks that an ack can arrive before its message is sent. `Ack(1)` is
    /// written before message 101 is posted, and 101 must still go out as
    /// seq 1 on the same stream.
    ///
    /// In both cases the oneshot return channel must resolve with `Err`: the
    /// sender is dropped without handing the message back.
    #[async_timed_test(timeout_secs = 15)]
    async fn test_ack_before_redelivery_in_net_tx() {
        let link = MockLink::<u64>::new();
        let receiver_storage = link.receiver_storage();
        let net_tx = spawn::<u64>(link);

        let (return_channel_tx, return_channel_rx) = oneshot::channel();
        net_tx.try_post(100, return_channel_tx);
        let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
        writer = FrameWrite::write_frame(
            writer,
            serialize_response(NetRxResponse::Ack(0)).unwrap(),
            1024,
            0,
        )
        .await
        .map_err(|(_, e)| e)
        .unwrap();
        assert!(return_channel_rx.await.is_err());

        // Ack seq 1 before the corresponding message has even been posted.
        let _ = FrameWrite::write_frame(
            writer,
            serialize_response(NetRxResponse::Ack(1)).unwrap(),
            1024,
            0,
        )
        .await
        .map_err(|(_, e)| e)
        .unwrap();

        let (return_channel_tx, return_channel_rx) = oneshot::channel();
        net_tx.try_post(101, return_channel_tx);
        verify_message(&mut reader, (1u64, 101u64), line!()).await;
        assert!(return_channel_rx.await.is_err());
    }
3032
    /// Shared body of the ack-timeout tests, using a 2s `MESSAGE_DELIVERY_TIMEOUT`.
    ///
    /// - Message 100 is delivered and acked. After 3s the tx status must still
    ///   be `Active` and unchanged, so an acked message never trips the timeout.
    /// - Message 101 is delivered but never acked. The tx must then close, with
    ///   a log saying whether the link was "broken" or "connected" at the time.
    ///
    /// When `disconnect_before_ack` is true, the link is forced down (and its
    /// fail-connects switch turned on) right after 101 is seen.
    async fn verify_ack_exceeded_limit(disconnect_before_ack: bool) {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(2));

        let link: MockLink<u64> = MockLink::<u64>::new();
        let disconnect_signal = link.disconnect_signal().clone();
        let fail_connect_switch = link.fail_connects_switch();
        let receiver_storage = link.receiver_storage();
        let tx = spawn::<u64>(link);
        let mut tx_status = tx.status().clone();
        tx.post(100);
        let (mut reader, writer) = take_receiver(&receiver_storage).await;
        verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
        let _ = FrameWrite::write_frame(
            writer,
            serialize_response(NetRxResponse::Ack(0)).unwrap(),
            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
            0,
        )
        .await
        .map_err(|(_, e)| e)
        .unwrap();
        // Wait past the 2s timeout: with message 100 acked, the status must not change.
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert!(!tx_status.has_changed().unwrap());
        assert_eq!(*tx_status.borrow(), TxStatus::Active);

        // Message 101 is delivered but deliberately never acked.
        tx.post(101);
        verify_message(&mut reader, (1u64, 101u64), line!()).await;

        if disconnect_before_ack {
            // Judging by its name, this switch makes later connect attempts fail.
            fail_connect_switch.store(true, Ordering::Release);
            // The mock link's relay tasks exit when this signal fires, breaking the link.
            disconnect_signal.send(()).unwrap();
        }

        let expected_log: &str = if disconnect_before_ack {
            "failed to receive ack within timeout 2s; link is currently broken"
        } else {
            "failed to receive ack within timeout 2s; link is currently connected"
        };

        verify_tx_closed(&mut tx_status, expected_log).await;
    }
3084
    /// Ack timeout while the link stays connected: expects the "link is currently connected" log.
    #[tracing_test::traced_test]
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_ack_exceeded_limit_with_connected_link() {
        verify_ack_exceeded_limit(false).await;
    }

    /// Ack timeout after the link is forced down: expects the "link is currently broken" log.
    #[tracing_test::traced_test]
    #[async_timed_test(timeout_secs = 30)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_ack_exceeded_limit_with_broken_link() {
        verify_ack_exceeded_limit(true).await;
    }
3100
    /// End-to-end ordering test over a flaky `MockLink` (random disconnects
    /// plus added latency) feeding a real `NetRx` built with
    /// `serve_with_listener`. All 10001 posted messages must arrive exactly
    /// once and in order.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_network_flakiness_in_channel() {
        hyperactor_telemetry::initialize_logging_for_test();

        // Used both for the link's debug-log sampling and for the receive-side logging below.
        let sampling_rate = 100;
        // NOTE(review): the disconnect tuple is presumably (probability, max
        // disconnects, min interval between disconnects). Confirm against `NetworkFlakiness`.
        let mut link = MockLink::<u64>::with_network_flakiness(NetworkFlakiness {
            disconnect_params: Some((0.001, 15, Duration::from_millis(400))),
            latency_range: Some((Duration::from_millis(100), Duration::from_millis(200))),
        });
        link.set_sampling_rate(sampling_rate);
        link.set_buffer_size(1024000);
        let disconnected_count = link.disconnected_count();
        let receiver_storage = link.receiver_storage();
        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
        let local_addr = listener.channel_addr.clone();
        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
            super::server::serve_with_listener(listener, local_addr).unwrap();
        let tx = spawn::<u64>(link);
        let messages: Vec<_> = (0..10001).collect();
        let messages_clone = messages.clone();
        let send_task_handle = tokio::spawn(async move {
            // Space the posts out by a random 0-99µs.
            for message in messages_clone {
                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
                tx.post(message);
            }
            tracing::debug!("NetTx sent all messages");
            // Returning `tx` keeps it alive until this task's handle is awaited.
            tx
        });

        for message in &messages {
            if message % sampling_rate == 0 {
                tracing::debug!("NetRx received a message: {message}");
            }
            assert_eq!(nx.recv().await.unwrap(), *message);
        }
        tracing::debug!("NetRx received all messages");

        let send_result = send_task_handle.await;
        assert!(send_result.is_ok());

        tracing::debug!(
            "MockLink disconnected {} times.",
            disconnected_count.load(Ordering::SeqCst)
        );
    }
3166
3167 #[async_timed_test(timeout_secs = 60)]
3168 async fn test_ack_every_n_messages() {
3169 let config = hyperactor_config::global::lock();
3170 let _guard_message_ack = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 600);
3171 let _guard_time_interval =
3172 config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(1000));
3173 sparse_ack().await;
3174 }
3175
3176 #[async_timed_test(timeout_secs = 60)]
3177 async fn test_ack_every_time_interval() {
3178 let config = hyperactor_config::global::lock();
3179 let _guard_message_ack =
3180 config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
3181 let _guard_time_interval = config.override_key(
3182 config::MESSAGE_ACK_TIME_INTERVAL,
3183 Duration::from_millis(500),
3184 );
3185 sparse_ack().await;
3186 }
3187
    /// Shared body of the sparse-ack tests. Streams 20001 messages through a
    /// `MockLink` into a real `NetRx` built with `serve_with_listener`, and
    /// asserts they all arrive in order.
    ///
    /// Callers set the ack cadence through config overrides before calling this.
    async fn sparse_ack() {
        let mut link = MockLink::<u64>::new();
        link.set_buffer_size(1024000);
        let disconnected_count = link.disconnected_count();
        let receiver_storage = link.receiver_storage();
        let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
        let local_addr = listener.channel_addr.clone();
        let (_, mut nx): (ChannelAddr, NetRx<u64>) =
            super::server::serve_with_listener(listener, local_addr).unwrap();
        let tx = spawn::<u64>(link);
        let messages: Vec<_> = (0..20001).collect();
        let messages_clone = messages.clone();
        let send_task_handle = tokio::spawn(async move {
            for message in messages_clone {
                tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
                tx.post(message);
            }
            // Linger after the last post before releasing `tx`.
            tokio::time::sleep(Duration::from_secs(5)).await;
            tracing::debug!("NetTx sent all messages");
            // Returning `tx` keeps it alive until this task's handle is awaited.
            tx
        });

        for message in &messages {
            assert_eq!(nx.recv().await.unwrap(), *message);
        }
        tracing::debug!("NetRx received all messages");

        let send_result = send_task_handle.await;
        assert!(send_result.is_ok());

        tracing::debug!(
            "MockLink disconnected {} times.",
            disconnected_count.load(Ordering::SeqCst)
        );
    }
3226
3227 #[test]
3228 fn test_metatls_parsing() {
3229 let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3231 assert_eq!(
3232 channel,
3233 ChannelAddr::MetaTls(TlsAddr::new("localhost", 1234))
3234 );
3235 let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3237 assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("1.2.3.4", 1234)));
3238 let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
3240 .parse()
3241 .unwrap();
3242 assert_eq!(
3243 channel,
3244 ChannelAddr::MetaTls(TlsAddr::new("2401:db00:33c:6902:face:0:2a2:0", 1234))
3245 );
3246
3247 let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3248 assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("::", 1234)));
3249 }
3250
    /// Throughput smoke test over real TCP. Ten dialed senders each post
    /// `total_num_msgs` copies of a random 2048-character string to one
    /// server, and the receiver must see all `10 * total_num_msgs` messages.
    #[async_timed_test(timeout_secs = 300)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_tcp_throughput() {
        let config = hyperactor_config::global::lock();
        // Long delivery timeout so a slow run is presumably not cut short by it.
        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));

        let socket_addr: SocketAddr = "[::1]:0".parse().unwrap();
        let (local_addr, mut rx) = server::serve::<String>(ChannelAddr::Tcp(socket_addr)).unwrap();

        let total_num_msgs = 500000;

        let receive_handle = tokio::spawn(async move {
            let mut num = 0;
            for _ in 0..10 * total_num_msgs {
                rx.recv().await.unwrap();
                num += 1;

                if num % 100000 == 0 {
                    tracing::info!("total number of received messages: {}", num);
                }
            }
        });

        let mut tx_handles = vec![];
        // Keeps every sender alive until the test ends, independent of the posting tasks.
        let mut txs = vec![];
        for _ in 0..10 {
            let server_addr = local_addr.clone();
            let tx = Arc::new(channel::dial::<String>(server_addr).unwrap());
            let tx2 = Arc::clone(&tx);
            txs.push(tx);
            tx_handles.push(tokio::spawn(async move {
                // One random payload per sender, cloned for every post.
                let random_string = rand::thread_rng()
                    .sample_iter(&Alphanumeric)
                    .take(2048)
                    .map(char::from)
                    .collect::<String>();
                for _ in 0..total_num_msgs {
                    tx2.post(random_string.clone());
                }
            }));
        }

        receive_handle.await.unwrap();
        for handle in tx_handles {
            handle.await.unwrap();
        }
    }
3300
    /// Checks that a `Reject` response written by the receiving side closes
    /// the NetTx and produces a "server rejected connection" log.
    #[tracing_test::traced_test]
    #[async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_net_tx_closed_on_server_reject() {
        let link = MockLink::<u64>::new();
        let receiver_storage = link.receiver_storage();
        let mut tx = spawn::<u64>(link);
        net_tx_send(&tx, &[100]).await;

        {
            // `_reader` is bound (not `_`), so the read half stays alive for the whole scope.
            let (_reader, writer) = take_receiver(&receiver_storage).await;
            // The result of writing the Reject frame is discarded.
            let _ = FrameWrite::write_frame(
                writer,
                serialize_response(NetRxResponse::Reject("testing".to_string())).unwrap(),
                1024,
                0,
            )
            .await
            .map_err(|(_, e)| e);

            // Hold the stream open for a while before leaving the scope.
            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
        }

        verify_tx_closed(&mut tx.status, "server rejected connection").await;
    }
3329
    /// Checks that a gap in sequence numbers makes the acceptor reject the
    /// connection. Seq 0 and 1 are delivered and acked, then seq 3 (skipping
    /// 2) must produce a `Reject` response.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_server_rejects_conn_on_out_of_sequence_message() {
        let config = hyperactor_config::global::lock();
        // Ack cadence of 1, so acks 0 and 1 each arrive as their own frame.
        let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
        let session_id = SessionId(123);
        let (_handle, mvar, mut rx, _cancel_token) = serve_acceptor_test::<u64>(session_id);

        let (sender, receiver) = tokio::io::duplex(5000);
        mvar.put(receiver).await;
        let (r, writer) = tokio::io::split(sender);
        let mut reader = FrameReader::new(
            r,
            hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
        );

        // Seq 2 is intentionally missing.
        let _ = write_stream(writer, 123, &[(0, 100u64), (1, 101u64), (3, 103u64)], true).await;
        assert_eq!(rx.recv().await, Some(100u64));
        assert_eq!(rx.recv().await, Some(101u64));
        let (_, bytes) = reader.next().await.unwrap().unwrap();
        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
        assert_eq!(acked, 0);
        let (_, bytes) = reader.next().await.unwrap().unwrap();
        let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
        assert_eq!(acked, 1);
        // `serve_acceptor_test` maps a sequence error to a `Reject` response.
        let (_, bytes) = reader.next().await.unwrap().unwrap();
        assert!(deserialize_response(bytes).unwrap().is_reject());
    }
3357
    /// Over real TCP: after the receiving server is stopped, a further post
    /// must drive the NetTx status to `Closed`.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_stop_net_tx_after_stopping_net_rx() {
        hyperactor_telemetry::initialize_logging_for_test();

        let config = hyperactor_config::global::lock();
        // The 5-minute delivery timeout exceeds this test's 60s limit, so the
        // `Closed` status below cannot be caused by that timeout.
        let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
        let (addr, mut rx) =
            server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap())).unwrap();
        let socket_addr = match addr {
            ChannelAddr::Tcp(a) => a,
            _ => panic!("unexpected channel type"),
        };
        let tx: NetTx<u64> = spawn(tcp::link(socket_addr));
        tx.send(100).await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), 100);
        // Stop the server side; after that, receiving must fail.
        rx.2.stop("testing");
        assert!(rx.recv().await.is_err());

        tx.post(101);
        let mut watcher = tx.status().clone();
        let _ = watcher.wait_for(|val| *val == TxStatus::Closed).await;
        assert_eq!(*watcher.borrow(), TxStatus::Closed);
    }
3394}