1use std::fmt;
51use std::fmt::Debug;
52use std::net::ToSocketAddrs;
53use std::time::Duration;
54
55use backoff::ExponentialBackoffBuilder;
56use backoff::backoff::Backoff;
57use bytes::Bytes;
58use enum_as_inner::EnumAsInner;
59use serde::Deserialize;
60use serde::Serialize;
61use serde::de::Error;
62use tokio::io::AsyncRead;
63use tokio::io::AsyncReadExt;
64use tokio::io::AsyncWrite;
65use tokio::io::AsyncWriteExt;
66use tokio::sync::watch;
67use tokio::time::Instant;
68
69use super::*;
70use crate::RemoteMessage;
71
72pub mod duplex;
73mod framed;
74pub(super) mod server;
75pub(super) mod session;
76pub use server::ServerHandle;
77
/// Byte-stream abstraction used by the links and listeners in this module:
/// any type that is asynchronously readable and writable and is also
/// `Unpin + Send + Sync + Debug + 'static`.
pub(crate) trait Stream:
    AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static
{
}
// Blanket impl: every type satisfying the bounds is automatically a `Stream`.
impl<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug + 'static> Stream for S {}
83
/// 64-bit session identifier. It is written big-endian into the link-init
/// header and displayed as 16 lowercase hex digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct SessionId(u64);
89
90impl SessionId {
91 pub fn random() -> Self {
93 Self(rand::random())
94 }
95}
96
97impl fmt::Display for SessionId {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 write!(f, "{:016x}", self.0)
100 }
101}
102
/// Stream tag passed to `stream(..)` by the send paths in this file; by its
/// name, the initiator-to-acceptor direction.
pub(crate) const INITIATOR_TO_ACCEPTOR: u8 = 0;

/// Stream tag for the opposite direction (by its name, acceptor-to-initiator).
pub(crate) const ACCEPTOR_TO_INITIATOR: u8 = 1;

