1use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::time::Duration;
19use std::time::Instant;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use bytes::Bytes;
24use bytes::BytesMut;
25use dashmap::DashMap;
26use hyperactor::Actor;
27use hyperactor::ActorHandle;
28use hyperactor::Context;
29use hyperactor::HandleClient;
30use hyperactor::Handler;
31use hyperactor::Instance;
32use hyperactor::OncePortHandle;
33use hyperactor::PortHandle;
34use hyperactor::RefClient;
35use hyperactor::actor::ActorError;
36use hyperactor::channel;
37use hyperactor::channel::ChannelAddr;
38use hyperactor::channel::ChannelTx;
39use hyperactor::channel::Rx;
40use hyperactor::channel::Tx;
41use hyperactor::context;
42use hyperactor::context::Actor as _;
43use hyperactor::reference;
44use hyperactor::reference::OncePortRef;
45use hyperactor_mesh::transport::default_transport;
46use serde::Deserialize;
47use serde::Serialize;
48use serde_multipart::Part;
49use tokio::time::timeout as tokio_timeout;
50use tokio_util::sync::CancellationToken;
51use typeuri::Named;
52
53use super::TcpOp;
54use crate::RdmaLocalMemory;
55use crate::RdmaOp;
56use crate::RdmaOpType;
57use crate::RdmaTransportLevel;
58use crate::backend::RdmaBackend;
59use crate::rdma_manager_actor::GetTcpActorRefClient;
60use crate::rdma_manager_actor::RdmaManagerActor;
61use crate::rdma_manager_actor::RdmaManagerMessageClient;
62
/// Wire envelope for a single chunk of buffer data, returned by `ReadChunk`
/// and carried by `WriteChunk`; wraps a multipart `Part` payload.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct TcpChunk(Part);
wirevalue::register_type!(TcpChunk);
70
/// One chunk of a parallel transfer, sent over a raw data channel rather
/// than through actor messaging.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
struct TcpDataChunk {
    // Identifies which registered transfer this chunk belongs to.
    transfer_id: usize,
    // Byte offset in the destination memory where `data` is written.
    offset: usize,
    // Chunk payload; may consist of multiple fragments.
    data: Part,
}
wirevalue::register_type!(TcpDataChunk);
81
/// Receiver-side bookkeeping for one in-flight parallel transfer.
#[derive(Debug)]
struct TransferState {
    // Destination memory that incoming chunks are written into.
    local_memory: Arc<dyn RdmaLocalMemory>,

    // Number of chunks successfully written so far.
    chunks_received: usize,

    // Total chunks expected; completion fires when received == total.
    total_chunks: usize,

    // Port notified with Ok(()) on completion or Err(..) on failure.
    done: OncePortRef<Result<(), String>>,
}
101
102impl TransferState {
103 fn new(
104 total_chunks: usize,
105 local_memory: Arc<dyn RdmaLocalMemory>,
106 done: OncePortRef<Result<(), String>>,
107 ) -> Self {
108 Self {
109 local_memory,
110 chunks_received: 0,
111 total_chunks,
112 done,
113 }
114 }
115}
116
/// Internal message: carries a finished transfer's final status from the
/// background receiver task back into the actor's mailbox, which then
/// forwards it to the transfer's `done` port.
#[derive(Debug, Serialize, Deserialize, Named)]
struct SendTransferResult {
    done: OncePortRef<Result<(), String>>,
    result: Result<(), String>,
}
131
/// Internal message: a fatal error raised by a sender/receiver task. The
/// actor's handler logs the message and fails the actor with it.
#[derive(Debug, Serialize, Deserialize, Named)]
struct TransferError {
    message: String,
}
140
/// Local-only (non-serializable) message: register the receiving side of a
/// parallel transfer on the in-process actor; replies with the transfer id.
#[derive(Debug)]
struct RegisterTransferLocal {
    local_memory: Arc<dyn RdmaLocalMemory>,
    total_chunks: usize,
    done: OncePortRef<Result<(), String>>,
    reply: OncePortHandle<usize>,
}
151
/// Local-only message: start streaming `local_memory` in `chunk_size`
/// pieces to the data channel at `dest_addr` for transfer `transfer_id`.
#[derive(Debug)]
struct ExecuteTransferLocal {
    transfer_id: usize,
    local_memory: Arc<dyn RdmaLocalMemory>,
    chunk_size: usize,
    dest_addr: ChannelAddr,
}
161
/// Control-plane messages handled by `TcpManagerActor`.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
enum TcpManagerMessage {
    // Serial path: write `data` into buffer `buf_id` at `offset`.
    WriteChunk {
        buf_id: usize,
        offset: usize,
        data: Part,
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
    // Serial path: read `size` bytes from buffer `buf_id` at `offset`.
    ReadChunk {
        buf_id: usize,
        offset: usize,
        size: usize,
        #[reply]
        reply: OncePortRef<Result<TcpChunk, String>>,
    },
    // Address of this actor's chunk data channel; `None` when the
    // parallel path is disabled (parallelism <= 1).
    GetChannelAddress {
        #[reply]
        reply: OncePortRef<Option<ChannelAddr>>,
    },
    // Parallel path: register the receiving side of a transfer into
    // buffer `buf_id`; replies with the assigned transfer id.
    RegisterTransferRemote {
        buf_id: usize,
        total_chunks: usize,
        done: OncePortRef<Result<(), String>>,
        #[reply]
        reply: OncePortRef<Result<usize, String>>,
    },
    // Parallel path: start streaming buffer `buf_id` to `dest_addr` for
    // an already-registered transfer.
    ExecuteTransferRemote {
        transfer_id: usize,
        buf_id: usize,
        chunk_size: usize,
        dest_addr: ChannelAddr,
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(TcpManagerMessage);
211
/// Actor implementing the TCP fallback path for RDMA operations. Spawned
/// as a child of `RdmaManagerActor`, which it uses to resolve buffer ids
/// to local memory.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [TcpManagerMessage],
)]
pub struct TcpManagerActor {
    // Parent RdmaManagerActor; set exactly once in `init`.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,
    // Monotonic id source for transfers registered on this actor.
    next_transfer_id: usize,
    // In-flight receiving transfers, shared with the receiver task.
    transfers: Arc<DashMap<usize, TransferState>>,
    // Address of the chunk-receiving channel (None unless parallelism > 1).
    channel_addr: Option<ChannelAddr>,
    // Cached outbound connections per destination, one per parallel lane.
    outbound: HashMap<ChannelAddr, Vec<Arc<ChannelTx<TcpDataChunk>>>>,
    // Cancels sender/receiver background tasks during cleanup.
    cancel: CancellationToken,
    // Signalled when the receiver task has drained and exited.
    receiver_done: Option<tokio::sync::oneshot::Receiver<()>>,
}
234
impl TcpManagerActor {
    /// Creates a fresh actor with no owner, no transfers, and no channels.
    pub fn new() -> Self {
        Self {
            owner: OnceLock::new(),
            next_transfer_id: 0,
            transfers: Arc::new(DashMap::new()),
            channel_addr: None,
            outbound: HashMap::new(),
            cancel: CancellationToken::new(),
            receiver_done: None,
        }
    }

    /// Allocates a new transfer id and records receiver-side state for it.
    fn register_transfer(
        &mut self,
        local_memory: Arc<dyn RdmaLocalMemory>,
        total_chunks: usize,
        done: OncePortRef<Result<(), String>>,
    ) -> usize {
        let transfer_id = self.next_transfer_id;
        self.next_transfer_id += 1;
        self.transfers.insert(
            transfer_id,
            TransferState::new(total_chunks, local_memory, done),
        );
        transfer_id
    }

    /// Sender side of a parallel transfer: streams `local_memory` to
    /// `dest_addr` as `TcpDataChunk`s over `RDMA_TCP_FALLBACK_PARALLELISM`
    /// cached connections. One task per connection claims chunk indices
    /// from a shared atomic counter, so chunks may arrive out of order;
    /// task failures are reported through `error_port`, whose handler
    /// fails the actor.
    fn execute_transfer(
        &mut self,
        cx: &Context<Self>,
        transfer_id: usize,
        local_memory: Arc<dyn RdmaLocalMemory>,
        chunk_size: usize,
        dest_addr: ChannelAddr,
    ) -> Result<()> {
        let parallelism =
            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);

