1use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::time::Duration;
19use std::time::Instant;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use bytes::Bytes;
24use bytes::BytesMut;
25use dashmap::DashMap;
26use hyperactor::Actor;
27use hyperactor::ActorHandle;
28use hyperactor::ActorRef;
29use hyperactor::Context;
30use hyperactor::Endpoint as _;
31use hyperactor::HandleClient;
32use hyperactor::Handler;
33use hyperactor::Instance;
34use hyperactor::OncePortHandle;
35use hyperactor::OncePortRef;
36use hyperactor::PortHandle;
37use hyperactor::RefClient;
38use hyperactor::actor::ActorError;
39use hyperactor::channel;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::channel::ChannelTx;
42use hyperactor::channel::Rx;
43use hyperactor::channel::Tx;
44use hyperactor::context;
45use hyperactor::context::Actor as _;
46use hyperactor_mesh::transport::default_bind_spec;
47use serde::Deserialize;
48use serde::Serialize;
49use serde_multipart::Part;
50use tokio::time::timeout as tokio_timeout;
51use tokio_util::sync::CancellationToken;
52use typeuri::Named;
53
54use super::TcpOp;
55use crate::RdmaOp;
56use crate::RdmaOpType;
57use crate::RdmaTransportLevel;
58use crate::backend::RdmaBackend;
59use crate::local_memory::KeepaliveLocalMemory;
60use crate::rdma_manager_actor::GetTcpActorRefClient;
61use crate::rdma_manager_actor::RdmaManagerActor;
62use crate::rdma_manager_actor::RdmaManagerMessageClient;
63
/// Payload returned by [`TcpManagerMessage::ReadChunk`]: the bytes read from the
/// remote buffer, wrapped in a [`Part`].
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct TcpChunk(Part);
wirevalue::register_type!(TcpChunk);
71
/// One chunk of a parallel transfer.
///
/// Sent by the sender tasks spawned in `TcpManagerActor::execute_transfer` over the
/// channels they dial, and consumed by the receiver task started in
/// `TcpManagerActor::init`.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
struct TcpDataChunk {
    /// Receive-side transfer id, as returned by transfer registration.
    transfer_id: usize,
    /// Byte offset in the destination buffer at which `data` is written.
    offset: usize,
    /// Chunk bytes (possibly split into several fragments).
    data: Part,
}
wirevalue::register_type!(TcpDataChunk);
82
/// Receive-side bookkeeping for one in-flight parallel transfer, keyed by transfer
/// id in `TcpManagerActor::transfers`.
#[derive(Debug)]
struct TransferState {
    /// Destination memory that incoming chunks are written into.
    local_memory: Arc<KeepaliveLocalMemory>,

    /// Number of chunks written successfully so far.
    chunks_received: usize,

    /// Number of chunks after which the transfer is complete.
    total_chunks: usize,