/// Four magic bytes that open every link-init header.
const LINK_INIT_MAGIC: [u8; 4] = *b"LNK\0";
/// Size of the link-init header: 4 magic bytes + 8-byte session id + 1-byte
/// stream id.
const LINK_INIT_SIZE: usize = 4 + 8 + 1;
119
/// Decoded link-init header, as produced by `read_link_init`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct LinkInit {
    /// Session id carried in header bytes 4..12 (big-endian).
    pub session_id: SessionId,
    /// Stream id carried in header byte 12.
    pub stream_id: u8,
}
126
127async fn write_link_init<S: AsyncWrite + Unpin>(
129 stream: &mut S,
130 session_id: SessionId,
131 stream_id: u8,
132) -> Result<(), std::io::Error> {
133 let mut buf = [0u8; LINK_INIT_SIZE];
134 buf[0..4].copy_from_slice(&LINK_INIT_MAGIC);
135 buf[4..12].copy_from_slice(&session_id.0.to_be_bytes());
136 buf[12] = stream_id;
137 stream.write_all(&buf).await
138}
139
140async fn read_link_init<S: AsyncRead + Unpin>(stream: &mut S) -> Result<LinkInit, std::io::Error> {
142 let mut buf = [0u8; LINK_INIT_SIZE];
143 stream.read_exact(&mut buf).await?;
144 if buf[0..4] != LINK_INIT_MAGIC {
145 return Err(std::io::Error::new(
146 std::io::ErrorKind::InvalidData,
147 format!("invalid LinkInit magic: expected LNK, got {:?}", &buf[0..4]),
148 ));
149 }
150 let session_id = SessionId(u64::from_be_bytes(buf[4..12].try_into().unwrap()));
151 let stream_id = buf[12];
152 Ok(LinkInit {
153 session_id,
154 stream_id,
155 })
156}
157
/// A source of streams to a single destination, identified by a session id.
#[async_trait]
pub(crate) trait Link: Send + Sync + Debug + 'static {
    /// Concrete stream type produced by this link.
    type Stream: Stream;

    /// Address this link targets.
    fn dest(&self) -> ChannelAddr;

    /// Session id associated with this link.
    fn link_id(&self) -> SessionId;

    /// Produces a stream for this link, or a `ClientError` on failure.
    /// NOTE(review): the unix implementation in this file opens a new
    /// connection and writes the link-init header before returning;
    /// presumably the other transports do the same — confirm.
    async fn next(&mut self) -> Result<Self::Stream, ClientError>;
}
176
177use session::Session;
178
179use crate::config;
180use crate::metrics;
181
/// Connection history of a link, used to annotate log messages.
pub(crate) enum LinkStatus {
    /// No connection has been recorded yet.
    NeverConnected,
    /// Currently connected; holds the instant the connection was recorded.
    Connected(tokio::time::Instant),
    /// Was connected and has since been marked disconnected.
    Disconnected {
        /// Instant at which the most recent connection was recorded.
        last_connected: tokio::time::Instant,
        /// Instant at which the disconnect was recorded.
        since: tokio::time::Instant,
    },
}
190
191impl LinkStatus {
192 fn connected(&mut self) {
193 *self = LinkStatus::Connected(tokio::time::Instant::now());
194 }
195
196 fn disconnected(&mut self) {
197 match *self {
198 LinkStatus::Connected(at) => {
199 *self = LinkStatus::Disconnected {
200 last_connected: at,
201 since: tokio::time::Instant::now(),
202 };
203 }
204 _ => {}
206 }
207 }
208}
209
210impl std::fmt::Display for LinkStatus {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 match self {
213 LinkStatus::NeverConnected => write!(f, "never connected"),
214 LinkStatus::Connected(at) => {
215 write!(f, "connected for {:.1}s", at.elapsed().as_secs_f64())
216 }
217 LinkStatus::Disconnected {
218 last_connected,
219 since,
220 } => {
221 write!(
222 f,
223 "last connected {:.1}s ago, disconnected for {:.1}s",
224 last_connected.elapsed().as_secs_f64(),
225 since.elapsed().as_secs_f64(),
226 )
227 }
228 }
229 }
230}
231
232fn log_send_error(
235 error: &session::SendLoopError,
236 dest: &ChannelAddr,
237 session_id: u64,
238 mode: &str,
239 link_status: &LinkStatus,
240) -> bool {
241 match error {
242 session::SendLoopError::Io(err) => {
243 tracing::info!(dest = %dest, session_id, error = %err, mode, "send error; {link_status}");
244 metrics::CHANNEL_ERRORS.add(
245 1,
246 hyperactor_telemetry::kv_pairs!(
247 "dest" => dest.to_string(),
248 "session_id" => session_id.to_string(),
249 "error_type" => metrics::ChannelErrorType::SendError.as_str(),
250 "mode" => mode.to_string(),
251 ),
252 );
253 false
254 }
255 session::SendLoopError::AppClosed => true,
256 session::SendLoopError::Rejected(reason) => {
257 tracing::error!(dest = %dest, session_id, mode, "server rejected connection: {reason}; {link_status}");
258 true
259 }
260 session::SendLoopError::ServerClosed => {
261 tracing::info!(dest = %dest, session_id, mode, "server closed the channel; {link_status}");
262 true
263 }
264 session::SendLoopError::DeliveryTimeout => {
265 let timeout = hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
266 tracing::error!(
267 dest = %dest, session_id, mode,
268 "failed to receive ack within timeout {timeout:?}; link is currently connected; {link_status}"
269 );
270 true
271 }
272 session::SendLoopError::OversizedFrame(reason) => {
273 tracing::error!(dest = %dest, session_id, mode, "oversized frame: {reason}; {link_status}");
274 true
275 }
276 }
277}
278
/// Spawns the send task for a single `link` and returns the `NetTx` that
/// feeds it. Delegates directly to `spawn_inner`.
pub(crate) fn spawn<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    spawn_inner(link)
}
283
284pub(crate) fn spawn_unordered<M: RemoteMessage>(links: Vec<impl Link + 'static>) -> NetTx<M> {
288 assert!(!links.is_empty());
289 if links.len() == 1 {
290 return spawn(links.into_iter().next().unwrap());
291 }
292
293 let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
294 let dest = links[0].dest();
295 let session_id = links[0].link_id();
296 let (notify, status) = watch::channel(TxStatus::Active);
297 let tx = NetTx {
298 sender,
299 dest: dest.clone(),
300 status,
301 };
302
303 let num_streams = links.len();
304
305 crate::init::get_runtime().spawn(async move {
306 let (queue_tx, queue_rx) =
311 async_channel::bounded::<session::PendingMessage<M>>(num_streams * 8);
312
313 let unacked: Arc<
317 tokio::sync::Mutex<std::collections::BTreeMap<u64, session::QueuedMessage<M>>>,
318 > = Arc::new(tokio::sync::Mutex::new(std::collections::BTreeMap::new()));
319
320 let mut writer_handles: Vec<tokio::task::JoinHandle<()>> = Vec::with_capacity(num_streams);
321 let log_id = format!("session {}.{:016x}", dest, session_id.0);
322
323 for (i, link) in links.into_iter().enumerate() {
324 let dest = dest.clone();
325 let unacked = unacked.clone();
326 let queue_rx = queue_rx.clone();
327 let log_id = log_id.clone();
328
329 writer_handles.push(tokio::spawn(async move {
330 let mut session = Session::new(link);
331 let mut reconnect_backoff = ExponentialBackoffBuilder::new()
332 .with_initial_interval(Duration::from_millis(10))
333 .with_multiplier(2.0)
334 .with_randomization_factor(0.1)
335 .with_max_interval(Duration::from_secs(5))
336 .with_max_elapsed_time(None)
337 .build();
338
339 loop {
340 let connected = match session.connect().await {
341 Ok(s) => s,
342 Err(_) => {
343 tracing::info!(
344 dest = %dest, stream = i,
345 "multi-stream writer {} connect failed", i
346 );
347 break;
348 }
349 };
350 tracing::info!(
351 dest = %dest, stream = i, "multi-stream writer {} connected", i
352 );
353
354 let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
355 let connected_at = tokio::time::Instant::now();
356
357 let result: Result<(), session::SendLoopError> = async {
359 loop {
360 tokio::select! {
361 biased;
362
363 ack_result = stream.next() => {
364 match ack_result {
365 Ok(Some(buffer)) => {
366 let response = deserialize_response(buffer)
367 .map_err(|e| session::SendLoopError::Io(e.into()))?;
368 match response {
369 NetRxResponse::Ack(ack) => {
370 let mut guard = unacked.lock().await;
371 let retain: std::collections::BTreeMap<u64, session::QueuedMessage<M>> = guard.split_off(&(ack + 1));
373 drop(std::mem::replace(&mut *guard, retain));
374 }
375 NetRxResponse::Reject(reason) => {
376 return Err(session::SendLoopError::Rejected(reason));
377 }
378 NetRxResponse::Closed => {
379 return Err(session::SendLoopError::ServerClosed);
380 }
381 }
382 }
383 Ok(None) => return Ok(()),
384 Err(e) => return Err(session::SendLoopError::Io(e.into())),
385 }
386 }
387
388 msg = queue_rx.recv() => {
389 let pending = match msg {
390 Ok(m) => m,
391 Err(_) => return Ok(()),
393 };
394 let session::PendingMessage {
395 seq,
396 message,
397 received_at,
398 return_channel,
399 } = pending;
400 let frame = Frame::Message(seq, message);
401 let serialized = match serde_multipart::serialize_bincode(&frame) {
402 Ok(m) => m,
403 Err(e) => {
404 tracing::error!(
405 "{log_id}: serialization error: {e}"
406 );
407 continue;
411 }
412 };
413 let mut queued = session::QueuedMessage {
414 seq,
415 message: serialized,
416 received_at,
417 sent_at: None,
418 return_channel,
419 };
420 let framed = queued.message.clone().framed();
421 stream.write(framed).drive().await.map_err(|e| {
422 session::SendLoopError::Io(e.into())
423 })?;
424 queued.sent_at = Some(tokio::time::Instant::now());
425 unacked.lock().await.insert(queued.seq, queued);
426 }
427 }
428 }
429 }
430 .await;
431
432 session = connected.release();
433
434 if connected_at.elapsed() > Duration::from_secs(1) {
435 reconnect_backoff.reset();
436 }
437
438 match result {
439 Ok(()) => {
440 if queue_rx.is_closed() {
441 break;
443 }
444 if let Some(delay) = reconnect_backoff.next_backoff() {
445 tokio::time::sleep(delay).await;
446 }
447 }
448 Err(ref e) => {
449 if log_send_error(e, &dest, session_id.0, "multi-stream", &LinkStatus::NeverConnected) {
450 break;
451 }
452 if let Some(delay) = reconnect_backoff.next_backoff() {
453 tokio::time::sleep(delay).await;
454 }
455 }
456 }
457 }
458
459 tracing::info!(
460 dest = %dest,
461 stream = i,
462 "multi-stream writer {} shutting down",
463 i,
464 );
465 }));
466 }
467
468 drop(queue_rx);
471
472 let mut next_seq = 0u64;
476
477 tracing::info!(
478 %dest, session = %log_id, num_streams,
479 "multi-stream dispatcher started"
480 );
481
482 while let Some((message, return_channel, received_at)) = receiver.recv().await {
483 let pending = session::PendingMessage {
484 seq: next_seq,
485 message,
486 received_at,
487 return_channel,
488 };
489 next_seq += 1;
490
491 if queue_tx.send(pending).await.is_err() {
492 break;
494 }
495 }
496
497 drop(queue_tx);
499 for handle in writer_handles {
500 let _ = handle.await;
501 }
502
503 let reason = format!("{log_id}: dispatcher closed");
504 let _ = notify.send(TxStatus::Closed(reason.into()));
505 });
506
507 tx
508}
509
/// Spawns the single-link ("simplex") send task for `link` and returns the
/// `NetTx` that feeds it.
///
/// The task waits for the first posted message before connecting. It then
/// loops: connect, requeue anything unacknowledged, run `send_connected`
/// until it returns, and reconnect after a backoff unless the error is
/// terminal. On exit every queued, unacknowledged, or still-buffered message
/// is returned to its sender with the closing reason, and the status watch
/// is set to `Closed`.
fn spawn_inner<M: RemoteMessage>(link: impl Link) -> NetTx<M> {
    let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
    let dest = link.dest();
    let session_id = link.link_id();
    let (notify, status) = watch::channel(TxStatus::Active);
    let tx = NetTx {
        sender,
        dest: dest.clone(),
        status,
    };
    crate::init::get_runtime().spawn(async move {
        let mut session = Session::new(link);
        let log_id = format!("session {}.{:016x}", dest, session_id.0);
        let mut deliveries = session::Deliveries {
            outbox: session::Outbox::new(log_id.clone(), dest.clone(), session_id.0),
            unacked: session::Unacked::new(None, log_id.clone()),
        };
        let mut receiver = receiver;

        // Wait for the first message before doing any connection work.
        match receiver.recv().await {
            Some(msg) => {
                if let Err(err) = deliveries.outbox.push_back(msg) {
                    tracing::error!(
                        dest = %dest,
                        session_id = session_id.0,
                        error = %err,
                        "failed to push message to outbox"
                    );
                    let _ = notify.send(TxStatus::Closed("failed to push to outbox".into()));
                    return;
                }
            }
            None => {
                // Every sender was dropped before anything was posted.
                let _ = notify.send(TxStatus::Closed("sender dropped".into()));
                return;
            }
        }

        let mut reconnect_backoff = ExponentialBackoffBuilder::new()
            .with_initial_interval(Duration::from_millis(10))
            .with_multiplier(2.0)
            .with_randomization_factor(0.1)
            .with_max_interval(Duration::from_secs(5))
            .with_max_elapsed_time(None)
            .build();

        let mut link_status = LinkStatus::NeverConnected;

        // The loop's break value is the human-readable reason for closing.
        let reason: String = 'outer: loop {
            // When deliveries have an expiry time, connecting is bounded by it.
            let connected = match deliveries.expiry_time() {
                Some(deadline) => match session.connect_by(deadline).await {
                    Ok(s) => s,
                    Err(_) => {
                        let timeout =
                            hyperactor_config::global::get(config::MESSAGE_DELIVERY_TIMEOUT);
                        // The message distinguishes an expired outbox entry
                        // from a missing ack.
                        let error_msg = if deliveries.outbox.is_expired(timeout) {
                            format!("failed to deliver message within timeout {timeout:?}; {link_status}")
                        } else {
                            format!(
                                "failed to receive ack within timeout {timeout:?}; \
                                link is currently broken; {link_status}",
                            )
                        };
                        tracing::error!(
                            dest = %dest, session_id = session_id.0, "{}", error_msg
                        );
                        break 'outer format!("{log_id}: {error_msg}");
                    }
                },
                None => match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break 'outer "session shut down".into(),
                },
            };

            metrics::CHANNEL_CONNECTIONS.add(
                1,
                hyperactor_telemetry::kv_pairs!(
                    "transport" => dest.transport().to_string(),
                    "mode" => "simplex",
                    "reason" => "link connected",
                ),
            );

            // Unacked messages at connect time mean this is a reconnect with
            // work still in flight.
            if !deliveries.unacked.is_empty() {
                metrics::CHANNEL_RECONNECTIONS.add(
                    1,
                    hyperactor_telemetry::kv_pairs!(
                        "dest" => dest.to_string(),
                        "transport" => dest.transport().to_string(),
                        "mode" => "simplex",
                        "reason" => "reconnect_with_unacked",
                    ),
                );
            }
            deliveries.requeue_unacked();

            link_status.connected();
            let connected_at = tokio::time::Instant::now();

            // The stream lives only inside this block, so it is gone before
            // `connected.release()` below.
            let result = {
                let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                session::send_connected(&stream, &mut deliveries, &mut receiver).await
            };
            session = connected.release();

            link_status.disconnected();

            // A connection that lasted more than a second counts as healthy,
            // so the backoff starts over.
            if connected_at.elapsed() > Duration::from_secs(1) {
                reconnect_backoff.reset();
            }

            match result {
                Ok(()) => {
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            "send_connected returned EOF, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
                Err(ref e) => {
                    // `true` means the error is terminal: stop and close.
                    if log_send_error(e, &dest, session_id.0, "simplex", &link_status) {
                        break 'outer format!("{log_id}: {e}");
                    }
                    if let Some(delay) = reconnect_backoff.next_backoff() {
                        tracing::info!(
                            dest = %dest,
                            session_id = session_id.0,
                            delay_ms = delay.as_millis() as u64,
                            error = %e,
                            "send_connected returned recoverable error, reconnecting after backoff; {link_status}"
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        };

        tracing::info!(
            dest = %dest, session_id = session_id.0, "NetTx closing: {reason}"
        );

        // Stop accepting new messages, then return everything still held:
        // unacked first, then the outbox, then whatever is left in the channel.
        receiver.close();
        deliveries
            .unacked
            .deque
            .drain(..)
            .chain(deliveries.outbox.deque.drain(..))
            .for_each(|queued| queued.try_return(Some(reason.clone())));
        while let Ok((msg, return_channel, _)) = receiver.try_recv() {
            let _ = return_channel.send(SendError {
                error: ChannelError::Closed,
                message: msg,
                reason: Some(reason.clone()),
            });
        }

        let _ = notify.send(TxStatus::Closed(reason.into()));
    });
    tx
}
680
/// A link over one of the supported network transports.
#[derive(Debug)]
pub(crate) enum NetLink {
    /// Link over TCP.
    Tcp(tcp::TcpLink),
    /// Link over a Unix domain socket.
    Unix(unix::UnixLink),
    /// Link over TLS. `link` also uses this variant for `MetaTls` addresses.
    Tls(tls::TlsLink),
}
689
690pub(crate) fn link(
694 addr: ChannelAddr,
695 session_id: SessionId,
696 stream_id: u8,
697) -> Result<NetLink, ClientError> {
698 match addr {
699 ChannelAddr::Tcp(socket_addr) => {
700 Ok(NetLink::Tcp(tcp::link(socket_addr, session_id, stream_id)))
701 }
702 ChannelAddr::Unix(unix_addr) => {
703 Ok(NetLink::Unix(unix::link(unix_addr, session_id, stream_id)))
704 }
705 ChannelAddr::Tls(tls_addr) => Ok(NetLink::Tls(tls::link(tls_addr, session_id, stream_id)?)),
706 ChannelAddr::MetaTls(meta_addr) => {
707 Ok(NetLink::Tls(meta::link(meta_addr, session_id, stream_id)?))
708 }
709 other => Err(ClientError::Connect(
710 other,
711 std::io::Error::other("unsupported transport"),
712 "unsupported transport".into(),
713 )),
714 }
715}
716
717#[async_trait]
718impl Link for NetLink {
719 type Stream = Box<dyn Stream>;
720
721 fn dest(&self) -> ChannelAddr {
722 match self {
723 Self::Tcp(l) => l.dest(),
724 Self::Unix(l) => l.dest(),
725 Self::Tls(l) => l.dest(),
726 }
727 }
728
729 fn link_id(&self) -> SessionId {
730 match self {
731 Self::Tcp(l) => l.link_id(),
732 Self::Unix(l) => l.link_id(),
733 Self::Tls(l) => l.link_id(),
734 }
735 }
736
737 async fn next(&mut self) -> Result<Box<dyn Stream>, ClientError> {
738 match self {
739 Self::Tcp(l) => Ok(Box::new(l.next().await?)),
740 Self::Unix(l) => Ok(Box::new(l.next().await?)),
741 Self::Tls(l) => Ok(Box::new(l.next().await?)),
742 }
743 }
744}
745
/// Server-side acceptor of incoming streams.
#[async_trait]
pub(crate) trait Listener: Send + Unpin + 'static {
    /// Concrete stream type produced by this listener.
    type Stream: Stream;

    /// Accepts one incoming connection, returning the stream and an address
    /// for it (the unix implementation in this file returns the peer address).
    async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError>;
}
758
/// A listener over one of the socket types used by this module.
/// `listen_with_prebound` also uses the `Tcp` variant for TLS and MetaTLS
/// addresses.
#[derive(Debug)]
pub(crate) enum NetListener {
    /// Listener on a TCP socket.
    Tcp(tcp::TcpSocketListener),
    /// Listener on a Unix domain socket.
    Unix(unix::UnixSocketListener),
}
768
769#[async_trait]
770impl Listener for NetListener {
771 type Stream = Box<dyn Stream>;
772
773 async fn accept(&mut self) -> Result<(Box<dyn Stream>, ChannelAddr), ServerError> {
774 match self {
775 Self::Tcp(l) => {
776 let (stream, addr) = l.accept().await?;
777 Ok((Box::new(stream), addr))
778 }
779 Self::Unix(l) => {
780 let (stream, addr) = l.accept().await?;
781 Ok((Box::new(stream), addr))
782 }
783 }
784 }
785}
786
/// Creates a non-blocking listener for `addr` and returns it together with
/// the address it is actually bound to.
///
/// For TCP, TLS, and MetaTLS addresses a caller-supplied `prebound` std
/// listener is used instead of binding a new socket when it is `Some`; the
/// Unix branch does not consult `prebound`. Any other transport yields
/// `ServerError::Listen`.
pub(crate) fn listen_with_prebound(
    addr: ChannelAddr,
    prebound: Option<std::net::TcpListener>,
) -> Result<(NetListener, ChannelAddr), ServerError> {
    match addr {
        ChannelAddr::Tcp(socket_addr) => {
            let std_listener = match prebound {
                Some(l) => l,
                None => std::net::TcpListener::bind(socket_addr)
                    .map_err(|err| ServerError::Listen(ChannelAddr::Tcp(socket_addr), err))?,
            };
            // The socket must be non-blocking before it is handed to tokio.
            std_listener
                .set_nonblocking(true)
                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
                .map_err(|e| ServerError::Listen(ChannelAddr::Tcp(socket_addr), e))?;
            // Report the address the listener reports for itself rather than
            // the requested one.
            let local_addr = tokio_listener
                .local_addr()
                .map_err(|err| ServerError::Resolve(ChannelAddr::Tcp(socket_addr), err))?;
            let listener = tcp::TcpSocketListener {
                inner: tokio_listener,
                addr: local_addr,
            };
            Ok((NetListener::Tcp(listener), ChannelAddr::Tcp(local_addr)))
        }
        ChannelAddr::Unix(ref unix_addr) => {
            use std::os::unix::net::UnixDatagram as StdUnixDatagram;
            use std::os::unix::net::UnixListener as StdUnixListener;

            let caddr = addr.clone();
            let maybe_listener = match unix_addr {
                unix::SocketAddr::Bound(sock_addr) => StdUnixListener::bind_addr(sock_addr),
                // No address given: take the local address of an unbound
                // datagram socket and bind the listener to that.
                unix::SocketAddr::Unbound => StdUnixDatagram::unbound()
                    .and_then(|u| u.local_addr())
                    .and_then(|uaddr| StdUnixListener::bind_addr(&uaddr)),
            };
            let std_listener =
                maybe_listener.map_err(|err| ServerError::Listen(caddr.clone(), err))?;
            std_listener
                .set_nonblocking(true)
                .map_err(|err| ServerError::Listen(caddr.clone(), err))?;
            let local_addr = std_listener
                .local_addr()
                .map_err(|err| ServerError::Resolve(caddr.clone(), err))?;
            let tokio_listener = tokio::net::UnixListener::from_std(std_listener)
                .map_err(|err| ServerError::Io(caddr, err))?;
            let bound_addr = unix::SocketAddr::new(local_addr);
            let listener = unix::UnixSocketListener {
                inner: tokio_listener,
                addr: bound_addr.clone(),
            };
            Ok((NetListener::Unix(listener), ChannelAddr::Unix(bound_addr)))
        }
        addr @ (ChannelAddr::Tls(_) | ChannelAddr::MetaTls(_)) => {
            // Both variants carry a `TlsAddr`; remember which one was given so
            // errors and the returned address use the same variant.
            let is_meta = matches!(addr, ChannelAddr::MetaTls(_));
            let tls_addr = match addr {
                ChannelAddr::Tls(a) | ChannelAddr::MetaTls(a) => a,
                _ => unreachable!(),
            };
            let TlsAddr { hostname, port } = tls_addr;
            let make_channel_addr = |h: &str, p: Port| {
                if is_meta {
                    ChannelAddr::MetaTls(TlsAddr::new(h, p))
                } else {
                    ChannelAddr::Tls(TlsAddr::new(h, p))
                }
            };

            // Resolve the hostname to one or more socket addresses.
            let addrs: Vec<core::net::SocketAddr> = (hostname.as_ref(), port)
                .to_socket_addrs()
                .map_err(|err| ServerError::Resolve(make_channel_addr(&hostname, port), err))?
                .collect();

            if addrs.is_empty() {
                return Err(ServerError::Resolve(
                    make_channel_addr(&hostname, port),
                    std::io::Error::other("no available socket addr"),
                ));
            }

            let channel_addr = make_channel_addr(&hostname, port);
            let std_listener = match prebound {
                Some(l) => l,
                None => std::net::TcpListener::bind(&addrs[..])
                    .map_err(|err| ServerError::Listen(channel_addr.clone(), err))?,
            };
            std_listener
                .set_nonblocking(true)
                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
            let tokio_listener = tokio::net::TcpListener::from_std(std_listener)
                .map_err(|e| ServerError::Listen(channel_addr.clone(), e))?;
            let local_addr = tokio_listener
                .local_addr()
                .map_err(|err| ServerError::Resolve(channel_addr, err))?;
            let listener = tcp::TcpSocketListener {
                inner: tokio_listener,
                addr: local_addr,
            };
            // The returned address keeps the original hostname and takes the
            // port from the listener's local address.
            Ok((
                NetListener::Tcp(listener),
                make_channel_addr(&hostname, local_addr.port()),
            ))
        }
        other => Err(ServerError::Listen(
            other.clone(),
            std::io::Error::other(format!("unsupported transport: {}", other)),
        )),
    }
}
899
/// Creates a listener for `addr` without a pre-bound socket; equivalent to
/// `listen_with_prebound(addr, None)`.
#[expect(
    dead_code,
    reason = "canonical listen() entry point; callers currently route through listen_with_prebound"
)]
pub(crate) fn listen(addr: ChannelAddr) -> Result<(NetListener, ChannelAddr), ServerError> {
    listen_with_prebound(addr, None)
}
910
/// Unit sent from the sending side to the receiver.
#[derive(Debug, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub(super) enum Frame<M> {
    /// A message together with its sequence number.
    Message(u64, M),
}
917
/// Response sent back to the sending side.
#[derive(Debug, Serialize, Deserialize, EnumAsInner)]
pub(super) enum NetRxResponse {
    /// Acknowledgement up to and including the given sequence number: the
    /// multi-stream sender in this file drops every unacked entry with
    /// `seq <= ack`.
    Ack(u64),
    /// The connection was rejected, with a reason.
    Reject(String),
    /// The server closed the channel.
    Closed,
}
926
927pub(super) fn serialize_response(
928 response: NetRxResponse,
929) -> Result<Bytes, bincode::error::EncodeError> {
930 bincode::serde::encode_to_vec(&response, bincode::config::legacy()).map(|bytes| bytes.into())
931}
932
933pub(super) fn deserialize_response(
934 data: Bytes,
935) -> Result<NetRxResponse, bincode::error::DecodeError> {
936 bincode::serde::decode_from_slice(&data, bincode::config::legacy()).map(|(v, _)| v)
937}
938
/// Sending half of a network channel; a handle to a background send task.
pub(crate) struct NetTx<M: RemoteMessage> {
    /// Queue to the send task: the message, the channel on which a failed
    /// send is returned, and the instant the message was posted.
    sender: mpsc::UnboundedSender<(M, oneshot::Sender<SendError<M>>, Instant)>,
    /// Destination address, as reported by `addr()`.
    dest: ChannelAddr,
    /// Status watch, set to `Closed` by the send task when it exits.
    status: watch::Receiver<TxStatus>,
}
946
947#[async_trait]
948impl<M: RemoteMessage> Tx<M> for NetTx<M> {
949 fn addr(&self) -> ChannelAddr {
950 self.dest.clone()
951 }
952
953 fn status(&self) -> &watch::Receiver<TxStatus> {
954 &self.status
955 }
956
957 fn do_post(&self, message: M, return_channel: Option<oneshot::Sender<SendError<M>>>) {
958 tracing::trace!(
959 name = "post",
960 dest = %self.dest,
961 "sending message"
962 );
963
964 let return_channel = return_channel.unwrap_or_else(|| oneshot::channel().0);
965 if let Err(mpsc::error::SendError((message, return_channel, _))) =
966 self.sender
967 .send((message, return_channel, tokio::time::Instant::now()))
968 {
969 let reason = self.status.borrow().as_closed().map(|r| r.to_string());
970 let _ = return_channel.send(SendError {
971 error: ChannelError::Closed,
972 message,
973 reason,
974 });
975 }
976 }
977}
978
/// Receiving half of a network channel. Fields, in order: the queue of
/// received messages, the channel address, and the handle of the server
/// that is stopped when the receiver is joined or dropped.
pub struct NetRx<M: RemoteMessage>(mpsc::Receiver<M>, ChannelAddr, ServerHandle);
980
981#[async_trait]
982impl<M: RemoteMessage> Rx<M> for NetRx<M> {
983 async fn recv(&mut self) -> Result<M, ChannelError> {
984 tracing::trace!(
985 name = "recv",
986 dest = %self.1,
987 "receiving message"
988 );
989 self.0.recv().await.ok_or(ChannelError::Closed)
990 }
991
992 fn addr(&self) -> ChannelAddr {
993 self.1.clone()
994 }
995
996 async fn join(mut self) {
999 self.2
1000 .stop(&format!("NetRx joined; channel address: {}", self.1));
1001 let _ = (&mut self.2).await;
1002 }
1004}
1005
1006impl<M: RemoteMessage> Drop for NetRx<M> {
1007 fn drop(&mut self) {
1008 self.2
1009 .stop(&format!("NetRx dropped; channel address: {}", self.1));
1010 }
1011}
1012
/// Errors raised on the listening (server) side. Every variant carries the
/// channel address it relates to.
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
    /// I/O failure on an established listener. Only the underlying error is
    /// included in the rendered message, not the address.
    #[error("io: {1}")]
    Io(ChannelAddr, #[source] std::io::Error),
    /// Failure while binding or configuring a listener.
    #[error("listen: {0} {1}")]
    Listen(ChannelAddr, #[source] std::io::Error),
    /// Failure while resolving an address or reading the local address.
    #[error("resolve: {0} {1}")]
    Resolve(ChannelAddr, #[source] std::io::Error),
    /// Any other internal failure.
    #[error("internal: {0} {1}")]
    Internal(ChannelAddr, #[source] anyhow::Error),
}
1029
/// Errors raised on the connecting (client) side.
#[derive(thiserror::Error, Debug)]
pub enum ClientError {
    /// Connecting to the address failed; carries the I/O error and an
    /// additional description.
    #[error("connection to {0} failed: {1}: {2}")]
    Connect(ChannelAddr, std::io::Error, String),
    /// The address could not be resolved.
    #[error("unable to resolve address: {0}")]
    Resolve(ChannelAddr),
    /// I/O failure while talking to the address.
    #[error("io: {0} {1}")]
    Io(ChannelAddr, std::io::Error),
    /// Encoding failed while sending to the address.
    #[error("send {0}: serialize: {1}")]
    Serialize(ChannelAddr, bincode::error::EncodeError),
    /// The supplied address string is not valid.
    #[error("invalid address: {0}")]
    InvalidAddress(String),
}
1043
/// Test helper: reports whether `addr` uses one of the network transports
/// (TCP, MetaTLS, TLS, or Unix).
#[cfg(test)]
pub(super) fn is_net_addr(addr: &ChannelAddr) -> bool {
    matches!(
        addr.transport(),
        ChannelTransport::Tcp(_)
            | ChannelTransport::MetaTls(_)
            | ChannelTransport::Tls
            | ChannelTransport::Unix
    )
}
1056
1057pub(crate) mod unix {
1058
1059 use core::str;
1060 use std::os::unix::net::SocketAddr as StdSocketAddr;
1061 use std::os::unix::net::UnixStream as StdUnixStream;
1062
1063 use rand::RngExt as _;
1064 use rand::distr::Alphanumeric;
1065 use tokio::net::UnixListener;
1066 use tokio::net::UnixStream;
1067
1068 use super::*;
1069
    /// Link that connects to a Unix domain socket.
    #[derive(Debug)]
    pub(crate) struct UnixLink {
        /// Socket address to connect to.
        pub(super) addr: SocketAddr,
        /// Session id written into the link-init header.
        pub(super) session_id: SessionId,
        /// Stream id written into the link-init header.
        pub(super) stream_id: u8,
    }
1076
1077 #[async_trait]
1078 impl Link for UnixLink {
1079 type Stream = UnixStream;
1080
1081 fn dest(&self) -> ChannelAddr {
1082 ChannelAddr::Unix(self.addr.clone())
1083 }
1084
1085 fn link_id(&self) -> SessionId {
1086 self.session_id
1087 }
1088
1089 async fn next(&mut self) -> Result<Self::Stream, ClientError> {
1090 let session_id = self.session_id;
1091 let sock_addr = match &self.addr {
1092 SocketAddr::Bound(a) => a,
1093 SocketAddr::Unbound => return Err(ClientError::Resolve(self.dest())),
1094 };
1095 let mut backoff = ExponentialBackoffBuilder::new()
1096 .with_initial_interval(Duration::from_millis(1))
1097 .with_multiplier(2.0)
1098 .with_randomization_factor(0.1)
1099 .with_max_interval(Duration::from_millis(1000))
1100 .with_max_elapsed_time(None)
1101 .build();
1102 loop {
1103 match StdUnixStream::connect_addr(sock_addr) {
1104 Ok(std_stream) => {
1105 std_stream
1106 .set_nonblocking(true)
1107 .map_err(|err| ClientError::Io(self.dest(), err))?;
1108 let mut stream = UnixStream::from_std(std_stream)
1109 .map_err(|err| ClientError::Io(self.dest(), err))?;
1110 write_link_init(&mut stream, session_id, self.stream_id)
1111 .await
1112 .map_err(|err| ClientError::Io(self.dest(), err))?;
1113 return Ok(stream);
1114 }
1115 Err(err) => {
1116 tracing::debug!(error = %err, "unix connect failed, backing off");
1117 if let Some(delay) = backoff.next_backoff() {
1118 tokio::time::sleep(delay).await;
1119 }
1120 }
1121 }
1122 }
1123 }
1124 }
1125
    /// Listener on a Unix domain socket.
    #[derive(Debug)]
    pub(crate) struct UnixSocketListener {
        /// Underlying tokio listener.
        pub(super) inner: UnixListener,
        /// Address the listener is bound to; attached to accept errors.
        pub(super) addr: SocketAddr,
    }
1132
1133 #[async_trait]
1134 impl super::Listener for UnixSocketListener {
1135 type Stream = UnixStream;
1136
1137 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1138 let (stream, peer_addr) = self
1139 .inner
1140 .accept()
1141 .await
1142 .map_err(|err| ServerError::Io(ChannelAddr::Unix(self.addr.clone()), err))?;
1143 let std_addr: StdSocketAddr = peer_addr.into();
1145 Ok((stream, ChannelAddr::Unix(SocketAddr::new(std_addr))))
1146 }
1147 }
1148
    /// Builds a `UnixLink` from its parts; no connection is made here.
    pub(crate) fn link(addr: SocketAddr, session_id: SessionId, stream_id: u8) -> UnixLink {
        UnixLink {
            addr,
            session_id,
            stream_id,
        }
    }
1157
    /// Unix socket address used by channels: either a concrete std socket
    /// address or a placeholder for "not bound yet".
    #[derive(Clone, Debug)]
    pub enum SocketAddr {
        /// A concrete std Unix socket address.
        Bound(Box<StdSocketAddr>),
        /// No address yet; listening on it binds to an address obtained from
        /// a fresh unbound socket.
        Unbound,
    }
1164
1165 impl PartialOrd for SocketAddr {
1166 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1167 Some(self.cmp(other))
1168 }
1169 }
1170
1171 impl Ord for SocketAddr {
1172 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
1173 self.to_string().cmp(&other.to_string())
1174 }
1175 }
1176
1177 impl<'de> Deserialize<'de> for SocketAddr {
1178 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
1179 where
1180 D: serde::Deserializer<'de>,
1181 {
1182 let s = String::deserialize(deserializer)?;
1183 Self::from_str(&s).map_err(D::Error::custom)
1184 }
1185 }
1186
1187 impl Serialize for SocketAddr {
1188 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
1189 where
1190 S: serde::Serializer,
1191 {
1192 serializer.serialize_str(String::from(self).as_str())
1193 }
1194 }
1195
1196 impl From<&SocketAddr> for String {
1197 fn from(value: &SocketAddr) -> Self {
1198 match value {
1199 SocketAddr::Bound(addr) => match addr.as_pathname() {
1200 Some(path) => path
1201 .to_str()
1202 .expect("unable to get str for path")
1203 .to_string(),
1204 #[cfg(target_os = "linux")]
1205 _ => match addr.as_abstract_name() {
1206 Some(name) => format!("@{}", String::from_utf8_lossy(name)),
1207 _ => String::from("(unnamed)"),
1208 },
1209 #[cfg(not(target_os = "linux"))]
1210 _ => String::from("(unnamed)"),
1211 },
1212 SocketAddr::Unbound => String::from("(unbound)"),
1213 }
1214 }
1215 }
1216
1217 impl FromStr for SocketAddr {
1218 type Err = anyhow::Error;
1219
1220 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
1221 match s {
1222 "" => {
1223 let random_string = rand::rng()
1226 .sample_iter(&Alphanumeric)
1227 .take(24)
1228 .map(char::from)
1229 .collect::<String>();
1230 SocketAddr::from_abstract_name(&random_string)
1231 }
1232 name if name.starts_with("@") => {
1234 SocketAddr::from_abstract_name(name.strip_prefix("@").unwrap())
1235 }
1236 path => SocketAddr::from_pathname(path),
1237 }
1238 }
1239 }
1240
1241 impl Eq for SocketAddr {}
1242 impl std::hash::Hash for SocketAddr {
1243 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
1244 String::from(self).hash(state);
1245 }
1246 }
1247 impl PartialEq for SocketAddr {
1248 fn eq(&self, other: &Self) -> bool {
1249 match (self, other) {
1250 (Self::Bound(saddr), Self::Bound(oaddr)) => {
1251 if saddr.is_unnamed() || oaddr.is_unnamed() {
1252 return false;
1253 }
1254
1255 #[cfg(target_os = "linux")]
1256 {
1257 saddr.as_pathname() == oaddr.as_pathname()
1258 && saddr.as_abstract_name() == oaddr.as_abstract_name()
1259 }
1260 #[cfg(not(target_os = "linux"))]
1261 {
1262 saddr.as_pathname() == oaddr.as_pathname()
1264 }
1265 }
1266 (Self::Unbound, _) | (_, Self::Unbound) => false,
1267 }
1268 }
1269 }
1270
1271 impl fmt::Display for SocketAddr {
1272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1273 match self {
1274 Self::Bound(addr) => match addr.as_pathname() {
1275 Some(path) => {
1276 write!(f, "{}", path.to_string_lossy())
1277 }
1278 #[cfg(target_os = "linux")]
1279 _ => match addr.as_abstract_name() {
1280 Some(name) => {
1281 if name.starts_with(b"@") {
1282 return write!(f, "{}", String::from_utf8_lossy(name));
1283 }
1284 write!(f, "@{}", String::from_utf8_lossy(name))
1285 }
1286 _ => write!(f, "(unnamed)"),
1287 },
1288 #[cfg(not(target_os = "linux"))]
1289 _ => write!(f, "(unnamed)"),
1290 },
1291 Self::Unbound => write!(f, "(unbound)"),
1292 }
1293 }
1294 }
1295
1296 impl SocketAddr {
1297 pub fn new(addr: StdSocketAddr) -> Self {
1299 Self::Bound(Box::new(addr))
1300 }
1301
1302 #[cfg(target_os = "linux")]
1305 pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1306 Ok(Self::new(StdSocketAddr::from_abstract_name(
1307 name.strip_prefix("@").unwrap_or(name),
1308 )?))
1309 }
1310
1311 #[cfg(not(target_os = "linux"))]
1312 pub fn from_abstract_name(name: &str) -> anyhow::Result<Self> {
1313 let name = name.strip_prefix("@").unwrap_or(name);
1315 let path = Self::abstract_to_filesystem_path(name);
1316 Self::from_pathname(&path.to_string_lossy())
1317 }
1318
1319 #[cfg(not(target_os = "linux"))]
1320 fn abstract_to_filesystem_path(abstract_name: &str) -> std::path::PathBuf {
1321 use std::collections::hash_map::DefaultHasher;
1322 use std::hash::Hash;
1323 use std::hash::Hasher;
1324
1325 let mut hasher = DefaultHasher::new();
1327 abstract_name.hash(&mut hasher);
1328 let hash = hasher.finish();
1329
1330 let process_id = std::process::id();
1332
1333 std::path::PathBuf::from(format!("/tmp/hyperactor_{}_{:x}", process_id, hash))
1335 }
1336
1337 pub fn from_pathname(name: &str) -> anyhow::Result<Self> {
1339 Ok(Self::new(StdSocketAddr::from_pathname(name)?))
1340 }
1341 }
1342
1343 impl TryFrom<SocketAddr> for StdSocketAddr {
1344 type Error = anyhow::Error;
1345
1346 fn try_from(value: SocketAddr) -> Result<Self, Self::Error> {
1347 match value {
1348 SocketAddr::Bound(addr) => Ok(*addr),
1349 SocketAddr::Unbound => Err(anyhow::anyhow!(
1350 "std::os::unix::SocketAddr must be a bound address"
1351 )),
1352 }
1353 }
1354 }
1355}
1356
1357pub(crate) mod tcp {
1358 use tokio::net::TcpListener;
1359 use tokio::net::TcpStream;
1360
1361 use super::*;
1362
1363 #[derive(Debug)]
1364 pub(crate) struct TcpLink {
1365 pub(super) addr: SocketAddr,
1366 pub(super) session_id: SessionId,
1367 pub(super) stream_id: u8,
1368 }
1369
1370 #[async_trait]
1371 impl Link for TcpLink {
1372 type Stream = TcpStream;
1373
1374 fn dest(&self) -> ChannelAddr {
1375 ChannelAddr::Tcp(self.addr)
1376 }
1377
1378 fn link_id(&self) -> SessionId {
1379 self.session_id
1380 }
1381
1382 async fn next(&mut self) -> Result<Self::Stream, ClientError> {
1383 let session_id = self.session_id;
1384 let mut backoff = ExponentialBackoffBuilder::new()
1385 .with_initial_interval(Duration::from_millis(1))
1386 .with_multiplier(2.0)
1387 .with_randomization_factor(0.1)
1388 .with_max_interval(Duration::from_millis(1000))
1389 .with_max_elapsed_time(None)
1390 .build();
1391 loop {
1392 match TcpStream::connect(&self.addr).await {
1393 Ok(mut stream) => {
1394 stream.set_nodelay(true).map_err(|err| {
1395 ClientError::Connect(
1396 self.dest(),
1397 err,
1398 "cannot disable Nagle algorithm".to_string(),
1399 )
1400 })?;
1401 write_link_init(&mut stream, session_id, self.stream_id)
1402 .await
1403 .map_err(|err| ClientError::Io(self.dest(), err))?;
1404 return Ok(stream);
1405 }
1406 Err(err) => {
1407 tracing::debug!(error = %err, "tcp connect failed, backing off");
1408 if let Some(delay) = backoff.next_backoff() {
1409 tokio::time::sleep(delay).await;
1410 }
1411 }
1412 }
1413 }
1414 }
1415 }
1416
1417 #[derive(Debug)]
1419 pub(crate) struct TcpSocketListener {
1420 pub(super) inner: TcpListener,
1421 pub(super) addr: SocketAddr,
1422 }
1423
1424 #[async_trait]
1425 impl super::Listener for TcpSocketListener {
1426 type Stream = TcpStream;
1427
1428 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
1429 let (stream, peer_addr) = self
1430 .inner
1431 .accept()
1432 .await
1433 .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1434 stream
1435 .set_nodelay(true)
1436 .map_err(|err| ServerError::Io(ChannelAddr::Tcp(self.addr), err))?;
1437 Ok((stream, ChannelAddr::Tcp(peer_addr)))
1438 }
1439 }
1440
1441 pub(crate) fn link(addr: SocketAddr, session_id: SessionId, stream_id: u8) -> TcpLink {
1443 TcpLink {
1444 addr,
1445 session_id,
1446 stream_id,
1447 }
1448 }
1449}
1450
1451pub(crate) mod meta {
1453 use std::io;
1454 use std::path::PathBuf;
1455 use std::sync::Arc;
1456
1457 use anyhow::Result;
1458 use tokio_rustls::TlsAcceptor;
1459 use tokio_rustls::TlsConnector;
1460
1461 use super::*;
1462 use crate::config::Pem;
1463 use crate::config::PemBundle;
1464
1465 const THRIFT_TLS_SRV_CA_PATH_ENV: &str = "THRIFT_TLS_SRV_CA_PATH";
1466 const DEFAULT_SRV_CA_PATH: &str = "/var/facebook/rootcanal/ca.pem";
1467 const THRIFT_TLS_CL_CERT_PATH_ENV: &str = "THRIFT_TLS_CL_CERT_PATH";
1468 const THRIFT_TLS_CL_KEY_PATH_ENV: &str = "THRIFT_TLS_CL_KEY_PATH";
1469 const DEFAULT_SERVER_PEM_PATH: &str = "/var/facebook/x509_identities/server.pem";
1470
1471 #[allow(clippy::result_large_err)] pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1473 let parts = addr_string.rsplit_once(":");
1475 match parts {
1476 Some((hostname, port_str)) => {
1477 let Ok(port) = port_str.parse() else {
1478 return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1479 };
1480 Ok(ChannelAddr::MetaTls(TlsAddr::new(hostname, port)))
1481 }
1482 _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1483 }
1484 }
1485
1486 pub(super) fn get_server_pem_bundle() -> PemBundle {
1489 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1490 .map(PathBuf::from)
1491 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1492 let server_pem_path = PathBuf::from(DEFAULT_SERVER_PEM_PATH);
1493 PemBundle {
1494 ca: Pem::File(ca_path),
1495 cert: Pem::File(server_pem_path.clone()),
1496 key: Pem::File(server_pem_path),
1497 }
1498 }
1499
1500 fn get_client_pem_bundle() -> Option<PemBundle> {
1503 let cert_path = std::env::var_os(THRIFT_TLS_CL_CERT_PATH_ENV)?;
1504 let key_path = std::env::var_os(THRIFT_TLS_CL_KEY_PATH_ENV)?;
1505 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1506 .map(PathBuf::from)
1507 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1508 Some(PemBundle {
1509 ca: Pem::File(ca_path),
1510 cert: Pem::File(PathBuf::from(cert_path)),
1511 key: Pem::File(PathBuf::from(key_path)),
1512 })
1513 }
1514
1515 pub(crate) fn tls_acceptor(enforce_client_tls: bool) -> Result<TlsAcceptor> {
1517 let bundle = get_server_pem_bundle();
1518 tls::tls_acceptor_from_bundle(&bundle, enforce_client_tls)
1519 }
1520
1521 pub(super) fn try_tls_connector() -> Result<TlsConnector> {
1527 tls_connector()
1528 }
1529
1530 fn tls_connector() -> Result<TlsConnector> {
1533 let _ = rustls::crypto::ring::default_provider().install_default();
1536
1537 let ca_path = std::env::var_os(THRIFT_TLS_SRV_CA_PATH_ENV)
1538 .map(PathBuf::from)
1539 .unwrap_or_else(|| PathBuf::from(DEFAULT_SRV_CA_PATH));
1540 let ca_pem = Pem::File(ca_path);
1541 let root_store = tls::build_root_store(&ca_pem)?;
1542
1543 let config = rustls::ClientConfig::builder().with_root_certificates(Arc::new(root_store));
1545
1546 let config = if let Some(bundle) = get_client_pem_bundle() {
1547 let certs = tls::load_certs(&bundle.cert)?;
1548 let key = tls::load_key(&bundle.key)?;
1549 config
1550 .with_client_auth_cert(certs, key)
1551 .map_err(|e| anyhow::anyhow!("load client certs: {}", e))?
1552 } else {
1553 config.with_no_client_auth()
1554 };
1555
1556 Ok(TlsConnector::from(Arc::new(config)))
1557 }
1558
1559 pub fn link(
1561 addr: TlsAddr,
1562 session_id: SessionId,
1563 stream_id: u8,
1564 ) -> Result<tls::TlsLink, ClientError> {
1565 let connector = tls_connector().map_err(|e| {
1566 ClientError::Connect(
1567 ChannelAddr::MetaTls(addr.clone()),
1568 io::Error::other(e.to_string()),
1569 "failed to create TLS connector".to_string(),
1570 )
1571 })?;
1572 let TlsAddr { hostname, port } = addr;
1573 Ok(tls::TlsLink {
1574 hostname,
1575 port,
1576 connector,
1577 addr_type: tls::TlsAddrType::MetaTls,
1578 session_id,
1579 stream_id,
1580 })
1581 }
1582}
1583
1584pub(crate) mod tls {
1586 use std::io;
1587 use std::io::BufReader;
1588 use std::sync::Arc;
1589
1590 use anyhow::Context;
1591 use anyhow::Result;
1592 use rustls::RootCertStore;
1593 use rustls::pki_types::CertificateDer;
1594 use rustls::pki_types::PrivateKeyDer;
1595 use rustls::pki_types::ServerName;
1596 use tokio::net::TcpStream;
1597 use tokio_rustls::TlsAcceptor;
1598 use tokio_rustls::TlsConnector;
1599 use tokio_rustls::client::TlsStream;
1600
1601 use super::*;
1602 use crate::channel::TlsAddr;
1603 use crate::config::Pem;
1604 use crate::config::PemBundle;
1605 use crate::config::TLS_CA;
1606 use crate::config::TLS_CERT;
1607 use crate::config::TLS_KEY;
1608
    /// Selects which `ChannelAddr` variant a `TlsLink` reports as its
    /// destination.
    #[derive(Debug, Clone, Copy)]
    pub(crate) enum TlsAddrType {
        /// Destination is reported as `ChannelAddr::Tls`.
        Tls,
        /// Destination is reported as `ChannelAddr::MetaTls`.
        MetaTls,
    }
1615
1616 #[allow(clippy::result_large_err)]
1618 pub(crate) fn parse(addr_string: &str) -> Result<ChannelAddr, ChannelError> {
1619 let parts = addr_string.rsplit_once(":");
1621 match parts {
1622 Some((hostname, port_str)) => {
1623 let Ok(port) = port_str.parse() else {
1624 return Err(ChannelError::InvalidAddress(addr_string.to_string()));
1625 };
1626 Ok(ChannelAddr::Tls(TlsAddr::new(hostname, port)))
1627 }
1628 _ => Err(ChannelError::InvalidAddress(addr_string.to_string())),
1629 }
1630 }
1631
1632 pub(super) fn load_certs(pem: &Pem) -> Result<Vec<CertificateDer<'static>>> {
1634 let mut reader = BufReader::new(pem.reader()?);
1635 let certs = rustls_pemfile::certs(&mut reader)
1636 .filter_map(Result::ok)
1637 .collect();
1638 Ok(certs)
1639 }
1640
1641 pub(super) fn load_key(pem: &Pem) -> Result<PrivateKeyDer<'static>> {
1643 let mut reader = BufReader::new(pem.reader()?);
1644 loop {
1645 break match rustls_pemfile::read_one(&mut reader)? {
1646 Some(rustls_pemfile::Item::Pkcs1Key(key)) => Ok(PrivateKeyDer::Pkcs1(key)),
1647 Some(rustls_pemfile::Item::Pkcs8Key(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
1648 Some(rustls_pemfile::Item::Sec1Key(key)) => Ok(PrivateKeyDer::Sec1(key)),
1649 Some(_) => continue,
1650 None => anyhow::bail!("no private key found in TLS key file"),
1651 };
1652 }
1653 }
1654
1655 pub(super) fn build_root_store(ca_pem: &Pem) -> Result<RootCertStore> {
1657 let mut root_store = RootCertStore::empty();
1658 let certs = load_certs(ca_pem)?;
1659 root_store.add_parsable_certificates(certs);
1660 Ok(root_store)
1661 }
1662
1663 fn get_pem_bundle() -> PemBundle {
1665 PemBundle {
1666 ca: hyperactor_config::global::get_cloned(TLS_CA),
1667 cert: hyperactor_config::global::get_cloned(TLS_CERT),
1668 key: hyperactor_config::global::get_cloned(TLS_KEY),
1669 }
1670 }
1671
1672 pub(super) fn tls_acceptor_from_bundle(
1675 bundle: &PemBundle,
1676 enforce_client_tls: bool,
1677 ) -> Result<TlsAcceptor> {
1678 let _ = rustls::crypto::ring::default_provider().install_default();
1681
1682 let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1683 let key = load_key(&bundle.key).context("load TLS key")?;
1684 let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1685
1686 let config = rustls::ServerConfig::builder();
1687 let config = if enforce_client_tls {
1688 let client_verifier =
1690 rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
1691 .build()
1692 .map_err(|e| anyhow::anyhow!("failed to build client verifier: {}", e))?;
1693 config.with_client_cert_verifier(client_verifier)
1694 } else {
1695 config.with_no_client_auth()
1696 }
1697 .with_single_cert(certs, key)?;
1698
1699 Ok(TlsAcceptor::from(Arc::new(config)))
1700 }
1701
1702 pub(crate) fn tls_acceptor() -> Result<TlsAcceptor> {
1704 tls_acceptor_from_bundle(&get_pem_bundle(), true)
1705 }
1706
1707 pub(super) fn tls_connector_from_bundle(bundle: &PemBundle) -> Result<TlsConnector> {
1709 let _ = rustls::crypto::ring::default_provider().install_default();
1712
1713 let certs = load_certs(&bundle.cert).context("load TLS certificate")?;
1714 let key = load_key(&bundle.key).context("load TLS key")?;
1715 let root_store = build_root_store(&bundle.ca).context("build root cert store")?;
1716
1717 let config = rustls::ClientConfig::builder()
1718 .with_root_certificates(Arc::new(root_store))
1719 .with_client_auth_cert(certs, key)
1720 .context("configure client auth")?;
1721
1722 Ok(TlsConnector::from(Arc::new(config)))
1723 }
1724
1725 fn tls_connector() -> Result<TlsConnector> {
1727 tls_connector_from_bundle(&get_pem_bundle())
1728 }
1729
1730 pub(crate) struct TlsLink {
1732 pub(crate) hostname: Hostname,
1733 pub(crate) port: Port,
1734 pub(crate) connector: TlsConnector,
1735 pub(crate) addr_type: TlsAddrType,
1736 pub(crate) session_id: SessionId,
1737 pub(crate) stream_id: u8,
1738 }
1739
1740 impl std::fmt::Debug for TlsLink {
1741 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1742 f.debug_struct("TlsLink")
1743 .field("hostname", &self.hostname)
1744 .field("port", &self.port)
1745 .field("addr_type", &self.addr_type)
1746 .finish()
1747 }
1748 }
1749
1750 #[async_trait]
1751 impl Link for TlsLink {
1752 type Stream = TlsStream<TcpStream>;
1753
1754 fn dest(&self) -> ChannelAddr {
1755 let addr = TlsAddr::new(self.hostname.clone(), self.port);
1756 match self.addr_type {
1757 TlsAddrType::Tls => ChannelAddr::Tls(addr),
1758 TlsAddrType::MetaTls => ChannelAddr::MetaTls(addr),
1759 }
1760 }
1761
1762 fn link_id(&self) -> SessionId {
1763 self.session_id
1764 }
1765
1766 async fn next(&mut self) -> Result<Self::Stream, ClientError> {
1767 let session_id = self.session_id;
1768 let server_name = ServerName::try_from(self.hostname.clone()).map_err(|e| {
1769 ClientError::Connect(
1770 self.dest(),
1771 io::Error::other(e.to_string()),
1772 "invalid server name".to_string(),
1773 )
1774 })?;
1775 let mut backoff = ExponentialBackoffBuilder::new()
1776 .with_initial_interval(Duration::from_millis(1))
1777 .with_multiplier(2.0)
1778 .with_randomization_factor(0.1)
1779 .with_max_interval(Duration::from_millis(1000))
1780 .with_max_elapsed_time(None)
1781 .build();
1782 loop {
1783 let mut addrs = (self.hostname.as_ref(), self.port)
1784 .to_socket_addrs()
1785 .map_err(|_| ClientError::Resolve(self.dest()))?;
1786 let addr = addrs.next().ok_or(ClientError::Resolve(self.dest()))?;
1787 match TcpStream::connect(&addr).await {
1788 Ok(stream) => {
1789 stream.set_nodelay(true).map_err(|err| {
1790 ClientError::Connect(
1791 self.dest(),
1792 err,
1793 "cannot disable Nagle algorithm".to_string(),
1794 )
1795 })?;
1796 let mut tls_stream = self
1797 .connector
1798 .connect(server_name.clone(), stream)
1799 .await
1800 .map_err(|err| {
1801 tracing::info!(
1802 dest = %self.dest(),
1803 error = %err,
1804 "TLS handshake failed"
1805 );
1806 ClientError::Connect(
1807 self.dest(),
1808 err,
1809 format!("cannot establish TLS connection to {:?}", server_name),
1810 )
1811 })?;
1812 write_link_init(&mut tls_stream, session_id, self.stream_id)
1813 .await
1814 .map_err(|err| ClientError::Io(self.dest(), err))?;
1815 return Ok(tls_stream);
1816 }
1817 Err(err) => {
1818 tracing::debug!(error = %err, "tls connect failed, backing off");
1819 if let Some(delay) = backoff.next_backoff() {
1820 tokio::time::sleep(delay).await;
1821 }
1822 }
1823 }
1824 }
1825 }
1826 }
1827
1828 pub fn link(
1830 addr: TlsAddr,
1831 session_id: SessionId,
1832 stream_id: u8,
1833 ) -> Result<TlsLink, ClientError> {
1834 let connector = tls_connector().map_err(|e| {
1835 ClientError::Connect(
1836 ChannelAddr::Tls(addr.clone()),
1837 io::Error::other(e.to_string()),
1838 "failed to create TLS connector".to_string(),
1839 )
1840 })?;
1841 let TlsAddr { hostname, port } = addr;
1842 Ok(TlsLink {
1843 hostname,
1844 port,
1845 connector,
1846 addr_type: TlsAddrType::Tls,
1847 session_id,
1848 stream_id,
1849 })
1850 }
1851
1852 #[cfg(test)]
1853 mod tests {
1854 use timed_test::async_timed_test;
1855
1856 use super::*;
1857 use crate::channel::Rx;
1858 use crate::channel::net::server;
1859 use crate::config::Pem;
1860 use crate::config::TLS_CA;
1861 use crate::config::TLS_CERT;
1862 use crate::config::TLS_KEY;
1863
1864 const TEST_CA_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1868MIIDBTCCAe2gAwIBAgIUaGNmboiIosG+8Up0vgDr/+cg+2IwDQYJKoZIhvcNAQEL
1869BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1870NzA4MzlaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB
1871DwAwggEKAoIBAQC9RBoMYXCajklswt8Vi1JI1lEYzic0WNOmz45vG/7H6jTWkgL3
1872K5Ri+Seg3MobDNc48YHWXYm4hP9wCzkx8ih3ntT5XiY1My/G3jLUuoIEE9pF/BoJ
1873YQwZVoPNFhA9WhXNRsINf1cXFf8NzRfXpxBfKWtQJxYXU4JiDBQ6rLnQQABo8JmQ
1874vYFhJbBaYip5jTSiVNn7mB1zNr5jsVxuoSF53Pb7xQ76bwBdOq4zd6PSxL5/lr4G
1875cHSoxwZQdZMG7PL6hbxDQ2S2YI2lYVET1zwc2WPKCfjbEXBC/jzx828CInQtuksk
187618gJt6xHkTFEA8CSA29GM3lejnwYWf51xyyBAgMBAAGjUzBRMB0GA1UdDgQWBBRX
1877cbxSZ70NsUkAS3Hhy6irugywJDAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1878ugywJDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQA7aAFfyW67
1879Z+uGSVYhpsT/uH/3Z3nr7X1smTz5CGEfq2czEcTC7gbYI2l8GZ47GPfnAvHTBZVm
1880V/XncBCsj7/thOh2jYEHFyCbPckoaSCRyCOnK7LPUlr4HN5uP9EFe45qBLCJDEoY
1881GTTw7MtzwdovfjchNfKQCTtkBJCXQ95WLCf6UOh02Sn28UTlgfXzF0X0FrcWqWa3
1882uJZd4XOo4O6hKKlHaBaQPiEr++1xc3SWPV7jZHbckI/vKBnDdEZ9JQX5fFZuypUI
1883sgomYHxvxrU2hWx+7k53CRdjfaIvT9Ie44z9sSdsU/+blw2S8f/ZTmuECoIAAXYO
18840qpzlxZMdr7T
1885-----END CERTIFICATE-----"#;
1886
1887 const TEST_SERVER_CERT: &str = r#"-----BEGIN CERTIFICATE-----
1889MIIDJDCCAgygAwIBAgIUaz66DsWaH5ZXM4hCFnbVbMsyN1cwDQYJKoZIhvcNAQEL
1890BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yNjAxMjgxNzA4MzlaFw0yNzAxMjgx
1891NzA4MzlaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQAD
1892ggEPADCCAQoCggEBAKCbp++qNyTn5LOsV0h9gLKJALBcjg2A14I3804N9UyDhPW2
1893QKQ2W424u2P1MfKrw/2C+CErGlrADlnco2RQVDAarAIuGdFvBOt5UezqOS7Mk4OS
18949MlS7NZnMbc37KuM9UIG5ScJjXR/Z5z9dxeR0I9y3n0Ix6khbV7tOSHobiweI0FI
18958LftBS+CQnXr6vbWPcHcW6Z0FHUv7IWhqMWmv9PlZRGe9Y6VzXrRp0PBnZMOnAYf
1896aMQUwYRswWdm9j9Z1sMdTJ14G+KVmO3Vj6XI6Sm9uIcYhlwG/kORwogJFWlVuP9o
1897rloFRCjyHJ1d7GZqqnRyHHDDCBms8ed+3YfEYQECAwEAAaNwMG4wLAYDVR0RBCUw
1898I4IJbG9jYWxob3N0hwR/AAABhxAAAAAAAAAAAAAAAAAAAAABMB0GA1UdDgQWBBQl
1899J4vxUoCzqqeTwQAiLqE8wYezKzAfBgNVHSMEGDAWgBRXcbxSZ70NsUkAS3Hhy6ir
1900ugywJDANBgkqhkiG9w0BAQsFAAOCAQEAnXHIBDQ4AHAMV71piTOuI41ShASQed6L
1901bi7XUMZgZDslLkfU1vnP3BlwpliraBsAytSYQC6kbytOuz1uQ4K7yLb2tAAmUgEO
1902EdIVt9SXr5tCcIPeLmInF0pysPqjZO8n7vtJyd9gryKqdhm1uzA7WQWq/Az8a9Sk
1903uW2J6Oc5p6P7Mf3/ixqXzvGRo8rzu0CUJOJ67UTE/HhbJuplQ5dep5CEEOAIsAtH
1904zn9O4rW92ueBkoBJM++YILS1vQ7jKc2N3RNrnHm7FeootBrtR9mBi0TH97K73ZPZ
19052Cdhnym0CsCJggrllFGH32cYo7+K2PO7/4oj5XbBCSWcssicvd8ovg==
1906-----END CERTIFICATE-----"#;
1907
1908 const TEST_SERVER_KEY: &str = r#"-----BEGIN PRIVATE KEY-----
1912MIIEugIBADANBgkqhkiG9w0BAQEFAASCBKQwggSgAgEAAoIBAQCgm6fvqjck5+Sz
1913rFdIfYCyiQCwXI4NgNeCN/NODfVMg4T1tkCkNluNuLtj9THyq8P9gvghKxpawA5Z
19143KNkUFQwGqwCLhnRbwTreVHs6jkuzJODkvTJUuzWZzG3N+yrjPVCBuUnCY10f2ec
1915/XcXkdCPct59CMepIW1e7Tkh6G4sHiNBSPC37QUvgkJ16+r21j3B3FumdBR1L+yF
1916oajFpr/T5WURnvWOlc160adDwZ2TDpwGH2jEFMGEbMFnZvY/WdbDHUydeBvilZjt
19171Y+lyOkpvbiHGIZcBv5DkcKICRVpVbj/aK5aBUQo8hydXexmaqp0chxwwwgZrPHn
1918ft2HxGEBAgMBAAECgf8G5qlQov+7ljs9fSpC8yGUik59RXzVF7Qq5DyQHglsQDp2
1919VF5yr+M/M7DZmq+KvdauDfKbej6np5j2Q4TByrHTX1IExfZWCW8srwnWJDpQyHmO
1920LcJW5DlI/SYluUFyHZxsOd+ezcpGNzM8i6eSW7GaeFUXCkmJ+uW4LnlF+7bALnnd
1921D6sak/58EsII+IJyd4lFn+voszlPn3CZGR0jkp21rvpaKgrMIsKVWWQO/sLDU5pr
1922VbpBThcLU5gRcnQouQX12e2VTCIlFu75WTsJ8V/KnEaOZUVlU/B/Bs+WQF3U+/Jo
1923eX4N+D6OsEcNQjERAFyWujxsl1WpD4uSsbFMN0ECgYEA2b7AdL+oKPQHku2KcBhr
1924Zw8K4tMDlr2VPPNwZcBTLo+O71vv/xXjMcXrXmowzkgEQckUmt1VB46riyydhwdP
1925/n9ciWcz0Va/nwHR6Y9F9unBiyUBP7PRhRyjQyRZZRGDSJvP+Xmc5UJFpRr07VLU
1926nfgMXDj37vXzKDpfhdEB2nkCgYEAvNMfA8P8w3+6246x5YHflvTkPdw+2oyge+LD
1927mphB/w7SF8mlyNGloj3+KBZmd9SkvT57wCvO96Y9/n+mBAVisRggc0hK4ymOVYhb
1928+im/JvqGQMbVeg6iCOHnWdaZf9tL8uVsegQy3kVTN7vAa+CMFgX1dt65cGBX6XkB
192944pYmMkCgYALhbiRdQLlB+TOtZs5y1EDpxwgXKI3+9hF3Wv5NnAwapBZwje0++eF
19303r9Rw7TJda4j/QwGFehF+hrBxp6fYpetE/hFnRx0225Qb7w368j8A+ql/lNOl6li
1931rd1F1EqWupKD6RrcTL8sspEU55RGaretlE6zIqCcGI/BdTVQ03qRoQKBgHDC3zWf
1932d7XD9HGjQGdfbIe4jQjIGxzmd/wjik4q+NZ5IkukVwWa9P/zZ3DHF8Ad05dT1hEH
19332FwaAdGWpyyljq9VSiOuG1KXAXHgsZSuE4ISf9P1KYzvaiJFzaPfvOEWs79E9MfU
19349A+6dJzG2X1SpjWMr26iSTlrv3QkmFUqzAfJAoGASBkn4wls+oC5rv/Mch43pBv5
1935UmKru4ltnEHJZdbSi2DJ+AnDLD222JCasb1VT1tm2XgW6DBqrdVRPPP6GOlB0MHU
1936+3ULtZxAczt7I+ST2bo0/DV2Hse89Cm63w4wLOiVZs7+1wrAzJZLokWF7Q5gesra
1937u19txmtkiMEH+aNmekk=
1938-----END PRIVATE KEY-----"#;
1939
1940 #[async_timed_test(timeout_secs = 30)]
1941 async fn test_tls_basic() {
1942 let _ = rustls::crypto::ring::default_provider().install_default();
1945
1946 let config = hyperactor_config::global::lock();
1948 let _guard_cert =
1949 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1950 let _guard_key =
1951 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1952 let _guard_ca =
1953 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1954
1955 let addr = TlsAddr::new("localhost", 0);
1957
1958 let (local_addr, mut rx) =
1959 server::serve::<u64>(ChannelAddr::Tls(addr), None).expect("failed to serve");
1960
1961 let tx: super::NetTx<u64> = super::spawn(
1963 link(
1964 match &local_addr {
1965 ChannelAddr::Tls(addr) => addr.clone(),
1966 _ => panic!("unexpected address type"),
1967 },
1968 SessionId::random(),
1969 0,
1970 )
1971 .expect("failed to create link"),
1972 );
1973
1974 tx.post(42u64);
1976
1977 let received = rx.recv().await.expect("failed to receive");
1979 assert_eq!(received, 42u64);
1980 }
1981
1982 #[async_timed_test(timeout_secs = 30)]
1983 async fn test_tls_multiple_messages() {
1984 let _ = rustls::crypto::ring::default_provider().install_default();
1985
1986 let config = hyperactor_config::global::lock();
1988 let _guard_cert =
1989 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
1990 let _guard_key =
1991 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
1992 let _guard_ca =
1993 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
1994
1995 let addr = TlsAddr::new("localhost", 0);
1996
1997 let (local_addr, mut rx) =
1998 server::serve::<String>(ChannelAddr::Tls(addr), None).expect("failed to serve");
1999 let tx: super::NetTx<String> = super::spawn(
2000 link(
2001 match &local_addr {
2002 ChannelAddr::Tls(addr) => addr.clone(),
2003 _ => panic!("unexpected address type"),
2004 },
2005 SessionId::random(),
2006 0,
2007 )
2008 .expect("failed to create link"),
2009 );
2010
2011 for i in 0..10 {
2013 tx.post(format!("message {}", i));
2014 }
2015
2016 for i in 0..10 {
2018 let received = rx.recv().await.expect("failed to receive");
2019 assert_eq!(received, format!("message {}", i));
2020 }
2021 }
2022
2023 #[test]
2024 fn test_tls_parse_hostname_port() {
2025 let addr = parse("localhost:8080").expect("failed to parse");
2026 assert!(matches!(
2027 addr,
2028 ChannelAddr::Tls(TlsAddr { hostname, port })
2029 if hostname == "localhost" && port == 8080
2030 ));
2031 }
2032
2033 #[test]
2034 fn test_tls_parse_socket_addr() {
2035 let addr = parse("127.0.0.1:8080").expect("failed to parse");
2036 assert!(matches!(
2037 addr,
2038 ChannelAddr::Tls(TlsAddr { hostname, port })
2039 if hostname == "127.0.0.1" && port == 8080
2040 ));
2041 }
2042
2043 #[test]
2044 fn test_tls_certs_parsing() {
2045 let cert_pem = Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec());
2047 let key_pem = Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec());
2048 let ca_pem = Pem::Value(TEST_CA_CERT.as_bytes().to_vec());
2049
2050 let certs = super::load_certs(&cert_pem).expect("failed to load certs");
2051 assert!(!certs.is_empty(), "expected at least one certificate");
2052
2053 let _key = super::load_key(&key_pem).expect("failed to load key");
2054
2055 let root_store = super::build_root_store(&ca_pem).expect("failed to build root store");
2056 assert!(!root_store.is_empty(), "expected at least one CA cert");
2057 }
2058
2059 #[test]
2060 fn test_tls_acceptor_creation() {
2061 let _ = rustls::crypto::ring::default_provider().install_default();
2064
2065 let config = hyperactor_config::global::lock();
2067 let _guard_cert =
2068 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
2069 let _guard_key =
2070 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
2071 let _guard_ca =
2072 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
2073
2074 let _acceptor = super::tls_acceptor().expect("failed to create TLS acceptor");
2076 }
2077
2078 #[test]
2079 fn test_tls_connector_creation() {
2080 let _ = rustls::crypto::ring::default_provider().install_default();
2083
2084 let config = hyperactor_config::global::lock();
2086 let _guard_cert =
2087 config.override_key(TLS_CERT, Pem::Value(TEST_SERVER_CERT.as_bytes().to_vec()));
2088 let _guard_key =
2089 config.override_key(TLS_KEY, Pem::Value(TEST_SERVER_KEY.as_bytes().to_vec()));
2090 let _guard_ca =
2091 config.override_key(TLS_CA, Pem::Value(TEST_CA_CERT.as_bytes().to_vec()));
2092
2093 let _connector = super::tls_connector().expect("failed to create TLS connector");
2095 }
2096 }
2097}
2098
2099fn oss_pem_bundle() -> crate::config::PemBundle {
2101 crate::config::PemBundle {
2102 ca: hyperactor_config::global::get_cloned(crate::config::TLS_CA),
2103 cert: hyperactor_config::global::get_cloned(crate::config::TLS_CERT),
2104 key: hyperactor_config::global::get_cloned(crate::config::TLS_KEY),
2105 }
2106}
2107
2108pub fn try_tls_pem_bundle() -> Option<crate::config::PemBundle> {
2118 let oss_bundle = oss_pem_bundle();
2119 if oss_bundle.ca.reader().is_ok() {
2120 return Some(oss_bundle);
2121 }
2122 tracing::debug!("OSS TLS bundle: CA not readable, trying Meta paths");
2123
2124 let meta_bundle = meta::get_server_pem_bundle();
2125 if meta_bundle.ca.reader().is_ok() {
2126 return Some(meta_bundle);
2127 }
2128 tracing::debug!("Meta TLS bundle: CA not readable, no TLS available");
2129
2130 None
2131}
2132
2133pub fn try_tls_acceptor(enforce_client_tls: bool) -> Option<tokio_rustls::TlsAcceptor> {
2153 let oss_bundle = oss_pem_bundle();
2154 if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&oss_bundle, enforce_client_tls) {
2155 return Some(acceptor);
2156 }
2157 tracing::debug!("OSS TLS acceptor failed, trying Meta paths");
2158
2159 let meta_bundle = meta::get_server_pem_bundle();
2160 if let Ok(acceptor) = tls::tls_acceptor_from_bundle(&meta_bundle, enforce_client_tls) {
2161 return Some(acceptor);
2162 }
2163 tracing::debug!("Meta TLS acceptor failed, no TLS available");
2164
2165 None
2166}
2167
2168pub fn try_tls_connector() -> Option<tokio_rustls::TlsConnector> {
2180 let oss_bundle = oss_pem_bundle();
2181 if let Ok(connector) = tls::tls_connector_from_bundle(&oss_bundle) {
2182 return Some(connector);
2183 }
2184 tracing::debug!("OSS TLS connector failed, trying Meta paths");
2185
2186 if let Ok(connector) = meta::try_tls_connector() {
2187 return Some(connector);
2188 }
2189 tracing::debug!("Meta TLS connector failed, no TLS available");
2190
2191 None
2192}
2193
2194#[cfg(test)]
2195mod tests {
2196
2197 #![expect(
2198 clippy::await_holding_invalid_type,
2199 reason = "tracing_test::traced_test macro expansion holds tracing::span::Entered across awaits; can't be fixed in our code"
2200 )]
2201
2202 use std::assert_matches;
2203 use std::collections::VecDeque;
2204 use std::marker::PhantomData;
2205 use std::sync::Arc;
2206 use std::sync::RwLock;
2207 use std::sync::atomic::AtomicBool;
2208 use std::sync::atomic::AtomicU64;
2209 use std::sync::atomic::Ordering;
2210 use std::time::Duration;
2211 #[cfg(target_os = "linux")] use std::time::UNIX_EPOCH;
2213
2214 #[cfg(target_os = "linux")] use anyhow::Result;
2216 use bytes::Bytes;
2217 use rand::RngExt as _;
2218 use rand::SeedableRng as _;
2219 use rand::distr::Alphanumeric;
2220 use rand::rngs::SysRng;
2221 use timed_test::async_timed_test;
2222 use tokio::io::AsyncWrite;
2223 use tokio::io::DuplexStream;
2224 use tokio::io::ReadHalf;
2225 use tokio::io::WriteHalf;
2226 use tokio::task::JoinHandle;
2227 use tokio_util::sync::CancellationToken;
2228
2229 use super::server;
2230 use super::*;
2231 use crate::channel;
2232 use crate::channel::net::framed::FrameReader;
2233 use crate::channel::net::framed::FrameWrite;
2234 use crate::channel::net::server::AcceptorLink;
2235 use crate::config;
2236 use crate::metrics;
2237 use crate::sync::mvar::MVar;
2238
2239 fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
2243 let buf = tracing_test::internal::global_buf().lock().unwrap();
2244 let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
2245 let lines: Vec<&str> = logs_str.lines().collect();
2246 match f(&lines) {
2247 Ok(()) => {}
2248 Err(msg) => panic!("{}", msg),
2249 }
2250 }
2251
2252 #[cfg(target_os = "linux")] #[tracing_test::traced_test]
2254 #[tokio::test]
2255 async fn test_unix_basic() -> Result<()> {
2256 let timestamp = std::time::SystemTime::now()
2257 .duration_since(UNIX_EPOCH)
2258 .unwrap()
2259 .as_nanos();
2260 let unique_address = format!("test_unix_basic_{}", timestamp);
2261
2262 let (addr, mut rx) = server::serve::<u64>(
2263 ChannelAddr::Unix(unix::SocketAddr::from_abstract_name(&unique_address)?),
2264 None,
2265 )
2266 .unwrap();
2267
2268 {
2275 let tx: ChannelTx<u64> = channel::dial::<u64>(addr.clone()).unwrap();
2276 tx.post(123);
2277 assert_eq!(rx.recv().await.unwrap(), 123);
2278 }
2279
2280 {
2281 let tx = channel::dial::<u64>(addr.clone()).unwrap();
2282 tx.post(321);
2283 tx.post(111);
2284 tx.post(444);
2285
2286 assert_eq!(rx.recv().await.unwrap(), 321);
2287 assert_eq!(rx.recv().await.unwrap(), 111);
2288 assert_eq!(rx.recv().await.unwrap(), 444);
2289 }
2290
2291 {
2292 let tx = channel::dial::<u64>(addr).unwrap();
2293 drop(rx);
2294
2295 let (return_tx, return_rx) = oneshot::channel();
2296 tx.try_post(123, return_tx);
2297 assert_matches!(
2298 return_rx.await,
2299 Ok(SendError {
2300 error: ChannelError::Closed,
2301 message: 123,
2302 ..
2303 })
2304 );
2305 }
2306
2307 Ok(())
2308 }
2309
2310 #[cfg(target_os = "linux")] #[tracing_test::traced_test]
2312 #[tokio::test]
2313 async fn test_unix_basic_client_before_server() -> Result<()> {
2314 let timestamp = std::time::SystemTime::now()
2316 .duration_since(UNIX_EPOCH)
2317 .unwrap()
2318 .as_nanos();
2319 let socket_addr =
2320 unix::SocketAddr::from_abstract_name(&format!("test_unix_basic_{}", timestamp))
2321 .unwrap();
2322
2323 let addr = ChannelAddr::Unix(socket_addr.clone());
2325 let tx = crate::channel::dial::<u64>(addr.clone()).unwrap();
2326 tx.post(123);
2327
2328 let (_, mut rx) = server::serve::<u64>(ChannelAddr::Unix(socket_addr), None).unwrap();
2329 assert_eq!(rx.recv().await.unwrap(), 123);
2330
2331 tx.post(321);
2332 tx.post(111);
2333 tx.post(444);
2334
2335 assert_eq!(rx.recv().await.unwrap(), 321);
2336 assert_eq!(rx.recv().await.unwrap(), 111);
2337 assert_eq!(rx.recv().await.unwrap(), 444);
2338
2339 Ok(())
2340 }
2341
2342 #[tracing_test::traced_test]
2343 #[async_timed_test(timeout_secs = 60)]
2344 #[cfg_attr(not(fbcode_build), ignore)]
2346 async fn test_tcp_basic() {
2347 let (addr, mut rx) =
2348 server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
2349 {
2350 let tx = channel::dial::<u64>(addr.clone()).unwrap();
2351 tx.post(123);
2352 assert_eq!(rx.recv().await.unwrap(), 123);
2353 }
2354
2355 {
2356 let tx = channel::dial::<u64>(addr.clone()).unwrap();
2357 tx.post(321);
2358 tx.post(111);
2359 tx.post(444);
2360
2361 assert_eq!(rx.recv().await.unwrap(), 321);
2362 assert_eq!(rx.recv().await.unwrap(), 111);
2363 assert_eq!(rx.recv().await.unwrap(), 444);
2364 }
2365
2366 {
2367 let tx = channel::dial::<u64>(addr).unwrap();
2368 drop(rx);
2369
2370 let (return_tx, return_rx) = oneshot::channel();
2371 tx.try_post(123, return_tx);
2372 assert_matches!(
2373 return_rx.await,
2374 Ok(SendError {
2375 error: ChannelError::Closed,
2376 message: 123,
2377 ..
2378 })
2379 );
2380 }
2381 }
2382
2383 #[async_timed_test(timeout_secs = 5)]
2385 #[cfg_attr(not(fbcode_build), ignore)]
2387 async fn test_tcp_message_size() {
2388 let default_size_in_bytes = 100 * 1024 * 1024;
2389 let config = hyperactor_config::global::lock();
2391 let _guard1 = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
2392 let _guard2 = config.override_key(config::CODEC_MAX_FRAME_LENGTH, default_size_in_bytes);
2393
2394 let (addr, mut rx) =
2395 server::serve::<String>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
2396
2397 let tx = channel::dial::<String>(addr.clone()).unwrap();
2398 {
2400 let message = "a".repeat(default_size_in_bytes - 1024);
2402 tx.post(message.clone());
2403 assert_eq!(rx.recv().await.unwrap(), message);
2404 }
2405 {
2407 let (return_channel, return_receiver) = oneshot::channel();
2408 let message = "a".repeat(default_size_in_bytes + 1024);
2409 tx.try_post(message.clone(), return_channel);
2410 let returned = return_receiver.await.unwrap();
2411 assert_eq!(message, returned.message);
2412 }
2413 }
2414
2415 #[async_timed_test(timeout_secs = 30)]
2416 #[cfg_attr(not(fbcode_build), ignore)]
2418 async fn test_ack_flush() {
2419 let config = hyperactor_config::global::lock();
2420 let _guard_message_ack =
2423 config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
2424 let _guard_delivery_timeout =
2425 config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(5));
2426
2427 let (addr, mut net_rx) =
2428 server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
2429 let net_tx = channel::dial::<u64>(addr.clone()).unwrap();
2430 let (tx, rx) = oneshot::channel();
2431 net_tx.try_post(1, tx);
2432 assert_eq!(net_rx.recv().await.unwrap(), 1);
2433 drop(net_rx);
2434 assert!(rx.await.is_err());
2437 }
2438
2439 #[async_timed_test(timeout_secs = 60)]
2440 #[cfg_attr(not(fbcode_build), ignore)]
2442 async fn test_meta_tls_basic() {
2443 hyperactor_telemetry::initialize_logging_for_test();
2444
2445 let addr = ChannelAddr::any(ChannelTransport::MetaTls(TlsMode::IpV6));
2446 let meta_addr = match addr {
2447 ChannelAddr::MetaTls(meta_addr) => meta_addr,
2448 _ => panic!("expected MetaTls address"),
2449 };
2450 let (local_addr, mut rx) =
2451 server::serve::<u64>(ChannelAddr::MetaTls(meta_addr), None).unwrap();
2452 {
2453 let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2454 tx.post(123);
2455 }
2456 assert_eq!(rx.recv().await.unwrap(), 123);
2457
2458 {
2459 let tx = channel::dial::<u64>(local_addr.clone()).unwrap();
2460 tx.post(321);
2461 tx.post(111);
2462 tx.post(444);
2463 assert_eq!(rx.recv().await.unwrap(), 321);
2464 assert_eq!(rx.recv().await.unwrap(), 111);
2465 assert_eq!(rx.recv().await.unwrap(), 444);
2466 }
2467
2468 {
2469 let tx = channel::dial::<u64>(local_addr).unwrap();
2470 drop(rx);
2471
2472 let (return_tx, return_rx) = oneshot::channel();
2473 tx.try_post(123, return_tx);
2474 assert_matches!(
2475 return_rx.await,
2476 Ok(SendError {
2477 error: ChannelError::Closed,
2478 message: 123,
2479 ..
2480 })
2481 );
2482 }
2483 }
2484
    /// Fault-injection settings for the mock link's relay tasks. The default
    /// (both fields `None`) injects neither disconnects nor latency.
    #[derive(Clone, Debug, Default)]
    struct NetworkFlakiness {
        /// `(probability, max_disconnects, min_interval)`: a disconnect is
        /// rolled with `probability` only while fewer than `max_disconnects`
        /// have happened and more than `min_interval` has elapsed since the
        /// previous one. `None` disables random disconnects.
        disconnect_params: Option<(f64, u64, Duration)>,
        /// `(min, max)` bounds of the random delay applied to a queued frame
        /// before it is relayed. `None` relays without added delay.
        latency_range: Option<(Duration, Duration)>,
    }