        // Lazily dial and cache one connection per parallel lane.
        if !self.outbound.contains_key(&dest_addr) {
            let conns = (0..parallelism)
                .map(|_| {
                    channel::dial::<TcpDataChunk>(dest_addr.clone())
                        .map(Arc::new)
                        .map_err(anyhow::Error::from)
                })
                .collect::<Result<Vec<_>>>()?;
            self.outbound.insert(dest_addr.clone(), conns);
        }
        let conns = self.outbound.get(&dest_addr).unwrap();

        let size = local_memory.size();
        let total_chunks = size.div_ceil(chunk_size);
        // NOTE(review): when `size == 0`, no chunks are sent, so the remote
        // completion condition (chunks_received == total_chunks) never
        // fires and the caller only unblocks via its timeout — confirm
        // zero-size ops cannot reach this path.

        // Shared work queue: each sender task claims the next chunk index.
        let chunk_index = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let error_port: PortHandle<TransferError> = cx.port();
        let proc = cx.instance().proc().clone();
        let cancel = self.cancel.clone();

        for conn in conns.clone() {
            let mem = local_memory.clone();
            let transfer_id = transfer_id;
            let chunk_index = chunk_index.clone();
            let error_port = error_port.clone();
            let proc = proc.clone();
            let cancel = cancel.clone();

            tokio::spawn(async move {
                // Each task needs its own instance to send on the error port.
                let sender_name = reference::name::Name::generate(
                    reference::name::Ident::new("tcp_manager_actor".into()).unwrap(),
                    reference::name::Ident::new("tcp_chunk_sender".into()).unwrap(),
                );
                let (instance, _handle) = proc
                    .instance(&sender_name.to_string())
                    .expect("failed to create sender instance");

                loop {
                    // Stop promptly when the actor is being cleaned up.
                    if cancel.is_cancelled() {
                        return;
                    }

                    let idx = chunk_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    if idx >= total_chunks {
                        break;
                    }

                    // The final chunk may be shorter than `chunk_size`.
                    let offset = idx * chunk_size;
                    let len = std::cmp::min(chunk_size, size - offset);
                    let mut buf = BytesMut::zeroed(len);
                    if let Err(e) = mem.read_at(offset, &mut buf) {
                        error_port
                            .send(
                                &instance,
                                TransferError {
                                    message: format!("read_at failed at offset {offset}: {e}"),
                                },
                            )
                            .unwrap();
                        return;
                    }

                    let chunk = TcpDataChunk {
                        transfer_id,
                        offset,
                        data: Part::from(buf.freeze()),
                    };

                    if let Err(e) = conn.send(chunk).await {
                        error_port
                            .send(
                                &instance,
                                TransferError {
                                    message: format!(
                                        "failed to send chunk at offset {offset}: {e}"
                                    ),
                                },
                            )
                            .unwrap();
                        return;
                    }
                }
            });
        }

        Ok(())
    }

    /// Resolves the `TcpManagerActor` living in this process by asking the
    /// process-local `RdmaManagerActor` for its TCP child and downcasting
    /// the returned ref back to a local handle.
    pub async fn local_handle(
        client: &(impl context::Actor + Send + Sync),
    ) -> Result<ActorHandle<Self>, anyhow::Error> {
        let rdma_handle = RdmaManagerActor::local_handle(client);
        let tcp_ref = rdma_handle.get_tcp_actor_ref(client).await?;
        tcp_ref
            .downcast_handle(client)
            .ok_or_else(|| anyhow::anyhow!("TcpManagerActor is not in the local process"))
    }
}
374
#[async_trait]
impl Actor for TcpManagerActor {
    /// Captures the parent `RdmaManagerActor` as owner and, when
    /// `RDMA_TCP_FALLBACK_PARALLELISM > 1`, binds a data channel and
    /// spawns the background task that receives parallel chunks.
    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
        let owner = this.parent_handle().ok_or_else(|| {
            anyhow::anyhow!("RdmaManagerActor not found as parent of TcpManagerActor")
        })?;
        self.owner
            .set(owner)
            .map_err(|_| anyhow::anyhow!("TcpManagerActor owner already set"))?;

        let parallelism =
            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
        if parallelism > 1 {
            // Serve a channel for incoming TcpDataChunks; its bound address
            // is advertised to peers via GetChannelAddress.
            let addr = ChannelAddr::any(default_transport());
            let (bound_addr, mut rx) = channel::serve::<TcpDataChunk>(addr)?;
            self.channel_addr = Some(bound_addr);

            let transfers = self.transfers.clone();
            let proc = this.proc().clone();
            // Results/errors are routed back through the actor's own ports
            // so completion and failure are handled on the actor.
            let result_port: PortHandle<SendTransferResult> = this.port();
            let error_port: PortHandle<TransferError> = this.port();
            let cancel = self.cancel.clone();
            let receiver_name = reference::name::Name::generate(
                reference::name::Ident::new("tcp_manager_actor".into()).unwrap(),
                reference::name::Ident::new("tcp_chunk_receiver".into()).unwrap(),
            );

            // Lets cleanup() wait until the receiver task has fully exited.
            let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>();
            self.receiver_done = Some(done_rx);

            tokio::spawn(async move {
                let (instance, _handle) = proc
                    .instance(&receiver_name.to_string())
                    .expect("failed to create receiver instance");

                loop {
                    let chunk = tokio::select! {
                        _ = cancel.cancelled() => break,
                        result = rx.recv() => match result {
                            Ok(chunk) => chunk,
                            Err(e) => {
                                error_port
                                    .send(
                                        &instance,
                                        TransferError {
                                            message: format!(
                                                "parallel channel receive error: {e}"
                                            ),
                                        },
                                    )
                                    .unwrap();
                                break;
                            }
                        },
                    };

                    // Chunks for unknown transfers (e.g. one already failed
                    // and removed) are dropped with a warning.
                    let mut entry = match transfers.get_mut(&chunk.transfer_id) {
                        Some(entry) => entry,
                        None => {
                            tracing::warn!(
                                "received chunk for unknown transfer {:?}",
                                chunk.transfer_id,
                            );
                            continue;
                        }
                    };

                    // A Part may be fragmented; write fragments in order,
                    // stopping at the first write failure.
                    let mut write_offset = chunk.offset;
                    let fragments = chunk.data.into_inner();
                    let write_err = fragments.iter().find_map(|fragment| {
                        let result = entry.local_memory.write_at(write_offset, fragment);
                        write_offset += fragment.len();
                        result.err()
                    });
                    if let Some(e) = write_err {
                        // Fail the transfer. Drop the map guard before
                        // removing the entry so the shard lock is released.
                        let transfer_id = chunk.transfer_id;
                        drop(entry);
                        let (_, state) = transfers.remove(&transfer_id).unwrap();
                        result_port
                            .send(
                                &instance,
                                SendTransferResult {
                                    done: state.done,
                                    result: Err(e.to_string()),
                                },
                            )
                            .unwrap();
                        continue;
                    }

                    entry.chunks_received += 1;
                    if entry.chunks_received == entry.total_chunks {
                        // All chunks written: notify the waiting done port.
                        let transfer_id = chunk.transfer_id;
                        drop(entry);
                        let (_, state) = transfers.remove(&transfer_id).unwrap();
                        result_port
                            .send(
                                &instance,
                                SendTransferResult {
                                    done: state.done,
                                    result: Ok(()),
                                },
                            )
                            .unwrap();
                    }
                }
                // Drain the channel before signalling cleanup().
                rx.join().await;
                done_tx.send(()).unwrap();
            });
        }

