Skip to main content

monarch_rdma/backend/tcp/
manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! TCP manager actor for RDMA fallback transport.
10//!
11//! Transfers buffer data over the default hyperactor channel transport
12//! in chunks controlled by
13//! [`RDMA_MAX_CHUNK_SIZE_MB`](crate::config::RDMA_MAX_CHUNK_SIZE_MB).
14
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::time::Duration;
19use std::time::Instant;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use bytes::Bytes;
24use bytes::BytesMut;
25use dashmap::DashMap;
26use hyperactor::Actor;
27use hyperactor::ActorHandle;
28use hyperactor::ActorRef;
29use hyperactor::Context;
30use hyperactor::Endpoint as _;
31use hyperactor::HandleClient;
32use hyperactor::Handler;
33use hyperactor::Instance;
34use hyperactor::OncePortHandle;
35use hyperactor::OncePortRef;
36use hyperactor::PortHandle;
37use hyperactor::RefClient;
38use hyperactor::actor::ActorError;
39use hyperactor::channel;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::channel::ChannelTx;
42use hyperactor::channel::Rx;
43use hyperactor::channel::Tx;
44use hyperactor::context;
45use hyperactor::context::Actor as _;
46use hyperactor_mesh::transport::default_bind_spec;
47use serde::Deserialize;
48use serde::Serialize;
49use serde_multipart::Part;
50use tokio::time::timeout as tokio_timeout;
51use tokio_util::sync::CancellationToken;
52use typeuri::Named;
53
54use super::TcpOp;
55use crate::RdmaOp;
56use crate::RdmaOpType;
57use crate::RdmaTransportLevel;
58use crate::backend::RdmaBackend;
59use crate::local_memory::KeepaliveLocalMemory;
60use crate::rdma_manager_actor::GetTcpActorRefClient;
61use crate::rdma_manager_actor::RdmaManagerActor;
62use crate::rdma_manager_actor::RdmaManagerMessageClient;
63
/// [`Named`] wrapper around [`Part`] for use as a reply type.
///
/// [`Part`] itself does not implement [`Named`], which is required by
/// [`OncePortRef`]. This newtype adds the missing trait. It is the success
/// payload of the `ReadChunk` reply.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct TcpChunk(Part);
wirevalue::register_type!(TcpChunk);
71
/// Data chunk sent over direct parallel channels.
///
/// Produced by the sender tasks spawned in `execute_transfer` and consumed by
/// the receive loop started in `init`, which writes `data` into the matching
/// transfer's buffer at `offset`.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
struct TcpDataChunk {
    // Which specific transfer this chunk is associated with (the ID returned
    // when the transfer was registered on the receiving side).
    transfer_id: usize,
    // Byte offset into the destination buffer at which `data` is written.
    offset: usize,
    // Chunk payload: at most the sender's chunk size in bytes (the final
    // chunk may be shorter).
    data: Part,
}
wirevalue::register_type!(TcpDataChunk);
82
/// Tracks the progress of a single parallel transfer (read or write) on the
/// receiving side.
///
/// Shared between channel receive loops and actor message handlers
/// via a [`DashMap`] keyed by transfer ID. Handlers insert an entry when a
/// transfer is registered; the receive loop removes it once every chunk has
/// been written or a chunk write fails.
#[derive(Debug)]
struct TransferState {
    /// Buffer backing this transfer, provided at construction. Incoming
    /// chunks are written into it at their offsets.
    local_memory: Arc<KeepaliveLocalMemory>,

    /// Number of chunks received (and successfully written) so far.
    chunks_received: usize,

    /// Total chunks expected for this transfer.
    total_chunks: usize,