2499
2500 impl NetworkFlakiness {
2501 async fn should_disconnect(
2503 &self,
2504 rng: &mut impl rand::Rng,
2505 disconnected_count: u64,
2506 prev_disconnected_at: &RwLock<Instant>,
2507 ) -> bool {
2508 let Some((prob, max_disconnects, duration)) = &self.disconnect_params else {
2509 return false;
2510 };
2511
2512 let disconnected_at = prev_disconnected_at.read().unwrap();
2513 if disconnected_at.elapsed() > *duration && disconnected_count < *max_disconnects {
2514 rng.random_bool(*prob)
2515 } else {
2516 false
2517 }
2518 }
2519 }
2520
    /// In-memory link used by the tests: each connection is a pair of
    /// `tokio::io::duplex` pipes joined by relay tasks that can inject
    /// latency and disconnects.
    struct MockLink<M> {
        /// Capacity passed to `tokio::io::duplex` for every connection.
        buffer_size: usize,
        /// Identity of this link; returned by `link_id` and written in the
        /// link-init header of each new connection.
        session_id: SessionId,
        /// Slot that receives the server-side end of each new connection.
        receiver_storage: Arc<MVar<DuplexStream>>,
        /// While `true`, every connection attempt fails with a connect error.
        fail_connects: Arc<AtomicBool>,
        /// Relay tasks subscribe to this and stop relaying when it changes.
        disconnect_signal: watch::Sender<()>,
        /// Latency / random-disconnect settings handed to the relay tasks.
        network_flakiness: NetworkFlakiness,
        /// Number of disconnects triggered by `network_flakiness` so far.
        disconnected_count: Arc<AtomicU64>,
        /// Time of the latest flakiness-triggered disconnect (initialised to
        /// the link's creation time).
        prev_disconnected_at: Arc<RwLock<Instant>>,
        /// When set, relay tasks emit a debug log for a sample of the frames.
        debug_log_sampling_rate: Option<u64>,
        /// Ties the link to the message type `M` without storing one.
        _message_type: PhantomData<M>,
    }