        Ok(())
    }

    /// Cancels background tasks and waits for the receiver task to exit so
    /// no chunk writes can outlive the actor's buffers.
    async fn cleanup(
        &mut self,
        _this: &Instance<Self>,
        _err: Option<&ActorError>,
    ) -> Result<(), anyhow::Error> {
        self.cancel.cancel();
        if let Some(done_rx) = self.receiver_done.take() {
            done_rx.await?;
        }
        Ok(())
    }
}
501
#[async_trait]
#[hyperactor::handle(TcpManagerMessage)]
impl TcpManagerMessageHandler for TcpManagerActor {
    /// Serial write path: resolves `buf_id` through the owner and writes
    /// `data` at `offset`. Application-level failures are returned in the
    /// inner `Result` rather than failing the actor.
    async fn write_chunk(
        &mut self,
        cx: &Context<Self>,
        buf_id: usize,
        offset: usize,
        data: Part,
    ) -> Result<Result<(), String>, anyhow::Error> {
        let owner = self.owner.get().expect("TcpManagerActor owner not set");
        let mem = match owner.request_local_memory(cx, buf_id).await {
            Ok(Some(mem)) => mem,
            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
            Err(e) => return Ok(Err(e.to_string())),
        };

        let bytes = data.into_bytes();
        if let Err(e) = mem.write_at(offset, &bytes) {
            return Ok(Err(e.to_string()));
        }

        Ok(Ok(()))
    }

    /// Serial read path: reads `size` bytes from buffer `buf_id` at
    /// `offset` and returns them as a `TcpChunk`.
    async fn read_chunk(
        &mut self,
        cx: &Context<Self>,
        buf_id: usize,
        offset: usize,
        size: usize,
    ) -> Result<Result<TcpChunk, String>, anyhow::Error> {
        let owner = self.owner.get().expect("TcpManagerActor owner not set");
        let mem = match owner.request_local_memory(cx, buf_id).await {
            Ok(Some(mem)) => mem,
            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
            Err(e) => return Ok(Err(e.to_string())),
        };

        let mut buf = BytesMut::zeroed(size);
        if let Err(e) = mem.read_at(offset, &mut buf) {
            return Ok(Err(e.to_string()));
        }
        Ok(Ok(TcpChunk(Part::from(buf.freeze()))))
    }

    /// Returns the address of this actor's chunk data channel, or `None`
    /// when the parallel path was not enabled in `init`.
    async fn get_channel_address(
        &mut self,
        _cx: &Context<Self>,
    ) -> Result<Option<ChannelAddr>, anyhow::Error> {
        Ok(self.channel_addr.clone())
    }

    /// Registers the receiving side of a parallel transfer into buffer
    /// `buf_id`; replies with the transfer id the sender must use.
    async fn register_transfer_remote(
        &mut self,
        cx: &Context<Self>,
        buf_id: usize,
        total_chunks: usize,
        done: OncePortRef<Result<(), String>>,
    ) -> Result<Result<usize, String>, anyhow::Error> {
        let owner = self.owner.get().expect("TcpManagerActor owner not set");
        let mem = match owner.request_local_memory(cx, buf_id).await {
            Ok(Some(mem)) => mem,
            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
            Err(e) => return Ok(Err(e.to_string())),
        };
        let transfer_id = self.register_transfer(mem, total_chunks, done);
        Ok(Ok(transfer_id))
    }

    /// Starts streaming buffer `buf_id` to `dest_addr` for a transfer id
    /// previously obtained from the receiving side.
    async fn execute_transfer_remote(
        &mut self,
        cx: &Context<Self>,
        transfer_id: usize,
        buf_id: usize,
        chunk_size: usize,
        dest_addr: ChannelAddr,
    ) -> Result<Result<(), String>, anyhow::Error> {
        let owner = self.owner.get().expect("TcpManagerActor owner not set");
        let mem = match owner.request_local_memory(cx, buf_id).await {
            Ok(Some(mem)) => mem,
            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
            Err(e) => return Ok(Err(e.to_string())),
        };
        self.execute_transfer(cx, transfer_id, mem, chunk_size, dest_addr)?;
        Ok(Ok(()))
    }
}
590
591#[async_trait]
592impl Handler<RegisterTransferLocal> for TcpManagerActor {
593 async fn handle(
594 &mut self,
595 cx: &Context<Self>,
596 message: RegisterTransferLocal,
597 ) -> Result<(), anyhow::Error> {
598 let transfer_id =
599 self.register_transfer(message.local_memory, message.total_chunks, message.done);
600 message.reply.send(cx, transfer_id)?;
601 Ok(())
602 }
603}
604
605#[async_trait]
606impl Handler<ExecuteTransferLocal> for TcpManagerActor {
607 async fn handle(
608 &mut self,
609 cx: &Context<Self>,
610 message: ExecuteTransferLocal,
611 ) -> Result<(), anyhow::Error> {
612 self.execute_transfer(
613 cx,
614 message.transfer_id,
615 message.local_memory,
616 message.chunk_size,
617 message.dest_addr,
618 )
619 }
620}
621
622#[async_trait]
623impl Handler<SendTransferResult> for TcpManagerActor {
624 async fn handle(
625 &mut self,
626 cx: &Context<Self>,
627 message: SendTransferResult,
628 ) -> Result<(), anyhow::Error> {
629 Ok(message.done.send(cx, message.result)?)
630 }
631}
632
633#[async_trait]
634impl Handler<TransferError> for TcpManagerActor {
635 async fn handle(
636 &mut self,
637 _cx: &Context<Self>,
638 message: TransferError,
639 ) -> Result<(), anyhow::Error> {
640 tracing::error!("fatal transfer error: {}", message.message);
641 Err(anyhow::anyhow!(message.message))
642 }
643}
644
/// TCP-based `RdmaBackend` implementation that routes RDMA ops through a
/// local `TcpManagerActor` handle.
#[derive(Debug, Clone)]
pub struct TcpBackend(pub ActorHandle<TcpManagerActor>);
654
655impl std::ops::Deref for TcpBackend {
656 type Target = ActorHandle<TcpManagerActor>;
657 fn deref(&self) -> &Self::Target {
658 &self.0
659 }
660}
661
impl TcpBackend {
    /// Parallel write: registers a transfer on the remote, then tells the
    /// local actor to stream chunks to the remote's data channel, and
    /// waits (bounded by `deadline`) for the remote's `done` notification.
    async fn execute_parallel_write(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let total_chunks = size.div_ceil(chunk_size);

        // One-shot port the remote signals when all chunks have landed.
        let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
        let done_ref = done_handle.bind();

        let remaining = deadline.saturating_duration_since(Instant::now());
        let transfer_id = tokio_timeout(
            remaining,
            op.remote_tcp_manager.register_transfer_remote(
                cx,
                op.remote_buf_id,
                total_chunks,
                done_ref,
            ),
        )
        .await
        .map_err(|_| anyhow::anyhow!("register_transfer_remote timed out"))??
        .map_err(|e| anyhow::anyhow!(e))?;

        let dest_addr = tokio_timeout(
            deadline.saturating_duration_since(Instant::now()),
            op.remote_tcp_manager.get_channel_address(cx),
        )
        .await
        .map_err(|_| anyhow::anyhow!("get_channel_address timed out"))??
        .ok_or_else(|| anyhow::anyhow!("remote does not have parallel channels enabled"))?;

        // Fire-and-forget to the local actor; completion is observed via
        // the `done` port, not this send.
        self.0.send(
            cx,
            ExecuteTransferLocal {
                transfer_id,
                local_memory: op.local_memory.clone(),
                chunk_size,
                dest_addr,
            },
        )?;

        let remaining = deadline.saturating_duration_since(Instant::now());
        let result = tokio_timeout(remaining, done_rx.recv())
            .await
            .map_err(|_| anyhow::anyhow!("parallel write timed out"))?
            .map_err(|e| anyhow::anyhow!(e))?;
        result.map_err(|e| anyhow::anyhow!(e))
    }

