1use std::sync::OnceLock;
16use std::time::Duration;
17use std::time::Instant;
18
19use anyhow::Result;
20use async_trait::async_trait;
21use bytes::Bytes;
22use bytes::BytesMut;
23use hyperactor::Actor;
24use hyperactor::ActorHandle;
25use hyperactor::Context;
26use hyperactor::HandleClient;
27use hyperactor::Handler;
28use hyperactor::Instance;
29use hyperactor::OncePortHandle;
30use hyperactor::RefClient;
31use hyperactor::context;
32use hyperactor::context::Mailbox;
33use hyperactor::reference::OncePortRef;
34use serde::Deserialize;
35use serde::Serialize;
36use serde_multipart::Part;
37use tokio::time::timeout as tokio_timeout;
38use typeuri::Named;
39
40use super::TcpOp;
41use crate::RdmaOp;
42use crate::RdmaOpType;
43use crate::RdmaTransportLevel;
44use crate::backend::RdmaBackend;
45use crate::rdma_manager_actor::GetTcpActorRefClient;
46use crate::rdma_manager_actor::RdmaManagerActor;
47use crate::rdma_manager_actor::RdmaManagerMessageClient;
48
/// Payload returned in the reply to [`TcpManagerMessage::ReadChunk`].
///
/// Newtype around a [`Part`] so the chunk has its own named, registered
/// wire type.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct TcpChunk(pub Part);
wirevalue::register_type!(TcpChunk);
56
/// Messages handled by [`TcpManagerActor`] to move buffer contents one
/// chunk at a time.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum TcpManagerMessage {
    /// Write `data` into the buffer identified by `buf_id`, starting at
    /// byte `offset`.
    WriteChunk {
        /// Id of the destination buffer, looked up through the owning
        /// `RdmaManagerActor`.
        buf_id: usize,
        /// Byte offset inside the buffer at which the write starts.
        offset: usize,
        /// Bytes to write.
        data: Part,
        /// Receives `Ok(())` on success or a stringified error.
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
    /// Read `size` bytes from the buffer identified by `buf_id`, starting
    /// at byte `offset`.
    ReadChunk {
        /// Id of the source buffer, looked up through the owning
        /// `RdmaManagerActor`.
        buf_id: usize,
        /// Byte offset inside the buffer at which the read starts.
        offset: usize,
        /// Number of bytes to read.
        size: usize,
        /// Receives the bytes read or a stringified error.
        #[reply]
        reply: OncePortRef<Result<TcpChunk, String>>,
    },
}
wirevalue::register_type!(TcpManagerMessage);
81
/// Local-only request asking a [`TcpManagerActor`] to execute a batch of
/// TCP operations.
///
/// It derives no `Serialize`/`Deserialize` and replies through an
/// [`OncePortHandle`] rather than an `OncePortRef`.
#[derive(Handler, HandleClient, Debug)]
pub struct TcpSubmit {
    /// Operations to execute, in order.
    pub ops: Vec<TcpOp>,
    /// Time budget for the whole batch.
    pub timeout: Duration,
    /// Receives `Ok(())` when every op succeeded, otherwise the first error
    /// as a string.
    #[reply]
    pub reply: OncePortHandle<Result<(), String>>,
}
92
/// Actor that serves chunked buffer reads and writes for the TCP transport.
///
/// Only [`TcpManagerMessage`] is exported as a handler.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [TcpManagerMessage],
)]
pub struct TcpManagerActor {
    /// Handle to the parent `RdmaManagerActor`; set exactly once during
    /// `init` and used to look up local memory by buffer id.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,
}
104
105impl TcpManagerActor {
106 pub fn new() -> Self {
107 Self {
108 owner: OnceLock::new(),
109 }
110 }
111
112 pub async fn local_handle(
115 client: &(impl context::Actor + Send + Sync),
116 ) -> Result<ActorHandle<Self>, anyhow::Error> {
117 let rdma_handle = RdmaManagerActor::local_handle(client);
118 let tcp_ref = rdma_handle.get_tcp_actor_ref(client).await?;
119 tcp_ref
120 .downcast_handle(client)
121 .ok_or_else(|| anyhow::anyhow!("TcpManagerActor is not in the local process"))
122 }
123
124 fn is_same_actor(cx: &Context<'_, Self>, op: &TcpOp) -> bool {
126 cx.mailbox().actor_id() == op.remote_tcp_manager.actor_id()
127 }
128
129 async fn execute_write(
135 &mut self,
136 cx: &Context<'_, Self>,
137 op: &TcpOp,
138 chunk_size: usize,
139 deadline: Instant,
140 ) -> Result<()> {
141 let same_process = Self::is_same_actor(cx, op);
142 let size = op.local_memory.size();
143 let mut offset = 0;
144
145 while offset < size {
146 let remaining = deadline.saturating_duration_since(Instant::now());
147 if remaining.is_zero() {
148 anyhow::bail!("tcp write timed out");
149 }
150
151 let len = std::cmp::min(chunk_size, size - offset);
152
153 let mut buf = vec![0u8; len];
154 op.local_memory.read_at(offset, &mut buf)?;
155 let data = Part::from(Bytes::from(buf));
156
157 if same_process {
158 tokio_timeout(
159 remaining,
160 self.write_chunk(cx, op.remote_buf_id, offset, data),
161 )
162 .await
163 .map_err(|_| anyhow::anyhow!("tcp write chunk timed out"))??
164 .map_err(|e| anyhow::anyhow!(e))?;
165 } else {
166 tokio_timeout(
167 remaining,
168 op.remote_tcp_manager
169 .write_chunk(cx, op.remote_buf_id, offset, data),
170 )
171 .await
172 .map_err(|_| anyhow::anyhow!("tcp write chunk timed out"))??
173 .map_err(|e| anyhow::anyhow!(e))?;
174 }
175
176 offset += len;
177 }
178
179 Ok(())
180 }
181
182 async fn execute_read(
188 &mut self,
189 cx: &Context<'_, Self>,
190 op: &TcpOp,
191 chunk_size: usize,
192 deadline: Instant,
193 ) -> Result<()> {
194 let same_process = Self::is_same_actor(cx, op);
195 let size = op.local_memory.size();
196 let mut offset = 0;
197
198 while offset < size {
199 let remaining = deadline.saturating_duration_since(Instant::now());
200 if remaining.is_zero() {
201 anyhow::bail!("tcp read timed out");
202 }
203
204 let len = std::cmp::min(chunk_size, size - offset);
205
206 let chunk = if same_process {
207 tokio_timeout(
208 remaining,
209 self.read_chunk(cx, op.remote_buf_id, offset, len),
210 )
211 .await
212 .map_err(|_| anyhow::anyhow!("tcp read chunk timed out"))??
213 .map_err(|e| anyhow::anyhow!(e))?
214 } else {
215 tokio_timeout(
216 remaining,
217 op.remote_tcp_manager
218 .read_chunk(cx, op.remote_buf_id, offset, len),
219 )
220 .await
221 .map_err(|_| anyhow::anyhow!("tcp read chunk timed out"))??
222 .map_err(|e| anyhow::anyhow!(e))?
223 };
224 let data = chunk.0.into_bytes();
225
226 anyhow::ensure!(
227 data.len() == len,
228 "tcp read chunk size mismatch: expected {len}, got {}",
229 data.len()
230 );
231
232 op.local_memory.write_at(offset, &data)?;
233
234 offset += len;
235 }
236
237 Ok(())
238 }
239}
240
241#[async_trait]
242impl Actor for TcpManagerActor {
243 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
244 let owner = this.parent_handle().ok_or_else(|| {
245 anyhow::anyhow!("RdmaManagerActor not found as parent of TcpManagerActor")
246 })?;
247 self.owner
248 .set(owner)
249 .map_err(|_| anyhow::anyhow!("TcpManagerActor owner already set"))?;
250 Ok(())
251 }
252}
253
254#[async_trait]
255#[hyperactor::handle(TcpManagerMessage)]
256impl TcpManagerMessageHandler for TcpManagerActor {
257 async fn write_chunk(
258 &mut self,
259 cx: &Context<Self>,
260 buf_id: usize,
261 offset: usize,
262 data: Part,
263 ) -> Result<Result<(), String>, anyhow::Error> {
264 let owner = self.owner.get().expect("TcpManagerActor owner not set");
265 let mem = match owner.request_local_memory(cx, buf_id).await {
266 Ok(Some(mem)) => mem,
267 Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
268 Err(e) => return Ok(Err(e.to_string())),
269 };
270
271 let bytes = data.into_bytes();
272 if let Err(e) = mem.write_at(offset, &bytes) {
273 return Ok(Err(e.to_string()));
274 }
275
276 Ok(Ok(()))
277 }
278
279 async fn read_chunk(
280 &mut self,
281 cx: &Context<Self>,
282 buf_id: usize,
283 offset: usize,
284 size: usize,
285 ) -> Result<Result<TcpChunk, String>, anyhow::Error> {
286 let owner = self.owner.get().expect("TcpManagerActor owner not set");
287 let mem = match owner.request_local_memory(cx, buf_id).await {
288 Ok(Some(mem)) => mem,
289 Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
290 Err(e) => return Ok(Err(e.to_string())),
291 };
292
293 let mut buf = BytesMut::zeroed(size);
294 if let Err(e) = mem.read_at(offset, &mut buf) {
295 return Ok(Err(e.to_string()));
296 }
297 Ok(Ok(TcpChunk(Part::from(buf.freeze()))))
298 }
299}
300
301#[async_trait]
302#[hyperactor::handle(TcpSubmit)]
303impl TcpSubmitHandler for TcpManagerActor {
304 async fn tcp_submit(
305 &mut self,
306 cx: &Context<Self>,
307 ops: Vec<TcpOp>,
308 timeout: Duration,
309 ) -> Result<Result<(), String>, anyhow::Error> {
310 let chunk_size =
311 hyperactor_config::global::get(crate::config::RDMA_MAX_CHUNK_SIZE_MB) * 1024 * 1024;
312 let deadline = Instant::now() + timeout;
313 let mut result = Ok(());
314
315 for op in &ops {
316 let remaining = deadline.saturating_duration_since(Instant::now());
317 if remaining.is_zero() {
318 result = Err("tcp submit timed out".to_string());
319 break;
320 }
321
322 let op_result = match op.op_type {
323 RdmaOpType::WriteFromLocal => {
324 self.execute_write(cx, op, chunk_size, deadline).await
325 }
326 RdmaOpType::ReadIntoLocal => self.execute_read(cx, op, chunk_size, deadline).await,
327 };
328
329 if let Err(e) = op_result {
330 result = Err(e.to_string());
331 break;
332 }
333 }
334
335 Ok(result)
336 }
337}
338
339#[async_trait]
340impl RdmaBackend for ActorHandle<TcpManagerActor> {
341 type TransportInfo = ();
342
343 async fn submit(
349 &mut self,
350 cx: &(impl context::Actor + Send + Sync),
351 ops: Vec<RdmaOp>,
352 timeout: Duration,
353 ) -> Result<()> {
354 let mut tcp_ops = Vec::with_capacity(ops.len());
355
356 for op in ops {
357 let (remote_tcp_mgr, remote_buf_id) = op.remote.resolve_tcp()?;
358 tcp_ops.push(TcpOp {
359 op_type: op.op_type,
360 local_memory: op.local,
361 remote_tcp_manager: remote_tcp_mgr,
362 remote_buf_id,
363 });
364 }
365
366 <Self as TcpSubmitClient>::tcp_submit(self, cx, tcp_ops, timeout)
367 .await?
368 .map_err(|e| anyhow::anyhow!(e))
369 }
370
371 fn transport_level(&self) -> RdmaTransportLevel {
372 RdmaTransportLevel::Tcp
373 }
374
375 fn transport_info(&self) -> Option<Self::TransportInfo> {
376 None
377 }
378}
379
380#[cfg(test)]
381mod tests {
382 use std::sync::Arc;
383 use std::sync::atomic::AtomicUsize;
384 use std::sync::atomic::Ordering;
385
386 use hyperactor::ActorHandle;
387 use hyperactor::Proc;
388 use hyperactor::RemoteSpawn;
389 use hyperactor::channel::ChannelAddr;
390 use hyperactor_config::Flattrs;
391
392 use super::TcpManagerActor;
393 use crate::RdmaManagerMessageClient;
394 use crate::RdmaOp;
395 use crate::RdmaOpType;
396 use crate::backend::RdmaBackend;
397 use crate::local_memory::RdmaLocalMemory;
398 use crate::local_memory::UnsafeLocalMemory;
399 use crate::rdma_manager_actor::GetTcpActorRefClient;
400 use crate::rdma_manager_actor::RdmaManagerActor;
401
    /// Monotonic id mixed into proc names so each test environment gets
    /// distinct names.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
403
    /// Two-proc test fixture: each proc has its own client instance, TCP
    /// manager handle and one registered buffer backed by a CPU allocation.
    ///
    /// Underscore-prefixed fields are never read; they are stored so the
    /// values live as long as the fixture.
    struct TcpTestEnv {
        _proc_1: Proc,
        _proc_2: Proc,
        instance_1: hyperactor::Instance<()>,
        _instance_2: hyperactor::Instance<()>,
        tcp_handle_1: ActorHandle<TcpManagerActor>,
        tcp_handle_2: ActorHandle<TcpManagerActor>,
        rdma_buf_1: crate::RdmaRemoteBuffer,
        rdma_buf_2: crate::RdmaRemoteBuffer,
        local_mem_1: Arc<dyn RdmaLocalMemory>,
        _local_mem_2: Arc<dyn RdmaLocalMemory>,
        // Backing storage for `local_mem_1` / `_local_mem_2`; the local
        // memory objects are built from raw pointers into these boxes.
        cpu_buf_1: Box<[u8]>,
        cpu_buf_2: Box<[u8]>,
    }
418
419 async fn setup_tcp_env(buffer_size: usize) -> anyhow::Result<TcpTestEnv> {
425 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
426
427 let proc_1 = Proc::direct(
428 ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
429 format!("tcp_test_{id}_a"),
430 )?;
431 let proc_2 = Proc::direct(
432 ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
433 format!("tcp_test_{id}_b"),
434 )?;
435
436 let (instance_1, _ch1) = proc_1.instance("client")?;
437 let (instance_2, _ch2) = proc_2.instance("client")?;
438
439 let rdma_actor_1 = RdmaManagerActor::new(None, Flattrs::default()).await?;
440 let rdma_handle_1 = proc_1.spawn("rdma_manager", rdma_actor_1)?;
441
442 let rdma_actor_2 = RdmaManagerActor::new(None, Flattrs::default()).await?;
443 let rdma_handle_2 = proc_2.spawn("rdma_manager", rdma_actor_2)?;
444
445 let tcp_ref_1 = rdma_handle_1.get_tcp_actor_ref(&instance_1).await?;
447 let tcp_handle_1 = tcp_ref_1
448 .downcast_handle(&instance_1)
449 .ok_or_else(|| anyhow::anyhow!("tcp actor 1 not local"))?;
450
451 let tcp_ref_2 = rdma_handle_2.get_tcp_actor_ref(&instance_2).await?;
452 let tcp_handle_2 = tcp_ref_2
453 .downcast_handle(&instance_2)
454 .ok_or_else(|| anyhow::anyhow!("tcp actor 2 not local"))?;
455
456 let mut cpu_buf_1 = vec![0u8; buffer_size].into_boxed_slice();
458 let ptr_1 = cpu_buf_1.as_mut_ptr() as usize;
459 let local_mem_1: Arc<dyn RdmaLocalMemory> =
460 Arc::new(UnsafeLocalMemory::new(ptr_1, buffer_size));
461
462 let mut cpu_buf_2 = vec![0u8; buffer_size].into_boxed_slice();
463 let ptr_2 = cpu_buf_2.as_mut_ptr() as usize;
464 let local_mem_2: Arc<dyn RdmaLocalMemory> =
465 Arc::new(UnsafeLocalMemory::new(ptr_2, buffer_size));
466
467 let rdma_buf_1 = rdma_handle_1
469 .request_buffer(&instance_1, local_mem_1.clone())
470 .await?;
471
472 let rdma_buf_2 = rdma_handle_2
473 .request_buffer(&instance_2, local_mem_2.clone())
474 .await?;
475
476 Ok(TcpTestEnv {
477 _proc_1: proc_1,
478 _proc_2: proc_2,
479 instance_1,
480 _instance_2: instance_2,
481 tcp_handle_1,
482 tcp_handle_2,
483 rdma_buf_1,
484 rdma_buf_2,
485 local_mem_1,
486 _local_mem_2: local_mem_2,
487 cpu_buf_1,
488 cpu_buf_2,
489 })
490 }
491
492 #[timed_test::async_timed_test(timeout_secs = 30)]
494 async fn test_tcp_write_from_local() -> anyhow::Result<()> {
495 let config = hyperactor_config::global::lock();
496 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
497
498 let buf_size = 4096;
499 let mut env = setup_tcp_env(buf_size).await?;
500
501 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
502 *byte = (i % 256) as u8;
503 }
504
505 env.tcp_handle_1
506 .submit(
507 &env.instance_1,
508 vec![RdmaOp {
509 op_type: RdmaOpType::WriteFromLocal,
510 local: env.local_mem_1.clone(),
511 remote: env.rdma_buf_2.clone(),
512 }],
513 Duration::from_secs(30),
514 )
515 .await?;
516
517 for (i, byte) in env.cpu_buf_2.iter().enumerate() {
518 assert_eq!(*byte, (i % 256) as u8, "mismatch at offset {i} after write");
519 }
520
521 Ok(())
522 }
523
524 #[timed_test::async_timed_test(timeout_secs = 30)]
526 async fn test_tcp_read_into_local() -> anyhow::Result<()> {
527 let config = hyperactor_config::global::lock();
528 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
529
530 let buf_size = 2048;
531 let mut env = setup_tcp_env(buf_size).await?;
532
533 for (i, byte) in env.cpu_buf_2.iter_mut().enumerate() {
534 *byte = ((i * 7 + 3) % 256) as u8;
535 }
536
537 env.tcp_handle_1
538 .submit(
539 &env.instance_1,
540 vec![RdmaOp {
541 op_type: RdmaOpType::ReadIntoLocal,
542 local: env.local_mem_1.clone(),
543 remote: env.rdma_buf_2.clone(),
544 }],
545 Duration::from_secs(30),
546 )
547 .await?;
548
549 for (i, byte) in env.cpu_buf_1.iter().enumerate() {
550 assert_eq!(
551 *byte,
552 ((i * 7 + 3) % 256) as u8,
553 "mismatch at offset {i} after read"
554 );
555 }
556
557 Ok(())
558 }
559
560 #[timed_test::async_timed_test(timeout_secs = 30)]
562 async fn test_tcp_write_then_read_back() -> anyhow::Result<()> {
563 let config = hyperactor_config::global::lock();
564 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
565
566 let buf_size = 4096;
567 let mut env = setup_tcp_env(buf_size).await?;
568
569 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
570 *byte = ((i * 13 + 5) % 256) as u8;
571 }
572
573 env.tcp_handle_1
575 .submit(
576 &env.instance_1,
577 vec![RdmaOp {
578 op_type: RdmaOpType::WriteFromLocal,
579 local: env.local_mem_1.clone(),
580 remote: env.rdma_buf_2.clone(),
581 }],
582 Duration::from_secs(30),
583 )
584 .await?;
585
586 for byte in env.cpu_buf_1.iter_mut() {
588 *byte = 0;
589 }
590
591 env.tcp_handle_1
593 .submit(
594 &env.instance_1,
595 vec![RdmaOp {
596 op_type: RdmaOpType::ReadIntoLocal,
597 local: env.local_mem_1.clone(),
598 remote: env.rdma_buf_2.clone(),
599 }],
600 Duration::from_secs(30),
601 )
602 .await?;
603
604 for (i, byte) in env.cpu_buf_1.iter().enumerate() {
605 assert_eq!(
606 *byte,
607 ((i * 13 + 5) % 256) as u8,
608 "mismatch at offset {i} after read-back"
609 );
610 }
611
612 Ok(())
613 }
614
615 #[timed_test::async_timed_test(timeout_secs = 30)]
619 async fn test_tcp_multi_chunk_write() -> anyhow::Result<()> {
620 let config = hyperactor_config::global::lock();
621 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
622 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
623
624 let buf_size = 3 * 1024 * 512;
626 let mut env = setup_tcp_env(buf_size).await?;
627
628 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
629 *byte = (i % 251) as u8;
630 }
631
632 env.tcp_handle_1
633 .submit(
634 &env.instance_1,
635 vec![RdmaOp {
636 op_type: RdmaOpType::WriteFromLocal,
637 local: env.local_mem_1.clone(),
638 remote: env.rdma_buf_2.clone(),
639 }],
640 Duration::from_secs(30),
641 )
642 .await?;
643
644 for (i, byte) in env.cpu_buf_2.iter().enumerate() {
645 assert_eq!(*byte, (i % 251) as u8, "mismatch at offset {i}");
646 }
647
648 Ok(())
649 }
650
651 #[timed_test::async_timed_test(timeout_secs = 30)]
653 async fn test_tcp_multi_chunk_read() -> anyhow::Result<()> {
654 let config = hyperactor_config::global::lock();
655 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
656 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
657
658 let buf_size = 3 * 1024 * 512; let mut env = setup_tcp_env(buf_size).await?;
660
661 for (i, byte) in env.cpu_buf_2.iter_mut().enumerate() {
662 *byte = ((i * 3 + 17) % 256) as u8;
663 }
664
665 env.tcp_handle_1
666 .submit(
667 &env.instance_1,
668 vec![RdmaOp {
669 op_type: RdmaOpType::ReadIntoLocal,
670 local: env.local_mem_1.clone(),
671 remote: env.rdma_buf_2.clone(),
672 }],
673 Duration::from_secs(30),
674 )
675 .await?;
676
677 for (i, byte) in env.cpu_buf_1.iter().enumerate() {
678 assert_eq!(*byte, ((i * 3 + 17) % 256) as u8, "mismatch at offset {i}");
679 }
680
681 Ok(())
682 }
683
684 #[timed_test::async_timed_test(timeout_secs = 30)]
687 async fn test_tcp_multi_chunk_round_trip() -> anyhow::Result<()> {
688 let config = hyperactor_config::global::lock();
689 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
690 let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
691
692 let buf_size = 5 * 1024 * 512;
694 let mut env = setup_tcp_env(buf_size).await?;
695
696 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
697 *byte = ((i * 41 + 7) % 256) as u8;
698 }
699
700 env.tcp_handle_1
702 .submit(
703 &env.instance_1,
704 vec![RdmaOp {
705 op_type: RdmaOpType::WriteFromLocal,
706 local: env.local_mem_1.clone(),
707 remote: env.rdma_buf_2.clone(),
708 }],
709 Duration::from_secs(30),
710 )
711 .await?;
712
713 for byte in env.cpu_buf_1.iter_mut() {
715 *byte = 0;
716 }
717
718 env.tcp_handle_1
720 .submit(
721 &env.instance_1,
722 vec![RdmaOp {
723 op_type: RdmaOpType::ReadIntoLocal,
724 local: env.local_mem_1.clone(),
725 remote: env.rdma_buf_2.clone(),
726 }],
727 Duration::from_secs(30),
728 )
729 .await?;
730
731 for (i, byte) in env.cpu_buf_1.iter().enumerate() {
732 assert_eq!(
733 *byte,
734 ((i * 41 + 7) % 256) as u8,
735 "mismatch at offset {i} after multi-chunk round-trip"
736 );
737 }
738
739 Ok(())
740 }
741
742 #[timed_test::async_timed_test(timeout_secs = 30)]
745 async fn test_tcp_resolve_tcp() -> anyhow::Result<()> {
746 let config = hyperactor_config::global::lock();
747 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
748
749 let env = setup_tcp_env(64).await?;
750
751 let (tcp_ref_1, id_1) = env.rdma_buf_1.resolve_tcp()?;
752 assert_eq!(id_1, env.rdma_buf_1.id);
753 let expected_1: hyperactor::ActorRef<TcpManagerActor> = env.tcp_handle_1.bind();
754 assert_eq!(tcp_ref_1.actor_id(), expected_1.actor_id());
755
756 let (tcp_ref_2, id_2) = env.rdma_buf_2.resolve_tcp()?;
757 assert_eq!(id_2, env.rdma_buf_2.id);
758 let expected_2: hyperactor::ActorRef<TcpManagerActor> = env.tcp_handle_2.bind();
759 assert_eq!(tcp_ref_2.actor_id(), expected_2.actor_id());
760
761 Ok(())
762 }
763
764 #[timed_test::async_timed_test(timeout_secs = 30)]
766 async fn test_tcp_write_to_released_buffer() -> anyhow::Result<()> {
767 let config = hyperactor_config::global::lock();
768 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
769
770 let buf_size = 64;
771 let mut env = setup_tcp_env(buf_size).await?;
772
773 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
775 *byte = (i % 256) as u8;
776 }
777
778 env.tcp_handle_1
780 .submit(
781 &env.instance_1,
782 vec![RdmaOp {
783 op_type: RdmaOpType::WriteFromLocal,
784 local: env.local_mem_1.clone(),
785 remote: env.rdma_buf_2.clone(),
786 }],
787 Duration::from_secs(10),
788 )
789 .await?;
790
791 use crate::rdma_manager_actor::ReleaseBufferClient;
793 let owner_ref = env.rdma_buf_2.owner.clone();
794 owner_ref
795 .release_buffer(&env.instance_1, env.rdma_buf_2.id)
796 .await?;
797
798 let result = env
800 .tcp_handle_1
801 .submit(
802 &env.instance_1,
803 vec![RdmaOp {
804 op_type: RdmaOpType::WriteFromLocal,
805 local: env.local_mem_1.clone(),
806 remote: env.rdma_buf_2.clone(),
807 }],
808 Duration::from_secs(10),
809 )
810 .await;
811 assert!(result.is_err(), "expected error writing to released buffer");
812
813 let result = env
816 .tcp_handle_2
817 .submit(
818 &env.instance_1,
819 vec![RdmaOp {
820 op_type: RdmaOpType::ReadIntoLocal,
821 local: env.local_mem_1.clone(),
822 remote: env.rdma_buf_1.clone(),
823 }],
824 Duration::from_secs(10),
825 )
826 .await;
827 assert!(
828 result.is_ok(),
829 "TCP actor should still be alive after error"
830 );
831
832 Ok(())
833 }
834
835 #[timed_test::async_timed_test(timeout_secs = 30)]
837 async fn test_tcp_read_from_released_buffer() -> anyhow::Result<()> {
838 let config = hyperactor_config::global::lock();
839 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
840
841 let buf_size = 64;
842 let mut env = setup_tcp_env(buf_size).await?;
843
844 use crate::rdma_manager_actor::ReleaseBufferClient;
846 let owner_ref = env.rdma_buf_2.owner.clone();
847 owner_ref
848 .release_buffer(&env.instance_1, env.rdma_buf_2.id)
849 .await?;
850
851 let result = env
853 .tcp_handle_1
854 .submit(
855 &env.instance_1,
856 vec![RdmaOp {
857 op_type: RdmaOpType::ReadIntoLocal,
858 local: env.local_mem_1.clone(),
859 remote: env.rdma_buf_2.clone(),
860 }],
861 Duration::from_secs(10),
862 )
863 .await;
864 assert!(
865 result.is_err(),
866 "expected error reading from released buffer"
867 );
868
869 let result = env
871 .tcp_handle_2
872 .submit(
873 &env.instance_1,
874 vec![RdmaOp {
875 op_type: RdmaOpType::ReadIntoLocal,
876 local: env.local_mem_1.clone(),
877 remote: env.rdma_buf_1.clone(),
878 }],
879 Duration::from_secs(10),
880 )
881 .await;
882 assert!(
883 result.is_ok(),
884 "TCP actor should still be alive after error"
885 );
886
887 Ok(())
888 }
889
    /// Single-proc test fixture: one TCP manager handle and two registered
    /// buffers that both live in the same proc.
    ///
    /// Underscore-prefixed fields are never read; they are stored so the
    /// values live as long as the fixture.
    struct SameProcTcpTestEnv {
        _proc: Proc,
        instance: hyperactor::Instance<()>,
        tcp_handle: ActorHandle<TcpManagerActor>,
        _rdma_buf_1: crate::RdmaRemoteBuffer,
        rdma_buf_2: crate::RdmaRemoteBuffer,
        local_mem_1: Arc<dyn RdmaLocalMemory>,
        _local_mem_2: Arc<dyn RdmaLocalMemory>,
        // Backing storage for `local_mem_1` / `_local_mem_2`; the local
        // memory objects are built from raw pointers into these boxes.
        cpu_buf_1: Box<[u8]>,
        cpu_buf_2: Box<[u8]>,
    }