2538
2539 impl<M> fmt::Debug for MockLink<M> {
2540 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2541 f.debug_struct("MockLink")
2542 .field("buffer_size", &self.buffer_size)
2543 .field("receiver_storage", &"<MVar<DuplexStream>>")
2544 .field("fail_connects", &self.fail_connects)
2545 .field("disconnect_signal", &"<watch::Sender>")
2546 .field("network_flakiness", &self.network_flakiness)
2547 .field("disconnected_count", &self.disconnected_count)
2548 .field("prev_disconnected_at", &"<RwLock<Instant>>")
2549 .field("debug_log_sampling_rate", &self.debug_log_sampling_rate)
2550 .finish()
2551 }
2552 }
2553
2554 impl<M: RemoteMessage> MockLink<M> {
2555 fn new() -> Self {
2556 let (sender, _) = watch::channel(());
2557 Self {
2558 buffer_size: 64,
2559 session_id: SessionId::random(),
2560 receiver_storage: Arc::new(MVar::empty()),
2561 fail_connects: Arc::new(AtomicBool::new(false)),
2562 disconnect_signal: sender,
2563 network_flakiness: NetworkFlakiness::default(),
2564 disconnected_count: Arc::new(AtomicU64::new(0)),
2565 prev_disconnected_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
2566 debug_log_sampling_rate: None,
2567 _message_type: PhantomData,
2568 }
2569 }
2570
2571 fn fail_connects() -> Self {
2574 Self {
2575 fail_connects: Arc::new(AtomicBool::new(true)),
2576 ..Self::new()
2577 }
2578 }
2579
2580 fn with_network_flakiness(network_flakiness: NetworkFlakiness) -> Self {
2581 if let Some((min, max)) = network_flakiness.latency_range {
2582 assert!(min < max);
2583 }
2584
2585 Self {
2586 network_flakiness,
2587 ..Self::new()
2588 }
2589 }
2590
2591 fn receiver_storage(&self) -> Arc<MVar<DuplexStream>> {
2592 self.receiver_storage.clone()
2593 }
2594
2595 fn disconnected_count(&self) -> Arc<AtomicU64> {
2596 self.disconnected_count.clone()
2597 }
2598
2599 fn disconnect_signal(&self) -> &watch::Sender<()> {
2600 &self.disconnect_signal
2601 }
2602
2603 fn fail_connects_switch(&self) -> Arc<AtomicBool> {
2604 self.fail_connects.clone()
2605 }
2606
2607 fn set_buffer_size(&mut self, size: usize) {
2608 self.buffer_size = size;
2609 }
2610
2611 fn set_sampling_rate(&mut self, sampling_rate: u64) {
2612 self.debug_log_sampling_rate = Some(sampling_rate);
2613 }
2614 }
2615
2616 #[async_trait]
2617 impl<M: RemoteMessage> Link for MockLink<M> {
2618 type Stream = DuplexStream;
2619
2620 fn dest(&self) -> ChannelAddr {
2621 ChannelAddr::Local(u64::MAX)
2622 }
2623
2624 fn link_id(&self) -> SessionId {
2625 self.session_id
2626 }
2627
2628 async fn next(&mut self) -> Result<Self::Stream, ClientError> {
2629 let session_id = self.session_id;
2630 tracing::debug!("MockLink starts to connect.");
2631 if self.fail_connects.load(Ordering::Acquire) {
2632 return Err(ClientError::Connect(
2633 self.dest(),
2634 std::io::Error::other("intentional error"),
2635 "expected failure injected by the mock".to_string(),
2636 ));
2637 }
2638
2639 async fn relay_message<M: RemoteMessage>(
2645 mut disconnect_signal: watch::Receiver<()>,
2646 network_flakiness: NetworkFlakiness,
2647 disconnected_count: Arc<AtomicU64>,
2648 prev_disconnected_at: Arc<RwLock<Instant>>,
2649 mut reader: FrameReader<ReadHalf<DuplexStream>>,
2650 mut writer: WriteHalf<DuplexStream>,
2651 task_coordination_token: CancellationToken,
2654 debug_log_sampling_rate: Option<u64>,
2655 is_from_client: bool,
2658 ) {
2659 async fn wait_for_latency_elapse(
2663 queue: &VecDeque<(Bytes, Instant)>,
2664 network_flakiness: &NetworkFlakiness,
2665 rng: &mut impl rand::Rng,
2666 ) {
2667 if let Some((min, max)) = network_flakiness.latency_range {
2668 let diff = max.abs_diff(min);
2669 let factor = rng.random_range(0.0..=1.0);
2670 let latency = min + diff.mul_f64(factor);
2671 tokio::time::sleep_until(queue.front().unwrap().1 + latency).await;
2672 }
2673 }
2674
2675 let mut rng = rand::rngs::SmallRng::try_from_rng(&mut SysRng).unwrap();
2676 let mut queue: VecDeque<(Bytes, Instant)> = VecDeque::new();
2677 let mut send_count = 0u64;
2678
2679 loop {
2680 tokio::select! {
2681 read_res = reader.next() => {
2682 match read_res {
2683 Ok(Some((_, data))) => {
2684 queue.push_back((data, tokio::time::Instant::now()));
2685 }
2686 Ok(None) | Err(_) => {
2687 tracing::debug!("The upstream is closed or dropped. MockLink disconnects");
2688 break;
2689 }
2690 }
2691 }
2692 _ = wait_for_latency_elapse(&queue, &network_flakiness, &mut rng), if !queue.is_empty() => {
2693 let count = disconnected_count.load(Ordering::Relaxed);
2694 if network_flakiness.should_disconnect(&mut rng, count, &prev_disconnected_at).await {
2695 tracing::debug!("MockLink disconnects");
2696 disconnected_count.fetch_add(1, Ordering::Relaxed);
2697
2698 metrics::CHANNEL_RECONNECTIONS.add(
2699 1,
2700 hyperactor_telemetry::kv_pairs!(
2701 "transport" => "mock",
2702 "reason" => "network_flakiness",
2703 ),
2704 );
2705
2706 let mut w = prev_disconnected_at.write().unwrap();
2707 *w = tokio::time::Instant::now();
2708 break;
2709 }
2710 let data = queue.pop_front().unwrap().0;
2711 let is_sampled = debug_log_sampling_rate.is_some_and(|sample_rate| send_count % sample_rate == 1);
2712 if is_sampled {
2713 if is_from_client {
2714 if let Ok((Frame::Message(_seq, _msg), _)) = bincode::serde::decode_from_slice::<Frame<M>, _>(&data, bincode::config::legacy()) {
2715 tracing::debug!("MockLink relays a msg from client. msg type: {}", std::any::type_name::<M>());
2716 }
2717 } else {
2718 let result = deserialize_response(data.clone());
2719 if let Ok(NetRxResponse::Ack(seq)) = result {
2720 tracing::debug!("MockLink relays an ack from server. seq: {}", seq);
2721 }
2722 }
2723 }
2724 let mut fw = FrameWrite::new(writer, data, hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH), 0).unwrap();
2725 if fw.send().await.is_err() {
2726 break;
2727 }
2728 writer = fw.complete();
2729 send_count += 1;
2730 }
2731 _ = task_coordination_token.cancelled() => break,
2732
2733 changed = disconnect_signal.changed() => {
2734 tracing::debug!("MockLink disconnects per disconnect_signal {:?}", changed);
2735 break;
2736 }
2737 }
2738 }
2739
2740 task_coordination_token.cancel();
2741 }
2742
2743 let (server, mut server_relay) = tokio::io::duplex(self.buffer_size);
2744 let (client, client_relay) = tokio::io::duplex(self.buffer_size);
2745
2746 write_link_init(&mut server_relay, session_id, 0)
2750 .await
2751 .map_err(|err| ClientError::Io(self.dest(), err))?;
2752
2753 let (server_r, server_writer) = tokio::io::split(server_relay);
2754 let (client_r, client_writer) = tokio::io::split(client_relay);
2755
2756 let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
2757 let server_reader = FrameReader::new(server_r, max_len);
2758 let client_reader = FrameReader::new(client_r, max_len);
2759
2760 let task_coordination_token = CancellationToken::new();
2761 let _server_relay_task_handle = tokio::spawn(relay_message::<M>(
2762 self.disconnect_signal.subscribe(),
2763 self.network_flakiness.clone(),
2764 self.disconnected_count.clone(),
2765 self.prev_disconnected_at.clone(),
2766 server_reader,
2767 client_writer,
2768 task_coordination_token.clone(),
2769 self.debug_log_sampling_rate,
2770 false,
2771 ));
2772 let _client_relay_task_handle = tokio::spawn(relay_message::<M>(
2773 self.disconnect_signal.subscribe(),
2774 self.network_flakiness.clone(),
2775 self.disconnected_count.clone(),
2776 self.prev_disconnected_at.clone(),
2777 client_reader,
2778 server_writer,
2779 task_coordination_token,
2780 self.debug_log_sampling_rate,
2781 true,
2782 ));
2783
2784 self.receiver_storage.put(server).await;
2785 Ok(client)
2786 }
2787 }
2788
    /// Listener side of the mock link: "accepts" whatever stream is deposited
    /// in the shared storage slot.
    struct MockLinkListener {
        /// Slot from which `accept` takes the next incoming stream.
        receiver_storage: Arc<MVar<DuplexStream>>,
        /// Address reported alongside every accepted stream.
        channel_addr: ChannelAddr,
    }