    /// Completion reply port. Fired (via [`SendTransferResult`]) when all
    /// chunks arrive or writing a chunk fails.
    done: OncePortRef<Result<(), String>>,
}
102
103impl TransferState {
104    fn new(
105        total_chunks: usize,
106        local_memory: Arc<KeepaliveLocalMemory>,
107        done: OncePortRef<Result<(), String>>,
108    ) -> Self {
109        Self {
110            local_memory,
111            chunks_received: 0,
112            total_chunks,
113            done,
114        }
115    }
116}
117
/// Sends the result of a completed transfer to the caller's reply port.
///
/// Sending an actor message from the spawned receiver task requires the
/// loop to own a dummy [`context::Actor`] impl. If the task sent directly
/// to a remote [`OncePortRef`], the message would appear to come from
/// this dummy context, and undeliverable messages wouldn't be handled
/// properly. This intermediate message lets us use a [`PortHandle`]
/// whose message cannot be undeliverable; the handler then forwards the
/// result using the real actor's context.
#[derive(Debug, Serialize, Deserialize, Named)]
struct SendTransferResult {
    // Reply port of the transfer's initiator, taken from its `TransferState`.
    done: OncePortRef<Result<(), String>>,
    // `Ok(())` when every chunk was written, or the chunk-write error message.
    result: Result<(), String>,
}
132
/// Fatal error from a spawned transfer task.
///
/// Posted by the parallel receive loop when the channel receive fails, and
/// by sender tasks when reading local memory or sending a chunk fails.
/// The handler logs the error and returns `Err`, which triggers a
/// supervision event and crashes the actor.
#[derive(Debug, Serialize, Deserialize, Named)]
struct TransferError {
    // Human-readable description of the failure.
    message: String,
}
141
/// Set up the local TcpManagerActor to receive a parallel transfer from
/// a remote TcpManagerActor.
///
/// In-process only: it is not serializable, and carries a local memory
/// handle plus a [`OncePortHandle`].
#[derive(Debug)]
struct RegisterTransferLocal {
    // Local buffer the incoming chunks are written into.
    local_memory: Arc<KeepaliveLocalMemory>,
    // Number of chunks the sender will push.
    total_chunks: usize,
    // Fired once the transfer completes or a chunk write fails.
    done: OncePortRef<Result<(), String>>,
    // Receives the transfer ID assigned by the actor.
    reply: OncePortHandle<usize>,
}
152
/// Tell the local TcpManagerActor to read local memory and push
/// chunks to `dest_addr`.
///
/// In-process only: it is not serializable, and carries a local memory handle.
#[derive(Debug)]
struct ExecuteTransferLocal {
    // ID the receiver assigned when the transfer was registered.
    transfer_id: usize,
    // Local buffer whose contents are sent.
    local_memory: Arc<KeepaliveLocalMemory>,
    // Maximum bytes per chunk.
    chunk_size: usize,
    // Parallel channel address served by the receiving actor.
    dest_addr: ChannelAddr,
}
162
/// Serializable messages for the [`TcpManagerActor`].
///
/// These travel over the wire between processes. The [`Part`] payload
/// is transferred via the multipart codec without an extra copy.
/// Failures the caller should see are returned as the inner `Err(String)`
/// of each reply.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
enum TcpManagerMessage {
    /// Write a chunk of data into a registered buffer at the given offset.
    WriteChunk {
        // Buffer ID resolved through the owning `RdmaManagerActor`.
        buf_id: usize,
        // Byte offset within the buffer.
        offset: usize,
        // Bytes to write.
        data: Part,
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
    /// Read a chunk of data from a registered buffer at the given offset.
    ReadChunk {
        // Buffer ID resolved through the owning `RdmaManagerActor`.
        buf_id: usize,
        // Byte offset within the buffer.
        offset: usize,
        // Number of bytes to read.
        size: usize,
        #[reply]
        reply: OncePortRef<Result<TcpChunk, String>>,
    },
    /// Return the channel address served by this actor for parallel transfers.
    /// `None` unless parallelism is greater than 1.
    GetChannelAddress {
        #[reply]
        reply: OncePortRef<Option<ChannelAddr>>,
    },
    /// Set up a remote TcpManagerActor to receive a parallel transfer from
    /// the sender.
    RegisterTransferRemote {
        // Destination buffer on the receiving side.
        buf_id: usize,
        // Number of chunks the sender will push.
        total_chunks: usize,
        // Fired by the receiver once the transfer completes or fails.
        done: OncePortRef<Result<(), String>>,
        // Replies with the assigned transfer ID.
        #[reply]
        reply: OncePortRef<Result<usize, String>>,
    },
    /// Tell the remote TcpManagerActor to read its local memory and push
    /// chunks to the dest_addr provided by the sender.
    ExecuteTransferRemote {
        // Transfer ID assigned by the receiving side.
        transfer_id: usize,
        // Source buffer on the pushing side.
        buf_id: usize,
        // Maximum bytes per chunk.
        chunk_size: usize,
        // Parallel channel address of the receiver.
        dest_addr: ChannelAddr,
        // Replies once the sender tasks are spawned, not when data has arrived.
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(TcpManagerMessage);
212
/// TCP fallback RDMA backend actor.
///
/// Spawned as a child of [`RdmaManagerActor`]. Transfers buffer data
/// over the default hyperactor channel transport in chunks. It does this
/// either with one request/reply message per chunk (`WriteChunk` /
/// `ReadChunk`) or, when parallelism is enabled, by streaming
/// [`TcpDataChunk`]s over direct channels.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [TcpManagerMessage],
)]
pub struct TcpManagerActor {
    /// Handle to the parent [`RdmaManagerActor`], set once in `init`; used to
    /// resolve buffer IDs to local memory.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,
    /// ID handed out to the next registered transfer.
    next_transfer_id: usize,
    /// Incoming parallel transfers in progress, shared with the receive loop.
    transfers: Arc<DashMap<usize, TransferState>>,
    /// Address of the direct channel served for parallel transfers.
    /// `None` when parallelism is 1 (default).
    channel_addr: Option<ChannelAddr>,
    /// Cached outbound connections keyed by remote channel address.
    outbound: HashMap<ChannelAddr, Vec<Arc<ChannelTx<TcpDataChunk>>>>,
    /// Cancellation token for spawned tasks.
    cancel: CancellationToken,
    /// Signaled when the parallel receive loop exits.
    receiver_done: Option<tokio::sync::oneshot::Receiver<()>>,
}
235
236impl TcpManagerActor {
237    pub fn new() -> Self {
238        Self {
239            owner: OnceLock::new(),
240            next_transfer_id: 0,
241            transfers: Arc::new(DashMap::new()),
242            channel_addr: None,
243            outbound: HashMap::new(),
244            cancel: CancellationToken::new(),
245            receiver_done: None,
246        }
247    }
248
249    fn register_transfer(
250        &mut self,
251        local_memory: Arc<KeepaliveLocalMemory>,
252        total_chunks: usize,
253        done: OncePortRef<Result<(), String>>,
254    ) -> usize {
255        let transfer_id = self.next_transfer_id;
256        self.next_transfer_id += 1;
257        self.transfers.insert(
258            transfer_id,
259            TransferState::new(total_chunks, local_memory, done),
260        );
261        transfer_id
262    }
263
264    fn execute_transfer(
265        &mut self,
266        cx: &Context<Self>,
267        transfer_id: usize,
268        local_memory: Arc<KeepaliveLocalMemory>,
269        chunk_size: usize,
270        dest_addr: ChannelAddr,
271    ) -> Result<()> {
272        let parallelism =
273            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
274
275        if !self.outbound.contains_key(&dest_addr) {
276            let conns = (0..parallelism)
277                .map(|_| {
278                    channel::dial::<TcpDataChunk>(dest_addr.clone())
279                        .map(Arc::new)
280                        .map_err(anyhow::Error::from)
281                })
282                .collect::<Result<Vec<_>>>()?;
283            self.outbound.insert(dest_addr.clone(), conns);
284        }
285        let conns = self.outbound.get(&dest_addr).unwrap();
286
287        let size = local_memory.size();
288        let total_chunks = size.div_ceil(chunk_size);
289
290        let chunk_index = Arc::new(std::sync::atomic::AtomicUsize::new(0));
291        let error_port: PortHandle<TransferError> = cx.port();
292        let proc = cx.instance().proc().clone();
293        let cancel = self.cancel.clone();
294
295        for conn in conns.clone() {
296            let mem = local_memory.clone();
297            let chunk_index = chunk_index.clone();
298            let error_port = error_port.clone();
299            let proc = proc.clone();
300            let cancel = cancel.clone();
301
302            tokio::spawn(async move {
303                let sender_name = format!(
304                    "tcp_chunk_sender_{}",
305                    hyperactor_mesh::shortuuid::ShortUuid::generate()
306                );
307                let (instance, _handle) = proc
308                    .client(&sender_name)
309                    .expect("failed to create sender instance");
310
311                loop {
312                    if cancel.is_cancelled() {
313                        return;
314                    }
315
316                    let idx = chunk_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
317                    if idx >= total_chunks {
318                        break;
319                    }
320
321                    let offset = idx * chunk_size;
322                    let len = std::cmp::min(chunk_size, size - offset);
323                    let mut buf = BytesMut::zeroed(len);
324                    // SAFETY: the caller is responsible for ensuring that no other
325                    // component writes the target byte range concurrently.
326                    if let Err(e) = unsafe { mem.read_at(offset, &mut buf) } {
327                        error_port.post(
328                            &instance,
329                            TransferError {
330                                message: format!("read_at failed at offset {offset}: {e}"),
331                            },
332                        );
333                        return;
334                    }
335
336                    let chunk = TcpDataChunk {
337                        transfer_id,
338                        offset,
339                        data: Part::from(buf.freeze()),
340                    };
341
342                    if let Err(e) = conn.send(chunk).await {
343                        error_port.post(
344                            &instance,
345                            TransferError {
346                                message: format!("failed to send chunk at offset {offset}: {e}"),
347                            },
348                        );
349                        return;
350                    }
351                }
352            });
353        }
354
355        Ok(())
356    }
357
358    /// Construct an [`ActorHandle`] for the local [`TcpManagerActor`]
359    /// by querying the local [`RdmaManagerActor`].
360    pub async fn local_handle(
361        client: &(impl context::Actor + Send + Sync),
362    ) -> Result<ActorHandle<Self>, anyhow::Error> {
363        let rdma_handle = RdmaManagerActor::local_handle(client);
364        let tcp_ref: ActorRef<TcpManagerActor> = rdma_handle.get_tcp_actor_ref(client).await?;
365        tcp_ref
366            .downcast_handle(client)
367            .ok_or_else(|| anyhow::anyhow!("TcpManagerActor is not in the local process"))
368    }
369}
370
371#[async_trait]
372impl Actor for TcpManagerActor {
373    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
374        let owner = this.parent_handle().ok_or_else(|| {
375            anyhow::anyhow!("RdmaManagerActor not found as parent of TcpManagerActor")
376        })?;
377        self.owner
378            .set(owner)
379            .map_err(|_| anyhow::anyhow!("TcpManagerActor owner already set"))?;
380
381        let parallelism =
382            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
383        if parallelism > 1 {
384            let addr = match default_bind_spec() {
385                channel::BindSpec::Any(transport) => ChannelAddr::any(transport),
386                channel::BindSpec::Addr(addr) => addr,
387            };
388            let (bound_addr, mut rx) = channel::serve::<TcpDataChunk>(addr)?;
389            self.channel_addr = Some(bound_addr);
390
391            let transfers = self.transfers.clone();
392            let proc = this.proc().clone();
393            let result_port: PortHandle<SendTransferResult> = this.port();
394            let error_port: PortHandle<TransferError> = this.port();
395            let cancel = self.cancel.clone();
396            let receiver_name = format!(
397                "tcp_chunk_receiver_{}",
398                hyperactor_mesh::shortuuid::ShortUuid::generate()
399            );
400
401            let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>();
402            self.receiver_done = Some(done_rx);
403
404            tokio::spawn(async move {
405                let (instance, _handle) = proc
406                    .client(&receiver_name.to_string())
407                    .expect("failed to create receiver instance");
408
409                loop {
410                    let chunk = tokio::select! {
411                        _ = cancel.cancelled() => break,
412                        result = rx.recv() => match result {
413                            Ok(chunk) => chunk,
414                            Err(e) => {
415                                error_port
416                                    .post(
417                                        &instance,
418                                        TransferError {
419                                            message: format!(
420                                                "parallel channel receive error: {e}"
421                                            ),
422                                        },
423                                    );
424                                break;
425                            }
426                        },
427                    };
428
429                    let mut entry = match transfers.get_mut(&chunk.transfer_id) {
430                        Some(entry) => entry,
431                        None => {
432                            tracing::warn!(
433                                "received chunk for unknown transfer {:?}",
434                                chunk.transfer_id,
435                            );
436                            continue;
437                        }
438                    };
439
440                    let mut write_offset = chunk.offset;
441                    let fragments = chunk.data.into_inner();
442                    let write_err = fragments.iter().find_map(|fragment| {
443                        // SAFETY: the caller is responsible for ensuring that no other
444                        // component reads or writes the target byte range concurrently.
445                        let result = unsafe { entry.local_memory.write_at(write_offset, fragment) };
446                        write_offset += fragment.len();
447                        result.err()
448                    });
449                    if let Some(e) = write_err {
450                        let transfer_id = chunk.transfer_id;
451                        drop(entry);
452                        let (_, state) = transfers.remove(&transfer_id).unwrap();
453                        result_port.post(
454                            &instance,
455                            SendTransferResult {
456                                done: state.done,
457                                result: Err(e.to_string()),
458                            },
459                        );
460                        continue;
461                    }
462
463                    entry.chunks_received += 1;
464                    if entry.chunks_received == entry.total_chunks {
465                        let transfer_id = chunk.transfer_id;
466                        drop(entry);
467                        let (_, state) = transfers.remove(&transfer_id).unwrap();
468                        result_port.post(
469                            &instance,
470                            SendTransferResult {
471                                done: state.done,
472                                result: Ok(()),
473                            },
474                        );
475                    }
476                }
477                rx.join().await;
478                done_tx.send(()).unwrap();
479            });
480        }
481
482        Ok(())
483    }
484
485    async fn cleanup(
486        &mut self,
487        _this: &Instance<Self>,
488        _err: Option<&ActorError>,
489    ) -> Result<(), anyhow::Error> {
490        self.cancel.cancel();
491        if let Some(done_rx) = self.receiver_done.take() {
492            done_rx.await?;
493        }
494        Ok(())
495    }
496}
497
498#[async_trait]
499#[hyperactor::handle(TcpManagerMessage)]
500impl TcpManagerMessageHandler for TcpManagerActor {
501    async fn write_chunk(
502        &mut self,
503        cx: &Context<Self>,
504        buf_id: usize,
505        offset: usize,
506        data: Part,
507    ) -> Result<Result<(), String>, anyhow::Error> {
508        let owner = self.owner.get().expect("TcpManagerActor owner not set");
509        let mem = match owner.request_local_memory(cx, buf_id).await {
510            Ok(Some(mem)) => mem,
511            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
512            Err(e) => return Ok(Err(e.to_string())),
513        };
514
515        let bytes = data.into_bytes();
516        // SAFETY: the remote peer that issued this `WriteChunk` had to
517        // register the buffer locally first; its caller is responsible
518        // for ensuring no other component reads or writes the target byte
519        // range concurrently.
520        if let Err(e) = unsafe { mem.write_at(offset, &bytes) } {
521            return Ok(Err(e.to_string()));
522        }
523
524        Ok(Ok(()))
525    }
526
527    async fn read_chunk(
528        &mut self,
529        cx: &Context<Self>,
530        buf_id: usize,
531        offset: usize,
532        size: usize,
533    ) -> Result<Result<TcpChunk, String>, anyhow::Error> {
534        let owner = self.owner.get().expect("TcpManagerActor owner not set");
535        let mem = match owner.request_local_memory(cx, buf_id).await {
536            Ok(Some(mem)) => mem,
537            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
538            Err(e) => return Ok(Err(e.to_string())),
539        };
540
541        let mut buf = BytesMut::zeroed(size);
542        // SAFETY: the remote peer that issued this `ReadChunk` had to
543        // register the buffer locally first; its caller is responsible
544        // for ensuring no other component writes the target byte range
545        // concurrently.
546        if let Err(e) = unsafe { mem.read_at(offset, &mut buf) } {
547            return Ok(Err(e.to_string()));
548        }
549        Ok(Ok(TcpChunk(Part::from(buf.freeze()))))
550    }
551
552    async fn get_channel_address(
553        &mut self,
554        _cx: &Context<Self>,
555    ) -> Result<Option<ChannelAddr>, anyhow::Error> {
556        Ok(self.channel_addr.clone())
557    }
558
559    async fn register_transfer_remote(
560        &mut self,
561        cx: &Context<Self>,
562        buf_id: usize,
563        total_chunks: usize,
564        done: OncePortRef<Result<(), String>>,
565    ) -> Result<Result<usize, String>, anyhow::Error> {
566        let owner = self.owner.get().expect("TcpManagerActor owner not set");
567        let mem = match owner.request_local_memory(cx, buf_id).await {
568            Ok(Some(mem)) => mem,
569            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
570            Err(e) => return Ok(Err(e.to_string())),
571        };
572        let transfer_id = self.register_transfer(mem, total_chunks, done);
573        Ok(Ok(transfer_id))
574    }
575
576    async fn execute_transfer_remote(
577        &mut self,
578        cx: &Context<Self>,
579        transfer_id: usize,
580        buf_id: usize,
581        chunk_size: usize,
582        dest_addr: ChannelAddr,
583    ) -> Result<Result<(), String>, anyhow::Error> {
584        let owner = self.owner.get().expect("TcpManagerActor owner not set");
585        let mem = match owner.request_local_memory(cx, buf_id).await {
586            Ok(Some(mem)) => mem,
587            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
588            Err(e) => return Ok(Err(e.to_string())),
589        };
590        self.execute_transfer(cx, transfer_id, mem, chunk_size, dest_addr)?;
591        Ok(Ok(()))
592    }
593}
594
595#[async_trait]
596impl Handler<RegisterTransferLocal> for TcpManagerActor {
597    async fn handle(
598        &mut self,
599        cx: &Context<Self>,
600        message: RegisterTransferLocal,
601    ) -> Result<(), anyhow::Error> {
602        let transfer_id =
603            self.register_transfer(message.local_memory, message.total_chunks, message.done);
604        message.reply.post(cx, transfer_id);
605        Ok(())
606    }
607}
608
609#[async_trait]
610impl Handler<ExecuteTransferLocal> for TcpManagerActor {
611    async fn handle(
612        &mut self,
613        cx: &Context<Self>,
614        message: ExecuteTransferLocal,
615    ) -> Result<(), anyhow::Error> {
616        self.execute_transfer(
617            cx,
618            message.transfer_id,
619            message.local_memory,
620            message.chunk_size,
621            message.dest_addr,
622        )
623    }
624}
625
626#[async_trait]
627impl Handler<SendTransferResult> for TcpManagerActor {
628    async fn handle(
629        &mut self,
630        cx: &Context<Self>,
631        message: SendTransferResult,
632    ) -> Result<(), anyhow::Error> {
633        message.done.post(cx, message.result);
634        Ok(())
635    }
636}
637
638#[async_trait]
639impl Handler<TransferError> for TcpManagerActor {
640    async fn handle(
641        &mut self,
642        _cx: &Context<Self>,
643        message: TransferError,
644    ) -> Result<(), anyhow::Error> {
645        tracing::error!("fatal transfer error: {}", message.message);
646        Err(anyhow::anyhow!(message.message))
647    }
648}
649
/// Wrapper around [`ActorHandle<TcpManagerActor>`] that moves the TCP
/// data-plane (chunked reads/writes) off the actor loop while keeping
/// buffer resolution serialized through actor messages.
///
/// Because submit logic runs outside the actor loop, same-process
/// messages do not deadlock: the actor loop stays free to handle
/// `WriteChunk`/`ReadChunk` messages.
#[derive(Debug, Clone)]
pub struct TcpBackend(pub ActorHandle<TcpManagerActor>);
659
660impl std::ops::Deref for TcpBackend {
661    type Target = ActorHandle<TcpManagerActor>;
662    fn deref(&self) -> &Self::Target {
663        &self.0
664    }
665}
666
667impl TcpBackend {
668    /// Execute a parallel write: register the transfer on the remote
669    /// side, then execute locally to push chunks over direct channels.
670    async fn execute_parallel_write(
671        &self,
672        cx: &(impl context::Actor + Send + Sync),
673        op: &TcpOp,
674        chunk_size: usize,
675        deadline: Instant,
676    ) -> Result<()> {
677        let size = op.local_memory.size();
678        let total_chunks = size.div_ceil(chunk_size);
679
680        let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
681        let done_ref = done_handle.bind();
682
683        let remaining = deadline.saturating_duration_since(Instant::now());
684        let transfer_id = tokio_timeout(
685            remaining,
686            op.remote_tcp_manager.register_transfer_remote(
687                cx,
688                op.remote_buf_id,
689                total_chunks,
690                done_ref,
691            ),
692        )
693        .await
694        .map_err(|_| anyhow::anyhow!("register_transfer_remote timed out"))??
695        .map_err(|e| anyhow::anyhow!(e))?;
696
697        let dest_addr = tokio_timeout(
698            deadline.saturating_duration_since(Instant::now()),
699            op.remote_tcp_manager.get_channel_address(cx),
700        )
701        .await
702        .map_err(|_| anyhow::anyhow!("get_channel_address timed out"))??
703        .ok_or_else(|| anyhow::anyhow!("remote does not have parallel channels enabled"))?;
704
705        self.0.post(
706            cx,
707            ExecuteTransferLocal {
708                transfer_id,
709                local_memory: op.local_memory.clone(),
710                chunk_size,
711                dest_addr,
712            },
713        );
714
715        let remaining = deadline.saturating_duration_since(Instant::now());
716        let result = tokio_timeout(remaining, done_rx.recv())
717            .await
718            .map_err(|_| anyhow::anyhow!("parallel write timed out"))?
719            .map_err(|e| anyhow::anyhow!(e))?;
720        result.map_err(|e| anyhow::anyhow!(e))
721    }
722
723    /// Execute a parallel read: register the transfer locally, then
724    /// ask the remote side to push chunks to our channel.
725    async fn execute_parallel_read(
726        &self,
727        cx: &(impl context::Actor + Send + Sync),
728        op: &TcpOp,
729        chunk_size: usize,
730        deadline: Instant,
731    ) -> Result<()> {
732        let size = op.local_memory.size();
733        let total_chunks = size.div_ceil(chunk_size);
734
735        let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
736        let done_ref = done_handle.bind();
737
738        let (id_handle, id_rx) = hyperactor::mailbox::open_once_port::<usize>(cx);
739
740        self.0.post(
741            cx,
742            RegisterTransferLocal {
743                local_memory: op.local_memory.clone(),
744                total_chunks,
745                done: done_ref,
746                reply: id_handle,
747            },
748        );
749
750        let transfer_id = id_rx
751            .recv()
752            .await
753            .map_err(|e| anyhow::anyhow!("failed to receive transfer id: {e}"))?;
754
755        let my_channel_addr = self
756            .0
757            .get_channel_address(cx)
758            .await?
759            .ok_or_else(|| anyhow::anyhow!("local parallel channels not enabled"))?;
760
761        let remaining = deadline.saturating_duration_since(Instant::now());
762        tokio_timeout(
763            remaining,
764            op.remote_tcp_manager.execute_transfer_remote(
765                cx,
766                transfer_id,
767                op.remote_buf_id,
768                chunk_size,
769                my_channel_addr,
770            ),
771        )
772        .await
773        .map_err(|_| anyhow::anyhow!("execute_transfer_remote timed out"))??
774        .map_err(|e| anyhow::anyhow!(e))?;
775
776        let remaining = deadline.saturating_duration_since(Instant::now());
777        let result = tokio_timeout(remaining, done_rx.recv())
778            .await
779            .map_err(|_| anyhow::anyhow!("parallel read timed out"))?
780            .map_err(|e| anyhow::anyhow!(e))?;
781        result.map_err(|e| anyhow::anyhow!(e))
782    }
783
784    /// Execute a write operation: read local memory in chunks and write
785    /// them into the remote buffer via actor messages.
786    async fn execute_write(
787        &self,
788        cx: &(impl context::Actor + Send + Sync),
789        op: &TcpOp,
790        chunk_size: usize,
791        deadline: Instant,
792    ) -> Result<()> {
793        let size = op.local_memory.size();
794        let mut offset = 0;
795
796        while offset < size {
797            let remaining = deadline.saturating_duration_since(Instant::now());
798            if remaining.is_zero() {
799                anyhow::bail!("tcp write timed out");
800            }
801
802            let len = std::cmp::min(chunk_size, size - offset);
803
804            let mut buf = vec![0u8; len];
805            // SAFETY: `op.local_memory` is the caller's buffer; that
806            // caller is responsible for excluding external writers
807            // while the `write_from_local` operation is in flight.
808            unsafe { op.local_memory.read_at(offset, &mut buf) }?;
809            let data = Part::from(Bytes::from(buf));
810
811            tokio_timeout(
812                remaining,
813                op.remote_tcp_manager
814                    .write_chunk(cx, op.remote_buf_id, offset, data),
815            )
816            .await
817            .map_err(|_| anyhow::anyhow!("tcp write chunk timed out"))??
818            .map_err(|e| anyhow::anyhow!(e))?;
819
820            offset += len;
821        }
822
823        Ok(())
824    }
825
826    /// Execute a read operation: request chunks from the remote buffer
827    /// and write them into local memory via actor messages.
828    async fn execute_read(
829        &self,
830        cx: &(impl context::Actor + Send + Sync),
831        op: &TcpOp,
832        chunk_size: usize,
833        deadline: Instant,
834    ) -> Result<()> {
835        let size = op.local_memory.size();
836        let mut offset = 0;
837
838        while offset < size {
839            let remaining = deadline.saturating_duration_since(Instant::now());
840            if remaining.is_zero() {
841                anyhow::bail!("tcp read timed out");
842            }
843
844            let len = std::cmp::min(chunk_size, size - offset);
845
846            let chunk = tokio_timeout(
847                remaining,
848                op.remote_tcp_manager
849                    .read_chunk(cx, op.remote_buf_id, offset, len),
850            )
851            .await
852            .map_err(|_| anyhow::anyhow!("tcp read chunk timed out"))??
853            .map_err(|e| anyhow::anyhow!(e))?;
854            let data = chunk.0.into_bytes();
855
856            anyhow::ensure!(
857                data.len() == len,
858                "tcp read chunk size mismatch: expected {len}, got {}",
859                data.len()
860            );
861
862            // SAFETY: `op.local_memory` is the caller's buffer; that
863            // caller is responsible for excluding external readers
864            // and writers while the `read_into_local` operation is in
865            // flight.
866            unsafe { op.local_memory.write_at(offset, &data) }?;
867
868            offset += len;
869        }
870
871        Ok(())
872    }
873}
874
875#[async_trait]
876impl RdmaBackend for TcpBackend {
877    type TransportInfo = ();
878
879    /// Submit a batch of RDMA operations over TCP.
880    ///
881    /// Each operation's remote buffer is resolved to its TCP backend
882    /// context, then executed directly — sending chunked write/read
883    /// messages to the remote [`TcpManagerActor`].
884    async fn submit(
885        &mut self,
886        cx: &(impl context::Actor + Send + Sync),
887        ops: Vec<RdmaOp>,
888        timeout: Duration,
889    ) -> Result<()> {
890        let chunk_size =
891            hyperactor_config::global::get(crate::config::RDMA_MAX_CHUNK_SIZE_MB) * 1024 * 1024;
892        let parallelism =
893            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
894        let deadline = Instant::now() + timeout;
895
896        for op in ops {
897            let remaining = deadline.saturating_duration_since(Instant::now());
898            if remaining.is_zero() {
899                anyhow::bail!("tcp submit timed out");
900            }
901
902            let (remote_tcp_mgr, remote_buf_id) = op.remote.resolve_tcp()?;
903            let tcp_op = TcpOp {
904                op_type: op.op_type.clone(),
905                local_memory: op.local,
906                remote_tcp_manager: remote_tcp_mgr,
907                remote_buf_id,
908            };
909
910            if parallelism > 1 {
911                match tcp_op.op_type {
912                    RdmaOpType::WriteFromLocal => {
913                        self.execute_parallel_write(cx, &tcp_op, chunk_size, deadline)
914                            .await?;
915                    }
916                    RdmaOpType::ReadIntoLocal => {
917                        self.execute_parallel_read(cx, &tcp_op, chunk_size, deadline)
918                            .await?;
919                    }
920                }
921            } else {
922                match tcp_op.op_type {
923                    RdmaOpType::WriteFromLocal => {
924                        self.execute_write(cx, &tcp_op, chunk_size, deadline)
925                            .await?;
926                    }
927                    RdmaOpType::ReadIntoLocal => {
928                        self.execute_read(cx, &tcp_op, chunk_size, deadline).await?;
929                    }
930                }
931            }
932        }
933
934        Ok(())
935    }
936
937    fn transport_level(&self) -> RdmaTransportLevel {
938        RdmaTransportLevel::Tcp
939    }
940
941    fn transport_info(&self) -> Option<Self::TransportInfo> {
942        None
943    }
944}
945
946#[cfg(test)]
947mod tests {
948    use std::sync::Arc;
949    use std::sync::atomic::AtomicUsize;
950    use std::sync::atomic::Ordering;
951    use std::time::Duration;
952
953    use hyperactor::ActorHandle;
954    use hyperactor::Proc;
955    use hyperactor::RemoteSpawn;
956    use hyperactor::channel::ChannelAddr;
957    use hyperactor_config::Flattrs;
958
959    use super::TcpBackend;
960    use super::TcpManagerActor;
961    use crate::RdmaManagerMessageClient;
962    use crate::RdmaOp;
963    use crate::RdmaOpType;
964    use crate::backend::RdmaBackend;
965    use crate::local_memory::KeepaliveLocalMemory;
966    use crate::rdma_manager_actor::GetTcpActorRefClient;
967    use crate::rdma_manager_actor::RdmaManagerActor;
968
    /// Monotonic counter used to give every test env a unique proc or
    /// client name (`tcp_test_{id}`, `client_{id}`, `tcp_gpu_test_{id}`).
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// One test participant: a proc, its rdma manager, a client instance,
    /// the local TCP backend, and one buffer registered with that manager.
    ///
    /// Fields are dropped in declaration order, after the `Drop` impl for
    /// this type has run. NOTE(review): keep this order unless teardown
    /// order has been confirmed not to matter.
    struct TcpTestProcEnv {
        // Proc hosting the rdma manager (and, for `on_proc` envs, shared).
        proc: Proc,
        // Handle to the rdma manager that registered `rdma_remote_buf`.
        rdma_handle: ActorHandle<RdmaManagerActor>,
        // Client instance used as the messaging context for this env.
        instance: hyperactor::Instance<()>,
        // Local TCP backend, used to submit ops from this env.
        tcp_backend: TcpBackend,
        // Remote-addressable handle for this env's registered buffer.
        rdma_remote_buf: crate::RdmaRemoteBuffer,
        // The buffer itself, as seen locally.
        local_memory: Arc<KeepaliveLocalMemory>,
    }
979
980    impl Drop for TcpTestProcEnv {
981        fn drop(&mut self) {
982            use crate::rdma_manager_actor::ReleaseBufferClient;
983            // Release the buffer so the actor drops its local_memory
984            // clone while the CUDA runtime is still alive.
985            tokio::task::block_in_place(|| {
986                tokio::runtime::Handle::current()
987                    .block_on(
988                        self.rdma_remote_buf
989                            .owner
990                            .release_buffer(&self.instance, self.rdma_remote_buf.id),
991                    )
992                    .expect("failed to release buffer in TcpTestProcEnv drop");
993            });
994        }
995    }
996
997    impl TcpTestProcEnv {
998        /// Create a standalone test environment with its own proc and rdma manager.
999        async fn new(buffer_size: usize) -> anyhow::Result<Self> {
1000            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1001            let proc = Proc::direct(
1002                ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
1003                format!("tcp_test_{id}"),
1004            )?;
1005            let (instance, _) = proc.client("client")?;
1006
1007            let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
1008            let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;
1009
1010            let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
1011            let tcp_backend = TcpBackend(
1012                tcp_ref
1013                    .downcast_handle(&instance)
1014                    .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
1015            );
1016
1017            let (local_memory, rdma_remote_buf) =
1018                Self::alloc_cpu_buffer(&instance, &rdma_handle, buffer_size).await?;
1019
1020            Ok(Self {
1021                proc,
1022                rdma_handle,
1023                instance,
1024                tcp_backend,
1025                rdma_remote_buf,
1026                local_memory,
1027            })
1028        }
1029
1030        /// Create a buffer on an existing proc's rdma manager.
1031        async fn on_proc(
1032            proc: &Proc,
1033            rdma_handle: &ActorHandle<RdmaManagerActor>,
1034            tcp_backend: TcpBackend,
1035            buffer_size: usize,
1036        ) -> anyhow::Result<Self> {
1037            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1038            let (instance, _) = proc.client(&format!("client_{id}"))?;
1039
1040            let (local_memory, rdma_remote_buf) =
1041                Self::alloc_cpu_buffer(&instance, rdma_handle, buffer_size).await?;
1042
1043            Ok(Self {
1044                proc: proc.clone(),
1045                rdma_handle: rdma_handle.clone(),
1046                instance,
1047                tcp_backend,
1048                rdma_remote_buf,
1049                local_memory,
1050            })
1051        }
1052
1053        async fn alloc_cpu_buffer(
1054            instance: &hyperactor::Instance<()>,
1055            rdma_handle: &ActorHandle<RdmaManagerActor>,
1056            buffer_size: usize,
1057        ) -> anyhow::Result<(Arc<KeepaliveLocalMemory>, crate::RdmaRemoteBuffer)> {
1058            let cpu_buf = vec![0u8; buffer_size].into_boxed_slice();
1059            let ptr = cpu_buf.as_ptr() as usize;
1060            let local_memory: Arc<KeepaliveLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
1061                ptr,
1062                buffer_size,
1063                Arc::new(cpu_buf),
1064            ));
1065            let rdma_remote_buf = rdma_handle
1066                .request_buffer(instance, local_memory.clone())
1067                .await?;
1068            Ok((local_memory, rdma_remote_buf))
1069        }
1070    }
1071
1072    /// Two separate procs, one buffer each.
1073    async fn setup_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1074        Ok(vec![
1075            TcpTestProcEnv::new(buf_size).await?,
1076            TcpTestProcEnv::new(buf_size).await?,
1077        ])
1078    }
1079
1080    /// Single proc, two buffers.
1081    async fn setup_same_proc_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1082        let first = TcpTestProcEnv::new(buf_size).await?;
1083        let second = TcpTestProcEnv::on_proc(
1084            &first.proc,
1085            &first.rdma_handle,
1086            first.tcp_backend.clone(),
1087            buf_size,
1088        )
1089        .await?;
1090        Ok(vec![first, second])
1091    }
1092
1093    /// Two procs, two buffers each (4 total). For concurrent tests that
1094    /// need independent source/dest pairs.
1095    async fn setup_tcp_env_pairs(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1096        let e0 = TcpTestProcEnv::new(buf_size).await?;
1097        let e1 = TcpTestProcEnv::new(buf_size).await?;
1098        let e2 =
1099            TcpTestProcEnv::on_proc(&e0.proc, &e0.rdma_handle, e0.tcp_backend.clone(), buf_size)
1100                .await?;
1101        let e3 =
1102            TcpTestProcEnv::on_proc(&e1.proc, &e1.rdma_handle, e1.tcp_backend.clone(), buf_size)
1103                .await?;
1104        Ok(vec![e0, e1, e2, e3])
1105    }
1106
1107    // --- Shared test helpers ---
1108
1109    /// Test-only wrapper around [`KeepaliveLocalMemory::write_at`].
1110    ///
1111    /// Every [`TcpTestProcEnv`] owns a distinct CPU buffer that no
1112    /// other thread accesses outside of explicit, serialized test
1113    /// operations, so the safety obligation of `write_at` is trivially
1114    /// satisfied across the whole module.
1115    fn test_write(mem: &KeepaliveLocalMemory, offset: usize, src: &[u8]) -> anyhow::Result<()> {
1116        // SAFETY: see the function-level comment.
1117        unsafe { mem.write_at(offset, src) }
1118    }
1119
1120    /// Test-only wrapper around [`KeepaliveLocalMemory::read_at`]. See
1121    /// [`test_write`] for the safety rationale.
1122    fn test_read(mem: &KeepaliveLocalMemory, offset: usize, dst: &mut [u8]) -> anyhow::Result<()> {
1123        // SAFETY: see the function-level comment.
1124        unsafe { mem.read_at(offset, dst) }
1125    }
1126
1127    /// Fill envs[0], write to envs[1], verify.
1128    async fn do_write_test(
1129        envs: &mut [TcpTestProcEnv],
1130        buf_size: usize,
1131        timeout: Duration,
1132    ) -> anyhow::Result<()> {
1133        let mut src = vec![0u8; buf_size];
1134        for (i, byte) in src.iter_mut().enumerate() {
1135            *byte = (i % 256) as u8;
1136        }
1137        test_write(&envs[0].local_memory, 0, &src)?;
1138
1139        let remote = envs[1].rdma_remote_buf.clone();
1140        let env = &mut envs[0];
1141        env.tcp_backend
1142            .submit(
1143                &env.instance,
1144                vec![RdmaOp {
1145                    op_type: RdmaOpType::WriteFromLocal,
1146                    local: env.local_memory.clone(),
1147                    remote,
1148                }],
1149                timeout,
1150            )
1151            .await?;
1152
1153        let mut dst = vec![0u8; buf_size];
1154        test_read(&envs[1].local_memory, 0, &mut dst)?;
1155        for (i, byte) in dst.iter().enumerate() {
1156            assert_eq!(*byte, (i % 256) as u8, "mismatch at offset {i} after write");
1157        }
1158        Ok(())
1159    }
1160
1161    /// Fill envs[1], read into envs[0], verify.
1162    async fn do_read_test(
1163        envs: &mut [TcpTestProcEnv],
1164        buf_size: usize,
1165        timeout: Duration,
1166    ) -> anyhow::Result<()> {
1167        let mut src = vec![0u8; buf_size];
1168        for (i, byte) in src.iter_mut().enumerate() {
1169            *byte = ((i * 7 + 3) % 256) as u8;
1170        }
1171        test_write(&envs[1].local_memory, 0, &src)?;
1172
1173        let remote = envs[1].rdma_remote_buf.clone();
1174        let env = &mut envs[0];
1175        env.tcp_backend
1176            .submit(
1177                &env.instance,
1178                vec![RdmaOp {
1179                    op_type: RdmaOpType::ReadIntoLocal,
1180                    local: env.local_memory.clone(),
1181                    remote,
1182                }],
1183                timeout,
1184            )
1185            .await?;
1186
1187        let mut dst = vec![0u8; buf_size];
1188        test_read(&envs[0].local_memory, 0, &mut dst)?;
1189        for (i, byte) in dst.iter().enumerate() {
1190            assert_eq!(
1191                *byte,
1192                ((i * 7 + 3) % 256) as u8,
1193                "mismatch at offset {i} after read"
1194            );
1195        }
1196        Ok(())
1197    }
1198
1199    /// Write, clear, read-back, verify round-trip.
1200    async fn do_round_trip_test(
1201        envs: &mut [TcpTestProcEnv],
1202        buf_size: usize,
1203        timeout: Duration,
1204    ) -> anyhow::Result<()> {
1205        let mut src = vec![0u8; buf_size];
1206        for (i, byte) in src.iter_mut().enumerate() {
1207            *byte = ((i * 13 + 5) % 256) as u8;
1208        }
1209        test_write(&envs[0].local_memory, 0, &src)?;
1210
1211        let remote = envs[1].rdma_remote_buf.clone();
1212        let env = &mut envs[0];
1213        env.tcp_backend
1214            .submit(
1215                &env.instance,
1216                vec![RdmaOp {
1217                    op_type: RdmaOpType::WriteFromLocal,
1218                    local: env.local_memory.clone(),
1219                    remote: remote.clone(),
1220                }],
1221                timeout,
1222            )
1223            .await?;
1224
1225        test_write(&envs[0].local_memory, 0, &vec![0u8; buf_size])?;
1226
1227        let env = &mut envs[0];
1228        env.tcp_backend
1229            .submit(
1230                &env.instance,
1231                vec![RdmaOp {
1232                    op_type: RdmaOpType::ReadIntoLocal,
1233                    local: env.local_memory.clone(),
1234                    remote,
1235                }],
1236                timeout,
1237            )
1238            .await?;
1239
1240        let mut dst = vec![0u8; buf_size];
1241        test_read(&envs[0].local_memory, 0, &mut dst)?;
1242        for (i, byte) in dst.iter().enumerate() {
1243            assert_eq!(
1244                *byte,
1245                ((i * 13 + 5) % 256) as u8,
1246                "mismatch at offset {i} after round-trip"
1247            );
1248        }
1249        Ok(())
1250    }
1251
1252    // --- Non-parallel two-proc tests ---
1253
1254    /// Write from local buffer 0 into remote buffer 1.
1255    #[timed_test::async_timed_test(timeout_secs = 30)]
1256    async fn test_tcp_write_from_local() -> anyhow::Result<()> {
1257        let config = hyperactor_config::global::lock();
1258        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1259
1260        let mut envs = setup_tcp_env(4096).await?;
1261        do_write_test(&mut envs, 4096, Duration::from_secs(30)).await
1262    }
1263
1264    /// Read from remote buffer 1 into local buffer 0.
1265    #[timed_test::async_timed_test(timeout_secs = 30)]
1266    async fn test_tcp_read_into_local() -> anyhow::Result<()> {
1267        let config = hyperactor_config::global::lock();
1268        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1269
1270        let mut envs = setup_tcp_env(2048).await?;
1271        do_read_test(&mut envs, 2048, Duration::from_secs(30)).await
1272    }
1273
1274    /// Write, clear, read-back, verify round-trip.
1275    #[timed_test::async_timed_test(timeout_secs = 30)]
1276    async fn test_tcp_write_then_read_back() -> anyhow::Result<()> {
1277        let config = hyperactor_config::global::lock();
1278        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1279
1280        let mut envs = setup_tcp_env(4096).await?;
1281        do_round_trip_test(&mut envs, 4096, Duration::from_secs(30)).await
1282    }
1283
1284    /// Multi-chunk write (1 MiB chunks, 1.5 MiB buffer).
1285    #[timed_test::async_timed_test(timeout_secs = 30)]
1286    async fn test_tcp_multi_chunk_write() -> anyhow::Result<()> {
1287        let config = hyperactor_config::global::lock();
1288        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1289        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1290
1291        let buf_size = 3 * 1024 * 512;
1292        let mut envs = setup_tcp_env(buf_size).await?;
1293
1294        let mut src = vec![0u8; buf_size];
1295        for (i, byte) in src.iter_mut().enumerate() {
1296            *byte = (i % 251) as u8;
1297        }
1298        test_write(&envs[0].local_memory, 0, &src)?;
1299
1300        let remote = envs[1].rdma_remote_buf.clone();
1301        let env = &mut envs[0];
1302        env.tcp_backend
1303            .submit(
1304                &env.instance,
1305                vec![RdmaOp {
1306                    op_type: RdmaOpType::WriteFromLocal,
1307                    local: env.local_memory.clone(),
1308                    remote,
1309                }],
1310                Duration::from_secs(30),
1311            )
1312            .await?;
1313
1314        let mut dst = vec![0u8; buf_size];
1315        test_read(&envs[1].local_memory, 0, &mut dst)?;
1316        for (i, byte) in dst.iter().enumerate() {
1317            assert_eq!(*byte, (i % 251) as u8, "mismatch at offset {i}");
1318        }
1319
1320        Ok(())
1321    }
1322
1323    /// Multi-chunk read.
1324    #[timed_test::async_timed_test(timeout_secs = 30)]
1325    async fn test_tcp_multi_chunk_read() -> anyhow::Result<()> {
1326        let config = hyperactor_config::global::lock();
1327        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1328        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1329
1330        let buf_size = 3 * 1024 * 512;
1331        let mut envs = setup_tcp_env(buf_size).await?;
1332
1333        let mut src = vec![0u8; buf_size];
1334        for (i, byte) in src.iter_mut().enumerate() {
1335            *byte = ((i * 3 + 17) % 256) as u8;
1336        }
1337        test_write(&envs[1].local_memory, 0, &src)?;
1338
1339        let remote = envs[1].rdma_remote_buf.clone();
1340        let env = &mut envs[0];
1341        env.tcp_backend
1342            .submit(
1343                &env.instance,
1344                vec![RdmaOp {
1345                    op_type: RdmaOpType::ReadIntoLocal,
1346                    local: env.local_memory.clone(),
1347                    remote,
1348                }],
1349                Duration::from_secs(30),
1350            )
1351            .await?;
1352
1353        let mut dst = vec![0u8; buf_size];
1354        test_read(&envs[0].local_memory, 0, &mut dst)?;
1355        for (i, byte) in dst.iter().enumerate() {
1356            assert_eq!(*byte, ((i * 3 + 17) % 256) as u8, "mismatch at offset {i}");
1357        }
1358
1359        Ok(())
1360    }
1361
1362    /// Multi-chunk write-then-read round-trip.
1363    #[timed_test::async_timed_test(timeout_secs = 30)]
1364    async fn test_tcp_multi_chunk_round_trip() -> anyhow::Result<()> {
1365        let config = hyperactor_config::global::lock();
1366        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1367        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1368
1369        let buf_size = 5 * 1024 * 512;
1370        let mut envs = setup_tcp_env(buf_size).await?;
1371
1372        let mut src = vec![0u8; buf_size];
1373        for (i, byte) in src.iter_mut().enumerate() {
1374            *byte = ((i * 41 + 7) % 256) as u8;
1375        }
1376        test_write(&envs[0].local_memory, 0, &src)?;
1377
1378        let remote = envs[1].rdma_remote_buf.clone();
1379        let env = &mut envs[0];
1380        env.tcp_backend
1381            .submit(
1382                &env.instance,
1383                vec![RdmaOp {
1384                    op_type: RdmaOpType::WriteFromLocal,
1385                    local: env.local_memory.clone(),
1386                    remote: remote.clone(),
1387                }],
1388                Duration::from_secs(30),
1389            )
1390            .await?;
1391
1392        test_write(&envs[0].local_memory, 0, &vec![0u8; buf_size])?;
1393
1394        let env = &mut envs[0];
1395        env.tcp_backend
1396            .submit(
1397                &env.instance,
1398                vec![RdmaOp {
1399                    op_type: RdmaOpType::ReadIntoLocal,
1400                    local: env.local_memory.clone(),
1401                    remote,
1402                }],
1403                Duration::from_secs(30),
1404            )
1405            .await?;
1406
1407        let mut dst = vec![0u8; buf_size];
1408        test_read(&envs[0].local_memory, 0, &mut dst)?;
1409        for (i, byte) in dst.iter().enumerate() {
1410            assert_eq!(
1411                *byte,
1412                ((i * 41 + 7) % 256) as u8,
1413                "mismatch at offset {i} after multi-chunk round-trip"
1414            );
1415        }
1416
1417        Ok(())
1418    }
1419
1420    /// resolve_tcp finds the Tcp backend context in a buffer.
1421    #[timed_test::async_timed_test(timeout_secs = 30)]
1422    async fn test_tcp_resolve_tcp() -> anyhow::Result<()> {
1423        let config = hyperactor_config::global::lock();
1424        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1425
1426        let envs = setup_tcp_env(64).await?;
1427
1428        for (i, env) in envs.iter().enumerate() {
1429            let (tcp_ref, id) = env.rdma_remote_buf.resolve_tcp()?;
1430            assert_eq!(id, env.rdma_remote_buf.id, "buf id mismatch for env {i}");
1431            let expected: hyperactor::ActorRef<TcpManagerActor> = env.tcp_backend.bind();
1432            assert_eq!(tcp_ref.actor_addr(), expected.actor_addr());
1433        }
1434
1435        Ok(())
1436    }
1437
1438    /// Write to a released buffer returns an error without crashing.
1439    #[timed_test::async_timed_test(timeout_secs = 30)]
1440    async fn test_tcp_write_to_released_buffer() -> anyhow::Result<()> {
1441        let config = hyperactor_config::global::lock();
1442        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1443
1444        let buf_size = 64;
1445        let mut envs = setup_tcp_env(buf_size).await?;
1446
1447        let mut src = vec![0u8; buf_size];
1448        for (i, byte) in src.iter_mut().enumerate() {
1449            *byte = (i % 256) as u8;
1450        }
1451        test_write(&envs[0].local_memory, 0, &src)?;
1452
1453        // Normal write should succeed.
1454        let remote = envs[1].rdma_remote_buf.clone();
1455        let env = &mut envs[0];
1456        env.tcp_backend
1457            .submit(
1458                &env.instance,
1459                vec![RdmaOp {
1460                    op_type: RdmaOpType::WriteFromLocal,
1461                    local: env.local_memory.clone(),
1462                    remote: remote.clone(),
1463                }],
1464                Duration::from_secs(10),
1465            )
1466            .await?;
1467
1468        // Release the remote buffer.
1469        use crate::rdma_manager_actor::ReleaseBufferClient;
1470        let owner_ref = envs[1].rdma_remote_buf.owner.clone();
1471        owner_ref
1472            .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
1473            .await?;
1474
1475        // Writing to the released buffer should fail.
1476        let env = &mut envs[0];
1477        let result = env
1478            .tcp_backend
1479            .submit(
1480                &env.instance,
1481                vec![RdmaOp {
1482                    op_type: RdmaOpType::WriteFromLocal,
1483                    local: env.local_memory.clone(),
1484                    remote: remote.clone(),
1485                }],
1486                Duration::from_secs(10),
1487            )
1488            .await;
1489        assert!(result.is_err(), "expected error writing to released buffer");
1490
1491        Ok(())
1492    }
1493
1494    /// Read from a released buffer returns an error without crashing.
1495    #[timed_test::async_timed_test(timeout_secs = 30)]
1496    async fn test_tcp_read_from_released_buffer() -> anyhow::Result<()> {
1497        let config = hyperactor_config::global::lock();
1498        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1499
1500        let buf_size = 64;
1501        let mut envs = setup_tcp_env(buf_size).await?;
1502
1503        // Release the remote buffer.
1504        use crate::rdma_manager_actor::ReleaseBufferClient;
1505        let owner_ref = envs[1].rdma_remote_buf.owner.clone();
1506        owner_ref
1507            .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
1508            .await?;
1509
1510        // Reading from the released buffer should fail.
1511        let remote = envs[1].rdma_remote_buf.clone();
1512        let env = &mut envs[0];
1513        let result = env
1514            .tcp_backend
1515            .submit(
1516                &env.instance,
1517                vec![RdmaOp {
1518                    op_type: RdmaOpType::ReadIntoLocal,
1519                    local: env.local_memory.clone(),
1520                    remote,
1521                }],
1522                Duration::from_secs(10),
1523            )
1524            .await;
1525        assert!(
1526            result.is_err(),
1527            "expected error reading from released buffer"
1528        );
1529
1530        Ok(())
1531    }
1532
1533    // --- Non-parallel same-proc tests ---
1534
1535    /// Same-process write.
1536    #[timed_test::async_timed_test(timeout_secs = 30)]
1537    async fn test_tcp_same_process_write() -> anyhow::Result<()> {
1538        let config = hyperactor_config::global::lock();
1539        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1540
1541        let mut envs = setup_same_proc_tcp_env(4096).await?;
1542        do_write_test(&mut envs, 4096, Duration::from_secs(10)).await
1543    }
1544
1545    /// Same-process read.
1546    #[timed_test::async_timed_test(timeout_secs = 30)]
1547    async fn test_tcp_same_process_read() -> anyhow::Result<()> {
1548        let config = hyperactor_config::global::lock();
1549        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1550
1551        let mut envs = setup_same_proc_tcp_env(2048).await?;
1552        do_read_test(&mut envs, 2048, Duration::from_secs(10)).await
1553    }
1554
1555    /// Same-process write-then-read round-trip.
1556    #[timed_test::async_timed_test(timeout_secs = 30)]
1557    async fn test_tcp_same_process_round_trip() -> anyhow::Result<()> {
1558        let config = hyperactor_config::global::lock();
1559        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1560
1561        let mut envs = setup_same_proc_tcp_env(4096).await?;
1562        do_round_trip_test(&mut envs, 4096, Duration::from_secs(10)).await
1563    }
1564
1565    /// When TCP fallback is disabled and ibverbs is unavailable,
1566    /// RdmaManagerActor::new returns an error.
1567    #[timed_test::async_timed_test(timeout_secs = 30)]
1568    async fn test_tcp_fallback_disabled_fails() -> anyhow::Result<()> {
1569        let config = hyperactor_config::global::lock();
1570        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, false);
1571
1572        let result = RdmaManagerActor::new(None, Flattrs::default()).await;
1573        if crate::ibverbs_supported() {
1574            assert!(result.is_ok());
1575        } else {
1576            assert!(result.is_err());
1577        }
1578
1579        Ok(())
1580    }
1581
1582    // --- Multi-GPU TCP fallback tests ---
1583
1584    use crate::backend::cuda_test_utils::CudaAllocator;
1585    use crate::backend::cuda_test_utils::cuda_device_count;
1586
1587    impl TcpTestProcEnv {
1588        /// Create a test environment backed by CUDA device memory.
1589        async fn new_gpu(device: i32, buffer_size: usize) -> anyhow::Result<Self> {
1590            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1591            let proc = Proc::direct(
1592                ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
1593                format!("tcp_gpu_test_{id}"),
1594            )?;
1595            let (instance, _) = proc.client("client")?;
1596
1597            let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
1598            let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;
1599
1600            let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
1601            let tcp_backend = TcpBackend(
1602                tcp_ref
1603                    .downcast_handle(&instance)
1604                    .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
1605            );
1606
1607            let alloc = CudaAllocator::get().allocate(device, buffer_size, buffer_size);
1608            let local_memory: Arc<KeepaliveLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
1609                alloc.ptr(),
1610                buffer_size,
1611                Arc::new(alloc),
1612            ));
1613            let rdma_remote_buf = rdma_handle
1614                .request_buffer(&instance, local_memory.clone())
1615                .await?;
1616
1617            Ok(Self {
1618                proc,
1619                rdma_handle,
1620                instance,
1621                tcp_backend,
1622                rdma_remote_buf,
1623                local_memory,
1624            })
1625        }
1626    }
1627
1628    /// TCP write from GPU on cuda:0 to GPU on cuda:1.
1629    #[timed_test::async_timed_test(timeout_secs = 60)]
1630    async fn test_tcp_write_multi_gpu() -> anyhow::Result<()> {
1631        if cuda_device_count() < 2 {
1632            println!("Skipping: need at least 2 CUDA devices");
1633            return Ok(());
1634        }
1635
1636        let config = hyperactor_config::global::lock();
1637        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1638
1639        let buf_size = 2 * 1024 * 1024;
1640        let mut envs = vec![
1641            TcpTestProcEnv::new_gpu(0, buf_size).await?,
1642            TcpTestProcEnv::new_gpu(1, buf_size).await?,
1643        ];
1644        do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1645    }
1646
1647    /// TCP read from GPU on cuda:1 into GPU on cuda:0.
1648    #[timed_test::async_timed_test(timeout_secs = 60)]
1649    async fn test_tcp_read_multi_gpu() -> anyhow::Result<()> {
1650        if cuda_device_count() < 2 {
1651            println!("Skipping: need at least 2 CUDA devices");
1652            return Ok(());
1653        }
1654
1655        let config = hyperactor_config::global::lock();
1656        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1657
1658        let buf_size = 2 * 1024 * 1024;
1659        let mut envs = vec![
1660            TcpTestProcEnv::new_gpu(0, buf_size).await?,
1661            TcpTestProcEnv::new_gpu(1, buf_size).await?,
1662        ];
1663        do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1664    }
1665
1666    /// TCP write-then-read round-trip between cuda:0 and cuda:1.
1667    #[timed_test::async_timed_test(timeout_secs = 60)]
1668    async fn test_tcp_round_trip_multi_gpu() -> anyhow::Result<()> {
1669        if cuda_device_count() < 2 {
1670            println!("Skipping: need at least 2 CUDA devices");
1671            return Ok(());
1672        }
1673
1674        let config = hyperactor_config::global::lock();
1675        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1676
1677        let buf_size = 2 * 1024 * 1024;
1678        let mut envs = vec![
1679            TcpTestProcEnv::new_gpu(0, buf_size).await?,
1680            TcpTestProcEnv::new_gpu(1, buf_size).await?,
1681        ];
1682        do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1683    }
1684
1685    /// Stopping the RdmaManagerActor with parallelism enabled cleanly
1686    /// shuts down the TcpManagerActor's receive loop without hanging.
1687    #[timed_test::async_timed_test(timeout_secs = 30)]
1688    async fn test_tcp_parallel_clean_shutdown() -> anyhow::Result<()> {
1689        let config = hyperactor_config::global::lock();
1690        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1691        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1692        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1693
1694        let buf_size = 3 * 1024 * 1024;
1695        let mut envs = setup_tcp_env(buf_size).await?;
1696
1697        // Do a transfer so the receive loop and outbound connections are live.
1698        let mut src = vec![0u8; buf_size];
1699        for (i, byte) in src.iter_mut().enumerate() {
1700            *byte = (i % 256) as u8;
1701        }
1702        test_write(&envs[0].local_memory, 0, &src)?;
1703        let remote = envs[1].rdma_remote_buf.clone();
1704        let env = &mut envs[0];
1705        env.tcp_backend
1706            .submit(
1707                &env.instance,
1708                vec![RdmaOp {
1709                    op_type: RdmaOpType::WriteFromLocal,
1710                    local: env.local_memory.clone(),
1711                    remote: remote.clone(),
1712                }],
1713                Duration::from_secs(30),
1714            )
1715            .await?;
1716
1717        // Stop the RdmaManagerActor, which cascades to TcpManagerActor.
1718        // The test timeout ensures we detect hangs in the cleanup path.
1719        envs[0].rdma_handle.drain_and_stop("test")?;
1720        envs[0].rdma_handle.clone().await;
1721        envs[1].rdma_handle.drain_and_stop("test")?;
1722        envs[1].rdma_handle.clone().await;
1723
1724        Ok(())
1725    }
1726
1727    // --- Parallel transfer tests ---
1728
1729    /// Parallel write via direct channels.
1730    #[timed_test::async_timed_test(timeout_secs = 30)]
1731    async fn test_tcp_parallel_write() -> anyhow::Result<()> {
1732        let config = hyperactor_config::global::lock();
1733        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1734        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1735        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1736
1737        // 3 MiB, 3 chunks spread across 2 workers.
1738        let buf_size = 3 * 1024 * 1024;
1739        let mut envs = setup_tcp_env(buf_size).await?;
1740        do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1741    }
1742
1743    /// Parallel read via direct channels.
1744    #[timed_test::async_timed_test(timeout_secs = 30)]
1745    async fn test_tcp_parallel_read() -> anyhow::Result<()> {
1746        let config = hyperactor_config::global::lock();
1747        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1748        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1749        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1750
1751        let buf_size = 3 * 1024 * 1024;
1752        let mut envs = setup_tcp_env(buf_size).await?;
1753        do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1754    }
1755
1756    /// Parallel write-then-read round-trip.
1757    #[timed_test::async_timed_test(timeout_secs = 30)]
1758    async fn test_tcp_parallel_round_trip() -> anyhow::Result<()> {
1759        let config = hyperactor_config::global::lock();
1760        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1761        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1762        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1763
1764        let buf_size = 3 * 1024 * 1024;
1765        let mut envs = setup_tcp_env(buf_size).await?;
1766        do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1767    }
1768
1769    /// Same-process parallel write.
1770    #[timed_test::async_timed_test(timeout_secs = 30)]
1771    async fn test_tcp_parallel_same_process_write() -> anyhow::Result<()> {
1772        let config = hyperactor_config::global::lock();
1773        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1774        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1775        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1776
1777        let buf_size = 3 * 1024 * 1024;
1778        let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1779        do_write_test(&mut envs, buf_size, Duration::from_secs(10)).await
1780    }
1781
1782    /// Same-process parallel read.
1783    #[timed_test::async_timed_test(timeout_secs = 30)]
1784    async fn test_tcp_parallel_same_process_read() -> anyhow::Result<()> {
1785        let config = hyperactor_config::global::lock();
1786        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1787        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1788        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1789
1790        let buf_size = 3 * 1024 * 1024;
1791        let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1792        do_read_test(&mut envs, buf_size, Duration::from_secs(10)).await
1793    }
1794
1795    // --- Concurrent parallel tests (4 envs, 2 independent pairs) ---
1796
1797    /// Two concurrent parallel writes to independent buffer pairs.
1798    #[timed_test::async_timed_test(timeout_secs = 30)]
1799    async fn test_tcp_parallel_concurrent_writes() -> anyhow::Result<()> {
1800        let config = hyperactor_config::global::lock();
1801        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1802        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1803        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1804
1805        let buf_size = 3 * 1024 * 1024;
1806        let envs = setup_tcp_env_pairs(buf_size).await?;
1807
1808        // Fill source buffers with distinct patterns.
1809        let mut src0 = vec![0u8; buf_size];
1810        for (i, byte) in src0.iter_mut().enumerate() {
1811            *byte = (i % 256) as u8;
1812        }
1813        test_write(&envs[0].local_memory, 0, &src0)?;
1814        let mut src2 = vec![0u8; buf_size];
1815        for (i, byte) in src2.iter_mut().enumerate() {
1816            *byte = ((i * 3 + 7) % 256) as u8;
1817        }
1818        test_write(&envs[2].local_memory, 0, &src2)?;
1819
1820        // Pair 1: envs[0] -> envs[1], Pair 2: envs[2] -> envs[3].
1821        let remote_1 = envs[1].rdma_remote_buf.clone();
1822        let remote_3 = envs[3].rdma_remote_buf.clone();
1823        let mut h0 = envs[0].tcp_backend.clone();
1824        let mut h2 = envs[2].tcp_backend.clone();
1825        let inst_0 = &envs[0].instance;
1826        let inst_2 = &envs[2].instance;
1827        let mem_0 = envs[0].local_memory.clone();
1828        let mem_2 = envs[2].local_memory.clone();
1829        let (r1, r2) = tokio::join!(
1830            h0.submit(
1831                inst_0,
1832                vec![RdmaOp {
1833                    op_type: RdmaOpType::WriteFromLocal,
1834                    local: mem_0,
1835                    remote: remote_1,
1836                }],
1837                Duration::from_secs(30),
1838            ),
1839            h2.submit(
1840                inst_2,
1841                vec![RdmaOp {
1842                    op_type: RdmaOpType::WriteFromLocal,
1843                    local: mem_2,
1844                    remote: remote_3,
1845                }],
1846                Duration::from_secs(30),
1847            ),
1848        );
1849        r1?;
1850        r2?;
1851
1852        let mut dst1 = vec![0u8; buf_size];
1853        test_read(&envs[1].local_memory, 0, &mut dst1)?;
1854        for (i, byte) in dst1.iter().enumerate() {
1855            assert_eq!(*byte, (i % 256) as u8, "pair 1 mismatch at offset {i}");
1856        }
1857        let mut dst3 = vec![0u8; buf_size];
1858        test_read(&envs[3].local_memory, 0, &mut dst3)?;
1859        for (i, byte) in dst3.iter().enumerate() {
1860            assert_eq!(
1861                *byte,
1862                ((i * 3 + 7) % 256) as u8,
1863                "pair 2 mismatch at offset {i}"
1864            );
1865        }
1866
1867        Ok(())
1868    }
1869
1870    /// Two concurrent parallel reads from independent buffer pairs.
1871    #[timed_test::async_timed_test(timeout_secs = 30)]
1872    async fn test_tcp_parallel_concurrent_reads() -> anyhow::Result<()> {
1873        let config = hyperactor_config::global::lock();
1874        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1875        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1876        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1877
1878        let buf_size = 3 * 1024 * 1024;
1879        let envs = setup_tcp_env_pairs(buf_size).await?;
1880
1881        // Fill remote buffers with distinct patterns.
1882        let mut src1 = vec![0u8; buf_size];
1883        for (i, byte) in src1.iter_mut().enumerate() {
1884            *byte = ((i * 11 + 3) % 256) as u8;
1885        }
1886        test_write(&envs[1].local_memory, 0, &src1)?;
1887        let mut src3 = vec![0u8; buf_size];
1888        for (i, byte) in src3.iter_mut().enumerate() {
1889            *byte = ((i * 5 + 13) % 256) as u8;
1890        }
1891        test_write(&envs[3].local_memory, 0, &src3)?;
1892
1893        // Pair 1: envs[0] <- envs[1], Pair 2: envs[2] <- envs[3].
1894        let remote_1 = envs[1].rdma_remote_buf.clone();
1895        let remote_3 = envs[3].rdma_remote_buf.clone();
1896        let mut h0 = envs[0].tcp_backend.clone();
1897        let mut h2 = envs[2].tcp_backend.clone();
1898        let inst_0 = &envs[0].instance;
1899        let inst_2 = &envs[2].instance;
1900        let mem_0 = envs[0].local_memory.clone();
1901        let mem_2 = envs[2].local_memory.clone();
1902        let (r1, r2) = tokio::join!(
1903            h0.submit(
1904                inst_0,
1905                vec![RdmaOp {
1906                    op_type: RdmaOpType::ReadIntoLocal,
1907                    local: mem_0,
1908                    remote: remote_1,
1909                }],
1910                Duration::from_secs(30),
1911            ),
1912            h2.submit(
1913                inst_2,
1914                vec![RdmaOp {
1915                    op_type: RdmaOpType::ReadIntoLocal,
1916                    local: mem_2,
1917                    remote: remote_3,
1918                }],
1919                Duration::from_secs(30),
1920            ),
1921        );
1922        r1?;
1923        r2?;
1924
1925        let mut dst0 = vec![0u8; buf_size];
1926        test_read(&envs[0].local_memory, 0, &mut dst0)?;
1927        for (i, byte) in dst0.iter().enumerate() {
1928            assert_eq!(
1929                *byte,
1930                ((i * 11 + 3) % 256) as u8,
1931                "pair 1 mismatch at offset {i}"
1932            );
1933        }
1934        let mut dst2 = vec![0u8; buf_size];
1935        test_read(&envs[2].local_memory, 0, &mut dst2)?;
1936        for (i, byte) in dst2.iter().enumerate() {
1937            assert_eq!(
1938                *byte,
1939                ((i * 5 + 13) % 256) as u8,
1940                "pair 2 mismatch at offset {i}"
1941            );
1942        }
1943
1944        Ok(())
1945    }
1946
1947    /// Concurrent parallel write and read on independent buffer pairs.
1948    #[timed_test::async_timed_test(timeout_secs = 30)]
1949    async fn test_tcp_parallel_concurrent_write_and_read() -> anyhow::Result<()> {
1950        let config = hyperactor_config::global::lock();
1951        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1952        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1953        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1954
1955        let buf_size = 3 * 1024 * 1024;
1956        let envs = setup_tcp_env_pairs(buf_size).await?;
1957
1958        // Fill source buffers.
1959        let mut src0 = vec![0u8; buf_size];
1960        for (i, byte) in src0.iter_mut().enumerate() {
1961            *byte = (i % 256) as u8;
1962        }
1963        test_write(&envs[0].local_memory, 0, &src0)?;
1964        let mut src3 = vec![0u8; buf_size];
1965        for (i, byte) in src3.iter_mut().enumerate() {
1966            *byte = ((i * 7 + 13) % 256) as u8;
1967        }
1968        test_write(&envs[3].local_memory, 0, &src3)?;
1969
1970        // Write envs[0] -> envs[1], read envs[2] <- envs[3] concurrently.
1971        let remote_1 = envs[1].rdma_remote_buf.clone();
1972        let remote_3 = envs[3].rdma_remote_buf.clone();
1973        let mut h0 = envs[0].tcp_backend.clone();
1974        let mut h2 = envs[2].tcp_backend.clone();
1975        let inst_0 = &envs[0].instance;
1976        let inst_2 = &envs[2].instance;
1977        let mem_0 = envs[0].local_memory.clone();
1978        let mem_2 = envs[2].local_memory.clone();
1979        let (write_result, read_result) = tokio::join!(
1980            h0.submit(
1981                inst_0,
1982                vec![RdmaOp {
1983                    op_type: RdmaOpType::WriteFromLocal,
1984                    local: mem_0,
1985                    remote: remote_1,
1986                }],
1987                Duration::from_secs(30),
1988            ),
1989            h2.submit(
1990                inst_2,
1991                vec![RdmaOp {
1992                    op_type: RdmaOpType::ReadIntoLocal,
1993                    local: mem_2,
1994                    remote: remote_3,
1995                }],
1996                Duration::from_secs(30),
1997            ),
1998        );
1999        write_result?;
2000        read_result?;
2001
2002        let mut dst1 = vec![0u8; buf_size];
2003        test_read(&envs[1].local_memory, 0, &mut dst1)?;
2004        for (i, byte) in dst1.iter().enumerate() {
2005            assert_eq!(*byte, (i % 256) as u8, "write mismatch at offset {i}");
2006        }
2007        let mut dst2 = vec![0u8; buf_size];
2008        test_read(&envs[2].local_memory, 0, &mut dst2)?;
2009        for (i, byte) in dst2.iter().enumerate() {
2010            assert_eq!(
2011                *byte,
2012                ((i * 7 + 13) % 256) as u8,
2013                "read mismatch at offset {i}"
2014            );
2015        }
2016
2017        Ok(())
2018    }
2019}