903
904 async fn setup_same_proc_tcp_env(buffer_size: usize) -> anyhow::Result<SameProcTcpTestEnv> {
905 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
906
907 let proc = Proc::direct(
908 ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
909 format!("tcp_same_proc_test_{id}"),
910 )?;
911 let (instance, _ch) = proc.instance("client")?;
912
913 let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
914 let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;
915
916 let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
917 let tcp_handle = tcp_ref
918 .downcast_handle(&instance)
919 .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?;
920
921 let mut cpu_buf_1 = vec![0u8; buffer_size].into_boxed_slice();
922 let ptr_1 = cpu_buf_1.as_mut_ptr() as usize;
923 let local_mem_1: Arc<dyn RdmaLocalMemory> =
924 Arc::new(UnsafeLocalMemory::new(ptr_1, buffer_size));
925
926 let mut cpu_buf_2 = vec![0u8; buffer_size].into_boxed_slice();
927 let ptr_2 = cpu_buf_2.as_mut_ptr() as usize;
928 let local_mem_2: Arc<dyn RdmaLocalMemory> =
929 Arc::new(UnsafeLocalMemory::new(ptr_2, buffer_size));
930
931 let rdma_buf_1 = rdma_handle
932 .request_buffer(&instance, local_mem_1.clone())
933 .await?;
934 let rdma_buf_2 = rdma_handle
935 .request_buffer(&instance, local_mem_2.clone())
936 .await?;
937
938 Ok(SameProcTcpTestEnv {
939 _proc: proc,
940 instance,
941 tcp_handle,
942 _rdma_buf_1: rdma_buf_1,
943 rdma_buf_2,
944 local_mem_1,
945 _local_mem_2: local_mem_2,
946 cpu_buf_1,
947 cpu_buf_2,
948 })
949 }
950
951 #[timed_test::async_timed_test(timeout_secs = 30)]
953 async fn test_tcp_same_process_write() -> anyhow::Result<()> {
954 let config = hyperactor_config::global::lock();
955 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
956
957 let buf_size = 4096;
958 let mut env = setup_same_proc_tcp_env(buf_size).await?;
959
960 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
961 *byte = (i % 256) as u8;
962 }
963
964 env.tcp_handle
965 .submit(
966 &env.instance,
967 vec![RdmaOp {
968 op_type: RdmaOpType::WriteFromLocal,
969 local: env.local_mem_1.clone(),
970 remote: env.rdma_buf_2.clone(),
971 }],
972 Duration::from_secs(10),
973 )
974 .await?;
975
976 for (i, byte) in env.cpu_buf_2.iter().enumerate() {
977 assert_eq!(*byte, (i % 256) as u8, "mismatch at offset {i}");
978 }
979
980 Ok(())
981 }
982
983 #[timed_test::async_timed_test(timeout_secs = 30)]
985 async fn test_tcp_same_process_read() -> anyhow::Result<()> {
986 let config = hyperactor_config::global::lock();
987 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
988
989 let buf_size = 2048;
990 let mut env = setup_same_proc_tcp_env(buf_size).await?;
991
992 for (i, byte) in env.cpu_buf_2.iter_mut().enumerate() {
993 *byte = ((i * 7 + 3) % 256) as u8;
994 }
995
996 env.tcp_handle
997 .submit(
998 &env.instance,
999 vec![RdmaOp {
1000 op_type: RdmaOpType::ReadIntoLocal,
1001 local: env.local_mem_1.clone(),
1002 remote: env.rdma_buf_2.clone(),
1003 }],
1004 Duration::from_secs(10),
1005 )
1006 .await?;
1007
1008 for (i, byte) in env.cpu_buf_1.iter().enumerate() {
1009 assert_eq!(*byte, ((i * 7 + 3) % 256) as u8, "mismatch at offset {i}");
1010 }
1011
1012 Ok(())
1013 }
1014
1015 #[timed_test::async_timed_test(timeout_secs = 30)]
1017 async fn test_tcp_same_process_round_trip() -> anyhow::Result<()> {
1018 let config = hyperactor_config::global::lock();
1019 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1020
1021 let buf_size = 4096;
1022 let mut env = setup_same_proc_tcp_env(buf_size).await?;
1023
1024 for (i, byte) in env.cpu_buf_1.iter_mut().enumerate() {
1025 *byte = ((i * 13 + 5) % 256) as u8;
1026 }
1027
1028 env.tcp_handle
1030 .submit(
1031 &env.instance,
1032 vec![RdmaOp {
1033 op_type: RdmaOpType::WriteFromLocal,
1034 local: env.local_mem_1.clone(),
1035 remote: env.rdma_buf_2.clone(),
1036 }],
1037 Duration::from_secs(10),
1038 )
1039 .await?;
1040
1041 for byte in env.cpu_buf_1.iter_mut() {
1043 *byte = 0;
1044 }
1045
1046 env.tcp_handle
1048 .submit(
1049 &env.instance,
1050 vec![RdmaOp {
1051 op_type: RdmaOpType::ReadIntoLocal,
1052 local: env.local_mem_1.clone(),
1053 remote: env.rdma_buf_2.clone(),
1054 }],
1055 Duration::from_secs(10),
1056 )
1057 .await?;
1058
1059 for (i, byte) in env.cpu_buf_1.iter().enumerate() {
1060 assert_eq!(
1061 *byte,
1062 ((i * 13 + 5) % 256) as u8,
1063 "mismatch at offset {i} after round-trip"
1064 );
1065 }
1066
1067 Ok(())
1068 }
1069
1070 #[timed_test::async_timed_test(timeout_secs = 30)]
1073 async fn test_tcp_fallback_disabled_fails() -> anyhow::Result<()> {
1074 let config = hyperactor_config::global::lock();
1075 let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, false);
1076
1077 let result = RdmaManagerActor::new(None, Flattrs::default()).await;
1078 if crate::ibverbs_supported() {
1079 assert!(result.is_ok());
1080 } else {
1081 assert!(result.is_err());
1082 }
1083
1084 Ok(())
1085 }
1086}