2793
2794 impl MockLinkListener {
2795 fn new(receiver_storage: Arc<MVar<DuplexStream>>, channel_addr: ChannelAddr) -> Self {
2796 Self {
2797 receiver_storage,
2798 channel_addr,
2799 }
2800 }
2801 }
2802
2803 impl fmt::Debug for MockLinkListener {
2804 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2805 f.debug_struct("MockLinkListener")
2806 .field("channel_addr", &self.channel_addr)
2807 .finish()
2808 }
2809 }
2810
2811 #[async_trait]
2812 impl super::Listener for MockLinkListener {
2813 type Stream = DuplexStream;
2814
2815 async fn accept(&mut self) -> Result<(Self::Stream, ChannelAddr), ServerError> {
2816 let stream = self.receiver_storage.take().await;
2817 Ok((stream, self.channel_addr.clone()))
2818 }
2819 }
2820
    /// Spawns a minimal acceptor-side receive loop for tests.
    ///
    /// Streams pushed into the returned unbounded sender become successive
    /// connections of a single session identified by `session_id`. Messages
    /// received on them are forwarded to the returned `mpsc::Receiver`.
    ///
    /// Returns `(task handle, stream sender, message receiver, cancel token)`.
    /// Cancelling the token ends the receive loop.
    fn serve_acceptor_test<M: RemoteMessage>(
        session_id: SessionId,
    ) -> (
        JoinHandle<()>,
        mpsc::UnboundedSender<DuplexStream>,
        mpsc::Receiver<M>,
        CancellationToken,
    ) {
        let (acceptor_tx, acceptor_rx) = mpsc::unbounded_channel::<DuplexStream>();
        let cancel_token = CancellationToken::new();
        let link = AcceptorLink {
            dest: ChannelAddr::Local(u64::MAX),
            session_id,
            stream: acceptor_rx,
            cancel: cancel_token.clone(),
        };
        let (tx, rx) = mpsc::channel::<M>(1024);
        let ct = cancel_token.clone();
        let handle = tokio::spawn(async move {
            let mut session = Session::new(link);
            // Sequence/ack state lives outside the loop so it carries over
            // from one connection of the session to the next.
            let mut next = session::Next { seq: 0, ack: 0 };

            loop {
                // Stop serving once no further connection can be established.
                let connected = match session.connect().await {
                    Ok(s) => s,
                    Err(_) => break,
                };

                // Receive until the connection ends or the test cancels us.
                let result = {
                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                    tokio::select! {
                        r = session::recv_connected::<M, _, _>(&stream, &tx, &mut next) => r,
                        _ = ct.cancelled() => Err(session::RecvLoopError::Cancelled),
                    }
                };

                // Flush an ack for everything received but not yet acked.
                if next.ack < next.seq {
                    let ack = serialize_response(NetRxResponse::Ack(next.seq - 1)).unwrap();
                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                    let mut completion = stream.write(ack);
                    match completion.drive().await {
                        Ok(()) => {
                            next.ack = next.seq;
                        }
                        Err(e) => {
                            tracing::debug!(
                                error = %e,
                                "failed to flush acks during cleanup"
                            );
                        }
                    }
                }

                // Sequence errors and cancellation are reported to the peer
                // with a final response; other outcomes send nothing.
                let terminal_response = match &result {
                    Err(session::RecvLoopError::SequenceError(reason)) => {
                        Some(NetRxResponse::Reject(reason.clone()))
                    }
                    Err(session::RecvLoopError::Cancelled) => Some(NetRxResponse::Closed),
                    _ => None,
                };
                if let Some(rsp) = terminal_response {
                    let data = serialize_response(rsp).unwrap();
                    let stream = connected.stream(INITIATOR_TO_ACCEPTOR);
                    let mut completion = stream.write(data);
                    // Best effort: a failed write is deliberately ignored.
                    let _ = completion.drive().await;
                }

                // A clean end or an I/O error waits for the next connection;
                // anything else terminates the loop.
                let recoverable = matches!(&result, Ok(()) | Err(session::RecvLoopError::Io(_)));
                session = connected.release();
                if recoverable {
                    continue;
                }
                break;
            }
        });
        (handle, acceptor_tx, rx, cancel_token)
    }