    /// Parallel read: registers a local receiving transfer, then asks the
    /// remote to stream its buffer into our data channel, and waits
    /// (bounded by `deadline`) for the local `done` notification.
    async fn execute_parallel_read(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let total_chunks = size.div_ceil(chunk_size);

        let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
        let done_ref = done_handle.bind();

        let (id_handle, id_rx) = hyperactor::mailbox::open_once_port::<usize>(cx);

        self.0.send(
            cx,
            RegisterTransferLocal {
                local_memory: op.local_memory.clone(),
                total_chunks,
                done: done_ref,
                reply: id_handle,
            },
        )?;

        // NOTE(review): unlike the remote calls below, this local reply and
        // the get_channel_address call are not bounded by `deadline`;
        // presumably in-process calls cannot stall — confirm.
        let transfer_id = id_rx
            .recv()
            .await
            .map_err(|e| anyhow::anyhow!("failed to receive transfer id: {e}"))?;

        let my_channel_addr = self
            .0
            .get_channel_address(cx)
            .await?
            .ok_or_else(|| anyhow::anyhow!("local parallel channels not enabled"))?;

        let remaining = deadline.saturating_duration_since(Instant::now());
        tokio_timeout(
            remaining,
            op.remote_tcp_manager.execute_transfer_remote(
                cx,
                transfer_id,
                op.remote_buf_id,
                chunk_size,
                my_channel_addr,
            ),
        )
        .await
        .map_err(|_| anyhow::anyhow!("execute_transfer_remote timed out"))??
        .map_err(|e| anyhow::anyhow!(e))?;

        let remaining = deadline.saturating_duration_since(Instant::now());
        let result = tokio_timeout(remaining, done_rx.recv())
            .await
            .map_err(|_| anyhow::anyhow!("parallel read timed out"))?
            .map_err(|e| anyhow::anyhow!(e))?;
        result.map_err(|e| anyhow::anyhow!(e))
    }

    /// Serial write: pushes the local buffer to the remote chunk-by-chunk
    /// via `WriteChunk` actor messages, honoring `deadline` per chunk.
    async fn execute_write(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let mut offset = 0;

        while offset < size {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                anyhow::bail!("tcp write timed out");
            }

            let len = std::cmp::min(chunk_size, size - offset);

            let mut buf = vec![0u8; len];
            op.local_memory.read_at(offset, &mut buf)?;
            let data = Part::from(Bytes::from(buf));

            tokio_timeout(
                remaining,
                op.remote_tcp_manager
                    .write_chunk(cx, op.remote_buf_id, offset, data),
            )
            .await
            .map_err(|_| anyhow::anyhow!("tcp write chunk timed out"))??
            .map_err(|e| anyhow::anyhow!(e))?;

            offset += len;
        }

        Ok(())
    }

    /// Serial read: pulls the remote buffer chunk-by-chunk via `ReadChunk`
    /// actor messages, honoring `deadline` per chunk.
    async fn execute_read(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let mut offset = 0;

        while offset < size {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                anyhow::bail!("tcp read timed out");
            }

            let len = std::cmp::min(chunk_size, size - offset);

            let chunk = tokio_timeout(
                remaining,
                op.remote_tcp_manager
                    .read_chunk(cx, op.remote_buf_id, offset, len),
            )
            .await
            .map_err(|_| anyhow::anyhow!("tcp read chunk timed out"))??
            .map_err(|e| anyhow::anyhow!(e))?;
            let data = chunk.0.into_bytes();

            // Guard against short or oversized replies before writing.
            anyhow::ensure!(
                data.len() == len,
                "tcp read chunk size mismatch: expected {len}, got {}",
                data.len()
            );

            op.local_memory.write_at(offset, &data)?;

            offset += len;
        }

        Ok(())
    }
}
862
#[async_trait]
impl RdmaBackend for TcpBackend {
    // TCP fallback exposes no transport-specific connection info.
    type TransportInfo = ();

    /// Executes `ops` sequentially over TCP, choosing the parallel or
    /// serial path per the `RDMA_TCP_FALLBACK_PARALLELISM` config. The
    /// whole batch shares a single deadline derived from `timeout`.
    async fn submit(
        &mut self,
        cx: &(impl context::Actor + Send + Sync),
        ops: Vec<RdmaOp>,
        timeout: Duration,
    ) -> Result<()> {
        let chunk_size =
            hyperactor_config::global::get(crate::config::RDMA_MAX_CHUNK_SIZE_MB) * 1024 * 1024;
        let parallelism =
            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
        let deadline = Instant::now() + timeout;

        for op in ops {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                anyhow::bail!("tcp submit timed out");
            }

            // Translate the generic RdmaOp into a TCP-addressable op.
            let (remote_tcp_mgr, remote_buf_id) = op.remote.resolve_tcp()?;
            let tcp_op = TcpOp {
                op_type: op.op_type.clone(),
                local_memory: op.local,
                remote_tcp_manager: remote_tcp_mgr,
                remote_buf_id,
            };

            if parallelism > 1 {
                match tcp_op.op_type {
                    RdmaOpType::WriteFromLocal => {
                        self.execute_parallel_write(cx, &tcp_op, chunk_size, deadline)
                            .await?;
                    }
                    RdmaOpType::ReadIntoLocal => {
                        self.execute_parallel_read(cx, &tcp_op, chunk_size, deadline)
                            .await?;
                    }
                }
            } else {
                match tcp_op.op_type {
                    RdmaOpType::WriteFromLocal => {
                        self.execute_write(cx, &tcp_op, chunk_size, deadline)
                            .await?;
                    }
                    RdmaOpType::ReadIntoLocal => {
                        self.execute_read(cx, &tcp_op, chunk_size, deadline).await?;
                    }
                }
            }
        }

        Ok(())
    }

    fn transport_level(&self) -> RdmaTransportLevel {
        RdmaTransportLevel::Tcp
    }

    fn transport_info(&self) -> Option<Self::TransportInfo> {
        None
    }
}
933
934#[cfg(test)]
935mod tests {
936 use std::sync::Arc;
937 use std::sync::atomic::AtomicUsize;
938 use std::sync::atomic::Ordering;
939 use std::time::Duration;
940
941 use hyperactor::ActorHandle;
942 use hyperactor::Proc;
943 use hyperactor::RemoteSpawn;
944 use hyperactor::channel::ChannelAddr;
945 use hyperactor_config::Flattrs;
946
947 use super::TcpBackend;
948 use super::TcpManagerActor;
949 use crate::RdmaManagerMessageClient;
950 use crate::RdmaOp;
951 use crate::RdmaOpType;
952 use crate::backend::RdmaBackend;
953 use crate::local_memory::Keepalive;
954 use crate::local_memory::KeepaliveLocalMemory;
955 use crate::local_memory::RdmaLocalMemory;
956 use crate::rdma_manager_actor::GetTcpActorRefClient;
957 use crate::rdma_manager_actor::RdmaManagerActor;
958
    // Boxed byte buffers can act as keepalive owners for test memory.
    impl Keepalive for Box<[u8]> {}

    // Distinguishes proc/client names across concurrently running tests.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
962
    /// One test endpoint: a proc hosting an RDMA manager, a TCP backend,
    /// and a registered CPU buffer to transfer to/from.
    struct TcpTestProcEnv {
        proc: Proc,
        rdma_handle: ActorHandle<RdmaManagerActor>,
        instance: hyperactor::Instance<()>,
        tcp_backend: TcpBackend,
        rdma_remote_buf: crate::RdmaRemoteBuffer,
        local_memory: Arc<dyn RdmaLocalMemory>,
    }
971
    impl Drop for TcpTestProcEnv {
        /// Synchronously releases the registered buffer when the env is
        /// dropped so the manager does not hold stale registrations.
        fn drop(&mut self) {
            use crate::rdma_manager_actor::ReleaseBufferClient;
            // block_in_place + block_on lets this async call run inside a
            // (multi-threaded) tokio test runtime from a sync Drop.
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current()
                    .block_on(
                        self.rdma_remote_buf
                            .owner
                            .release_buffer(&self.instance, self.rdma_remote_buf.id),
                    )
                    .expect("failed to release buffer in TcpTestProcEnv drop");
            });
        }
    }