    /// Port that receives the transfer's final result.
    done: OncePortRef<Result<(), String>>,
}
102
103impl TransferState {
104 fn new(
105 total_chunks: usize,
106 local_memory: Arc<KeepaliveLocalMemory>,
107 done: OncePortRef<Result<(), String>>,
108 ) -> Self {
109 Self {
110 local_memory,
111 chunks_received: 0,
112 total_chunks,
113 done,
114 }
115 }
116}
117
/// Self-message posted by the receiver task when a transfer finishes or fails.
///
/// The actor's handler forwards `result` to `done`.
#[derive(Debug, Serialize, Deserialize, Named)]
struct SendTransferResult {
    /// Completion port registered with the transfer.
    done: OncePortRef<Result<(), String>>,
    /// Final outcome of the transfer.
    result: Result<(), String>,
}
132
/// Self-message posted by a sender or receiver task on a fatal error.
///
/// The handler logs it and returns it as an error from the actor.
#[derive(Debug, Serialize, Deserialize, Named)]
struct TransferError {
    /// Human-readable description of the failure.
    message: String,
}
141
/// Local-only message: register a receive-side transfer into `local_memory`.
///
/// The new transfer id is sent back on `reply`.
#[derive(Debug)]
struct RegisterTransferLocal {
    /// Destination memory for incoming chunks.
    local_memory: Arc<KeepaliveLocalMemory>,
    /// Number of chunks expected.
    total_chunks: usize,
    /// Port that receives the transfer's final result.
    done: OncePortRef<Result<(), String>>,
    /// Port that receives the allocated transfer id.
    reply: OncePortHandle<usize>,
}
152
/// Local-only message: stream `local_memory` to the chunk channel at `dest_addr`.
///
/// The chunks are sent as transfer `transfer_id`, in pieces of at most `chunk_size` bytes.
#[derive(Debug)]
struct ExecuteTransferLocal {
    /// Receive-side transfer id previously registered at the destination.
    transfer_id: usize,
    /// Source memory to send.
    local_memory: Arc<KeepaliveLocalMemory>,
    /// Maximum bytes per chunk.
    chunk_size: usize,
    /// Address of the destination actor's chunk channel.
    dest_addr: ChannelAddr,
}
162
/// Remote messages handled by [`TcpManagerActor`].
///
/// Application-level failures are reported through the inner `Result` of each reply.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
enum TcpManagerMessage {
    /// Write `data` into local buffer `buf_id` starting at `offset`.
    WriteChunk {
        buf_id: usize,
        offset: usize,
        data: Part,
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
    /// Read `size` bytes from local buffer `buf_id` starting at `offset`.
    ReadChunk {
        buf_id: usize,
        offset: usize,
        size: usize,
        #[reply]
        reply: OncePortRef<Result<TcpChunk, String>>,
    },
    /// Return the address of this actor's parallel chunk channel.
    ///
    /// The reply is `None` when no channel was served in `init`.
    GetChannelAddress {
        #[reply]
        reply: OncePortRef<Option<ChannelAddr>>,
    },
    /// Register a receive-side transfer of `total_chunks` chunks into buffer `buf_id`.
    ///
    /// Replies with the transfer id; the final result is posted to `done`.
    RegisterTransferRemote {
        buf_id: usize,
        total_chunks: usize,
        done: OncePortRef<Result<(), String>>,
        #[reply]
        reply: OncePortRef<Result<usize, String>>,
    },
    /// Stream local buffer `buf_id` to the chunk channel at `dest_addr` as transfer
    /// `transfer_id`, in pieces of at most `chunk_size` bytes.
    ExecuteTransferRemote {
        transfer_id: usize,
        buf_id: usize,
        chunk_size: usize,
        dest_addr: ChannelAddr,
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(TcpManagerMessage);
212
/// Child actor of [`RdmaManagerActor`] that moves buffer contents over hyperactor
/// messages and, when `RDMA_TCP_FALLBACK_PARALLELISM > 1`, over dedicated
/// parallel chunk channels.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [TcpManagerMessage],
)]
pub struct TcpManagerActor {
    /// Parent `RdmaManagerActor`, set once in `init`; used to resolve buffer ids.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,
    /// Id handed out by the next call to `register_transfer`.
    next_transfer_id: usize,
    /// In-flight receive-side transfers; shared with the receiver task.
    transfers: Arc<DashMap<usize, TransferState>>,
    /// Address of the chunk channel served in `init`, if any.
    channel_addr: Option<ChannelAddr>,
    /// Cached outbound chunk connections, one `Vec` entry per unit of parallelism.
    outbound: HashMap<ChannelAddr, Vec<Arc<ChannelTx<TcpDataChunk>>>>,
    /// Cancelled in `cleanup`; checked by the receiver task and the sender tasks.
    cancel: CancellationToken,
    /// Signalled when the receiver task exits; awaited in `cleanup`.
    receiver_done: Option<tokio::sync::oneshot::Receiver<()>>,
}
235
236impl TcpManagerActor {
237 pub fn new() -> Self {
238 Self {
239 owner: OnceLock::new(),
240 next_transfer_id: 0,
241 transfers: Arc::new(DashMap::new()),
242 channel_addr: None,
243 outbound: HashMap::new(),
244 cancel: CancellationToken::new(),
245 receiver_done: None,
246 }
247 }
248
249 fn register_transfer(
250 &mut self,
251 local_memory: Arc<KeepaliveLocalMemory>,
252 total_chunks: usize,
253 done: OncePortRef<Result<(), String>>,
254 ) -> usize {
255 let transfer_id = self.next_transfer_id;
256 self.next_transfer_id += 1;
257 self.transfers.insert(
258 transfer_id,
259 TransferState::new(total_chunks, local_memory, done),
260 );
261 transfer_id
262 }
263
264 fn execute_transfer(
265 &mut self,
266 cx: &Context<Self>,
267 transfer_id: usize,
268 local_memory: Arc<KeepaliveLocalMemory>,
269 chunk_size: usize,
270 dest_addr: ChannelAddr,
271 ) -> Result<()> {
272 let parallelism =
273 hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
274
275 if !self.outbound.contains_key(&dest_addr) {
276 let conns = (0..parallelism)
277 .map(|_| {
278 channel::dial::<TcpDataChunk>(dest_addr.clone())
279 .map(Arc::new)
280 .map_err(anyhow::Error::from)
281 })
282 .collect::<Result<Vec<_>>>()?;
283 self.outbound.insert(dest_addr.clone(), conns);
284 }
285 let conns = self.outbound.get(&dest_addr).unwrap();
286
287 let size = local_memory.size();
288 let total_chunks = size.div_ceil(chunk_size);
289
290 let chunk_index = Arc::new(std::sync::atomic::AtomicUsize::new(0));
291 let error_port: PortHandle<TransferError> = cx.port();
292 let proc = cx.instance().proc().clone();
293 let cancel = self.cancel.clone();
294
295 for conn in conns.clone() {
296 let mem = local_memory.clone();
297 let chunk_index = chunk_index.clone();
298 let error_port = error_port.clone();
299 let proc = proc.clone();
300 let cancel = cancel.clone();
301
302 tokio::spawn(async move {
303 let sender_name = format!(
304 "tcp_chunk_sender_{}",
305 hyperactor_mesh::shortuuid::ShortUuid::generate()
306 );
307 let (instance, _handle) = proc
308 .client(&sender_name)
309 .expect("failed to create sender instance");
310
311 loop {
312 if cancel.is_cancelled() {
313 return;
314 }
315
316 let idx = chunk_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
317 if idx >= total_chunks {
318 break;
319 }
320
321 let offset = idx * chunk_size;
322 let len = std::cmp::min(chunk_size, size - offset);
323 let mut buf = BytesMut::zeroed(len);
324 if let Err(e) = unsafe { mem.read_at(offset, &mut buf) } {
327 error_port.post(
328 &instance,
329 TransferError {
330 message: format!("read_at failed at offset {offset}: {e}"),
331 },
332 );
333 return;
334 }
335
336 let chunk = TcpDataChunk {
337 transfer_id,
338 offset,
339 data: Part::from(buf.freeze()),
340 };
341
342 if let Err(e) = conn.send(chunk).await {
343 error_port.post(
344 &instance,
345 TransferError {
346 message: format!("failed to send chunk at offset {offset}: {e}"),
347 },
348 );
349 return;
350 }
351 }
352 });
353 }
354
355 Ok(())
356 }
357
358 pub async fn local_handle(
361 client: &(impl context::Actor + Send + Sync),
362 ) -> Result<ActorHandle<Self>, anyhow::Error> {
363 let rdma_handle = RdmaManagerActor::local_handle(client);
364 let tcp_ref: ActorRef<TcpManagerActor> = rdma_handle.get_tcp_actor_ref(client).await?;
365 tcp_ref
366 .downcast_handle(client)
367 .ok_or_else(|| anyhow::anyhow!("TcpManagerActor is not in the local process"))
368 }
369}
370
371#[async_trait]
372impl Actor for TcpManagerActor {
373 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
374 let owner = this.parent_handle().ok_or_else(|| {
375 anyhow::anyhow!("RdmaManagerActor not found as parent of TcpManagerActor")
376 })?;
377 self.owner
378 .set(owner)
379 .map_err(|_| anyhow::anyhow!("TcpManagerActor owner already set"))?;
380
381 let parallelism =
382 hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
383 if parallelism > 1 {
384 let addr = match default_bind_spec() {
385 channel::BindSpec::Any(transport) => ChannelAddr::any(transport),
386 channel::BindSpec::Addr(addr) => addr,
387 };
388 let (bound_addr, mut rx) = channel::serve::<TcpDataChunk>(addr)?;
389 self.channel_addr = Some(bound_addr);
390
391 let transfers = self.transfers.clone();
392 let proc = this.proc().clone();
393 let result_port: PortHandle<SendTransferResult> = this.port();
394 let error_port: PortHandle<TransferError> = this.port();
395 let cancel = self.cancel.clone();
396 let receiver_name = format!(
397 "tcp_chunk_receiver_{}",
398 hyperactor_mesh::shortuuid::ShortUuid::generate()
399 );
400
401 let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>();
402 self.receiver_done = Some(done_rx);
403
404 tokio::spawn(async move {
405 let (instance, _handle) = proc
406 .client(&receiver_name.to_string())
407 .expect("failed to create receiver instance");
408
409 loop {
410 let chunk = tokio::select! {
411 _ = cancel.cancelled() => break,
412 result = rx.recv() => match result {
413 Ok(chunk) => chunk,
414 Err(e) => {
415 error_port
416 .post(
417 &instance,
418 TransferError {
419 message: format!(
420 "parallel channel receive error: {e}"
421 ),
422 },
423 );
424 break;
425 }
426 },
427 };
428
429 let mut entry = match transfers.get_mut(&chunk.transfer_id) {
430 Some(entry) => entry,
431 None => {
432 tracing::warn!(
433 "received chunk for unknown transfer {:?}",
434 chunk.transfer_id,
435 );
436 continue;
437 }
438 };
439
440 let mut write_offset = chunk.offset;
441 let fragments = chunk.data.into_inner();
442 let write_err = fragments.iter().find_map(|fragment| {
443 let result = unsafe { entry.local_memory.write_at(write_offset, fragment) };
446 write_offset += fragment.len();
447 result.err()
448 });
449 if let Some(e) = write_err {
450 let transfer_id = chunk.transfer_id;
451 drop(entry);
452 let (_, state) = transfers.remove(&transfer_id).unwrap();
453 result_port.post(
454 &instance,
455 SendTransferResult {
456 done: state.done,
457 result: Err(e.to_string()),
458 },
459 );
460 continue;
461 }
462
463 entry.chunks_received += 1;
464 if entry.chunks_received == entry.total_chunks {
465 let transfer_id = chunk.transfer_id;
466 drop(entry);
467 let (_, state) = transfers.remove(&transfer_id).unwrap();
468 result_port.post(
469 &instance,
470 SendTransferResult {
471 done: state.done,
472 result: Ok(()),
473 },
474 );
475 }
476 }
477 rx.join().await;
478 done_tx.send(()).unwrap();
479 });
480 }
481
482 Ok(())
483 }
484
485 async fn cleanup(
486 &mut self,
487 _this: &Instance<Self>,
488 _err: Option<&ActorError>,
489 ) -> Result<(), anyhow::Error> {
490 self.cancel.cancel();
491 if let Some(done_rx) = self.receiver_done.take() {
492 done_rx.await?;
493 }
494 Ok(())
495 }
496}
497
498#[async_trait]
499#[hyperactor::handle(TcpManagerMessage)]
500impl TcpManagerMessageHandler for TcpManagerActor {
501 async fn write_chunk(
502 &mut self,
503 cx: &Context<Self>,
504 buf_id: usize,
505 offset: usize,
506 data: Part,
507 ) -> Result<Result<(), String>, anyhow::Error> {
508 let owner = self.owner.get().expect("TcpManagerActor owner not set");
509 let mem = match owner.request_local_memory(cx, buf_id).await {
510 Ok(Some(mem)) => mem,
511 Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
512 Err(e) => return Ok(Err(e.to_string())),
513 };
514
515 let bytes = data.into_bytes();
516 if let Err(e) = unsafe { mem.write_at(offset, &bytes) } {
521 return Ok(Err(e.to_string()));
522 }
523
524 Ok(Ok(()))
525 }
526
527 async fn read_chunk(
528 &mut self,
529 cx: &Context<Self>,
530 buf_id: usize,
531 offset: usize,
532 size: usize,
533 ) -> Result<Result<TcpChunk, String>, anyhow::Error> {
534 let owner = self.owner.get().expect("TcpManagerActor owner not set");
535 let mem = match owner.request_local_memory(cx, buf_id).await {
536 Ok(Some(mem)) => mem,
537 Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
538 Err(e) => return Ok(Err(e.to_string())),
539 };
540
541 let mut buf = BytesMut::zeroed(size);
542 if let Err(e) = unsafe { mem.read_at(offset, &mut buf) } {
547 return Ok(Err(e.to_string()));
548 }
549 Ok(Ok(TcpChunk(Part::from(buf.freeze()))))
550 }
551
552 async fn get_channel_address(
553 &mut self,
554 _cx: &Context<Self>,
555 ) -> Result<Option<ChannelAddr>, anyhow::Error> {
556 Ok(self.channel_addr.clone())
557 }
558
559 async fn register_transfer_remote(
560 &mut self,
561 cx: &Context<Self>,
562 buf_id: usize,
563 total_chunks: usize,
564 done: OncePortRef<Result<(), String>>,
565 ) -> Result<Result<usize, String>, anyhow::Error> {
566 let owner = self.owner.get().expect("TcpManagerActor owner not set");
567 let mem = match owner.request_local_memory(cx, buf_id).await {
568 Ok(Some(mem)) => mem,
569 Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
570 Err(e) => return Ok(Err(e.to_string())),
571 };
572 let transfer_id = self.register_transfer(mem, total_chunks, done);
573 Ok(Ok(transfer_id))
574 }
575
576 async fn execute_transfer_remote(
577 &mut self,
578 cx: &Context<Self>,
579 transfer_id: usize,
580 buf_id: usize,
581 chunk_size: usize,
582 dest_addr: ChannelAddr,
583 ) -> Result<Result<(), String>, anyhow::Error> {
584 let owner = self.owner.get().expect("TcpManagerActor owner not set");
585 let mem = match owner.request_local_memory(cx, buf_id).await {
586 Ok(Some(mem)) => mem,
587 Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
588 Err(e) => return Ok(Err(e.to_string())),
589 };
590 self.execute_transfer(cx, transfer_id, mem, chunk_size, dest_addr)?;
591 Ok(Ok(()))
592 }
593}
594
595#[async_trait]
596impl Handler<RegisterTransferLocal> for TcpManagerActor {
597 async fn handle(
598 &mut self,
599 cx: &Context<Self>,
600 message: RegisterTransferLocal,
601 ) -> Result<(), anyhow::Error> {
602 let transfer_id =
603 self.register_transfer(message.local_memory, message.total_chunks, message.done);
604 message.reply.post(cx, transfer_id);
605 Ok(())
606 }
607}
608
609#[async_trait]
610impl Handler<ExecuteTransferLocal> for TcpManagerActor {
611 async fn handle(
612 &mut self,
613 cx: &Context<Self>,
614 message: ExecuteTransferLocal,
615 ) -> Result<(), anyhow::Error> {
616 self.execute_transfer(
617 cx,
618 message.transfer_id,
619 message.local_memory,
620 message.chunk_size,
621 message.dest_addr,
622 )
623 }
624}
625
626#[async_trait]
627impl Handler<SendTransferResult> for TcpManagerActor {
628 async fn handle(
629 &mut self,
630 cx: &Context<Self>,
631 message: SendTransferResult,
632 ) -> Result<(), anyhow::Error> {
633 message.done.post(cx, message.result);
634 Ok(())
635 }
636}
637
638#[async_trait]
639impl Handler<TransferError> for TcpManagerActor {
640 async fn handle(
641 &mut self,
642 _cx: &Context<Self>,
643 message: TransferError,
644 ) -> Result<(), anyhow::Error> {
645 tracing::error!("fatal transfer error: {}", message.message);
646 Err(anyhow::anyhow!(message.message))
647 }
648}
649
/// [`RdmaBackend`] that performs operations through a local [`TcpManagerActor`].
///
/// Data moves via per-chunk actor messages, or via parallel chunk channels when
/// `RDMA_TCP_FALLBACK_PARALLELISM > 1`.
#[derive(Debug, Clone)]
pub struct TcpBackend(pub ActorHandle<TcpManagerActor>);
659
660impl std::ops::Deref for TcpBackend {
661 type Target = ActorHandle<TcpManagerActor>;
662 fn deref(&self) -> &Self::Target {
663 &self.0
664 }
665}
666
667impl TcpBackend {
668 async fn execute_parallel_write(
671 &self,
672 cx: &(impl context::Actor + Send + Sync),
673 op: &TcpOp,
674 chunk_size: usize,
675 deadline: Instant,
676 ) -> Result<()> {
677 let size = op.local_memory.size();
678 let total_chunks = size.div_ceil(chunk_size);
679
680 let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
681 let done_ref = done_handle.bind();
682
683 let remaining = deadline.saturating_duration_since(Instant::now());
684 let transfer_id = tokio_timeout(
685 remaining,
686 op.remote_tcp_manager.register_transfer_remote(
687 cx,
688 op.remote_buf_id,
689 total_chunks,
690 done_ref,
691 ),
692 )
693 .await
694 .map_err(|_| anyhow::anyhow!("register_transfer_remote timed out"))??
695 .map_err(|e| anyhow::anyhow!(e))?;
696
697 let dest_addr = tokio_timeout(
698 deadline.saturating_duration_since(Instant::now()),
699 op.remote_tcp_manager.get_channel_address(cx),
700 )
701 .await
702 .map_err(|_| anyhow::anyhow!("get_channel_address timed out"))??
703 .ok_or_else(|| anyhow::anyhow!("remote does not have parallel channels enabled"))?;
704
705 self.0.post(
706 cx,
707 ExecuteTransferLocal {
708 transfer_id,
709 local_memory: op.local_memory.clone(),
710 chunk_size,
711 dest_addr,
712 },
713 );
714
715 let remaining = deadline.saturating_duration_since(Instant::now());
716 let result = tokio_timeout(remaining, done_rx.recv())
717 .await
718 .map_err(|_| anyhow::anyhow!("parallel write timed out"))?
719 .map_err(|e| anyhow::anyhow!(e))?;
720 result.map_err(|e| anyhow::anyhow!(e))
721 }
722
723 async fn execute_parallel_read(
726 &self,
727 cx: &(impl context::Actor + Send + Sync),
728 op: &TcpOp,
729 chunk_size: usize,
730 deadline: Instant,
731 ) -> Result<()> {
732 let size = op.local_memory.size();
733 let total_chunks = size.div_ceil(chunk_size);
734
735 let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
736 let done_ref = done_handle.bind();
737
738 let (id_handle, id_rx) = hyperactor::mailbox::open_once_port::<usize>(cx);
739
740 self.0.post(
741 cx,
742 RegisterTransferLocal {
743 local_memory: op.local_memory.clone(),
744 total_chunks,
745 done: done_ref,
746 reply: id_handle,
747 },
748 );
749
750 let transfer_id = id_rx
751 .recv()
752 .await
753 .map_err(|e| anyhow::anyhow!("failed to receive transfer id: {e}"))?;
754
755 let my_channel_addr = self
756 .0
757 .get_channel_address(cx)
758 .await?
759 .ok_or_else(|| anyhow::anyhow!("local parallel channels not enabled"))?;
760
761 let remaining = deadline.saturating_duration_since(Instant::now());
762 tokio_timeout(
763 remaining,
764 op.remote_tcp_manager.execute_transfer_remote(
765 cx,
766 transfer_id,
767 op.remote_buf_id,
768 chunk_size,
769 my_channel_addr,
770 ),
771 )
772 .await
773 .map_err(|_| anyhow::anyhow!("execute_transfer_remote timed out"))??
774 .map_err(|e| anyhow::anyhow!(e))?;
775
776 let remaining = deadline.saturating_duration_since(Instant::now());
777 let result = tokio_timeout(remaining, done_rx.recv())
778 .await
779 .map_err(|_| anyhow::anyhow!("parallel read timed out"))?
780 .map_err(|e| anyhow::anyhow!(e))?;
781 result.map_err(|e| anyhow::anyhow!(e))
782 }
783
784 async fn execute_write(
787 &self,
788 cx: &(impl context::Actor + Send + Sync),
789 op: &TcpOp,
790 chunk_size: usize,
791 deadline: Instant,
792 ) -> Result<()> {
793 let size = op.local_memory.size();
794 let mut offset = 0;
795
796 while offset < size {
797 let remaining = deadline.saturating_duration_since(Instant::now());
798 if remaining.is_zero() {
799 anyhow::bail!("tcp write timed out");
800 }
801
802 let len = std::cmp::min(chunk_size, size - offset);
803
804 let mut buf = vec![0u8; len];
805 unsafe { op.local_memory.read_at(offset, &mut buf) }?;
809 let data = Part::from(Bytes::from(buf));
810
811 tokio_timeout(
812 remaining,
813 op.remote_tcp_manager
814 .write_chunk(cx, op.remote_buf_id, offset, data),
815 )
816 .await
817 .map_err(|_| anyhow::anyhow!("tcp write chunk timed out"))??
818 .map_err(|e| anyhow::anyhow!(e))?;
819
820 offset += len;
821 }
822
823 Ok(())
824 }
825
826 async fn execute_read(
829 &self,
830 cx: &(impl context::Actor + Send + Sync),
831 op: &TcpOp,
832 chunk_size: usize,
833 deadline: Instant,
834 ) -> Result<()> {
835 let size = op.local_memory.size();
836 let mut offset = 0;
837
838 while offset < size {
839 let remaining = deadline.saturating_duration_since(Instant::now());
840 if remaining.is_zero() {
841 anyhow::bail!("tcp read timed out");
842 }
843
844 let len = std::cmp::min(chunk_size, size - offset);
845
846 let chunk = tokio_timeout(
847 remaining,
848 op.remote_tcp_manager
849 .read_chunk(cx, op.remote_buf_id, offset, len),
850 )
851 .await
852 .map_err(|_| anyhow::anyhow!("tcp read chunk timed out"))??
853 .map_err(|e| anyhow::anyhow!(e))?;
854 let data = chunk.0.into_bytes();
855
856 anyhow::ensure!(
857 data.len() == len,
858 "tcp read chunk size mismatch: expected {len}, got {}",
859 data.len()
860 );
861
862 unsafe { op.local_memory.write_at(offset, &data) }?;
867
868 offset += len;
869 }
870
871 Ok(())
872 }
873}
874
875#[async_trait]
876impl RdmaBackend for TcpBackend {
877 type TransportInfo = ();
878
879 async fn submit(
885 &mut self,
886 cx: &(impl context::Actor + Send + Sync),
887 ops: Vec<RdmaOp>,
888 timeout: Duration,
889 ) -> Result<()> {
890 let chunk_size =
891 hyperactor_config::global::get(crate::config::RDMA_MAX_CHUNK_SIZE_MB) * 1024 * 1024;
892 let parallelism =
893 hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
894 let deadline = Instant::now() + timeout;
895
896 for op in ops {
897 let remaining = deadline.saturating_duration_since(Instant::now());
898 if remaining.is_zero() {
899 anyhow::bail!("tcp submit timed out");
900 }
901
902 let (remote_tcp_mgr, remote_buf_id) = op.remote.resolve_tcp()?;
903 let tcp_op = TcpOp {
904 op_type: op.op_type.clone(),
905 local_memory: op.local,
906 remote_tcp_manager: remote_tcp_mgr,
907 remote_buf_id,
908 };
909
910 if parallelism > 1 {
911 match tcp_op.op_type {
912 RdmaOpType::WriteFromLocal => {
913 self.execute_parallel_write(cx, &tcp_op, chunk_size, deadline)
914 .await?;
915 }
916 RdmaOpType::ReadIntoLocal => {
917 self.execute_parallel_read(cx, &tcp_op, chunk_size, deadline)
918 .await?;
919 }
920 }
921 } else {
922 match tcp_op.op_type {
923 RdmaOpType::WriteFromLocal => {
924 self.execute_write(cx, &tcp_op, chunk_size, deadline)
925 .await?;
926 }
927 RdmaOpType::ReadIntoLocal => {
928 self.execute_read(cx, &tcp_op, chunk_size, deadline).await?;
929 }
930 }
931 }
932 }
933
934 Ok(())
935 }
936
937 fn transport_level(&self) -> RdmaTransportLevel {
938 RdmaTransportLevel::Tcp
939 }
940
941 fn transport_info(&self) -> Option<Self::TransportInfo> {
942 None
943 }
944}
945
946#[cfg(test)]
947mod tests {
948 use std::sync::Arc;
949 use std::sync::atomic::AtomicUsize;
950 use std::sync::atomic::Ordering;
951 use std::time::Duration;
952
953 use hyperactor::ActorHandle;
954 use hyperactor::Proc;
955 use hyperactor::RemoteSpawn;
956 use hyperactor::channel::ChannelAddr;
957 use hyperactor_config::Flattrs;
958
959 use super::TcpBackend;
960 use super::TcpManagerActor;
961 use crate::RdmaManagerMessageClient;
962 use crate::RdmaOp;
963 use crate::RdmaOpType;
964 use crate::backend::RdmaBackend;
965 use crate::local_memory::KeepaliveLocalMemory;
966 use crate::rdma_manager_actor::GetTcpActorRefClient;
967 use crate::rdma_manager_actor::RdmaManagerActor;
968
969 static COUNTER: AtomicUsize = AtomicUsize::new(0);
970
971 struct TcpTestProcEnv {
972 proc: Proc,
973 rdma_handle: ActorHandle<RdmaManagerActor>,
974 instance: hyperactor::Instance<()>,
975 tcp_backend: TcpBackend,
976 rdma_remote_buf: crate::RdmaRemoteBuffer,
977 local_memory: Arc<KeepaliveLocalMemory>,
978 }
979
980 impl Drop for TcpTestProcEnv {
981 fn drop(&mut self) {
982 use crate::rdma_manager_actor::ReleaseBufferClient;
983 tokio::task::block_in_place(|| {
986 tokio::runtime::Handle::current()
987 .block_on(
988 self.rdma_remote_buf
989 .owner
990 .release_buffer(&self.instance, self.rdma_remote_buf.id),
991 )
992 .expect("failed to release buffer in TcpTestProcEnv drop");
993 });
994 }
995 }
996
997 impl TcpTestProcEnv {
998 async fn new(buffer_size: usize) -> anyhow::Result<Self> {
1000 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1001 let proc = Proc::direct(
1002 ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
1003 format!("tcp_test_{id}"),
1004 )?;
1005 let (instance, _) = proc.client("client")?;
1006
1007 let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
1008 let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;
1009
1010 let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
1011 let tcp_backend = TcpBackend(
1012 tcp_ref
1013 .downcast_handle(&instance)
1014 .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
1015 );
1016
1017 let (local_memory, rdma_remote_buf) =
1018 Self::alloc_cpu_buffer(&instance, &rdma_handle, buffer_size).await?;
1019
1020 Ok(Self {
1021 proc,
1022 rdma_handle,
1023 instance,
1024 tcp_backend,
1025 rdma_remote_buf,
1026 local_memory,
1027 })
1028 }
1029
1030 async fn on_proc(
1032 proc: &Proc,
1033 rdma_handle: &ActorHandle<RdmaManagerActor>,
1034 tcp_backend: TcpBackend,
1035 buffer_size: usize,
1036 ) -> anyhow::Result<Self> {
1037 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1038 let (instance, _) = proc.client(&format!("client_{id}"))?;
1039
1040 let (local_memory, rdma_remote_buf) =
1041 Self::alloc_cpu_buffer(&instance, rdma_handle, buffer_size).await?;
1042
1043 Ok(Self {
1044 proc: proc.clone(),
1045 rdma_handle: rdma_handle.clone(),
1046 instance,
1047 tcp_backend,
1048 rdma_remote_buf,
1049 local_memory,
1050 })
1051 }
1052
1053 async fn alloc_cpu_buffer(
1054 instance: &hyperactor::Instance<()>,
1055 rdma_handle: &ActorHandle<RdmaManagerActor>,
1056 buffer_size: usize,
1057 ) -> anyhow::Result<(Arc<KeepaliveLocalMemory>, crate::RdmaRemoteBuffer)> {
1058 let cpu_buf = vec![0u8; buffer_size].into_boxed_slice();
1059 let ptr = cpu_buf.as_ptr() as usize;
1060 let local_memory: Arc<KeepaliveLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
1061 ptr,
1062 buffer_size,
1063 Arc::new(cpu_buf),
1064 ));
1065 let rdma_remote_buf = rdma_handle
1066 .request_buffer(instance, local_memory.clone())
1067 .await?;
1068 Ok((local_memory, rdma_remote_buf))
1069 }
1070 }
1071
1072 async fn setup_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1074 Ok(vec![
1075 TcpTestProcEnv::new(buf_size).await?,
1076 TcpTestProcEnv::new(buf_size).await?,
1077 ])
1078 }
1079
1080 async fn setup_same_proc_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1082 let first = TcpTestProcEnv::new(buf_size).await?;
1083 let second = TcpTestProcEnv::on_proc(
1084 &first.proc,
1085 &first.rdma_handle,
1086 first.tcp_backend.clone(),
1087 buf_size,
1088 )
1089 .await?;
1090 Ok(vec![first, second])
1091 }
1092
1093 async fn setup_tcp_env_pairs(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1096 let e0 = TcpTestProcEnv::new(buf_size).await?;
1097 let e1 = TcpTestProcEnv::new(buf_size).await?;
1098 let e2 =
1099 TcpTestProcEnv::on_proc(&e0.proc, &e0.rdma_handle, e0.tcp_backend.clone(), buf_size)
1100 .await?;
1101 let e3 =
1102 TcpTestProcEnv::on_proc(&e1.proc, &e1.rdma_handle, e1.tcp_backend.clone(), buf_size)
1103 .await?;
1104 Ok(vec![e0, e1, e2, e3])
1105 }
1106
1107 fn test_write(mem: &KeepaliveLocalMemory, offset: usize, src: &[u8]) -> anyhow::Result<()> {
1116 unsafe { mem.write_at(offset, src) }
1118 }
1119
1120 fn test_read(mem: &KeepaliveLocalMemory, offset: usize, dst: &mut [u8]) -> anyhow::Result<()> {
1123 unsafe { mem.read_at(offset, dst) }
1125 }
1126
1127 async fn do_write_test(
1129 envs: &mut [TcpTestProcEnv],
1130 buf_size: usize,
1131 timeout: Duration,
1132 ) -> anyhow::Result<()> {
1133 let mut src = vec![0u8; buf_size];
1134 for (i, byte) in src.iter_mut().enumerate() {
1135 *byte = (i % 256) as u8;
1136 }
1137 test_write(&envs[0].local_memory, 0, &src)?;
1138
1139 let remote = envs[1].rdma_remote_buf.clone();
1140 let env = &mut envs[0];
1141 env.tcp_backend
1142 .submit(
1143 &env.instance,
1144 vec![RdmaOp {
1145 op_type: RdmaOpType::WriteFromLocal,
1146 local: env.local_memory.clone(),
1147 remote,
1148 }],
1149 timeout,
1150 )
1151 .await?;
1152
1153 let mut dst = vec![0u8; buf_size];
1154 test_read(&envs[1].local_memory, 0, &mut dst)?;
1155 for (i, byte) in dst.iter().enumerate() {
1156 assert_eq!(*byte, (i % 256) as u8, "mismatch at offset {i} after write");
1157 }
1158 Ok(())
1159 }
1160
1161 async fn do_read_test(
1163 envs: &mut [TcpTestProcEnv],
1164 buf_size: usize,
1165 timeout: Duration,
1166 ) -> anyhow::Result<()> {
1167 let mut src = vec![0u8; buf_size];
1168 for (i, byte) in src.iter_mut().enumerate() {
1169 *byte = ((i * 7 + 3) % 256) as u8;
1170 }
1171 test_write(&envs[1].local_memory, 0, &src)?;
1172
1173 let remote = envs[1].rdma_remote_buf.clone();
1174 let env = &mut envs[0];
1175 env.tcp_backend
1176 .submit(
1177 &env.instance,
1178 vec![RdmaOp {
1179 op_type: RdmaOpType::ReadIntoLocal,
1180 local: env.local_memory.clone(),
1181 remote,
1182 }],
1183 timeout,
1184 )
1185 .await?;
1186
1187 let mut dst = vec![0u8; buf_size];
1188 test_read(&envs[0].local_memory, 0, &mut dst)?;
1189 for (i, byte) in dst.iter().enumerate() {
1190 assert_eq!(
1191 *byte,
1192 ((i * 7 + 3) % 256) as u8,
1193 "mismatch at offset {i} after read"
1194 );
1195 }
1196 Ok(())
1197 }
1198
1199 async fn do_round_trip_test(
1201 envs: &mut [TcpTestProcEnv],
1202 buf_size: usize,
1203 timeout: Duration,
1204 ) -> anyhow::Result<()> {
1205 let mut src = vec![0u8; buf_size];
1206 for (i, byte) in src.iter_mut().enumerate() {
1207 *byte = ((i * 13 + 5) % 256) as u8;
1208 }
1209 test_write(&envs[0].local_memory, 0, &src)?;
1210
1211 let remote = envs[1].rdma_remote_buf.clone();
1212 let env = &mut envs[0];
1213 env.tcp_backend
1214 .submit(
1215 &env.instance,
1216 vec![RdmaOp {
1217 op_type: RdmaOpType::WriteFromLocal,
1218 local: env.local_memory.clone(),
1219 remote: remote.clone(),
1220 }],
1221 timeout,
1222 )
1223 .await?;
1224
1225 test_write(&envs[0].local_memory, 0, &vec![0u8; buf_size])?;
1226
1227 let env = &mut envs[0];
1228 env.tcp_backend
1229 .submit(
1230 &env.instance,
1231 vec![RdmaOp {
1232 op_type: RdmaOpType::ReadIntoLocal,
1233 local: env.local_memory.clone(),
1234 remote,
1235 }],
1236 timeout,
1237 )
1238 .await?;
1239
1240 let mut dst = vec![0u8; buf_size];
1241 test_read(&envs[0].local_memory, 0, &mut dst)?;
1242 for (i, byte) in dst.iter().enumerate() {
1243 assert_eq!(
1244 *byte,
1245 ((i * 13 + 5) % 256) as u8,
1246 "mismatch at offset {i} after round-trip"
1247 );
1248 }
1249 Ok(())
1250 }
1251
1252 #[timed_test::async_timed_test(timeout_secs = 30)]
1256 async fn test_tcp_write_from_local() -> anyhow::Result<()> {
1257 let config = hyperactor_config::global::lock();
1258 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1259
1260 let mut envs = setup_tcp_env(4096).await?;
1261 do_write_test(&mut envs, 4096, Duration::from_secs(30)).await
1262 }
1263
1264 #[timed_test::async_timed_test(timeout_secs = 30)]
1266 async fn test_tcp_read_into_local() -> anyhow::Result<()> {
1267 let config = hyperactor_config::global::lock();
1268 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1269
1270 let mut envs = setup_tcp_env(2048).await?;
1271 do_read_test(&mut envs, 2048, Duration::from_secs(30)).await
1272 }
1273
1274 #[timed_test::async_timed_test(timeout_secs = 30)]
1276 async fn test_tcp_write_then_read_back() -> anyhow::Result<()> {
1277 let config = hyperactor_config::global::lock();
1278 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1279
1280 let mut envs = setup_tcp_env(4096).await?;
1281 do_round_trip_test(&mut envs, 4096, Duration::from_secs(30)).await
1282 }
1283
1284 #[timed_test::async_timed_test(timeout_secs = 30)]
1286 async fn test_tcp_multi_chunk_write() -> anyhow::Result<()> {
1287 let config = hyperactor_config::global::lock();
1288 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1289 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1290
1291 let buf_size = 3 * 1024 * 512;
1292 let mut envs = setup_tcp_env(buf_size).await?;
1293
1294 let mut src = vec![0u8; buf_size];
1295 for (i, byte) in src.iter_mut().enumerate() {
1296 *byte = (i % 251) as u8;
1297 }
1298 test_write(&envs[0].local_memory, 0, &src)?;
1299
1300 let remote = envs[1].rdma_remote_buf.clone();
1301 let env = &mut envs[0];
1302 env.tcp_backend
1303 .submit(
1304 &env.instance,
1305 vec![RdmaOp {
1306 op_type: RdmaOpType::WriteFromLocal,
1307 local: env.local_memory.clone(),
1308 remote,
1309 }],
1310 Duration::from_secs(30),
1311 )
1312 .await?;
1313
1314 let mut dst = vec![0u8; buf_size];
1315 test_read(&envs[1].local_memory, 0, &mut dst)?;
1316 for (i, byte) in dst.iter().enumerate() {
1317 assert_eq!(*byte, (i % 251) as u8, "mismatch at offset {i}");
1318 }
1319
1320 Ok(())
1321 }
1322
1323 #[timed_test::async_timed_test(timeout_secs = 30)]
1325 async fn test_tcp_multi_chunk_read() -> anyhow::Result<()> {
1326 let config = hyperactor_config::global::lock();
1327 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1328 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1329
1330 let buf_size = 3 * 1024 * 512;
1331 let mut envs = setup_tcp_env(buf_size).await?;
1332
1333 let mut src = vec![0u8; buf_size];
1334 for (i, byte) in src.iter_mut().enumerate() {
1335 *byte = ((i * 3 + 17) % 256) as u8;
1336 }
1337 test_write(&envs[1].local_memory, 0, &src)?;
1338
1339 let remote = envs[1].rdma_remote_buf.clone();
1340 let env = &mut envs[0];
1341 env.tcp_backend
1342 .submit(
1343 &env.instance,
1344 vec![RdmaOp {
1345 op_type: RdmaOpType::ReadIntoLocal,
1346 local: env.local_memory.clone(),
1347 remote,
1348 }],
1349 Duration::from_secs(30),
1350 )
1351 .await?;
1352
1353 let mut dst = vec![0u8; buf_size];
1354 test_read(&envs[0].local_memory, 0, &mut dst)?;
1355 for (i, byte) in dst.iter().enumerate() {
1356 assert_eq!(*byte, ((i * 3 + 17) % 256) as u8, "mismatch at offset {i}");
1357 }
1358
1359 Ok(())
1360 }
1361
1362 #[timed_test::async_timed_test(timeout_secs = 30)]
1364 async fn test_tcp_multi_chunk_round_trip() -> anyhow::Result<()> {
1365 let config = hyperactor_config::global::lock();
1366 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1367 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1368
1369 let buf_size = 5 * 1024 * 512;
1370 let mut envs = setup_tcp_env(buf_size).await?;
1371
1372 let mut src = vec![0u8; buf_size];
1373 for (i, byte) in src.iter_mut().enumerate() {
1374 *byte = ((i * 41 + 7) % 256) as u8;
1375 }
1376 test_write(&envs[0].local_memory, 0, &src)?;
1377
1378 let remote = envs[1].rdma_remote_buf.clone();
1379 let env = &mut envs[0];
1380 env.tcp_backend
1381 .submit(
1382 &env.instance,
1383 vec![RdmaOp {
1384 op_type: RdmaOpType::WriteFromLocal,
1385 local: env.local_memory.clone(),
1386 remote: remote.clone(),
1387 }],
1388 Duration::from_secs(30),
1389 )
1390 .await?;
1391
1392 test_write(&envs[0].local_memory, 0, &vec![0u8; buf_size])?;
1393
1394 let env = &mut envs[0];
1395 env.tcp_backend
1396 .submit(
1397 &env.instance,
1398 vec![RdmaOp {
1399 op_type: RdmaOpType::ReadIntoLocal,
1400 local: env.local_memory.clone(),
1401 remote,
1402 }],
1403 Duration::from_secs(30),
1404 )
1405 .await?;
1406
1407 let mut dst = vec![0u8; buf_size];
1408 test_read(&envs[0].local_memory, 0, &mut dst)?;
1409 for (i, byte) in dst.iter().enumerate() {
1410 assert_eq!(
1411 *byte,
1412 ((i * 41 + 7) % 256) as u8,
1413 "mismatch at offset {i} after multi-chunk round-trip"
1414 );
1415 }
1416
1417 Ok(())
1418 }
1419
1420 #[timed_test::async_timed_test(timeout_secs = 30)]
1422 async fn test_tcp_resolve_tcp() -> anyhow::Result<()> {
1423 let config = hyperactor_config::global::lock();
1424 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1425
1426 let envs = setup_tcp_env(64).await?;
1427
1428 for (i, env) in envs.iter().enumerate() {
1429 let (tcp_ref, id) = env.rdma_remote_buf.resolve_tcp()?;
1430 assert_eq!(id, env.rdma_remote_buf.id, "buf id mismatch for env {i}");
1431 let expected: hyperactor::ActorRef<TcpManagerActor> = env.tcp_backend.bind();
1432 assert_eq!(tcp_ref.actor_addr(), expected.actor_addr());
1433 }
1434
1435 Ok(())
1436 }
1437
1438 #[timed_test::async_timed_test(timeout_secs = 30)]
1440 async fn test_tcp_write_to_released_buffer() -> anyhow::Result<()> {
1441 let config = hyperactor_config::global::lock();
1442 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1443
1444 let buf_size = 64;
1445 let mut envs = setup_tcp_env(buf_size).await?;
1446
1447 let mut src = vec![0u8; buf_size];
1448 for (i, byte) in src.iter_mut().enumerate() {
1449 *byte = (i % 256) as u8;
1450 }
1451 test_write(&envs[0].local_memory, 0, &src)?;
1452
1453 let remote = envs[1].rdma_remote_buf.clone();
1455 let env = &mut envs[0];
1456 env.tcp_backend
1457 .submit(
1458 &env.instance,
1459 vec![RdmaOp {
1460 op_type: RdmaOpType::WriteFromLocal,
1461 local: env.local_memory.clone(),
1462 remote: remote.clone(),
1463 }],
1464 Duration::from_secs(10),
1465 )
1466 .await?;
1467
1468 use crate::rdma_manager_actor::ReleaseBufferClient;
1470 let owner_ref = envs[1].rdma_remote_buf.owner.clone();
1471 owner_ref
1472 .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
1473 .await?;
1474
1475 let env = &mut envs[0];
1477 let result = env
1478 .tcp_backend
1479 .submit(
1480 &env.instance,
1481 vec![RdmaOp {
1482 op_type: RdmaOpType::WriteFromLocal,
1483 local: env.local_memory.clone(),
1484 remote: remote.clone(),
1485 }],
1486 Duration::from_secs(10),
1487 )
1488 .await;
1489 assert!(result.is_err(), "expected error writing to released buffer");
1490
1491 Ok(())
1492 }
1493
1494 #[timed_test::async_timed_test(timeout_secs = 30)]
1496 async fn test_tcp_read_from_released_buffer() -> anyhow::Result<()> {
1497 let config = hyperactor_config::global::lock();
1498 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1499
1500 let buf_size = 64;
1501 let mut envs = setup_tcp_env(buf_size).await?;
1502
1503 use crate::rdma_manager_actor::ReleaseBufferClient;
1505 let owner_ref = envs[1].rdma_remote_buf.owner.clone();
1506 owner_ref
1507 .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
1508 .await?;
1509
1510 let remote = envs[1].rdma_remote_buf.clone();
1512 let env = &mut envs[0];
1513 let result = env
1514 .tcp_backend
1515 .submit(
1516 &env.instance,
1517 vec![RdmaOp {
1518 op_type: RdmaOpType::ReadIntoLocal,
1519 local: env.local_memory.clone(),
1520 remote,
1521 }],
1522 Duration::from_secs(10),
1523 )
1524 .await;
1525 assert!(
1526 result.is_err(),
1527 "expected error reading from released buffer"
1528 );
1529
1530 Ok(())
1531 }
1532
1533 #[timed_test::async_timed_test(timeout_secs = 30)]
1537 async fn test_tcp_same_process_write() -> anyhow::Result<()> {
1538 let config = hyperactor_config::global::lock();
1539 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1540
1541 let mut envs = setup_same_proc_tcp_env(4096).await?;
1542 do_write_test(&mut envs, 4096, Duration::from_secs(10)).await
1543 }
1544
1545 #[timed_test::async_timed_test(timeout_secs = 30)]
1547 async fn test_tcp_same_process_read() -> anyhow::Result<()> {
1548 let config = hyperactor_config::global::lock();
1549 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1550
1551 let mut envs = setup_same_proc_tcp_env(2048).await?;
1552 do_read_test(&mut envs, 2048, Duration::from_secs(10)).await
1553 }
1554
1555 #[timed_test::async_timed_test(timeout_secs = 30)]
1557 async fn test_tcp_same_process_round_trip() -> anyhow::Result<()> {
1558 let config = hyperactor_config::global::lock();
1559 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1560
1561 let mut envs = setup_same_proc_tcp_env(4096).await?;
1562 do_round_trip_test(&mut envs, 4096, Duration::from_secs(10)).await
1563 }
1564
1565 #[timed_test::async_timed_test(timeout_secs = 30)]
1568 async fn test_tcp_fallback_disabled_fails() -> anyhow::Result<()> {
1569 let config = hyperactor_config::global::lock();
1570 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, false);
1571
1572 let result = RdmaManagerActor::new(None, Flattrs::default()).await;
1573 if crate::ibverbs_supported() {
1574 assert!(result.is_ok());
1575 } else {
1576 assert!(result.is_err());
1577 }
1578
1579 Ok(())
1580 }
1581
1582 use crate::backend::cuda_test_utils::CudaAllocator;
1585 use crate::backend::cuda_test_utils::cuda_device_count;
1586
1587 impl TcpTestProcEnv {
1588 async fn new_gpu(device: i32, buffer_size: usize) -> anyhow::Result<Self> {
1590 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1591 let proc = Proc::direct(
1592 ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
1593 format!("tcp_gpu_test_{id}"),
1594 )?;
1595 let (instance, _) = proc.client("client")?;
1596
1597 let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
1598 let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;
1599
1600 let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
1601 let tcp_backend = TcpBackend(
1602 tcp_ref
1603 .downcast_handle(&instance)
1604 .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
1605 );
1606
1607 let alloc = CudaAllocator::get().allocate(device, buffer_size, buffer_size);
1608 let local_memory: Arc<KeepaliveLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
1609 alloc.ptr(),
1610 buffer_size,
1611 Arc::new(alloc),
1612 ));
1613 let rdma_remote_buf = rdma_handle
1614 .request_buffer(&instance, local_memory.clone())
1615 .await?;
1616
1617 Ok(Self {
1618 proc,
1619 rdma_handle,
1620 instance,
1621 tcp_backend,
1622 rdma_remote_buf,
1623 local_memory,
1624 })
1625 }
1626 }
1627
1628 #[timed_test::async_timed_test(timeout_secs = 60)]
1630 async fn test_tcp_write_multi_gpu() -> anyhow::Result<()> {
1631 if cuda_device_count() < 2 {
1632 println!("Skipping: need at least 2 CUDA devices");
1633 return Ok(());
1634 }
1635
1636 let config = hyperactor_config::global::lock();
1637 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1638
1639 let buf_size = 2 * 1024 * 1024;
1640 let mut envs = vec![
1641 TcpTestProcEnv::new_gpu(0, buf_size).await?,
1642 TcpTestProcEnv::new_gpu(1, buf_size).await?,
1643 ];
1644 do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1645 }
1646
1647 #[timed_test::async_timed_test(timeout_secs = 60)]
1649 async fn test_tcp_read_multi_gpu() -> anyhow::Result<()> {
1650 if cuda_device_count() < 2 {
1651 println!("Skipping: need at least 2 CUDA devices");
1652 return Ok(());
1653 }
1654
1655 let config = hyperactor_config::global::lock();
1656 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1657
1658 let buf_size = 2 * 1024 * 1024;
1659 let mut envs = vec![
1660 TcpTestProcEnv::new_gpu(0, buf_size).await?,
1661 TcpTestProcEnv::new_gpu(1, buf_size).await?,
1662 ];
1663 do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1664 }
1665
1666 #[timed_test::async_timed_test(timeout_secs = 60)]
1668 async fn test_tcp_round_trip_multi_gpu() -> anyhow::Result<()> {
1669 if cuda_device_count() < 2 {
1670 println!("Skipping: need at least 2 CUDA devices");
1671 return Ok(());
1672 }
1673
1674 let config = hyperactor_config::global::lock();
1675 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1676
1677 let buf_size = 2 * 1024 * 1024;
1678 let mut envs = vec![
1679 TcpTestProcEnv::new_gpu(0, buf_size).await?,
1680 TcpTestProcEnv::new_gpu(1, buf_size).await?,
1681 ];
1682 do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1683 }
1684
1685 #[timed_test::async_timed_test(timeout_secs = 30)]
1688 async fn test_tcp_parallel_clean_shutdown() -> anyhow::Result<()> {
1689 let config = hyperactor_config::global::lock();
1690 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1691 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1692 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1693
1694 let buf_size = 3 * 1024 * 1024;
1695 let mut envs = setup_tcp_env(buf_size).await?;
1696
1697 let mut src = vec![0u8; buf_size];
1699 for (i, byte) in src.iter_mut().enumerate() {
1700 *byte = (i % 256) as u8;
1701 }
1702 test_write(&envs[0].local_memory, 0, &src)?;
1703 let remote = envs[1].rdma_remote_buf.clone();
1704 let env = &mut envs[0];
1705 env.tcp_backend
1706 .submit(
1707 &env.instance,
1708 vec![RdmaOp {
1709 op_type: RdmaOpType::WriteFromLocal,
1710 local: env.local_memory.clone(),
1711 remote: remote.clone(),
1712 }],
1713 Duration::from_secs(30),
1714 )
1715 .await?;
1716
1717 envs[0].rdma_handle.drain_and_stop("test")?;
1720 envs[0].rdma_handle.clone().await;
1721 envs[1].rdma_handle.drain_and_stop("test")?;
1722 envs[1].rdma_handle.clone().await;
1723
1724 Ok(())
1725 }
1726
1727 #[timed_test::async_timed_test(timeout_secs = 30)]
1731 async fn test_tcp_parallel_write() -> anyhow::Result<()> {
1732 let config = hyperactor_config::global::lock();
1733 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1734 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1735 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1736
1737 let buf_size = 3 * 1024 * 1024;
1739 let mut envs = setup_tcp_env(buf_size).await?;
1740 do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1741 }
1742
1743 #[timed_test::async_timed_test(timeout_secs = 30)]
1745 async fn test_tcp_parallel_read() -> anyhow::Result<()> {
1746 let config = hyperactor_config::global::lock();
1747 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1748 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1749 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1750
1751 let buf_size = 3 * 1024 * 1024;
1752 let mut envs = setup_tcp_env(buf_size).await?;
1753 do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1754 }
1755
1756 #[timed_test::async_timed_test(timeout_secs = 30)]
1758 async fn test_tcp_parallel_round_trip() -> anyhow::Result<()> {
1759 let config = hyperactor_config::global::lock();
1760 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1761 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1762 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1763
1764 let buf_size = 3 * 1024 * 1024;
1765 let mut envs = setup_tcp_env(buf_size).await?;
1766 do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1767 }
1768
1769 #[timed_test::async_timed_test(timeout_secs = 30)]
1771 async fn test_tcp_parallel_same_process_write() -> anyhow::Result<()> {
1772 let config = hyperactor_config::global::lock();
1773 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1774 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1775 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1776
1777 let buf_size = 3 * 1024 * 1024;
1778 let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1779 do_write_test(&mut envs, buf_size, Duration::from_secs(10)).await
1780 }
1781
1782 #[timed_test::async_timed_test(timeout_secs = 30)]
1784 async fn test_tcp_parallel_same_process_read() -> anyhow::Result<()> {
1785 let config = hyperactor_config::global::lock();
1786 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1787 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1788 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1789
1790 let buf_size = 3 * 1024 * 1024;
1791 let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1792 do_read_test(&mut envs, buf_size, Duration::from_secs(10)).await
1793 }
1794
1795 #[timed_test::async_timed_test(timeout_secs = 30)]
1799 async fn test_tcp_parallel_concurrent_writes() -> anyhow::Result<()> {
1800 let config = hyperactor_config::global::lock();
1801 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1802 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1803 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1804
1805 let buf_size = 3 * 1024 * 1024;
1806 let envs = setup_tcp_env_pairs(buf_size).await?;
1807
1808 let mut src0 = vec![0u8; buf_size];
1810 for (i, byte) in src0.iter_mut().enumerate() {
1811 *byte = (i % 256) as u8;
1812 }
1813 test_write(&envs[0].local_memory, 0, &src0)?;
1814 let mut src2 = vec![0u8; buf_size];
1815 for (i, byte) in src2.iter_mut().enumerate() {
1816 *byte = ((i * 3 + 7) % 256) as u8;
1817 }
1818 test_write(&envs[2].local_memory, 0, &src2)?;
1819
1820 let remote_1 = envs[1].rdma_remote_buf.clone();
1822 let remote_3 = envs[3].rdma_remote_buf.clone();
1823 let mut h0 = envs[0].tcp_backend.clone();
1824 let mut h2 = envs[2].tcp_backend.clone();
1825 let inst_0 = &envs[0].instance;
1826 let inst_2 = &envs[2].instance;
1827 let mem_0 = envs[0].local_memory.clone();
1828 let mem_2 = envs[2].local_memory.clone();
1829 let (r1, r2) = tokio::join!(
1830 h0.submit(
1831 inst_0,
1832 vec![RdmaOp {
1833 op_type: RdmaOpType::WriteFromLocal,
1834 local: mem_0,
1835 remote: remote_1,
1836 }],
1837 Duration::from_secs(30),
1838 ),
1839 h2.submit(
1840 inst_2,
1841 vec![RdmaOp {
1842 op_type: RdmaOpType::WriteFromLocal,
1843 local: mem_2,
1844 remote: remote_3,
1845 }],
1846 Duration::from_secs(30),
1847 ),
1848 );
1849 r1?;
1850 r2?;
1851
1852 let mut dst1 = vec![0u8; buf_size];
1853 test_read(&envs[1].local_memory, 0, &mut dst1)?;
1854 for (i, byte) in dst1.iter().enumerate() {
1855 assert_eq!(*byte, (i % 256) as u8, "pair 1 mismatch at offset {i}");
1856 }
1857 let mut dst3 = vec![0u8; buf_size];
1858 test_read(&envs[3].local_memory, 0, &mut dst3)?;
1859 for (i, byte) in dst3.iter().enumerate() {
1860 assert_eq!(
1861 *byte,
1862 ((i * 3 + 7) % 256) as u8,
1863 "pair 2 mismatch at offset {i}"
1864 );
1865 }
1866
1867 Ok(())
1868 }
1869
1870 #[timed_test::async_timed_test(timeout_secs = 30)]
1872 async fn test_tcp_parallel_concurrent_reads() -> anyhow::Result<()> {
1873 let config = hyperactor_config::global::lock();
1874 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1875 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1876 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1877
1878 let buf_size = 3 * 1024 * 1024;
1879 let envs = setup_tcp_env_pairs(buf_size).await?;
1880
1881 let mut src1 = vec![0u8; buf_size];
1883 for (i, byte) in src1.iter_mut().enumerate() {
1884 *byte = ((i * 11 + 3) % 256) as u8;
1885 }
1886 test_write(&envs[1].local_memory, 0, &src1)?;
1887 let mut src3 = vec![0u8; buf_size];
1888 for (i, byte) in src3.iter_mut().enumerate() {
1889 *byte = ((i * 5 + 13) % 256) as u8;
1890 }
1891 test_write(&envs[3].local_memory, 0, &src3)?;
1892
1893 let remote_1 = envs[1].rdma_remote_buf.clone();
1895 let remote_3 = envs[3].rdma_remote_buf.clone();
1896 let mut h0 = envs[0].tcp_backend.clone();
1897 let mut h2 = envs[2].tcp_backend.clone();
1898 let inst_0 = &envs[0].instance;
1899 let inst_2 = &envs[2].instance;
1900 let mem_0 = envs[0].local_memory.clone();
1901 let mem_2 = envs[2].local_memory.clone();
1902 let (r1, r2) = tokio::join!(
1903 h0.submit(
1904 inst_0,
1905 vec![RdmaOp {
1906 op_type: RdmaOpType::ReadIntoLocal,
1907 local: mem_0,
1908 remote: remote_1,
1909 }],
1910 Duration::from_secs(30),
1911 ),
1912 h2.submit(
1913 inst_2,
1914 vec![RdmaOp {
1915 op_type: RdmaOpType::ReadIntoLocal,
1916 local: mem_2,
1917 remote: remote_3,
1918 }],
1919 Duration::from_secs(30),
1920 ),
1921 );
1922 r1?;
1923 r2?;
1924
1925 let mut dst0 = vec![0u8; buf_size];
1926 test_read(&envs[0].local_memory, 0, &mut dst0)?;
1927 for (i, byte) in dst0.iter().enumerate() {
1928 assert_eq!(
1929 *byte,
1930 ((i * 11 + 3) % 256) as u8,
1931 "pair 1 mismatch at offset {i}"
1932 );
1933 }
1934 let mut dst2 = vec![0u8; buf_size];
1935 test_read(&envs[2].local_memory, 0, &mut dst2)?;
1936 for (i, byte) in dst2.iter().enumerate() {
1937 assert_eq!(
1938 *byte,
1939 ((i * 5 + 13) % 256) as u8,
1940 "pair 2 mismatch at offset {i}"
1941 );
1942 }
1943
1944 Ok(())
1945 }
1946
1947 #[timed_test::async_timed_test(timeout_secs = 30)]
1949 async fn test_tcp_parallel_concurrent_write_and_read() -> anyhow::Result<()> {
1950 let config = hyperactor_config::global::lock();
1951 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1952 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1953 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1954
1955 let buf_size = 3 * 1024 * 1024;
1956 let envs = setup_tcp_env_pairs(buf_size).await?;
1957
1958 let mut src0 = vec![0u8; buf_size];
1960 for (i, byte) in src0.iter_mut().enumerate() {
1961 *byte = (i % 256) as u8;
1962 }
1963 test_write(&envs[0].local_memory, 0, &src0)?;
1964 let mut src3 = vec![0u8; buf_size];
1965 for (i, byte) in src3.iter_mut().enumerate() {
1966 *byte = ((i * 7 + 13) % 256) as u8;
1967 }
1968 test_write(&envs[3].local_memory, 0, &src3)?;
1969
1970 let remote_1 = envs[1].rdma_remote_buf.clone();
1972 let remote_3 = envs[3].rdma_remote_buf.clone();
1973 let mut h0 = envs[0].tcp_backend.clone();
1974 let mut h2 = envs[2].tcp_backend.clone();
1975 let inst_0 = &envs[0].instance;
1976 let inst_2 = &envs[2].instance;
1977 let mem_0 = envs[0].local_memory.clone();
1978 let mem_2 = envs[2].local_memory.clone();
1979 let (write_result, read_result) = tokio::join!(
1980 h0.submit(
1981 inst_0,
1982 vec![RdmaOp {
1983 op_type: RdmaOpType::WriteFromLocal,
1984 local: mem_0,
1985 remote: remote_1,
1986 }],
1987 Duration::from_secs(30),
1988 ),
1989 h2.submit(
1990 inst_2,
1991 vec![RdmaOp {
1992 op_type: RdmaOpType::ReadIntoLocal,
1993 local: mem_2,
1994 remote: remote_3,
1995 }],
1996 Duration::from_secs(30),
1997 ),
1998 );
1999 write_result?;
2000 read_result?;
2001
2002 let mut dst1 = vec![0u8; buf_size];
2003 test_read(&envs[1].local_memory, 0, &mut dst1)?;
2004 for (i, byte) in dst1.iter().enumerate() {
2005 assert_eq!(*byte, (i % 256) as u8, "write mismatch at offset {i}");
2006 }
2007 let mut dst2 = vec![0u8; buf_size];
2008 test_read(&envs[2].local_memory, 0, &mut dst2)?;
2009 for (i, byte) in dst2.iter().enumerate() {
2010 assert_eq!(
2011 *byte,
2012 ((i * 7 + 13) % 256) as u8,
2013 "read mismatch at offset {i}"
2014 );
2015 }
2016
2017 Ok(())
2018 }
2019}