2903
2904 async fn write_stream<M, W>(
2905 mut writer: W,
2906 _session_id: u64,
2907 messages: &[(u64, M)],
2908 _init: bool,
2909 ) -> W
2910 where
2911 M: RemoteMessage + PartialEq + Clone,
2912 W: AsyncWrite + Unpin,
2913 {
2914 for (seq, message) in messages {
2915 let message =
2916 serde_multipart::serialize_bincode(&Frame::<M>::Message(*seq, message.clone()))
2917 .unwrap();
2918 let mut fw = FrameWrite::new(
2919 writer,
2920 message.framed(),
2921 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2922 0,
2923 )
2924 .map_err(|(_w, e)| e)
2925 .unwrap();
2926 fw.send().await.unwrap();
2927 writer = fw.complete();
2928 }
2929
2930 writer
2931 }
2932
2933 #[async_timed_test(timeout_secs = 60)]
2934 async fn test_persistent_server_session() {
2935 let config = hyperactor_config::global::lock();
2936 let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
2937
2938 async fn verify_ack(reader: &mut FrameReader<ReadHalf<DuplexStream>>, expected_last: u64) {
2939 let mut last_acked: i128 = -1;
2940 loop {
2941 let (_, bytes) = reader.next().await.unwrap().unwrap();
2942 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
2943 assert!(
2944 acked as i128 > last_acked,
2945 "acks should be delivered in ascending order"
2946 );
2947 last_acked = acked as i128;
2948 assert!(acked <= expected_last);
2949 if acked == expected_last {
2950 break;
2951 }
2952 }
2953 }
2954
2955 let session_id = SessionId(123);
2956 let (_handle, acceptor_tx, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
2957
2958 {
2960 let (sender, receiver) = tokio::io::duplex(5000);
2961 acceptor_tx.send(receiver).unwrap();
2962
2963 let (r, writer) = tokio::io::split(sender);
2964 let mut reader = FrameReader::new(
2965 r,
2966 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
2967 );
2968
2969 let _writer = write_stream(
2970 writer,
2971 123,
2972 &[
2973 (0u64, 100u64),
2974 (1u64, 101u64),
2975 (2u64, 102u64),
2976 (3u64, 103u64),
2977 ],
2978 true,
2979 )
2980 .await;
2981
2982 assert_eq!(rx.recv().await, Some(100));
2983 assert_eq!(rx.recv().await, Some(101));
2984 assert_eq!(rx.recv().await, Some(102));
2985 assert_eq!(rx.recv().await, Some(103));
2986
2987 verify_ack(&mut reader, 3).await;
2988 }
2990
2991 {
2993 let (sender2, receiver2) = tokio::io::duplex(5000);
2994 acceptor_tx.send(receiver2).unwrap();
2995
2996 let (r2, writer2) = tokio::io::split(sender2);
2997 let mut reader2 = FrameReader::new(
2998 r2,
2999 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3000 );
3001
3002 let _ = write_stream(
3003 writer2,
3004 123,
3005 &[
3006 (2u64, 102u64),
3007 (3u64, 103u64),
3008 (4u64, 104u64),
3009 (5u64, 105u64),
3010 ],
3011 true,
3012 )
3013 .await;
3014
3015 assert_eq!(rx.recv().await, Some(104));
3017 assert_eq!(rx.recv().await, Some(105));
3018
3019 verify_ack(&mut reader2, 5).await;
3020
3021 cancel_token.cancel();
3022 }
3023 }
3024
3025 #[async_timed_test(timeout_secs = 60)]
3026 async fn test_ack_from_server_session() {
3027 let config = hyperactor_config::global::lock();
3028 let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3029 let session_id = SessionId(123);
3030 let (_handle, acceptor_tx, mut rx, cancel_token) = serve_acceptor_test::<u64>(session_id);
3031
3032 let (sender, receiver) = tokio::io::duplex(5000);
3033 acceptor_tx.send(receiver).unwrap();
3034 let (r, mut writer) = tokio::io::split(sender);
3035 let mut reader = FrameReader::new(
3036 r,
3037 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3038 );
3039
3040 for i in 0u64..100u64 {
3041 writer = write_stream(writer, 123, &[(i, 100u64 + i)], i == 0u64).await;
3042 assert_eq!(rx.recv().await, Some(100u64 + i));
3043 let (_, bytes) = reader.next().await.unwrap().unwrap();
3044 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3045 assert_eq!(acked, i);
3046 }
3047
3048 tokio::time::sleep(Duration::from_secs(5)).await;
3050
3051 cancel_token.cancel();
3052
3053 let (_, bytes) = reader.next().await.unwrap().unwrap();
3055 assert!(deserialize_response(bytes).unwrap().is_closed());
3056 }
3057
3058 #[tracing_test::traced_test]
3059 async fn verify_tx_closed(tx_status: &mut watch::Receiver<TxStatus>, expected_log: &str) {
3060 match tokio::time::timeout(Duration::from_secs(5), tx_status.changed()).await {
3061 Ok(Ok(())) => {
3062 let current_status = tx_status.borrow().clone();
3063 assert!(current_status.is_closed());
3064 logs_assert_unscoped(|logs| {
3065 if logs.iter().any(|log| log.contains(expected_log)) {
3066 Ok(())
3067 } else {
3068 Err("expected log not found".to_string())
3069 }
3070 });
3071 }
3072 Ok(Err(_)) => panic!("watch::Receiver::changed() failed because sender is dropped."),
3073 Err(_) => panic!("timeout before tx_status changed"),
3074 }
3075 }
3076
3077 #[tracing_test::traced_test]
3078 #[tokio::test]
3079 #[cfg_attr(not(fbcode_build), ignore)]
3081 async fn test_tcp_tx_delivery_timeout() {
3082 let link = MockLink::<u64>::fail_connects();
3084 let tx = spawn::<u64>(link);
3085 let config = hyperactor_config::global::lock();
3087 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1));
3088 let mut tx_receiver = tx.status().clone();
3089 let (return_channel, _return_receiver) = oneshot::channel();
3090 tx.try_post(123, return_channel);
3091 verify_tx_closed(&mut tx_receiver, "failed to deliver message within timeout").await;
3092 }
3093
3094 async fn take_receiver(
3095 receiver_storage: &MVar<DuplexStream>,
3096 ) -> (FrameReader<ReadHalf<DuplexStream>>, WriteHalf<DuplexStream>) {
3097 let mut receiver = receiver_storage.take().await;
3098 let _link_init = read_link_init(&mut receiver).await.expect("read LinkInit");
3100 let (r, writer) = tokio::io::split(receiver);
3101 let reader = FrameReader::new(
3102 r,
3103 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3104 );
3105 (reader, writer)
3106 }
3107
3108 async fn verify_message<M: RemoteMessage + PartialEq + std::fmt::Debug>(
3109 reader: &mut FrameReader<ReadHalf<DuplexStream>>,
3110 expect: (u64, M),
3111 loc: u32,
3112 ) {
3113 let expected = Frame::Message(expect.0, expect.1);
3114 let (_, bytes) = reader.next().await.unwrap().expect("unexpected EOF");
3115 let message = serde_multipart::Message::from_framed(bytes).unwrap();
3116 let frame: Frame<M> = serde_multipart::deserialize_bincode(message).unwrap();
3117
3118 assert_eq!(frame, expected, "from ln={loc}");
3119 }
3120
3121 async fn verify_stream<M: RemoteMessage + PartialEq + std::fmt::Debug + Clone>(
3122 reader: &mut FrameReader<ReadHalf<DuplexStream>>,
3123 expects: &[(u64, M)],
3124 _expect_session_id: Option<u64>,
3125 loc: u32,
3126 ) {
3127 for expect in expects {
3128 verify_message(reader, expect.clone(), loc).await;
3129 }
3130 }
3131
3132 async fn net_tx_send(tx: &NetTx<u64>, msgs: &[u64]) {
3133 for msg in msgs {
3134 tx.post(*msg);
3135 }
3136 }
3137
3138 #[async_timed_test(timeout_secs = 30)]
3140 async fn test_ack_in_net_tx_basic() {
3141 let link = MockLink::<u64>::new();
3142 let receiver_storage = link.receiver_storage();
3143 let tx = spawn::<u64>(link);
3144
3145 net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
3147 {
3148 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3149 verify_stream(
3150 &mut reader,
3151 &[
3152 (0u64, 100u64),
3153 (1u64, 101u64),
3154 (2u64, 102u64),
3155 (3u64, 103u64),
3156 (4u64, 104u64),
3157 ],
3158 None,
3159 line!(),
3160 )
3161 .await;
3162
3163 for i in 0u64..5u64 {
3164 writer = FrameWrite::write_frame(
3165 writer,
3166 serialize_response(NetRxResponse::Ack(i)).unwrap(),
3167 1024,
3168 0,
3169 )
3170 .await
3171 .map_err(|(_, e)| e)
3172 .unwrap();
3173 }
3174 tokio::time::sleep(Duration::from_secs(3)).await;
3176 drop(reader);
3178 drop(writer);
3179 };
3180
3181 net_tx_send(&tx, &[105u64]).await;
3183 {
3184 let (mut reader, _writer) = take_receiver(&receiver_storage).await;
3185 verify_stream(&mut reader, &[(5u64, 105u64)], None, line!()).await;
3186 };
3188 }
3189
3190 #[async_timed_test(timeout_secs = 60)]
3192 async fn test_persistent_net_tx() {
3193 let link = MockLink::<u64>::new();
3194 let receiver_storage = link.receiver_storage();
3195
3196 let tx = spawn::<u64>(link);
3197
3198 net_tx_send(&tx, &[100, 101, 102, 103, 104]).await;
3200
3201 let n = 3;
3205
3206 for i in 0..n {
3209 {
3210 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3211 verify_stream(
3212 &mut reader,
3213 &[
3214 (0u64, 100u64),
3215 (1u64, 101u64),
3216 (2u64, 102u64),
3217 (3u64, 103u64),
3218 (4u64, 104u64),
3219 ],
3220 None,
3221 line!(),
3222 )
3223 .await;
3224
3225 if i == n - 1 {
3228 writer = FrameWrite::write_frame(
3229 writer,
3230 serialize_response(NetRxResponse::Ack(1)).unwrap(),
3231 1024,
3232 0,
3233 )
3234 .await
3235 .map_err(|(_, e)| e)
3236 .unwrap();
3237 tokio::time::sleep(Duration::from_secs(3)).await;
3239 }
3240 drop(reader);
3242 drop(writer);
3243 };
3244 }
3245
3246 for _ in 0..n {
3248 {
3249 let (mut reader, mut _writer) = take_receiver(&receiver_storage).await;
3250 verify_stream(
3251 &mut reader,
3252 &[(2u64, 102u64), (3u64, 103u64), (4u64, 104u64)],
3253 None,
3254 line!(),
3255 )
3256 .await;
3257 };
3259 }
3260
3261 net_tx_send(&tx, &[105u64, 106u64, 107u64, 108u64, 109u64]).await;
3263 for i in 0..n {
3266 {
3267 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3268 verify_stream(
3269 &mut reader,
3270 &[
3271 (2u64, 102u64),
3273 (3u64, 103u64),
3274 (4u64, 104u64),
3275 (5u64, 105u64),
3277 (6u64, 106u64),
3278 (7u64, 107u64),
3279 (8u64, 108u64),
3280 (9u64, 109u64),
3281 ],
3282 None,
3283 line!(),
3284 )
3285 .await;
3286
3287 if i == n - 1 {
3290 writer = FrameWrite::write_frame(
3293 writer,
3294 serialize_response(NetRxResponse::Ack(1)).unwrap(),
3295 1024,
3296 0,
3297 )
3298 .await
3299 .map_err(|(_, e)| e)
3300 .unwrap();
3301 writer = FrameWrite::write_frame(
3302 writer,
3303 serialize_response(NetRxResponse::Ack(2)).unwrap(),
3304 1024,
3305 0,
3306 )
3307 .await
3308 .map_err(|(_, e)| e)
3309 .unwrap();
3310 writer = FrameWrite::write_frame(
3311 writer,
3312 serialize_response(NetRxResponse::Ack(3)).unwrap(),
3313 1024,
3314 0,
3315 )
3316 .await
3317 .map_err(|(_, e)| e)
3318 .unwrap();
3319 tokio::time::sleep(Duration::from_secs(3)).await;
3321 }
3322 drop(reader);
3324 drop(writer);
3325 };
3326 }
3327
3328 for i in 0..n {
3329 {
3330 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3331 verify_stream(
3332 &mut reader,
3333 &[
3334 (4u64, 104),
3336 (5u64, 105u64),
3338 (6u64, 106u64),
3339 (7u64, 107u64),
3340 (8u64, 108u64),
3341 (9u64, 109u64),
3342 ],
3343 None,
3344 line!(),
3345 )
3346 .await;
3347
3348 if i == n - 1 {
3350 writer = FrameWrite::write_frame(
3351 writer,
3352 serialize_response(NetRxResponse::Ack(7)).unwrap(),
3353 1024,
3354 0,
3355 )
3356 .await
3357 .map_err(|(_, e)| e)
3358 .unwrap();
3359 tokio::time::sleep(Duration::from_secs(3)).await;
3361 }
3362 drop(reader);
3364 drop(writer);
3365 };
3366 }
3367
3368 for _ in 0..n {
3369 {
3370 let (mut reader, writer) = take_receiver(&receiver_storage).await;
3371 verify_stream(
3372 &mut reader,
3373 &[
3374 (8u64, 108u64),
3376 (9u64, 109u64),
3377 ],
3378 None,
3379 line!(),
3380 )
3381 .await;
3382 drop(reader);
3384 drop(writer);
3385 };
3386 }
3387 }
3388
3389 #[async_timed_test(timeout_secs = 15)]
3390 async fn test_ack_before_redelivery_in_net_tx() {
3391 let link = MockLink::<u64>::new();
3392 let receiver_storage = link.receiver_storage();
3393 let net_tx = spawn::<u64>(link);
3394
3395 let (return_channel_tx, return_channel_rx) = oneshot::channel();
3398 net_tx.try_post(100, return_channel_tx);
3399 let (mut reader, mut writer) = take_receiver(&receiver_storage).await;
3400 verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3401 writer = FrameWrite::write_frame(
3403 writer,
3404 serialize_response(NetRxResponse::Ack(0)).unwrap(),
3405 1024,
3406 0,
3407 )
3408 .await
3409 .map_err(|(_, e)| e)
3410 .unwrap();
3411 assert!(return_channel_rx.await.is_err());
3416
3417 let _ = FrameWrite::write_frame(
3422 writer,
3423 serialize_response(NetRxResponse::Ack(1)).unwrap(),
3424 1024,
3425 0,
3426 )
3427 .await
3428 .map_err(|(_, e)| e)
3429 .unwrap();
3430
3431 let (return_channel_tx, return_channel_rx) = oneshot::channel();
3432 net_tx.try_post(101, return_channel_tx);
3433 verify_message(&mut reader, (1u64, 101u64), line!()).await;
3435 assert!(return_channel_rx.await.is_err());
3442 }
3443
3444 async fn verify_ack_exceeded_limit(disconnect_before_ack: bool) {
3445 let config = hyperactor_config::global::lock();
3447 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(2));
3448
3449 let link: MockLink<u64> = MockLink::<u64>::new();
3450 let disconnect_signal = link.disconnect_signal().clone();
3451 let fail_connect_switch = link.fail_connects_switch();
3452 let receiver_storage = link.receiver_storage();
3453 let tx = spawn::<u64>(link);
3454 let mut tx_status = tx.status().clone();
3455 tx.post(100);
3457 let (mut reader, writer) = take_receiver(&receiver_storage).await;
3458 verify_stream(&mut reader, &[(0u64, 100u64)], None, line!()).await;
3460 let _ = FrameWrite::write_frame(
3462 writer,
3463 serialize_response(NetRxResponse::Ack(0)).unwrap(),
3464 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3465 0,
3466 )
3467 .await
3468 .map_err(|(_, e)| e)
3469 .unwrap();
3470 tokio::time::sleep(Duration::from_secs(3)).await;
3471 assert!(!tx_status.has_changed().unwrap());
3473 assert_eq!(*tx_status.borrow(), TxStatus::Active);
3474
3475 tx.post(101);
3476 verify_message(&mut reader, (1u64, 101u64), line!()).await;
3478
3479 if disconnect_before_ack {
3480 fail_connect_switch.store(true, Ordering::Release);
3482 disconnect_signal.send(()).unwrap();
3484 }
3485
3486 let expected_log: &str = if disconnect_before_ack {
3488 "failed to receive ack within timeout 2s; link is currently broken"
3489 } else {
3490 "failed to receive ack within timeout 2s; link is currently connected"
3491 };
3492
3493 verify_tx_closed(&mut tx_status, expected_log).await;
3494 }
3495
3496 #[tracing_test::traced_test]
3497 #[async_timed_test(timeout_secs = 30)]
3498 #[cfg_attr(not(fbcode_build), ignore)]
3500 async fn test_ack_exceeded_limit_with_connected_link() {
3501 verify_ack_exceeded_limit(false).await;
3502 }
3503
3504 #[tracing_test::traced_test]
3505 #[async_timed_test(timeout_secs = 30)]
3506 #[cfg_attr(not(fbcode_build), ignore)]
3508 async fn test_ack_exceeded_limit_with_broken_link() {
3509 verify_ack_exceeded_limit(true).await;
3510 }
3511
3512 #[async_timed_test(timeout_secs = 60)]
3515 async fn test_network_flakiness_in_channel() {
3516 hyperactor_telemetry::initialize_logging_for_test();
3517
3518 let sampling_rate = 100;
3519 let mut link = MockLink::<u64>::with_network_flakiness(NetworkFlakiness {
3520 disconnect_params: Some((0.001, 15, Duration::from_millis(400))),
3521 latency_range: Some((Duration::from_millis(100), Duration::from_millis(200))),
3522 });
3523 link.set_sampling_rate(sampling_rate);
3524 link.set_buffer_size(1024000);
3526 let disconnected_count = link.disconnected_count();
3527 let receiver_storage = link.receiver_storage();
3528 let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3529 let local_addr = listener.channel_addr.clone();
3530 let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3531 super::server::serve_with_listener(listener, local_addr).unwrap();
3532 let tx = spawn::<u64>(link);
3533 let messages: Vec<_> = (0..10001).collect();
3534 let messages_clone = messages.clone();
3535 let send_task_handle = tokio::spawn(async move {
3538 for message in messages_clone {
3539 tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3551 tx.post(message);
3552 }
3553 tracing::debug!("NetTx sent all messages");
3554 tx
3557 });
3558
3559 for message in &messages {
3560 if message % sampling_rate == 0 {
3561 tracing::debug!("NetRx received a message: {message}");
3562 }
3563 assert_eq!(nx.recv().await.unwrap(), *message);
3564 }
3565 tracing::debug!("NetRx received all messages");
3566
3567 let send_result = send_task_handle.await;
3568 assert!(send_result.is_ok());
3569
3570 tracing::debug!(
3571 "MockLink disconnected {} times.",
3572 disconnected_count.load(Ordering::SeqCst)
3573 );
3574 }
3577
3578 #[async_timed_test(timeout_secs = 60)]
3579 async fn test_ack_every_n_messages() {
3580 let config = hyperactor_config::global::lock();
3581 let _guard_message_ack = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 600);
3582 let _guard_time_interval =
3583 config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(1000));
3584 sparse_ack().await;
3585 }
3586
3587 #[async_timed_test(timeout_secs = 60)]
3588 async fn test_ack_every_time_interval() {
3589 let config = hyperactor_config::global::lock();
3590 let _guard_message_ack =
3591 config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 100000000);
3592 let _guard_time_interval = config.override_key(
3593 config::MESSAGE_ACK_TIME_INTERVAL,
3594 Duration::from_millis(500),
3595 );
3596 sparse_ack().await;
3597 }
3598
3599 async fn sparse_ack() {
3600 let mut link = MockLink::<u64>::new();
3601 link.set_buffer_size(1024000);
3603 let disconnected_count = link.disconnected_count();
3604 let receiver_storage = link.receiver_storage();
3605 let listener = MockLinkListener::new(receiver_storage.clone(), link.dest());
3606 let local_addr = listener.channel_addr.clone();
3607 let (_, mut nx): (ChannelAddr, NetRx<u64>) =
3608 super::server::serve_with_listener(listener, local_addr).unwrap();
3609 let tx = spawn::<u64>(link);
3610 let messages: Vec<_> = (0..20001).collect();
3611 let messages_clone = messages.clone();
3612 let send_task_handle = tokio::spawn(async move {
3615 for message in messages_clone {
3616 tokio::time::sleep(Duration::from_micros(rand::random::<u64>() % 100)).await;
3617 tx.post(message);
3618 }
3619 tokio::time::sleep(Duration::from_secs(5)).await;
3620 tracing::debug!("NetTx sent all messages");
3621 tx
3622 });
3623
3624 for message in &messages {
3625 assert_eq!(nx.recv().await.unwrap(), *message);
3626 }
3627 tracing::debug!("NetRx received all messages");
3628
3629 let send_result = send_task_handle.await;
3630 assert!(send_result.is_ok());
3631
3632 tracing::debug!(
3633 "MockLink disconnected {} times.",
3634 disconnected_count.load(Ordering::SeqCst)
3635 );
3636 }
3637
3638 #[test]
3639 fn test_metatls_parsing() {
3640 let channel: ChannelAddr = "metatls!localhost:1234".parse().unwrap();
3642 assert_eq!(
3643 channel,
3644 ChannelAddr::MetaTls(TlsAddr::new("localhost", 1234))
3645 );
3646 let channel: ChannelAddr = "metatls!1.2.3.4:1234".parse().unwrap();
3648 assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("1.2.3.4", 1234)));
3649 let channel: ChannelAddr = "metatls!2401:db00:33c:6902:face:0:2a2:0:1234"
3651 .parse()
3652 .unwrap();
3653 assert_eq!(
3654 channel,
3655 ChannelAddr::MetaTls(TlsAddr::new("2401:db00:33c:6902:face:0:2a2:0", 1234))
3656 );
3657
3658 let channel: ChannelAddr = "metatls![::]:1234".parse().unwrap();
3659 assert_eq!(channel, ChannelAddr::MetaTls(TlsAddr::new("::", 1234)));
3660 }
3661
3662 #[async_timed_test(timeout_secs = 300)]
3663 #[cfg_attr(not(fbcode_build), ignore)]
3665 async fn test_tcp_throughput() {
3666 let config = hyperactor_config::global::lock();
3667 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3668
3669 let socket_addr: SocketAddr = "[::1]:0".parse().unwrap();
3670 let (local_addr, mut rx) =
3671 server::serve::<String>(ChannelAddr::Tcp(socket_addr), None).unwrap();
3672
3673 let total_num_msgs = 500000;
3675
3676 let receive_handle = tokio::spawn(async move {
3677 let mut num = 0;
3678 for _ in 0..10 * total_num_msgs {
3679 rx.recv().await.unwrap();
3680 num += 1;
3681
3682 if num % 100000 == 0 {
3683 tracing::info!("total number of received messages: {}", num);
3684 }
3685 }
3686 });
3687
3688 let mut tx_handles = vec![];
3689 let mut txs = vec![];
3690 for _ in 0..10 {
3691 let server_addr = local_addr.clone();
3692 let tx = Arc::new(channel::dial::<String>(server_addr).unwrap());
3693 let tx2 = Arc::clone(&tx);
3694 txs.push(tx);
3695 tx_handles.push(tokio::spawn(async move {
3696 let random_string = rand::rng()
3697 .sample_iter(&Alphanumeric)
3698 .take(2048)
3699 .map(char::from)
3700 .collect::<String>();
3701 for _ in 0..total_num_msgs {
3702 tx2.post(random_string.clone());
3703 }
3704 }));
3705 }
3706
3707 receive_handle.await.unwrap();
3708 for handle in tx_handles {
3709 handle.await.unwrap();
3710 }
3711 }
3712
    /// Verifies that a `NetTx` ends up closed once the peer answers with a
    /// `NetRxResponse::Reject` frame.
    ///
    /// The test plays the server by hand: it takes the receiving end of the
    /// mock link, writes a single `Reject("testing")` response, and then
    /// checks the sender's status via `verify_tx_closed`.
    #[tracing_test::traced_test]
    #[async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_net_tx_closed_on_server_reject() {
        let link = MockLink::<u64>::new();
        let receiver_storage = link.receiver_storage();
        let mut tx = spawn::<u64>(link);
        net_tx_send(&tx, &[100]).await;

        {
            // `_reader` stays bound until the end of this scope, i.e. across
            // the sleep below.
            let (_reader, writer) = take_receiver(&receiver_storage).await;
            // The write result is deliberately discarded; a failed write
            // would surface through the final status check instead.
            let _ = FrameWrite::write_frame(
                writer,
                serialize_response(NetRxResponse::Reject("testing".to_string())).unwrap(),
                1024,
                0,
            )
            .await
            .map_err(|(_, e)| e);

            // NOTE(review): presumably gives the sender time to read and act
            // on the reject before the reader half is dropped; confirm.
            tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
        }

        verify_tx_closed(&mut tx.status, "server rejected connection").await;
    }