988
    impl TcpTestProcEnv {
        /// Creates a fresh proc with its own RDMA manager, TCP backend, and
        /// a registered zero-filled CPU buffer of `buffer_size` bytes.
        async fn new(buffer_size: usize) -> anyhow::Result<Self> {
            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
            let proc = Proc::direct(
                ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
                format!("tcp_test_{id}"),
            )?;
            let (instance, _) = proc.instance("client")?;

            let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
            let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;

            // Resolve the TCP child actor the manager spawned and wrap it.
            let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
            let tcp_backend = TcpBackend(
                tcp_ref
                    .downcast_handle(&instance)
                    .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
            );

            let (local_memory, rdma_remote_buf) =
                Self::alloc_cpu_buffer(&instance, &rdma_handle, buffer_size).await?;

            Ok(Self {
                proc,
                rdma_handle,
                instance,
                tcp_backend,
                rdma_remote_buf,
                local_memory,
            })
        }

        /// Creates an additional env that shares an existing proc, manager,
        /// and backend, but registers its own buffer and client instance.
        async fn on_proc(
            proc: &Proc,
            rdma_handle: &ActorHandle<RdmaManagerActor>,
            tcp_backend: TcpBackend,
            buffer_size: usize,
        ) -> anyhow::Result<Self> {
            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
            let (instance, _) = proc.instance(&format!("client_{id}"))?;

            let (local_memory, rdma_remote_buf) =
                Self::alloc_cpu_buffer(&instance, rdma_handle, buffer_size).await?;

            Ok(Self {
                proc: proc.clone(),
                rdma_handle: rdma_handle.clone(),
                instance,
                tcp_backend,
                rdma_remote_buf,
                local_memory,
            })
        }

        /// Allocates a boxed CPU buffer, wraps it as keepalive-backed local
        /// memory, and registers it with the RDMA manager.
        async fn alloc_cpu_buffer(
            instance: &hyperactor::Instance<()>,
            rdma_handle: &ActorHandle<RdmaManagerActor>,
            buffer_size: usize,
        ) -> anyhow::Result<(Arc<dyn RdmaLocalMemory>, crate::RdmaRemoteBuffer)> {
            let cpu_buf = vec![0u8; buffer_size].into_boxed_slice();
            let ptr = cpu_buf.as_ptr() as usize;
            // The Arc'd box keeps the allocation alive for the memory's
            // lifetime while `ptr` addresses it directly.
            let local_memory: Arc<dyn RdmaLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
                ptr,
                buffer_size,
                Arc::new(cpu_buf),
            ));
            let rdma_remote_buf = rdma_handle
                .request_buffer(instance, local_memory.clone())
                .await?;
            Ok((local_memory, rdma_remote_buf))
        }
    }
1063
1064 async fn setup_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1066 Ok(vec![
1067 TcpTestProcEnv::new(buf_size).await?,
1068 TcpTestProcEnv::new(buf_size).await?,
1069 ])
1070 }
1071
1072 async fn setup_same_proc_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1074 let first = TcpTestProcEnv::new(buf_size).await?;
1075 let second = TcpTestProcEnv::on_proc(
1076 &first.proc,
1077 &first.rdma_handle,
1078 first.tcp_backend.clone(),
1079 buf_size,
1080 )
1081 .await?;
1082 Ok(vec![first, second])
1083 }
1084
1085 async fn setup_tcp_env_pairs(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1088 let e0 = TcpTestProcEnv::new(buf_size).await?;
1089 let e1 = TcpTestProcEnv::new(buf_size).await?;
1090 let e2 =
1091 TcpTestProcEnv::on_proc(&e0.proc, &e0.rdma_handle, e0.tcp_backend.clone(), buf_size)
1092 .await?;
1093 let e3 =
1094 TcpTestProcEnv::on_proc(&e1.proc, &e1.rdma_handle, e1.tcp_backend.clone(), buf_size)
1095 .await?;
1096 Ok(vec![e0, e1, e2, e3])
1097 }
1098
    /// Fills env 0's buffer with a known pattern, writes it to env 1 via
    /// the backend, and verifies env 1 received it byte-for-byte.
    async fn do_write_test(
        envs: &mut [TcpTestProcEnv],
        buf_size: usize,
        timeout: Duration,
    ) -> anyhow::Result<()> {
        let mut src = vec![0u8; buf_size];
        for (i, byte) in src.iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        envs[0].local_memory.write_at(0, &src)?;

        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::WriteFromLocal,
                    local: env.local_memory.clone(),
                    remote,
                }],
                timeout,
            )
            .await?;

        let mut dst = vec![0u8; buf_size];
        envs[1].local_memory.read_at(0, &mut dst)?;
        for (i, byte) in dst.iter().enumerate() {
            assert_eq!(*byte, (i % 256) as u8, "mismatch at offset {i} after write");
        }
        Ok(())
    }
1134
    /// Fills env 1's buffer with a known pattern, reads it into env 0 via
    /// the backend, and verifies env 0 received it byte-for-byte.
    async fn do_read_test(
        envs: &mut [TcpTestProcEnv],
        buf_size: usize,
        timeout: Duration,
    ) -> anyhow::Result<()> {
        let mut src = vec![0u8; buf_size];
        for (i, byte) in src.iter_mut().enumerate() {
            *byte = ((i * 7 + 3) % 256) as u8;
        }
        envs[1].local_memory.write_at(0, &src)?;

        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::ReadIntoLocal,
                    local: env.local_memory.clone(),
                    remote,
                }],
                timeout,
            )
            .await?;

        let mut dst = vec![0u8; buf_size];
        envs[0].local_memory.read_at(0, &mut dst)?;
        for (i, byte) in dst.iter().enumerate() {
            assert_eq!(
                *byte,
                ((i * 7 + 3) % 256) as u8,
                "mismatch at offset {i} after read"
            );
        }
        Ok(())
    }
1172
    /// Writes a pattern from env 0 to env 1, zeroes env 0's buffer, reads
    /// it back from env 1, and verifies the pattern survived the round trip.
    async fn do_round_trip_test(
        envs: &mut [TcpTestProcEnv],
        buf_size: usize,
        timeout: Duration,
    ) -> anyhow::Result<()> {
        let mut src = vec![0u8; buf_size];
        for (i, byte) in src.iter_mut().enumerate() {
            *byte = ((i * 13 + 5) % 256) as u8;
        }
        envs[0].local_memory.write_at(0, &src)?;

        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::WriteFromLocal,
                    local: env.local_memory.clone(),
                    remote: remote.clone(),
                }],
                timeout,
            )
            .await?;

        // Clear the local buffer so the read-back must repopulate it.
        envs[0].local_memory.write_at(0, &vec![0u8; buf_size])?;

        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::ReadIntoLocal,
                    local: env.local_memory.clone(),
                    remote,
                }],
                timeout,
            )
            .await?;

        let mut dst = vec![0u8; buf_size];
        envs[0].local_memory.read_at(0, &mut dst)?;
        for (i, byte) in dst.iter().enumerate() {
            assert_eq!(
                *byte,
                ((i * 13 + 5) % 256) as u8,
                "mismatch at offset {i} after round-trip"
            );
        }
        Ok(())
    }
1225
    /// Single-buffer write over the serial TCP path (default parallelism).
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_write_from_local() -> anyhow::Result<()> {
        // Hold the global config lock so overrides don't race other tests.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);

        let mut envs = setup_tcp_env(4096).await?;
        do_write_test(&mut envs, 4096, Duration::from_secs(30)).await
    }
1237
    /// Single-buffer read over the serial TCP path (default parallelism).
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_read_into_local() -> anyhow::Result<()> {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);

        let mut envs = setup_tcp_env(2048).await?;
        do_read_test(&mut envs, 2048, Duration::from_secs(30)).await
    }
1247
    /// Write-then-read round trip over the serial TCP path.
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_write_then_read_back() -> anyhow::Result<()> {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);

        let mut envs = setup_tcp_env(4096).await?;
        do_round_trip_test(&mut envs, 4096, Duration::from_secs(30)).await
    }
1257
    /// Write larger than one chunk (1 MiB chunk size, 1.5 MiB buffer) so
    /// the chunking loop runs multiple iterations including a short tail.
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_multi_chunk_write() -> anyhow::Result<()> {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);

        // 1.5 MiB: not a multiple of the 1 MiB chunk size.
        let buf_size = 3 * 1024 * 512;
        let mut envs = setup_tcp_env(buf_size).await?;

        // Modulo a prime (251) so chunk boundaries don't align with the
        // pattern period.
        let mut src = vec![0u8; buf_size];
        for (i, byte) in src.iter_mut().enumerate() {
            *byte = (i % 251) as u8;
        }
        envs[0].local_memory.write_at(0, &src)?;

        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::WriteFromLocal,
                    local: env.local_memory.clone(),
                    remote,
                }],
                Duration::from_secs(30),
            )
            .await?;

        let mut dst = vec![0u8; buf_size];
        envs[1].local_memory.read_at(0, &mut dst)?;
        for (i, byte) in dst.iter().enumerate() {
            assert_eq!(*byte, (i % 251) as u8, "mismatch at offset {i}");
        }

        Ok(())
    }