3741
3742 #[async_timed_test(timeout_secs = 60)]
3743 async fn test_server_rejects_conn_on_out_of_sequence_message() {
3744 let config = hyperactor_config::global::lock();
3745 let _guard = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1);
3746 let session_id = SessionId(123);
3747 let (_handle, acceptor_tx, mut rx, _cancel_token) = serve_acceptor_test::<u64>(session_id);
3748
3749 let (sender, receiver) = tokio::io::duplex(5000);
3750 acceptor_tx.send(receiver).unwrap();
3751 let (r, writer) = tokio::io::split(sender);
3752 let mut reader = FrameReader::new(
3753 r,
3754 hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH),
3755 );
3756
3757 let _ = write_stream(writer, 123, &[(0, 100u64), (1, 101u64), (3, 103u64)], true).await;
3758 assert_eq!(rx.recv().await, Some(100u64));
3759 assert_eq!(rx.recv().await, Some(101u64));
3760 let (_, bytes) = reader.next().await.unwrap().unwrap();
3761 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3762 assert_eq!(acked, 0);
3763 let (_, bytes) = reader.next().await.unwrap().unwrap();
3764 let acked = deserialize_response(bytes).unwrap().into_ack().unwrap();
3765 assert_eq!(acked, 1);
3766 let (_, bytes) = reader.next().await.unwrap().unwrap();
3767 assert!(deserialize_response(bytes).unwrap().is_reject());
3768 }
3769
3770 #[async_timed_test(timeout_secs = 60)]
3771 #[cfg_attr(not(fbcode_build), ignore)]
3773 async fn test_stop_net_tx_after_stopping_net_rx() {
3774 hyperactor_telemetry::initialize_logging_for_test();
3775
3776 let config = hyperactor_config::global::lock();
3777 let _guard = config.override_key(config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_mins(5));
3778 let (addr, mut rx) =
3779 server::serve::<u64>(ChannelAddr::Tcp("[::1]:0".parse().unwrap()), None).unwrap();
3780 let socket_addr = match addr {
3781 ChannelAddr::Tcp(a) => a,
3782 _ => panic!("unexpected channel type"),
3783 };
3784 let tx: NetTx<u64> = spawn(tcp::link(socket_addr, SessionId::random(), 0));
3785 tx.send(100).await.unwrap();
3790 assert_eq!(rx.recv().await.unwrap(), 100);
3791 rx.2.stop("testing");
3793 assert!(rx.recv().await.is_err());
3794
3795 tx.post(101);
3798 let mut watcher = tx.status().clone();
3799 let _ = watcher.wait_for(|val| val.is_closed()).await;
3801 assert!(watcher.borrow().is_closed());
3805 }
3806
    /// Test listener that hands out a fixed queue of pre-built in-memory
    /// duplex streams instead of accepting real connections.
    struct QueueListener {
        /// Streams still to be returned by `accept`, front first.
        streams: std::collections::VecDeque<DuplexStream>,
        /// Address reported alongside every accepted stream.
        addr: ChannelAddr,
    }
3814
3815 #[async_trait]
3816 impl super::Listener for QueueListener {
3817 type Stream = DuplexStream;
3818
3819 async fn accept(&mut self) -> Result<(DuplexStream, ChannelAddr), ServerError> {
3820 match self.streams.pop_front() {
3821 Some(s) => Ok((s, self.addr.clone())),
3822 None => std::future::pending().await,
3823 }
3824 }
3825 }
3826
    /// Both ends of an in-memory connection produced by `prepare_connection`,
    /// with the client's outbound bytes already written.
    struct PreparedConnection {
        /// Server end of the duplex pipe, to be handed to a listener.
        server_side: DuplexStream,
        /// Client write half. Never read by the tests.
        /// NOTE(review): presumably retained only so the write half is not
        /// dropped while the test runs; confirm.
        _client_w: tokio::io::WriteHalf<DuplexStream>,
        /// Client read half, used by tests to read the server's response
        /// frames.
        client_r: ReadHalf<DuplexStream>,
    }
3837
3838 async fn prepare_connection(
3841 session_id: SessionId,
3842 stream_id: u8,
3843 messages: &[(u64, u64)],
3844 ) -> PreparedConnection {
3845 let (client_side, server_side) = tokio::io::duplex(8192);
3846 let (client_r, mut client_w) = tokio::io::split(client_side);
3847
3848 super::write_link_init(&mut client_w, session_id, stream_id)
3849 .await
3850 .unwrap();
3851 let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
3852 for (seq, value) in messages {
3853 let payload =
3854 serde_multipart::serialize_bincode(&Frame::<u64>::Message(*seq, *value)).unwrap();
3855 let mut fw = FrameWrite::new(client_w, payload.framed(), max_len, 0)
3856 .map_err(|(_w, e)| e)
3857 .unwrap();
3858 fw.send().await.unwrap();
3859 client_w = fw.complete();
3860 }
3861
3862 PreparedConnection {
3863 server_side,
3864 _client_w: client_w,
3865 client_r,
3866 }
3867 }
3868
    /// Description of one connection for `run_separate_sessions_flush_test`.
    struct SeparateSessionPlan {
        /// Session id written into the connection's link-init header.
        session_id: SessionId,
        /// Stream id written into the connection's link-init header.
        stream_id: u8,
        /// `(seq, value)` pairs pre-written as message frames, in order.
        messages: Vec<(u64, u64)>,
    }
3876
    /// Shared driver for tests that check `rx.join()` flushes pending acks.
    ///
    /// For every plan a connection is prepared with its messages already
    /// written, and all connections are served through a `QueueListener`.
    /// The test then asserts that:
    ///
    /// 1. every message value reaches the application,
    /// 2. no response frame is readable on any connection before
    ///    `rx.join()`, and
    /// 3. after `rx.join()` each connection yields an `Ack` carrying the
    ///    highest sequence number of its plan, followed by a `Closed` frame.
    ///
    /// `plans` describes one connection each; every plan must contain at
    /// least one message. `stream_id_label` is prefixed to assertion
    /// messages to identify the calling test.
    async fn run_separate_sessions_flush_test(
        plans: Vec<SeparateSessionPlan>,
        stream_id_label: &str,
    ) {
        let config = hyperactor_config::global::lock();
        // Raise both ack settings (1,000,000 messages / 3600 s) far beyond
        // what this test sends or waits for; the pre-join probe below checks
        // that no ack shows up early.
        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
        let _g_time =
            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));

        let mut conns: Vec<PreparedConnection> = Vec::with_capacity(plans.len());
        // Values are collected into a set: the order in which messages from
        // different connections arrive is not asserted.
        let mut expected_messages: std::collections::HashSet<u64> =
            std::collections::HashSet::new();
        let mut expected_acks: Vec<u64> = Vec::with_capacity(plans.len());
        for plan in &plans {
            for (_seq, value) in &plan.messages {
                expected_messages.insert(*value);
            }
            // The expected ack is the plan's highest sequence number.
            expected_acks.push(plan.messages.iter().map(|(s, _)| *s).max().unwrap());
            conns.push(prepare_connection(plan.session_id, plan.stream_id, &plan.messages).await);
        }

        let addr = ChannelAddr::Local(u64::MAX);
        let listener = QueueListener {
            streams: conns
                .iter_mut()
                .map(|c| {
                    // Move the server end out, leaving a throwaway duplex end
                    // behind so `conns` keeps owning the client halves.
                    std::mem::replace(&mut c.server_side, tokio::io::duplex(1).0)
                })
                .collect(),
            addr: addr.clone(),
        };
        let (_addr, mut rx) = super::server::serve_with_listener::<u64, _>(listener, addr).unwrap();

        let mut received: std::collections::HashSet<u64> = std::collections::HashSet::new();
        for _ in 0..expected_messages.len() {
            received.insert(rx.recv().await.unwrap());
        }
        assert_eq!(
            received, expected_messages,
            "{stream_id_label}: every produced message should reach the application"
        );

        // Leave time for any (unexpected) early response to be written
        // before probing the client read halves.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
        let mut readers: Vec<FrameReader<ReadHalf<DuplexStream>>> = conns
            .into_iter()
            .map(|c| FrameReader::new(c.client_r, max_len))
            .collect();
        for (idx, reader) in readers.iter_mut().enumerate() {
            match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
                // Timing out with nothing to read is the expected outcome.
                Err(_) => {}
                Ok(Err(e)) => panic!(
                    "{stream_id_label}: connection {idx} frame reader error before rx.join: {e}"
                ),
                Ok(Ok(None)) => {
                    panic!("{stream_id_label}: connection {idx} closed before rx.join()")
                }
                Ok(Ok(Some((_, bytes)))) => {
                    let resp = super::deserialize_response(bytes).unwrap();
                    panic!(
                        "{stream_id_label}: connection {idx} unexpectedly received {resp:?} \
                         before rx.join()"
                    );
                }
            }
        }

        rx.join().await;

        for (idx, (reader, expected_ack)) in readers.iter_mut().zip(&expected_acks).enumerate() {
            // First frame after join: the flushed ack.
            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
                .await
                .unwrap_or_else(|_| {
                    panic!(
                        "{stream_id_label}: connection {idx} produced no Ack frame within 50ms \
                         after rx.join()"
                    )
                })
                .expect("frame reader error")
                .expect("frame reader returned None");
            let acked = super::deserialize_response(bytes.1)
                .unwrap()
                .into_ack()
                .unwrap_or_else(|other| {
                    panic!("{stream_id_label}: connection {idx} expected Ack, got {other:?}")
                });
            assert_eq!(
                acked, *expected_ack,
                "{stream_id_label}: connection {idx} ack mismatch"
            );

            // Second frame after join: the terminal `Closed` response.
            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
                .await
                .unwrap_or_else(|_| {
                    panic!(
                        "{stream_id_label}: connection {idx} produced no Closed frame within 50ms"
                    )
                })
                .expect("frame reader error")
                .expect("frame reader returned None");
            assert!(
                super::deserialize_response(bytes.1).unwrap().is_closed(),
                "{stream_id_label}: connection {idx} expected Closed terminal frame"
            );
        }
    }
4006
4007 #[async_timed_test(timeout_secs = 30)]
4008 async fn rx_join_flushes_pending_ack_single_stream() {
4009 let plans = (1u64..=3)
4013 .map(|sid| SeparateSessionPlan {
4014 session_id: SessionId(sid),
4015 stream_id: 0,
4016 messages: (0u64..3).map(|seq| (seq, sid * 100 + seq)).collect(),
4017 })
4018 .collect();
4019 run_separate_sessions_flush_test(plans, "single-stream").await;
4020 }
4021
4022 #[async_timed_test(timeout_secs = 30)]
4023 async fn rx_join_flushes_pending_ack_multi_stream() {
4024 let plans = (1u64..=3)
4029 .map(|sid| SeparateSessionPlan {
4030 session_id: SessionId(sid),
4031 stream_id: 1,
4032 messages: (0u64..3).map(|seq| (seq, sid * 100 + seq)).collect(),
4033 })
4034 .collect();
4035 run_separate_sessions_flush_test(plans, "multi-stream").await;
4036 }
4037
    /// Flush-on-join scenario where three streams (ids 1 to 3) share one
    /// session id.
    ///
    /// Sequence numbers are assigned contiguously across the streams: stream
    /// `k` carries `(k - 1) * 3 .. (k - 1) * 3 + 3`, with value `1000 + seq`.
    /// After `rx.join()` the test expects between one and three `Ack` frames
    /// in total, each carrying the highest sequence number (8), and exactly
    /// one `Closed` frame per stream.
    #[async_timed_test(timeout_secs = 30)]
    async fn rx_join_flushes_pending_ack_shared_multi_stream_session() {
        let config = hyperactor_config::global::lock();
        // Raise both ack settings far beyond what this test sends or waits
        // for; the pre-join probe below checks that no ack shows up early.
        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
        let _g_time =
            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));

        let session_id = SessionId(99);
        let num_streams = 3u8;
        let msgs_per_stream = 3u64;
        let mut conns: Vec<PreparedConnection> = Vec::with_capacity(num_streams as usize);
        let mut expected_messages: std::collections::HashSet<u64> =
            std::collections::HashSet::new();
        for stream_id in 1..=num_streams {
            let messages: Vec<(u64, u64)> = (0u64..msgs_per_stream)
                .map(|i| {
                    // Offset by the stream index so the session-wide
                    // sequence range has no gaps or overlaps.
                    let seq = (stream_id as u64 - 1) * msgs_per_stream + i;
                    (seq, 1000 + seq)
                })
                .collect();
            for (_, v) in &messages {
                expected_messages.insert(*v);
            }
            conns.push(prepare_connection(session_id, stream_id, &messages).await);
        }
        let highest_seq = num_streams as u64 * msgs_per_stream - 1;

        let addr = ChannelAddr::Local(u64::MAX);
        let listener = QueueListener {
            streams: conns
                .iter_mut()
                // Move each server end out, leaving a throwaway duplex end
                // behind so `conns` keeps owning the client halves.
                .map(|c| std::mem::replace(&mut c.server_side, tokio::io::duplex(1).0))
                .collect(),
            addr: addr.clone(),
        };
        let (_addr, mut rx) = super::server::serve_with_listener::<u64, _>(listener, addr).unwrap();

        let mut received: std::collections::HashSet<u64> = std::collections::HashSet::new();
        for _ in 0..expected_messages.len() {
            received.insert(rx.recv().await.unwrap());
        }
        assert_eq!(
            received, expected_messages,
            "shared-session multi-stream: every message reaches the application"
        );

        // Leave time for any (unexpected) early response to be written
        // before probing the client read halves.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
        let mut readers: Vec<FrameReader<ReadHalf<DuplexStream>>> = conns
            .into_iter()
            .map(|c| FrameReader::new(c.client_r, max_len))
            .collect();
        for (idx, reader) in readers.iter_mut().enumerate() {
            match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
                // Unlike the separate-session tests, end-of-stream before
                // join is tolerated here as well as a timeout.
                Err(_) | Ok(Ok(None)) => {}
                Ok(Err(e)) => {
                    panic!("shared-session multi-stream: stream {idx} frame reader error: {e}")
                }
                Ok(Ok(Some((_, bytes)))) => {
                    let resp = super::deserialize_response(bytes).unwrap();
                    panic!(
                        "shared-session multi-stream: stream {idx} unexpectedly received \
                         {resp:?} before rx.join()"
                    );
                }
            }
        }

        rx.join().await;

        // Totals across all streams; which stream carries an ack is not
        // asserted, only how many acks and `Closed` frames appear overall.
        let mut ack_count = 0;
        let mut closed_count = 0;
        for (idx, reader) in readers.iter_mut().enumerate() {
            // Drain this stream until its `Closed` frame or end-of-stream.
            loop {
                match tokio::time::timeout(Duration::from_millis(50), reader.next()).await {
                    Err(_) => panic!(
                        "shared-session multi-stream: stream {idx} did not yield expected \
                         frames within 50ms after rx.join()"
                    ),
                    Ok(Err(e)) => panic!("frame reader error: {e}"),
                    Ok(Ok(None)) => break,
                    Ok(Ok(Some((_, bytes)))) => {
                        let resp = super::deserialize_response(bytes).unwrap();
                        match resp {
                            NetRxResponse::Ack(seq) => {
                                assert_eq!(
                                    seq, highest_seq,
                                    "shared-session multi-stream: ack should cover the full \
                                     contiguous range 0..={highest_seq}"
                                );
                                ack_count += 1;
                            }
                            NetRxResponse::Closed => {
                                closed_count += 1;
                                break;
                            }
                            other => panic!(
                                "shared-session multi-stream: stream {idx} unexpected {other:?}"
                            ),
                        }
                    }
                }
            }
        }
        assert!(
            ack_count >= 1,
            "shared-session multi-stream: expected at least one Ack({highest_seq}); \
             got {ack_count}"
        );
        assert!(
            ack_count <= num_streams as usize,
            "shared-session multi-stream: expected at most {num_streams} Ack({highest_seq}) \
             frames; got {ack_count}"
        );
        assert_eq!(
            closed_count, num_streams as usize,
            "shared-session multi-stream: every stream should emit its own Closed frame"
        );
    }
4169
    /// Flush-on-join scenario for the duplex server.
    ///
    /// Three sessions (ids 1 to 3, stream id 0) each carry three messages
    /// with sequence numbers 0..3 and values `sid * 100 + seq`. After the
    /// application has received every message, no response frame may be
    /// readable until `server.join()`; afterwards each connection must yield
    /// an `Ack` for its highest sequence number followed by a `Closed` frame.
    #[async_timed_test(timeout_secs = 30)]
    async fn server_join_flushes_pending_ack_duplex_session() {
        let config = hyperactor_config::global::lock();
        // Raise both ack settings far beyond what this test sends or waits
        // for; the pre-join probe below checks that no ack shows up early.
        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
        let _g_time =
            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));

        let session_count = 3u64;
        let msgs_per_session = 3u64;
        let mut conns: Vec<PreparedConnection> = Vec::with_capacity(session_count as usize);
        let mut expected_messages: std::collections::HashSet<u64> =
            std::collections::HashSet::new();
        let mut expected_acks: Vec<u64> = Vec::with_capacity(session_count as usize);
        for sid in 1..=session_count {
            let messages: Vec<(u64, u64)> = (0u64..msgs_per_session)
                .map(|seq| (seq, sid * 100 + seq))
                .collect();
            for (_, v) in &messages {
                expected_messages.insert(*v);
            }
            // The expected ack is the session's highest sequence number.
            expected_acks.push(messages.iter().map(|(s, _)| *s).max().unwrap());
            conns.push(prepare_connection(SessionId(sid), 0, &messages).await);
        }

        let addr = ChannelAddr::Local(u64::MAX);
        let listener = QueueListener {
            streams: conns
                .iter_mut()
                // Move each server end out, leaving a throwaway duplex end
                // behind so `conns` keeps owning the client halves.
                .map(|c| std::mem::replace(&mut c.server_side, tokio::io::duplex(1).0))
                .collect(),
            addr: addr.clone(),
        };
        let mut server = super::duplex::serve_with_listener::<u64, u64, _>(listener, addr).unwrap();

        // Accept every session up front and keep both halves in scope for
        // the rest of the test.
        let mut all_rx: Vec<super::duplex::DuplexRx<u64>> =
            Vec::with_capacity(session_count as usize);
        let mut all_tx: Vec<super::duplex::DuplexTx<u64>> =
            Vec::with_capacity(session_count as usize);
        for _ in 0..session_count {
            let (rx, tx) = server.accept().await.unwrap();
            all_rx.push(rx);
            all_tx.push(tx);
        }

        // Values are collected into a set, so the pairing between accepted
        // sessions and prepared connections is not asserted.
        let mut received: std::collections::HashSet<u64> = std::collections::HashSet::new();
        for rx in all_rx.iter_mut() {
            for _ in 0..msgs_per_session {
                received.insert(rx.recv().await.unwrap());
            }
        }
        assert_eq!(
            received, expected_messages,
            "duplex: every produced message should reach the application"
        );

        // Leave time for any (unexpected) early response to be written
        // before probing the client read halves.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
        let mut readers: Vec<FrameReader<ReadHalf<DuplexStream>>> = conns
            .into_iter()
            .map(|c| FrameReader::new(c.client_r, max_len))
            .collect();
        for (idx, reader) in readers.iter_mut().enumerate() {
            match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
                // Timing out with nothing to read is the expected outcome.
                Err(_) => {}
                Ok(Err(e)) => {
                    panic!("duplex: connection {idx} frame reader error before join: {e}")
                }
                Ok(Ok(None)) => {
                    panic!("duplex: connection {idx} closed before server.join()")
                }
                Ok(Ok(Some((_, bytes)))) => {
                    let resp = super::deserialize_response(bytes).unwrap();
                    panic!(
                        "duplex: connection {idx} unexpectedly received {resp:?} \
                         before server.join()"
                    );
                }
            }
        }

        server.join().await;

        for (idx, (reader, expected_ack)) in readers.iter_mut().zip(&expected_acks).enumerate() {
            // First frame after join: the flushed ack.
            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
                .await
                .unwrap_or_else(|_| {
                    panic!(
                        "duplex: connection {idx} produced no Ack frame within 50ms after \
                         server.join()"
                    )
                })
                .expect("frame reader error")
                .expect("frame reader returned None");
            let acked = super::deserialize_response(bytes.1)
                .unwrap()
                .into_ack()
                .unwrap_or_else(|other| {
                    panic!("duplex: connection {idx} expected Ack, got {other:?}")
                });
            assert_eq!(
                acked, *expected_ack,
                "duplex: connection {idx} ack mismatch"
            );

            // Second frame after join: the terminal `Closed` response.
            let bytes = tokio::time::timeout(Duration::from_millis(50), reader.next())
                .await
                .unwrap_or_else(|_| {
                    panic!("duplex: connection {idx} produced no Closed frame within 50ms")
                })
                .expect("frame reader error")
                .expect("frame reader returned None");
            assert!(
                super::deserialize_response(bytes.1).unwrap().is_closed(),
                "duplex: connection {idx} expected Closed terminal frame"
            );
        }
    }
4315
    /// Test link for the dialing side of a duplex channel, backed by a fixed
    /// queue of in-memory streams.
    struct DuplexDialMockLink {
        /// Session id reported by `link_id` and written into each stream's
        /// link-init header.
        session_id: SessionId,
        /// Streams still to be handed out by `next`, front first.
        streams: std::collections::VecDeque<DuplexStream>,
    }
4324
4325 impl fmt::Debug for DuplexDialMockLink {
4326 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4327 f.debug_struct("DuplexDialMockLink")
4328 .field("session_id", &self.session_id)
4329 .field("remaining_streams", &self.streams.len())
4330 .finish()
4331 }
4332 }
4333
4334 #[async_trait]
4335 impl super::Link for DuplexDialMockLink {
4336 type Stream = DuplexStream;
4337
4338 fn dest(&self) -> ChannelAddr {
4339 ChannelAddr::Local(u64::MAX)
4340 }
4341
4342 fn link_id(&self) -> SessionId {
4343 self.session_id
4344 }
4345
4346 async fn next(&mut self) -> Result<DuplexStream, ClientError> {
4347 match self.streams.pop_front() {
4348 Some(mut stream) => {
4349 super::write_link_init(&mut stream, self.session_id, 0)
4350 .await
4351 .map_err(|err| ClientError::Io(self.dest(), err))?;
4352 Ok(stream)
4353 }
4354 None => Err(ClientError::Connect(
4355 self.dest(),
4356 std::io::Error::other("mock link exhausted"),
4357 "no more streams".into(),
4358 )),
4359 }
4360 }
4361 }
4362
    /// Checks that a duplex server session flushes its pending ack when the
    /// application drops its `server_tx` handle.
    ///
    /// One session delivers three messages (sequence numbers 0 to 2). No
    /// response frame may be readable before `server_tx` is dropped;
    /// afterwards the client must read `Ack(2)` followed by a `Closed` frame.
    #[async_timed_test(timeout_secs = 30)]
    async fn duplex_serve_flushes_pending_ack_on_app_closed() {
        let config = hyperactor_config::global::lock();
        // Raise both ack settings far beyond what this test sends or waits
        // for; the probe before the drop checks that no ack shows up early.
        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
        let _g_time =
            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));

        let session_id = SessionId(1);
        let messages: Vec<(u64, u64)> = vec![(0, 100), (1, 200), (2, 300)];
        let expected_ack: u64 = 2;
        let conn = prepare_connection(session_id, 0, &messages).await;

        let addr = ChannelAddr::Local(u64::MAX);
        let listener = QueueListener {
            // Moves only the server end out of `conn`; its client halves
            // remain owned by `conn`.
            streams: std::collections::VecDeque::from([conn.server_side]),
            addr: addr.clone(),
        };

        let mut server = super::duplex::serve_with_listener::<u64, u64, _>(listener, addr).unwrap();

        let (mut server_rx, server_tx) = server.accept().await.unwrap();

        // A single session, so arrival order is asserted as well.
        let mut received: Vec<u64> = Vec::with_capacity(messages.len());
        for _ in &messages {
            received.push(server_rx.recv().await.unwrap());
        }
        let expected_values: Vec<u64> = messages.iter().map(|(_, v)| *v).collect();
        assert_eq!(
            received, expected_values,
            "duplex serve: every message should reach the application"
        );

        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
        // Leave time for any (unexpected) early response to be written
        // before probing the client read half.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let mut reader = FrameReader::new(conn.client_r, max_len);
        match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
            // Timing out with nothing to read is the expected outcome.
            Err(_) => {}
            Ok(Err(e)) => panic!("duplex serve: frame reader error before app close: {e}"),
            Ok(Ok(None)) => panic!("duplex serve: wire closed before app close"),
            Ok(Ok(Some((_, bytes)))) => {
                let resp = super::deserialize_response(bytes).unwrap();
                panic!("duplex serve: unexpectedly received {resp:?} before app close");
            }
        }

        // Dropping the application's send handle is the trigger under test.
        drop(server_tx);

        // First frame after the drop: the flushed ack.
        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
            .await
            .unwrap_or_else(|_| panic!("duplex serve: produced no Ack frame within 100ms"))
            .expect("frame reader error")
            .expect("frame reader returned None");
        let acked = super::deserialize_response(bytes.1)
            .unwrap()
            .into_ack()
            .unwrap_or_else(|other| panic!("duplex serve: expected Ack, got {other:?}"));
        assert_eq!(
            acked, expected_ack,
            "duplex serve: ack should cover the highest seq received"
        );

        // Second frame: the terminal `Closed` response.
        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
            .await
            .unwrap_or_else(|_| panic!("duplex serve: produced no Closed frame within 100ms"))
            .expect("frame reader error")
            .expect("frame reader returned None");
        assert!(
            super::deserialize_response(bytes.1).unwrap().is_closed(),
            "duplex serve: expected Closed terminal frame after Ack"
        );

        // Release the receive half before shutting the server down.
        drop(server_rx);
        server.join().await;
    }
4457
    /// Checks that joining a dialing duplex client flushes its pending ack.
    ///
    /// The test plays the accepting peer by hand over an in-memory pipe: it
    /// reads the client's link-init header, writes three message frames
    /// (sequence numbers 0 to 2) and, once the client application has
    /// received them, verifies that nothing is readable until
    /// `dial_client.join()`. After the join the peer must read `Ack(2)`
    /// followed by a `Closed` frame.
    #[async_timed_test(timeout_secs = 30)]
    async fn duplex_client_join_flushes_pending_ack() {
        let config = hyperactor_config::global::lock();
        // Raise both ack settings far beyond what this test sends or waits
        // for; the pre-join probe below checks that no ack shows up early.
        let _g_msg = config.override_key(config::MESSAGE_ACK_EVERY_N_MESSAGES, 1_000_000);
        let _g_time =
            config.override_key(config::MESSAGE_ACK_TIME_INTERVAL, Duration::from_secs(3600));

        let session_id = SessionId(123);
        let (client_side, server_side) = tokio::io::duplex(8192);
        let (mut test_r, mut test_w) = tokio::io::split(server_side);

        let link = DuplexDialMockLink {
            session_id,
            streams: std::collections::VecDeque::from([client_side]),
        };

        let mut dial_client = super::duplex::spawn::<u64, u64>(link);
        // NOTE(review): the send handle is never used; presumably it is held
        // so the client's send side stays open until the join — confirm.
        let _dial_tx = dial_client.tx();
        let mut dial_rx = dial_client.take_rx().unwrap();

        // Consume the link-init header so the frame reader created later
        // starts at a frame boundary.
        super::read_link_init(&mut test_r).await.unwrap();

        let max_len = hyperactor_config::global::get(config::CODEC_MAX_FRAME_LENGTH);
        let messages: Vec<(u64, u64)> = vec![(0, 100), (1, 200), (2, 300)];
        let expected_ack: u64 = 2;
        for (seq, value) in &messages {
            let payload =
                serde_multipart::serialize_bincode(&Frame::<u64>::Message(*seq, *value)).unwrap();
            // Frames are written with the acceptor-to-initiator constant,
            // since the test is acting as the accepting peer here.
            let mut fw = FrameWrite::new(
                test_w,
                payload.framed(),
                max_len,
                super::ACCEPTOR_TO_INITIATOR,
            )
            .map_err(|(_w, e)| e)
            .unwrap();
            fw.send().await.unwrap();
            // `FrameWrite` owns the writer while sending; take it back.
            test_w = fw.complete();
        }

        // A single session, so arrival order is asserted as well.
        let mut received: Vec<u64> = Vec::with_capacity(messages.len());
        for _ in &messages {
            received.push(dial_rx.recv().await.unwrap());
        }
        let expected_values: Vec<u64> = messages.iter().map(|(_, v)| *v).collect();
        assert_eq!(
            received, expected_values,
            "dial: every message should reach the application"
        );

        // Leave time for any (unexpected) early response to be written
        // before probing the peer's read half.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let mut reader = FrameReader::new(test_r, max_len);
        match tokio::time::timeout(Duration::from_millis(10), reader.next()).await {
            // Timing out with nothing to read is the expected outcome.
            Err(_) => {}
            Ok(Err(e)) => panic!("dial: frame reader error before join: {e}"),
            Ok(Ok(None)) => panic!("dial: wire closed before join"),
            Ok(Ok(Some((_, bytes)))) => {
                let resp = super::deserialize_response(bytes).unwrap();
                panic!("dial: unexpectedly received {resp:?} before join");
            }
        }

        dial_client.join().await;

        // First frame after join: the flushed ack.
        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
            .await
            .unwrap_or_else(|_| panic!("dial: produced no Ack frame within 100ms after join"))
            .expect("frame reader error")
            .expect("frame reader returned None");
        let acked = super::deserialize_response(bytes.1)
            .unwrap()
            .into_ack()
            .unwrap_or_else(|other| panic!("dial: expected Ack, got {other:?}"));
        assert_eq!(
            acked, expected_ack,
            "dial: ack should cover the highest seq received"
        );

        // Second frame: the terminal `Closed` response.
        let bytes = tokio::time::timeout(Duration::from_millis(100), reader.next())
            .await
            .unwrap_or_else(|_| panic!("dial: produced no Closed frame within 100ms after join"))
            .expect("frame reader error")
            .expect("frame reader returned None");
        assert!(
            super::deserialize_response(bytes.1).unwrap().is_closed(),
            "dial: expected Closed terminal frame after Ack"
        );

        // Explicit drops keep both handles in scope until this point.
        drop(dial_rx);
        drop(test_w);
    }
4571
4572 #[async_timed_test(timeout_secs = 30)]
4579 async fn duplex_client_join_terminates_in_progress_recv() {
4580 let session_id = SessionId(123);
4581 let (client_side, server_side) = tokio::io::duplex(8192);
4582 let (mut test_r, _test_w) = tokio::io::split(server_side);
4583
4584 let link = DuplexDialMockLink {
4585 session_id,
4586 streams: std::collections::VecDeque::from([client_side]),
4587 };
4588
4589 let mut dial_client = super::duplex::spawn::<u64, u64>(link);
4590 let mut dial_rx = dial_client.take_rx().unwrap();
4591
4592 super::read_link_init(&mut test_r).await.unwrap();
4595
4596 let recv_handle: tokio::task::JoinHandle<Result<u64, ChannelError>> =
4601 tokio::spawn(async move { dial_rx.recv().await });
4602
4603 dial_client.join().await;
4607
4608 let result = tokio::time::timeout(Duration::from_millis(100), recv_handle)
4609 .await
4610 .expect("parked recv should resolve within 100ms after join")
4611 .expect("recv task should not panic");
4612 assert!(
4613 matches!(result, Err(ChannelError::Closed)),
4614 "in-progress recv should resolve with ChannelError::Closed after join, got {result:?}"
4615 );
4616 }
4617}