1296
1297 #[timed_test::async_timed_test(timeout_secs = 30)]
1299 async fn test_tcp_multi_chunk_read() -> anyhow::Result<()> {
1300 let config = hyperactor_config::global::lock();
1301 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1302 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1303
1304 let buf_size = 3 * 1024 * 512;
1305 let mut envs = setup_tcp_env(buf_size).await?;
1306
1307 let mut src = vec![0u8; buf_size];
1308 for (i, byte) in src.iter_mut().enumerate() {
1309 *byte = ((i * 3 + 17) % 256) as u8;
1310 }
1311 envs[1].local_memory.write_at(0, &src)?;
1312
1313 let remote = envs[1].rdma_remote_buf.clone();
1314 let env = &mut envs[0];
1315 env.tcp_backend
1316 .submit(
1317 &env.instance,
1318 vec![RdmaOp {
1319 op_type: RdmaOpType::ReadIntoLocal,
1320 local: env.local_memory.clone(),
1321 remote,
1322 }],
1323 Duration::from_secs(30),
1324 )
1325 .await?;
1326
1327 let mut dst = vec![0u8; buf_size];
1328 envs[0].local_memory.read_at(0, &mut dst)?;
1329 for (i, byte) in dst.iter().enumerate() {
1330 assert_eq!(*byte, ((i * 3 + 17) % 256) as u8, "mismatch at offset {i}");
1331 }
1332
1333 Ok(())
1334 }
1335
1336 #[timed_test::async_timed_test(timeout_secs = 30)]
1338 async fn test_tcp_multi_chunk_round_trip() -> anyhow::Result<()> {
1339 let config = hyperactor_config::global::lock();
1340 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1341 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1342
1343 let buf_size = 5 * 1024 * 512;
1344 let mut envs = setup_tcp_env(buf_size).await?;
1345
1346 let mut src = vec![0u8; buf_size];
1347 for (i, byte) in src.iter_mut().enumerate() {
1348 *byte = ((i * 41 + 7) % 256) as u8;
1349 }
1350 envs[0].local_memory.write_at(0, &src)?;
1351
1352 let remote = envs[1].rdma_remote_buf.clone();
1353 let env = &mut envs[0];
1354 env.tcp_backend
1355 .submit(
1356 &env.instance,
1357 vec![RdmaOp {
1358 op_type: RdmaOpType::WriteFromLocal,
1359 local: env.local_memory.clone(),
1360 remote: remote.clone(),
1361 }],
1362 Duration::from_secs(30),
1363 )
1364 .await?;
1365
1366 envs[0].local_memory.write_at(0, &vec![0u8; buf_size])?;
1367
1368 let env = &mut envs[0];
1369 env.tcp_backend
1370 .submit(
1371 &env.instance,
1372 vec![RdmaOp {
1373 op_type: RdmaOpType::ReadIntoLocal,
1374 local: env.local_memory.clone(),
1375 remote,
1376 }],
1377 Duration::from_secs(30),
1378 )
1379 .await?;
1380
1381 let mut dst = vec![0u8; buf_size];
1382 envs[0].local_memory.read_at(0, &mut dst)?;
1383 for (i, byte) in dst.iter().enumerate() {
1384 assert_eq!(
1385 *byte,
1386 ((i * 41 + 7) % 256) as u8,
1387 "mismatch at offset {i} after multi-chunk round-trip"
1388 );
1389 }
1390
1391 Ok(())
1392 }
1393
1394 #[timed_test::async_timed_test(timeout_secs = 30)]
1396 async fn test_tcp_resolve_tcp() -> anyhow::Result<()> {
1397 let config = hyperactor_config::global::lock();
1398 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1399
1400 let envs = setup_tcp_env(64).await?;
1401
1402 for (i, env) in envs.iter().enumerate() {
1403 let (tcp_ref, id) = env.rdma_remote_buf.resolve_tcp()?;
1404 assert_eq!(id, env.rdma_remote_buf.id, "buf id mismatch for env {i}");
1405 let expected: hyperactor::ActorRef<TcpManagerActor> = env.tcp_backend.bind();
1406 assert_eq!(tcp_ref.actor_id(), expected.actor_id());
1407 }
1408
1409 Ok(())
1410 }
1411
1412 #[timed_test::async_timed_test(timeout_secs = 30)]
1414 async fn test_tcp_write_to_released_buffer() -> anyhow::Result<()> {
1415 let config = hyperactor_config::global::lock();
1416 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1417
1418 let buf_size = 64;
1419 let mut envs = setup_tcp_env(buf_size).await?;
1420
1421 let mut src = vec![0u8; buf_size];
1422 for (i, byte) in src.iter_mut().enumerate() {
1423 *byte = (i % 256) as u8;
1424 }
1425 envs[0].local_memory.write_at(0, &src)?;
1426
1427 let remote = envs[1].rdma_remote_buf.clone();
1429 let env = &mut envs[0];
1430 env.tcp_backend
1431 .submit(
1432 &env.instance,
1433 vec![RdmaOp {
1434 op_type: RdmaOpType::WriteFromLocal,
1435 local: env.local_memory.clone(),
1436 remote: remote.clone(),
1437 }],
1438 Duration::from_secs(10),
1439 )
1440 .await?;
1441
1442 use crate::rdma_manager_actor::ReleaseBufferClient;
1444 let owner_ref = envs[1].rdma_remote_buf.owner.clone();
1445 owner_ref
1446 .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
1447 .await?;
1448
1449 let env = &mut envs[0];
1451 let result = env
1452 .tcp_backend
1453 .submit(
1454 &env.instance,
1455 vec![RdmaOp {
1456 op_type: RdmaOpType::WriteFromLocal,
1457 local: env.local_memory.clone(),
1458 remote: remote.clone(),
1459 }],
1460 Duration::from_secs(10),
1461 )
1462 .await;
1463 assert!(result.is_err(), "expected error writing to released buffer");
1464
1465 Ok(())
1466 }
1467
1468 #[timed_test::async_timed_test(timeout_secs = 30)]
1470 async fn test_tcp_read_from_released_buffer() -> anyhow::Result<()> {
1471 let config = hyperactor_config::global::lock();
1472 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1473
1474 let buf_size = 64;
1475 let mut envs = setup_tcp_env(buf_size).await?;
1476
1477 use crate::rdma_manager_actor::ReleaseBufferClient;
1479 let owner_ref = envs[1].rdma_remote_buf.owner.clone();
1480 owner_ref
1481 .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
1482 .await?;
1483
1484 let remote = envs[1].rdma_remote_buf.clone();
1486 let env = &mut envs[0];
1487 let result = env
1488 .tcp_backend
1489 .submit(
1490 &env.instance,
1491 vec![RdmaOp {
1492 op_type: RdmaOpType::ReadIntoLocal,
1493 local: env.local_memory.clone(),
1494 remote,
1495 }],
1496 Duration::from_secs(10),
1497 )
1498 .await;
1499 assert!(
1500 result.is_err(),
1501 "expected error reading from released buffer"
1502 );
1503
1504 Ok(())
1505 }
1506
1507 #[timed_test::async_timed_test(timeout_secs = 30)]
1511 async fn test_tcp_same_process_write() -> anyhow::Result<()> {
1512 let config = hyperactor_config::global::lock();
1513 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1514
1515 let mut envs = setup_same_proc_tcp_env(4096).await?;
1516 do_write_test(&mut envs, 4096, Duration::from_secs(10)).await
1517 }
1518
1519 #[timed_test::async_timed_test(timeout_secs = 30)]
1521 async fn test_tcp_same_process_read() -> anyhow::Result<()> {
1522 let config = hyperactor_config::global::lock();
1523 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1524
1525 let mut envs = setup_same_proc_tcp_env(2048).await?;
1526 do_read_test(&mut envs, 2048, Duration::from_secs(10)).await
1527 }
1528
1529 #[timed_test::async_timed_test(timeout_secs = 30)]
1531 async fn test_tcp_same_process_round_trip() -> anyhow::Result<()> {
1532 let config = hyperactor_config::global::lock();
1533 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1534
1535 let mut envs = setup_same_proc_tcp_env(4096).await?;
1536 do_round_trip_test(&mut envs, 4096, Duration::from_secs(10)).await
1537 }
1538
1539 #[timed_test::async_timed_test(timeout_secs = 30)]
1542 async fn test_tcp_fallback_disabled_fails() -> anyhow::Result<()> {
1543 let config = hyperactor_config::global::lock();
1544 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, false);
1545
1546 let result = RdmaManagerActor::new(None, Flattrs::default()).await;
1547 if crate::ibverbs_supported() {
1548 assert!(result.is_ok());
1549 } else {
1550 assert!(result.is_err());
1551 }
1552
1553 Ok(())
1554 }
1555
1556 use crate::backend::cuda_test_utils::CudaAllocator;
1559 use crate::backend::cuda_test_utils::cuda_device_count;
1560
1561 impl TcpTestProcEnv {
1562 async fn new_gpu(device: i32, buffer_size: usize) -> anyhow::Result<Self> {
1564 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1565 let proc = Proc::direct(
1566 ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
1567 format!("tcp_gpu_test_{id}"),
1568 )?;
1569 let (instance, _) = proc.instance("client")?;
1570
1571 let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
1572 let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;
1573
1574 let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
1575 let tcp_backend = TcpBackend(
1576 tcp_ref
1577 .downcast_handle(&instance)
1578 .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
1579 );
1580
1581 let alloc = CudaAllocator::get().allocate(device, buffer_size);
1582 let local_memory: Arc<dyn RdmaLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
1583 alloc.ptr(),
1584 buffer_size,
1585 Arc::new(alloc),
1586 ));
1587 let rdma_remote_buf = rdma_handle
1588 .request_buffer(&instance, local_memory.clone())
1589 .await?;
1590
1591 Ok(Self {
1592 proc,
1593 rdma_handle,
1594 instance,
1595 tcp_backend,
1596 rdma_remote_buf,
1597 local_memory,
1598 })
1599 }
1600 }
1601
1602 #[timed_test::async_timed_test(timeout_secs = 60)]
1604 async fn test_tcp_write_multi_gpu() -> anyhow::Result<()> {
1605 if cuda_device_count() < 2 {
1606 println!("Skipping: need at least 2 CUDA devices");
1607 return Ok(());
1608 }
1609
1610 let config = hyperactor_config::global::lock();
1611 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1612
1613 let buf_size = 2 * 1024 * 1024;
1614 let mut envs = vec![
1615 TcpTestProcEnv::new_gpu(0, buf_size).await?,
1616 TcpTestProcEnv::new_gpu(1, buf_size).await?,
1617 ];
1618 do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1619 }
1620
1621 #[timed_test::async_timed_test(timeout_secs = 60)]
1623 async fn test_tcp_read_multi_gpu() -> anyhow::Result<()> {
1624 if cuda_device_count() < 2 {
1625 println!("Skipping: need at least 2 CUDA devices");
1626 return Ok(());
1627 }
1628
1629 let config = hyperactor_config::global::lock();
1630 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1631
1632 let buf_size = 2 * 1024 * 1024;
1633 let mut envs = vec![
1634 TcpTestProcEnv::new_gpu(0, buf_size).await?,
1635 TcpTestProcEnv::new_gpu(1, buf_size).await?,
1636 ];
1637 do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1638 }
1639
1640 #[timed_test::async_timed_test(timeout_secs = 60)]
1642 async fn test_tcp_round_trip_multi_gpu() -> anyhow::Result<()> {
1643 if cuda_device_count() < 2 {
1644 println!("Skipping: need at least 2 CUDA devices");
1645 return Ok(());
1646 }
1647
1648 let config = hyperactor_config::global::lock();
1649 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1650
1651 let buf_size = 2 * 1024 * 1024;
1652 let mut envs = vec![
1653 TcpTestProcEnv::new_gpu(0, buf_size).await?,
1654 TcpTestProcEnv::new_gpu(1, buf_size).await?,
1655 ];
1656 do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1657 }
1658
1659 #[timed_test::async_timed_test(timeout_secs = 30)]
1662 async fn test_tcp_parallel_clean_shutdown() -> anyhow::Result<()> {
1663 let config = hyperactor_config::global::lock();
1664 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1665 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1666 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1667
1668 let buf_size = 3 * 1024 * 1024;
1669 let mut envs = setup_tcp_env(buf_size).await?;
1670
1671 let mut src = vec![0u8; buf_size];
1673 for (i, byte) in src.iter_mut().enumerate() {
1674 *byte = (i % 256) as u8;
1675 }
1676 envs[0].local_memory.write_at(0, &src)?;
1677 let remote = envs[1].rdma_remote_buf.clone();
1678 let env = &mut envs[0];
1679 env.tcp_backend
1680 .submit(
1681 &env.instance,
1682 vec![RdmaOp {
1683 op_type: RdmaOpType::WriteFromLocal,
1684 local: env.local_memory.clone(),
1685 remote: remote.clone(),
1686 }],
1687 Duration::from_secs(30),
1688 )
1689 .await?;
1690
1691 envs[0].rdma_handle.drain_and_stop("test")?;
1694 envs[0].rdma_handle.clone().await;
1695 envs[1].rdma_handle.drain_and_stop("test")?;
1696 envs[1].rdma_handle.clone().await;
1697
1698 Ok(())
1699 }
1700
1701 #[timed_test::async_timed_test(timeout_secs = 30)]
1705 async fn test_tcp_parallel_write() -> anyhow::Result<()> {
1706 let config = hyperactor_config::global::lock();
1707 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1708 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1709 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1710
1711 let buf_size = 3 * 1024 * 1024;
1713 let mut envs = setup_tcp_env(buf_size).await?;
1714 do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1715 }
1716
1717 #[timed_test::async_timed_test(timeout_secs = 30)]
1719 async fn test_tcp_parallel_read() -> anyhow::Result<()> {
1720 let config = hyperactor_config::global::lock();
1721 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1722 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1723 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1724
1725 let buf_size = 3 * 1024 * 1024;
1726 let mut envs = setup_tcp_env(buf_size).await?;
1727 do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1728 }
1729
1730 #[timed_test::async_timed_test(timeout_secs = 30)]
1732 async fn test_tcp_parallel_round_trip() -> anyhow::Result<()> {
1733 let config = hyperactor_config::global::lock();
1734 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1735 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1736 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1737
1738 let buf_size = 3 * 1024 * 1024;
1739 let mut envs = setup_tcp_env(buf_size).await?;
1740 do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1741 }
1742
1743 #[timed_test::async_timed_test(timeout_secs = 30)]
1745 async fn test_tcp_parallel_same_process_write() -> anyhow::Result<()> {
1746 let config = hyperactor_config::global::lock();
1747 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1748 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1749 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1750
1751 let buf_size = 3 * 1024 * 1024;
1752 let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1753 do_write_test(&mut envs, buf_size, Duration::from_secs(10)).await
1754 }
1755
1756 #[timed_test::async_timed_test(timeout_secs = 30)]
1758 async fn test_tcp_parallel_same_process_read() -> anyhow::Result<()> {
1759 let config = hyperactor_config::global::lock();
1760 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1761 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1762 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1763
1764 let buf_size = 3 * 1024 * 1024;
1765 let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1766 do_read_test(&mut envs, buf_size, Duration::from_secs(10)).await
1767 }
1768
1769 #[timed_test::async_timed_test(timeout_secs = 30)]
1773 async fn test_tcp_parallel_concurrent_writes() -> anyhow::Result<()> {
1774 let config = hyperactor_config::global::lock();
1775 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1776 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1777 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1778
1779 let buf_size = 3 * 1024 * 1024;
1780 let envs = setup_tcp_env_pairs(buf_size).await?;
1781
1782 let mut src0 = vec![0u8; buf_size];
1784 for (i, byte) in src0.iter_mut().enumerate() {
1785 *byte = (i % 256) as u8;
1786 }
1787 envs[0].local_memory.write_at(0, &src0)?;
1788 let mut src2 = vec![0u8; buf_size];
1789 for (i, byte) in src2.iter_mut().enumerate() {
1790 *byte = ((i * 3 + 7) % 256) as u8;
1791 }
1792 envs[2].local_memory.write_at(0, &src2)?;
1793
1794 let remote_1 = envs[1].rdma_remote_buf.clone();
1796 let remote_3 = envs[3].rdma_remote_buf.clone();
1797 let mut h0 = envs[0].tcp_backend.clone();
1798 let mut h2 = envs[2].tcp_backend.clone();
1799 let inst_0 = &envs[0].instance;
1800 let inst_2 = &envs[2].instance;
1801 let mem_0 = envs[0].local_memory.clone();
1802 let mem_2 = envs[2].local_memory.clone();
1803 let (r1, r2) = tokio::join!(
1804 h0.submit(
1805 inst_0,
1806 vec![RdmaOp {
1807 op_type: RdmaOpType::WriteFromLocal,
1808 local: mem_0,
1809 remote: remote_1,
1810 }],
1811 Duration::from_secs(30),
1812 ),
1813 h2.submit(
1814 inst_2,
1815 vec![RdmaOp {
1816 op_type: RdmaOpType::WriteFromLocal,
1817 local: mem_2,
1818 remote: remote_3,
1819 }],
1820 Duration::from_secs(30),
1821 ),
1822 );
1823 r1?;
1824 r2?;
1825
1826 let mut dst1 = vec![0u8; buf_size];
1827 envs[1].local_memory.read_at(0, &mut dst1)?;
1828 for (i, byte) in dst1.iter().enumerate() {
1829 assert_eq!(*byte, (i % 256) as u8, "pair 1 mismatch at offset {i}");
1830 }
1831 let mut dst3 = vec![0u8; buf_size];
1832 envs[3].local_memory.read_at(0, &mut dst3)?;
1833 for (i, byte) in dst3.iter().enumerate() {
1834 assert_eq!(
1835 *byte,
1836 ((i * 3 + 7) % 256) as u8,
1837 "pair 2 mismatch at offset {i}"
1838 );
1839 }
1840
1841 Ok(())
1842 }
1843
1844 #[timed_test::async_timed_test(timeout_secs = 30)]
1846 async fn test_tcp_parallel_concurrent_reads() -> anyhow::Result<()> {
1847 let config = hyperactor_config::global::lock();
1848 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1849 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1850 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1851
1852 let buf_size = 3 * 1024 * 1024;
1853 let envs = setup_tcp_env_pairs(buf_size).await?;
1854
1855 let mut src1 = vec![0u8; buf_size];
1857 for (i, byte) in src1.iter_mut().enumerate() {
1858 *byte = ((i * 11 + 3) % 256) as u8;
1859 }
1860 envs[1].local_memory.write_at(0, &src1)?;
1861 let mut src3 = vec![0u8; buf_size];
1862 for (i, byte) in src3.iter_mut().enumerate() {
1863 *byte = ((i * 5 + 13) % 256) as u8;
1864 }
1865 envs[3].local_memory.write_at(0, &src3)?;
1866
1867 let remote_1 = envs[1].rdma_remote_buf.clone();
1869 let remote_3 = envs[3].rdma_remote_buf.clone();
1870 let mut h0 = envs[0].tcp_backend.clone();
1871 let mut h2 = envs[2].tcp_backend.clone();
1872 let inst_0 = &envs[0].instance;
1873 let inst_2 = &envs[2].instance;
1874 let mem_0 = envs[0].local_memory.clone();
1875 let mem_2 = envs[2].local_memory.clone();
1876 let (r1, r2) = tokio::join!(
1877 h0.submit(
1878 inst_0,
1879 vec![RdmaOp {
1880 op_type: RdmaOpType::ReadIntoLocal,
1881 local: mem_0,
1882 remote: remote_1,
1883 }],
1884 Duration::from_secs(30),
1885 ),
1886 h2.submit(
1887 inst_2,
1888 vec![RdmaOp {
1889 op_type: RdmaOpType::ReadIntoLocal,
1890 local: mem_2,
1891 remote: remote_3,
1892 }],
1893 Duration::from_secs(30),
1894 ),
1895 );
1896 r1?;
1897 r2?;
1898
1899 let mut dst0 = vec![0u8; buf_size];
1900 envs[0].local_memory.read_at(0, &mut dst0)?;
1901 for (i, byte) in dst0.iter().enumerate() {
1902 assert_eq!(
1903 *byte,
1904 ((i * 11 + 3) % 256) as u8,
1905 "pair 1 mismatch at offset {i}"
1906 );
1907 }
1908 let mut dst2 = vec![0u8; buf_size];
1909 envs[2].local_memory.read_at(0, &mut dst2)?;
1910 for (i, byte) in dst2.iter().enumerate() {
1911 assert_eq!(
1912 *byte,
1913 ((i * 5 + 13) % 256) as u8,
1914 "pair 2 mismatch at offset {i}"
1915 );
1916 }
1917
1918 Ok(())
1919 }
1920
1921 #[timed_test::async_timed_test(timeout_secs = 30)]
1923 async fn test_tcp_parallel_concurrent_write_and_read() -> anyhow::Result<()> {
1924 let config = hyperactor_config::global::lock();
1925 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1926 let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1927 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1928
1929 let buf_size = 3 * 1024 * 1024;
1930 let envs = setup_tcp_env_pairs(buf_size).await?;
1931
1932 let mut src0 = vec![0u8; buf_size];
1934 for (i, byte) in src0.iter_mut().enumerate() {
1935 *byte = (i % 256) as u8;
1936 }
1937 envs[0].local_memory.write_at(0, &src0)?;
1938 let mut src3 = vec![0u8; buf_size];
1939 for (i, byte) in src3.iter_mut().enumerate() {
1940 *byte = ((i * 7 + 13) % 256) as u8;
1941 }
1942 envs[3].local_memory.write_at(0, &src3)?;
1943
1944 let remote_1 = envs[1].rdma_remote_buf.clone();
1946 let remote_3 = envs[3].rdma_remote_buf.clone();
1947 let mut h0 = envs[0].tcp_backend.clone();
1948 let mut h2 = envs[2].tcp_backend.clone();
1949 let inst_0 = &envs[0].instance;
1950 let inst_2 = &envs[2].instance;
1951 let mem_0 = envs[0].local_memory.clone();
1952 let mem_2 = envs[2].local_memory.clone();
1953 let (write_result, read_result) = tokio::join!(
1954 h0.submit(
1955 inst_0,
1956 vec![RdmaOp {
1957 op_type: RdmaOpType::WriteFromLocal,
1958 local: mem_0,
1959 remote: remote_1,
1960 }],
1961 Duration::from_secs(30),
1962 ),
1963 h2.submit(
1964 inst_2,
1965 vec![RdmaOp {
1966 op_type: RdmaOpType::ReadIntoLocal,
1967 local: mem_2,
1968 remote: remote_3,
1969 }],
1970 Duration::from_secs(30),
1971 ),
1972 );
1973 write_result?;
1974 read_result?;
1975
1976 let mut dst1 = vec![0u8; buf_size];
1977 envs[1].local_memory.read_at(0, &mut dst1)?;
1978 for (i, byte) in dst1.iter().enumerate() {
1979 assert_eq!(*byte, (i % 256) as u8, "write mismatch at offset {i}");
1980 }
1981 let mut dst2 = vec![0u8; buf_size];
1982 envs[2].local_memory.read_at(0, &mut dst2)?;
1983 for (i, byte) in dst2.iter().enumerate() {
1984 assert_eq!(
1985 *byte,
1986 ((i * 7 + 13) % 256) as u8,
1987 "read mismatch at offset {i}"
1988 );
1989 }
1990
1991 Ok(())
1992 }
1993}