monarch_rdma/backend/tcp/
manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! TCP manager actor for RDMA fallback transport.
10//!
11//! Transfers buffer data over the default hyperactor channel transport
12//! in chunks controlled by
13//! [`RDMA_MAX_CHUNK_SIZE_MB`](crate::config::RDMA_MAX_CHUNK_SIZE_MB).
14
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::OnceLock;
18use std::time::Duration;
19use std::time::Instant;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use bytes::Bytes;
24use bytes::BytesMut;
25use dashmap::DashMap;
26use hyperactor::Actor;
27use hyperactor::ActorHandle;
28use hyperactor::Context;
29use hyperactor::HandleClient;
30use hyperactor::Handler;
31use hyperactor::Instance;
32use hyperactor::OncePortHandle;
33use hyperactor::PortHandle;
34use hyperactor::RefClient;
35use hyperactor::actor::ActorError;
36use hyperactor::channel;
37use hyperactor::channel::ChannelAddr;
38use hyperactor::channel::ChannelTx;
39use hyperactor::channel::Rx;
40use hyperactor::channel::Tx;
41use hyperactor::context;
42use hyperactor::context::Actor as _;
43use hyperactor::reference;
44use hyperactor::reference::OncePortRef;
45use hyperactor_mesh::transport::default_transport;
46use serde::Deserialize;
47use serde::Serialize;
48use serde_multipart::Part;
49use tokio::time::timeout as tokio_timeout;
50use tokio_util::sync::CancellationToken;
51use typeuri::Named;
52
53use super::TcpOp;
54use crate::RdmaLocalMemory;
55use crate::RdmaOp;
56use crate::RdmaOpType;
57use crate::RdmaTransportLevel;
58use crate::backend::RdmaBackend;
59use crate::rdma_manager_actor::GetTcpActorRefClient;
60use crate::rdma_manager_actor::RdmaManagerActor;
61use crate::rdma_manager_actor::RdmaManagerMessageClient;
62
/// [`Named`] wrapper around [`Part`] for use as a reply type.
///
/// [`Part`] itself does not implement [`Named`], which is required by
/// [`OncePortRef`]. This newtype adds the missing trait.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub struct TcpChunk(Part);
// Register with the wirevalue type registry so the wrapper can be
// (de)serialized across process boundaries.
wirevalue::register_type!(TcpChunk);
70
/// Data chunk sent over direct parallel channels.
///
/// Unlike [`TcpManagerMessage`] variants, these flow over dedicated
/// `channel::dial`/`channel::serve` connections rather than the actor
/// mailbox (see `execute_transfer` and the receive loop in `init`).
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
struct TcpDataChunk {
    /// Which specific transfer this chunk is associated with.
    transfer_id: usize,
    /// Offset into the buffer for this chunk.
    offset: usize,
    /// Chunk payload, carried via the multipart codec.
    data: Part,
}
wirevalue::register_type!(TcpDataChunk);
81
/// Tracks the progress of a single parallel transfer (read or write).
///
/// Shared between channel receive loops and actor message handlers
/// via a [`DashMap`].
#[derive(Debug)]
struct TransferState {
    /// Buffer backing this transfer, provided at construction.
    local_memory: Arc<dyn RdmaLocalMemory>,

    /// Number of chunks received so far. Incremented by the receive
    /// loop; the transfer completes when this reaches `total_chunks`.
    chunks_received: usize,

    /// Total chunks expected for this transfer.
    total_chunks: usize,

    /// Completion reply port. Fired when all chunks arrive
    /// or an error occurs.
    done: OncePortRef<Result<(), String>>,
}
101
102impl TransferState {
103    fn new(
104        total_chunks: usize,
105        local_memory: Arc<dyn RdmaLocalMemory>,
106        done: OncePortRef<Result<(), String>>,
107    ) -> Self {
108        Self {
109            local_memory,
110            chunks_received: 0,
111            total_chunks,
112            done,
113        }
114    }
115}
116
/// Sends the result of a completed transfer to the caller's reply port.
///
/// Sending an actor message from the spawned receiver task requires the
/// loop to own a dummy [`context::Actor`] impl. If the task sent directly
/// to a remote [`OncePortRef`], the message would appear to come from
/// this dummy context, and undeliverable messages wouldn't be handled
/// properly. This intermediate message lets us use a [`PortHandle`]
/// whose message cannot be undeliverable; the handler then forwards the
/// result using the real actor's context.
#[derive(Debug, Serialize, Deserialize, Named)]
struct SendTransferResult {
    /// Reply port supplied by the transfer's initiator.
    done: OncePortRef<Result<(), String>>,
    /// `Ok(())` on success, or the error text on failure.
    result: Result<(), String>,
}
131
/// Fatal error from the receive loop.
///
/// The handler logs the error and returns `Err`, which triggers a
/// supervision event and crashes the actor.
#[derive(Debug, Serialize, Deserialize, Named)]
struct TransferError {
    /// Human-readable description of what failed.
    message: String,
}
140
/// Set up the local TcpManagerActor to receive a parallel transfer from
/// a remote TcpManagerActor.
///
/// Local-only message (no `Serialize`): it carries the memory handle
/// and a reply port handle directly.
#[derive(Debug)]
struct RegisterTransferLocal {
    /// Destination memory the incoming chunks are written into.
    local_memory: Arc<dyn RdmaLocalMemory>,
    /// Number of chunks the sender will push for this transfer.
    total_chunks: usize,
    /// Fired with the overall transfer result once all chunks arrive.
    done: OncePortRef<Result<(), String>>,
    /// Replies with the newly assigned transfer ID.
    reply: OncePortHandle<usize>,
}
151
/// Tell the local TcpManagerActor to read local memory and push
/// chunks to `dest_addr`.
///
/// Local-only message (no `Serialize`); the remote-initiated
/// equivalent is [`TcpManagerMessage::ExecuteTransferRemote`].
#[derive(Debug)]
struct ExecuteTransferLocal {
    /// Transfer ID previously issued by the side receiving the chunks.
    transfer_id: usize,
    /// Source memory the chunks are read from.
    local_memory: Arc<dyn RdmaLocalMemory>,
    /// Size in bytes of each chunk (the final chunk may be shorter).
    chunk_size: usize,
    /// Channel address of the receiving side's parallel channel.
    dest_addr: ChannelAddr,
}
161
/// Serializable messages for the [`TcpManagerActor`].
///
/// These travel over the wire between processes. The [`Part`] payload
/// is transferred via the multipart codec without an extra copy.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
enum TcpManagerMessage {
    /// Write a chunk of data into a registered buffer at the given offset.
    WriteChunk {
        /// Identifier of the buffer on the receiving manager.
        buf_id: usize,
        /// Byte offset within the buffer.
        offset: usize,
        /// Chunk payload.
        data: Part,
        /// Inner `Err` carries an application-level failure message.
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
    /// Read a chunk of data from a registered buffer at the given offset.
    ReadChunk {
        /// Identifier of the buffer on the receiving manager.
        buf_id: usize,
        /// Byte offset within the buffer.
        offset: usize,
        /// Number of bytes to read.
        size: usize,
        #[reply]
        reply: OncePortRef<Result<TcpChunk, String>>,
    },
    /// Return the channel address served by this actor for parallel transfers.
    /// `None` when parallelism is 1.
    GetChannelAddress {
        #[reply]
        reply: OncePortRef<Option<ChannelAddr>>,
    },
    /// Set up a remote TcpManagerActor to receive a parallel transfer from
    /// the sender.
    RegisterTransferRemote {
        /// Identifier of the destination buffer on the receiving manager.
        buf_id: usize,
        /// Number of chunks the sender will push.
        total_chunks: usize,
        /// Fired with the overall result once all chunks arrive.
        done: OncePortRef<Result<(), String>>,
        /// Replies with the assigned transfer ID.
        #[reply]
        reply: OncePortRef<Result<usize, String>>,
    },
    /// Tell the remote TcpManagerActor to read its local memory and push
    /// chunks to the dest_addr provided by the sender.
    ExecuteTransferRemote {
        /// Transfer ID issued by the side that registered the transfer
        /// (the chunk receiver).
        transfer_id: usize,
        /// Identifier of the source buffer on the receiving manager.
        buf_id: usize,
        /// Size in bytes of each chunk (the final chunk may be shorter).
        chunk_size: usize,
        /// Where to push the chunks.
        dest_addr: ChannelAddr,
        #[reply]
        reply: OncePortRef<Result<(), String>>,
    },
}
wirevalue::register_type!(TcpManagerMessage);
211
/// TCP fallback RDMA backend actor.
///
/// Spawned as a child of [`RdmaManagerActor`]. Transfers buffer data
/// over the default hyperactor channel transport in chunks.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [TcpManagerMessage],
)]
pub struct TcpManagerActor {
    /// Parent [`RdmaManagerActor`], set once during `init`.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,
    /// Monotonic counter used to assign transfer IDs.
    next_transfer_id: usize,
    /// In-flight parallel transfers keyed by transfer ID. Shared with
    /// the receive loop spawned in `init`.
    transfers: Arc<DashMap<usize, TransferState>>,
    /// Address of the direct channel served for parallel transfers.
    /// `None` when parallelism is 1 (default).
    channel_addr: Option<ChannelAddr>,
    /// Cached outbound connections keyed by remote channel address.
    outbound: HashMap<ChannelAddr, Vec<Arc<ChannelTx<TcpDataChunk>>>>,
    /// Cancellation token for spawned tasks.
    cancel: CancellationToken,
    /// Signaled when the parallel receive loop exits.
    receiver_done: Option<tokio::sync::oneshot::Receiver<()>>,
}
234
impl TcpManagerActor {
    /// Create an uninitialized actor; `owner` and `channel_addr` are
    /// populated in [`Actor::init`].
    pub fn new() -> Self {
        Self {
            owner: OnceLock::new(),
            next_transfer_id: 0,
            transfers: Arc::new(DashMap::new()),
            channel_addr: None,
            outbound: HashMap::new(),
            cancel: CancellationToken::new(),
            receiver_done: None,
        }
    }

    /// Allocate a transfer ID and record the receiving-side state for a
    /// parallel transfer into `local_memory`.
    fn register_transfer(
        &mut self,
        local_memory: Arc<dyn RdmaLocalMemory>,
        total_chunks: usize,
        done: OncePortRef<Result<(), String>>,
    ) -> usize {
        let transfer_id = self.next_transfer_id;
        self.next_transfer_id += 1;
        self.transfers.insert(
            transfer_id,
            TransferState::new(total_chunks, local_memory, done),
        );
        transfer_id
    }

    /// Push the contents of `local_memory` to `dest_addr` as
    /// [`TcpDataChunk`]s, spreading chunks across the configured number
    /// of cached connections.
    ///
    /// Returns as soon as the sender tasks are spawned; failures inside
    /// a task are reported through `error_port`, whose handler crashes
    /// the actor.
    fn execute_transfer(
        &mut self,
        cx: &Context<Self>,
        transfer_id: usize,
        local_memory: Arc<dyn RdmaLocalMemory>,
        chunk_size: usize,
        dest_addr: ChannelAddr,
    ) -> Result<()> {
        let parallelism =
            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);

        // Dial (and cache) one connection per degree of parallelism.
        if !self.outbound.contains_key(&dest_addr) {
            let conns = (0..parallelism)
                .map(|_| {
                    channel::dial::<TcpDataChunk>(dest_addr.clone())
                        .map(Arc::new)
                        .map_err(anyhow::Error::from)
                })
                .collect::<Result<Vec<_>>>()?;
            self.outbound.insert(dest_addr.clone(), conns);
        }
        let conns = self.outbound.get(&dest_addr).unwrap();

        let size = local_memory.size();
        // NOTE(review): if `size` is 0, `total_chunks` is 0 and no chunk
        // is ever sent, so the receiving side's `done` port never fires
        // (completion is only checked on chunk arrival) — confirm
        // zero-length transfers cannot reach here.
        let total_chunks = size.div_ceil(chunk_size);

        // Shared work queue: each sender task atomically claims the next
        // chunk index until all chunks are taken.
        let chunk_index = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let error_port: PortHandle<TransferError> = cx.port();
        let proc = cx.instance().proc().clone();
        let cancel = self.cancel.clone();

        for conn in conns.clone() {
            let mem = local_memory.clone();
            let transfer_id = transfer_id;
            let chunk_index = chunk_index.clone();
            let error_port = error_port.clone();
            let proc = proc.clone();
            let cancel = cancel.clone();

            tokio::spawn(async move {
                // Dummy instance so this detached task can send actor
                // messages (see the comment on `SendTransferResult`).
                let sender_name = reference::name::Name::generate(
                    reference::name::Ident::new("tcp_manager_actor".into()).unwrap(),
                    reference::name::Ident::new("tcp_chunk_sender".into()).unwrap(),
                );
                let (instance, _handle) = proc
                    .instance(&sender_name.to_string())
                    .expect("failed to create sender instance");

                loop {
                    if cancel.is_cancelled() {
                        return;
                    }

                    // Atomically claim the next unsent chunk index.
                    let idx = chunk_index.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    if idx >= total_chunks {
                        break;
                    }

                    let offset = idx * chunk_size;
                    // The final chunk may be shorter than `chunk_size`.
                    let len = std::cmp::min(chunk_size, size - offset);
                    let mut buf = BytesMut::zeroed(len);
                    if let Err(e) = mem.read_at(offset, &mut buf) {
                        error_port
                            .send(
                                &instance,
                                TransferError {
                                    message: format!("read_at failed at offset {offset}: {e}"),
                                },
                            )
                            .unwrap();
                        return;
                    }

                    let chunk = TcpDataChunk {
                        transfer_id,
                        offset,
                        data: Part::from(buf.freeze()),
                    };

                    if let Err(e) = conn.send(chunk).await {
                        error_port
                            .send(
                                &instance,
                                TransferError {
                                    message: format!(
                                        "failed to send chunk at offset {offset}: {e}"
                                    ),
                                },
                            )
                            .unwrap();
                        return;
                    }
                }
            });
        }

        Ok(())
    }

    /// Construct an [`ActorHandle`] for the local [`TcpManagerActor`]
    /// by querying the local [`RdmaManagerActor`].
    pub async fn local_handle(
        client: &(impl context::Actor + Send + Sync),
    ) -> Result<ActorHandle<Self>, anyhow::Error> {
        let rdma_handle = RdmaManagerActor::local_handle(client);
        let tcp_ref = rdma_handle.get_tcp_actor_ref(client).await?;
        tcp_ref
            .downcast_handle(client)
            .ok_or_else(|| anyhow::anyhow!("TcpManagerActor is not in the local process"))
    }
}
374
375#[async_trait]
376impl Actor for TcpManagerActor {
377    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
378        let owner = this.parent_handle().ok_or_else(|| {
379            anyhow::anyhow!("RdmaManagerActor not found as parent of TcpManagerActor")
380        })?;
381        self.owner
382            .set(owner)
383            .map_err(|_| anyhow::anyhow!("TcpManagerActor owner already set"))?;
384
385        let parallelism =
386            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
387        if parallelism > 1 {
388            let addr = ChannelAddr::any(default_transport());
389            let (bound_addr, mut rx) = channel::serve::<TcpDataChunk>(addr)?;
390            self.channel_addr = Some(bound_addr);
391
392            let transfers = self.transfers.clone();
393            let proc = this.proc().clone();
394            let result_port: PortHandle<SendTransferResult> = this.port();
395            let error_port: PortHandle<TransferError> = this.port();
396            let cancel = self.cancel.clone();
397            let receiver_name = reference::name::Name::generate(
398                reference::name::Ident::new("tcp_manager_actor".into()).unwrap(),
399                reference::name::Ident::new("tcp_chunk_receiver".into()).unwrap(),
400            );
401
402            let (done_tx, done_rx) = tokio::sync::oneshot::channel::<()>();
403            self.receiver_done = Some(done_rx);
404
405            tokio::spawn(async move {
406                let (instance, _handle) = proc
407                    .instance(&receiver_name.to_string())
408                    .expect("failed to create receiver instance");
409
410                loop {
411                    let chunk = tokio::select! {
412                        _ = cancel.cancelled() => break,
413                        result = rx.recv() => match result {
414                            Ok(chunk) => chunk,
415                            Err(e) => {
416                                error_port
417                                    .send(
418                                        &instance,
419                                        TransferError {
420                                            message: format!(
421                                                "parallel channel receive error: {e}"
422                                            ),
423                                        },
424                                    )
425                                    .unwrap();
426                                break;
427                            }
428                        },
429                    };
430
431                    let mut entry = match transfers.get_mut(&chunk.transfer_id) {
432                        Some(entry) => entry,
433                        None => {
434                            tracing::warn!(
435                                "received chunk for unknown transfer {:?}",
436                                chunk.transfer_id,
437                            );
438                            continue;
439                        }
440                    };
441
442                    let mut write_offset = chunk.offset;
443                    let fragments = chunk.data.into_inner();
444                    let write_err = fragments.iter().find_map(|fragment| {
445                        let result = entry.local_memory.write_at(write_offset, fragment);
446                        write_offset += fragment.len();
447                        result.err()
448                    });
449                    if let Some(e) = write_err {
450                        let transfer_id = chunk.transfer_id;
451                        drop(entry);
452                        let (_, state) = transfers.remove(&transfer_id).unwrap();
453                        result_port
454                            .send(
455                                &instance,
456                                SendTransferResult {
457                                    done: state.done,
458                                    result: Err(e.to_string()),
459                                },
460                            )
461                            .unwrap();
462                        continue;
463                    }
464
465                    entry.chunks_received += 1;
466                    if entry.chunks_received == entry.total_chunks {
467                        let transfer_id = chunk.transfer_id;
468                        drop(entry);
469                        let (_, state) = transfers.remove(&transfer_id).unwrap();
470                        result_port
471                            .send(
472                                &instance,
473                                SendTransferResult {
474                                    done: state.done,
475                                    result: Ok(()),
476                                },
477                            )
478                            .unwrap();
479                    }
480                }
481                rx.join().await;
482                done_tx.send(()).unwrap();
483            });
484        }
485
486        Ok(())
487    }
488
489    async fn cleanup(
490        &mut self,
491        _this: &Instance<Self>,
492        _err: Option<&ActorError>,
493    ) -> Result<(), anyhow::Error> {
494        self.cancel.cancel();
495        if let Some(done_rx) = self.receiver_done.take() {
496            done_rx.await?;
497        }
498        Ok(())
499    }
500}
501
502#[async_trait]
503#[hyperactor::handle(TcpManagerMessage)]
504impl TcpManagerMessageHandler for TcpManagerActor {
505    async fn write_chunk(
506        &mut self,
507        cx: &Context<Self>,
508        buf_id: usize,
509        offset: usize,
510        data: Part,
511    ) -> Result<Result<(), String>, anyhow::Error> {
512        let owner = self.owner.get().expect("TcpManagerActor owner not set");
513        let mem = match owner.request_local_memory(cx, buf_id).await {
514            Ok(Some(mem)) => mem,
515            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
516            Err(e) => return Ok(Err(e.to_string())),
517        };
518
519        let bytes = data.into_bytes();
520        if let Err(e) = mem.write_at(offset, &bytes) {
521            return Ok(Err(e.to_string()));
522        }
523
524        Ok(Ok(()))
525    }
526
527    async fn read_chunk(
528        &mut self,
529        cx: &Context<Self>,
530        buf_id: usize,
531        offset: usize,
532        size: usize,
533    ) -> Result<Result<TcpChunk, String>, anyhow::Error> {
534        let owner = self.owner.get().expect("TcpManagerActor owner not set");
535        let mem = match owner.request_local_memory(cx, buf_id).await {
536            Ok(Some(mem)) => mem,
537            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
538            Err(e) => return Ok(Err(e.to_string())),
539        };
540
541        let mut buf = BytesMut::zeroed(size);
542        if let Err(e) = mem.read_at(offset, &mut buf) {
543            return Ok(Err(e.to_string()));
544        }
545        Ok(Ok(TcpChunk(Part::from(buf.freeze()))))
546    }
547
548    async fn get_channel_address(
549        &mut self,
550        _cx: &Context<Self>,
551    ) -> Result<Option<ChannelAddr>, anyhow::Error> {
552        Ok(self.channel_addr.clone())
553    }
554
555    async fn register_transfer_remote(
556        &mut self,
557        cx: &Context<Self>,
558        buf_id: usize,
559        total_chunks: usize,
560        done: OncePortRef<Result<(), String>>,
561    ) -> Result<Result<usize, String>, anyhow::Error> {
562        let owner = self.owner.get().expect("TcpManagerActor owner not set");
563        let mem = match owner.request_local_memory(cx, buf_id).await {
564            Ok(Some(mem)) => mem,
565            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
566            Err(e) => return Ok(Err(e.to_string())),
567        };
568        let transfer_id = self.register_transfer(mem, total_chunks, done);
569        Ok(Ok(transfer_id))
570    }
571
572    async fn execute_transfer_remote(
573        &mut self,
574        cx: &Context<Self>,
575        transfer_id: usize,
576        buf_id: usize,
577        chunk_size: usize,
578        dest_addr: ChannelAddr,
579    ) -> Result<Result<(), String>, anyhow::Error> {
580        let owner = self.owner.get().expect("TcpManagerActor owner not set");
581        let mem = match owner.request_local_memory(cx, buf_id).await {
582            Ok(Some(mem)) => mem,
583            Ok(None) => return Ok(Err(format!("buffer {buf_id} not found"))),
584            Err(e) => return Ok(Err(e.to_string())),
585        };
586        self.execute_transfer(cx, transfer_id, mem, chunk_size, dest_addr)?;
587        Ok(Ok(()))
588    }
589}
590
591#[async_trait]
592impl Handler<RegisterTransferLocal> for TcpManagerActor {
593    async fn handle(
594        &mut self,
595        cx: &Context<Self>,
596        message: RegisterTransferLocal,
597    ) -> Result<(), anyhow::Error> {
598        let transfer_id =
599            self.register_transfer(message.local_memory, message.total_chunks, message.done);
600        message.reply.send(cx, transfer_id)?;
601        Ok(())
602    }
603}
604
605#[async_trait]
606impl Handler<ExecuteTransferLocal> for TcpManagerActor {
607    async fn handle(
608        &mut self,
609        cx: &Context<Self>,
610        message: ExecuteTransferLocal,
611    ) -> Result<(), anyhow::Error> {
612        self.execute_transfer(
613            cx,
614            message.transfer_id,
615            message.local_memory,
616            message.chunk_size,
617            message.dest_addr,
618        )
619    }
620}
621
622#[async_trait]
623impl Handler<SendTransferResult> for TcpManagerActor {
624    async fn handle(
625        &mut self,
626        cx: &Context<Self>,
627        message: SendTransferResult,
628    ) -> Result<(), anyhow::Error> {
629        Ok(message.done.send(cx, message.result)?)
630    }
631}
632
633#[async_trait]
634impl Handler<TransferError> for TcpManagerActor {
635    async fn handle(
636        &mut self,
637        _cx: &Context<Self>,
638        message: TransferError,
639    ) -> Result<(), anyhow::Error> {
640        tracing::error!("fatal transfer error: {}", message.message);
641        Err(anyhow::anyhow!(message.message))
642    }
643}
644
/// Wrapper around [`ActorHandle<TcpManagerActor>`] that moves the TCP
/// data-plane (chunked reads/writes) off the actor loop while keeping
/// buffer resolution serialized through actor messages.
///
/// Because submit logic now runs outside the actor loop, same-process
/// messages no longer deadlock — the actor loop is free to handle
/// `WriteChunk`/`ReadChunk` messages.
///
/// Implements [`RdmaBackend`] by chunking each submitted operation over
/// the TCP transport.
#[derive(Debug, Clone)]
pub struct TcpBackend(pub ActorHandle<TcpManagerActor>);
654
// Allow calling `ActorHandle<TcpManagerActor>` methods (e.g. the
// generated message clients) directly on a `TcpBackend`.
impl std::ops::Deref for TcpBackend {
    type Target = ActorHandle<TcpManagerActor>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
661
impl TcpBackend {
    /// Execute a parallel write: register the transfer on the remote
    /// side, then execute locally to push chunks over direct channels.
    ///
    /// Every remote round-trip and the final completion wait are capped
    /// by `deadline`.
    async fn execute_parallel_write(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let total_chunks = size.div_ceil(chunk_size);

        // One-shot port the remote's receive loop fires once every
        // chunk has landed (or a write failed).
        let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
        let done_ref = done_handle.bind();

        let remaining = deadline.saturating_duration_since(Instant::now());
        let transfer_id = tokio_timeout(
            remaining,
            op.remote_tcp_manager.register_transfer_remote(
                cx,
                op.remote_buf_id,
                total_chunks,
                done_ref,
            ),
        )
        .await
        .map_err(|_| anyhow::anyhow!("register_transfer_remote timed out"))??
        .map_err(|e| anyhow::anyhow!(e))?;

        // NOTE(review): from here on, a failure or timeout leaves the
        // transfer registered on the remote side — confirm whether that
        // state is reclaimed elsewhere.
        let dest_addr = tokio_timeout(
            deadline.saturating_duration_since(Instant::now()),
            op.remote_tcp_manager.get_channel_address(cx),
        )
        .await
        .map_err(|_| anyhow::anyhow!("get_channel_address timed out"))??
        .ok_or_else(|| anyhow::anyhow!("remote does not have parallel channels enabled"))?;

        // Fire-and-forget to our own actor; completion is signaled by
        // the remote receive loop through `done_ref` above.
        self.0.send(
            cx,
            ExecuteTransferLocal {
                transfer_id,
                local_memory: op.local_memory.clone(),
                chunk_size,
                dest_addr,
            },
        )?;

        let remaining = deadline.saturating_duration_since(Instant::now());
        let result = tokio_timeout(remaining, done_rx.recv())
            .await
            .map_err(|_| anyhow::anyhow!("parallel write timed out"))?
            .map_err(|e| anyhow::anyhow!(e))?;
        result.map_err(|e| anyhow::anyhow!(e))
    }

    /// Execute a parallel read: register the transfer locally, then
    /// ask the remote side to push chunks to our channel.
    async fn execute_parallel_read(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let total_chunks = size.div_ceil(chunk_size);

        // Fired by our own receive loop once every chunk has landed.
        let (done_handle, done_rx) = hyperactor::mailbox::open_once_port::<Result<(), String>>(cx);
        let done_ref = done_handle.bind();

        let (id_handle, id_rx) = hyperactor::mailbox::open_once_port::<usize>(cx);

        self.0.send(
            cx,
            RegisterTransferLocal {
                local_memory: op.local_memory.clone(),
                total_chunks,
                done: done_ref,
                reply: id_handle,
            },
        )?;

        // Same-process round-trip to our own actor, so no deadline is
        // applied here (presumably fast — confirm).
        let transfer_id = id_rx
            .recv()
            .await
            .map_err(|e| anyhow::anyhow!("failed to receive transfer id: {e}"))?;

        let my_channel_addr = self
            .0
            .get_channel_address(cx)
            .await?
            .ok_or_else(|| anyhow::anyhow!("local parallel channels not enabled"))?;

        let remaining = deadline.saturating_duration_since(Instant::now());
        tokio_timeout(
            remaining,
            op.remote_tcp_manager.execute_transfer_remote(
                cx,
                transfer_id,
                op.remote_buf_id,
                chunk_size,
                my_channel_addr,
            ),
        )
        .await
        .map_err(|_| anyhow::anyhow!("execute_transfer_remote timed out"))??
        .map_err(|e| anyhow::anyhow!(e))?;

        // Wait for our receive loop to report that every chunk landed.
        let remaining = deadline.saturating_duration_since(Instant::now());
        let result = tokio_timeout(remaining, done_rx.recv())
            .await
            .map_err(|_| anyhow::anyhow!("parallel read timed out"))?
            .map_err(|e| anyhow::anyhow!(e))?;
        result.map_err(|e| anyhow::anyhow!(e))
    }

    /// Execute a write operation: read local memory in chunks and write
    /// them into the remote buffer via actor messages.
    ///
    /// Sequential (non-parallel) path: one `WriteChunk` round-trip per
    /// chunk, each capped by the time left until `deadline`.
    async fn execute_write(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let mut offset = 0;

        while offset < size {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                anyhow::bail!("tcp write timed out");
            }

            // The final chunk may be shorter than `chunk_size`.
            let len = std::cmp::min(chunk_size, size - offset);

            let mut buf = vec![0u8; len];
            op.local_memory.read_at(offset, &mut buf)?;
            let data = Part::from(Bytes::from(buf));

            tokio_timeout(
                remaining,
                op.remote_tcp_manager
                    .write_chunk(cx, op.remote_buf_id, offset, data),
            )
            .await
            .map_err(|_| anyhow::anyhow!("tcp write chunk timed out"))??
            .map_err(|e| anyhow::anyhow!(e))?;

            offset += len;
        }

        Ok(())
    }

    /// Execute a read operation: request chunks from the remote buffer
    /// and write them into local memory via actor messages.
    ///
    /// Sequential (non-parallel) path: one `ReadChunk` round-trip per
    /// chunk, each capped by the time left until `deadline`.
    async fn execute_read(
        &self,
        cx: &(impl context::Actor + Send + Sync),
        op: &TcpOp,
        chunk_size: usize,
        deadline: Instant,
    ) -> Result<()> {
        let size = op.local_memory.size();
        let mut offset = 0;

        while offset < size {
            let remaining = deadline.saturating_duration_since(Instant::now());
            if remaining.is_zero() {
                anyhow::bail!("tcp read timed out");
            }

            let len = std::cmp::min(chunk_size, size - offset);

            let chunk = tokio_timeout(
                remaining,
                op.remote_tcp_manager
                    .read_chunk(cx, op.remote_buf_id, offset, len),
            )
            .await
            .map_err(|_| anyhow::anyhow!("tcp read chunk timed out"))??
            .map_err(|e| anyhow::anyhow!(e))?;
            let data = chunk.0.into_bytes();

            // Guard against a short or oversized remote response before
            // touching local memory.
            anyhow::ensure!(
                data.len() == len,
                "tcp read chunk size mismatch: expected {len}, got {}",
                data.len()
            );

            op.local_memory.write_at(offset, &data)?;

            offset += len;
        }

        Ok(())
    }
}
862
863#[async_trait]
864impl RdmaBackend for TcpBackend {
865    type TransportInfo = ();
866
867    /// Submit a batch of RDMA operations over TCP.
868    ///
869    /// Each operation's remote buffer is resolved to its TCP backend
870    /// context, then executed directly — sending chunked write/read
871    /// messages to the remote [`TcpManagerActor`].
872    async fn submit(
873        &mut self,
874        cx: &(impl context::Actor + Send + Sync),
875        ops: Vec<RdmaOp>,
876        timeout: Duration,
877    ) -> Result<()> {
878        let chunk_size =
879            hyperactor_config::global::get(crate::config::RDMA_MAX_CHUNK_SIZE_MB) * 1024 * 1024;
880        let parallelism =
881            hyperactor_config::global::get(crate::config::RDMA_TCP_FALLBACK_PARALLELISM);
882        let deadline = Instant::now() + timeout;
883
884        for op in ops {
885            let remaining = deadline.saturating_duration_since(Instant::now());
886            if remaining.is_zero() {
887                anyhow::bail!("tcp submit timed out");
888            }
889
890            let (remote_tcp_mgr, remote_buf_id) = op.remote.resolve_tcp()?;
891            let tcp_op = TcpOp {
892                op_type: op.op_type.clone(),
893                local_memory: op.local,
894                remote_tcp_manager: remote_tcp_mgr,
895                remote_buf_id,
896            };
897
898            if parallelism > 1 {
899                match tcp_op.op_type {
900                    RdmaOpType::WriteFromLocal => {
901                        self.execute_parallel_write(cx, &tcp_op, chunk_size, deadline)
902                            .await?;
903                    }
904                    RdmaOpType::ReadIntoLocal => {
905                        self.execute_parallel_read(cx, &tcp_op, chunk_size, deadline)
906                            .await?;
907                    }
908                }
909            } else {
910                match tcp_op.op_type {
911                    RdmaOpType::WriteFromLocal => {
912                        self.execute_write(cx, &tcp_op, chunk_size, deadline)
913                            .await?;
914                    }
915                    RdmaOpType::ReadIntoLocal => {
916                        self.execute_read(cx, &tcp_op, chunk_size, deadline).await?;
917                    }
918                }
919            }
920        }
921
922        Ok(())
923    }
924
925    fn transport_level(&self) -> RdmaTransportLevel {
926        RdmaTransportLevel::Tcp
927    }
928
929    fn transport_info(&self) -> Option<Self::TransportInfo> {
930        None
931    }
932}
933
934#[cfg(test)]
935mod tests {
936    use std::sync::Arc;
937    use std::sync::atomic::AtomicUsize;
938    use std::sync::atomic::Ordering;
939    use std::time::Duration;
940
941    use hyperactor::ActorHandle;
942    use hyperactor::Proc;
943    use hyperactor::RemoteSpawn;
944    use hyperactor::channel::ChannelAddr;
945    use hyperactor_config::Flattrs;
946
947    use super::TcpBackend;
948    use super::TcpManagerActor;
949    use crate::RdmaManagerMessageClient;
950    use crate::RdmaOp;
951    use crate::RdmaOpType;
952    use crate::backend::RdmaBackend;
953    use crate::local_memory::Keepalive;
954    use crate::local_memory::KeepaliveLocalMemory;
955    use crate::local_memory::RdmaLocalMemory;
956    use crate::rdma_manager_actor::GetTcpActorRefClient;
957    use crate::rdma_manager_actor::RdmaManagerActor;
958
    // Let a plain boxed byte slice serve as the keepalive owner of a
    // KeepaliveLocalMemory region in these tests.
    impl Keepalive for Box<[u8]> {}

    // Monotonic id source so each test environment gets a unique proc
    // and client-instance name.
    static COUNTER: AtomicUsize = AtomicUsize::new(0);
962
    /// Per-test environment: a proc with an RDMA manager actor, a TCP
    /// backend handle, and one registered local buffer.
    struct TcpTestProcEnv {
        // Kept alive so actors spawned on it survive for the test's duration.
        proc: Proc,
        rdma_handle: ActorHandle<RdmaManagerActor>,
        // Client instance used to send messages from the test body.
        instance: hyperactor::Instance<()>,
        tcp_backend: TcpBackend,
        // Handle for the buffer registered with `rdma_handle`.
        rdma_remote_buf: crate::RdmaRemoteBuffer,
        local_memory: Arc<dyn RdmaLocalMemory>,
    }
971
    impl Drop for TcpTestProcEnv {
        fn drop(&mut self) {
            use crate::rdma_manager_actor::ReleaseBufferClient;
            // Release the buffer so the actor drops its local_memory
            // clone while the CUDA runtime is still alive.
            // NOTE: Drop cannot be async, so the async release is driven
            // via block_in_place + block_on; this requires the tests to
            // run on a multi-threaded tokio runtime.
            tokio::task::block_in_place(|| {
                tokio::runtime::Handle::current()
                    .block_on(
                        self.rdma_remote_buf
                            .owner
                            .release_buffer(&self.instance, self.rdma_remote_buf.id),
                    )
                    .expect("failed to release buffer in TcpTestProcEnv drop");
            });
        }
    }
988
    impl TcpTestProcEnv {
        /// Create a standalone test environment with its own proc and rdma manager.
        async fn new(buffer_size: usize) -> anyhow::Result<Self> {
            // Unique suffix keeps proc names distinct across tests.
            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
            let proc = Proc::direct(
                ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
                format!("tcp_test_{id}"),
            )?;
            let (instance, _) = proc.instance("client")?;

            let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
            let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;

            // The TCP actor lives in this process, so the returned ref
            // can be downcast to a direct local handle.
            let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
            let tcp_backend = TcpBackend(
                tcp_ref
                    .downcast_handle(&instance)
                    .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
            );

            let (local_memory, rdma_remote_buf) =
                Self::alloc_cpu_buffer(&instance, &rdma_handle, buffer_size).await?;

            Ok(Self {
                proc,
                rdma_handle,
                instance,
                tcp_backend,
                rdma_remote_buf,
                local_memory,
            })
        }

        /// Create a buffer on an existing proc's rdma manager.
        async fn on_proc(
            proc: &Proc,
            rdma_handle: &ActorHandle<RdmaManagerActor>,
            tcp_backend: TcpBackend,
            buffer_size: usize,
        ) -> anyhow::Result<Self> {
            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
            // Each env on the shared proc needs its own client instance name.
            let (instance, _) = proc.instance(&format!("client_{id}"))?;

            let (local_memory, rdma_remote_buf) =
                Self::alloc_cpu_buffer(&instance, rdma_handle, buffer_size).await?;

            Ok(Self {
                proc: proc.clone(),
                rdma_handle: rdma_handle.clone(),
                instance,
                tcp_backend,
                rdma_remote_buf,
                local_memory,
            })
        }

        /// Allocate a zeroed CPU buffer, wrap it as RDMA-local memory,
        /// and register it with the manager to get a remote handle.
        async fn alloc_cpu_buffer(
            instance: &hyperactor::Instance<()>,
            rdma_handle: &ActorHandle<RdmaManagerActor>,
            buffer_size: usize,
        ) -> anyhow::Result<(Arc<dyn RdmaLocalMemory>, crate::RdmaRemoteBuffer)> {
            let cpu_buf = vec![0u8; buffer_size].into_boxed_slice();
            let ptr = cpu_buf.as_ptr() as usize;
            // The boxed slice is owned by the KeepaliveLocalMemory, so
            // the raw pointer stays valid for the memory's lifetime.
            let local_memory: Arc<dyn RdmaLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
                ptr,
                buffer_size,
                Arc::new(cpu_buf),
            ));
            let rdma_remote_buf = rdma_handle
                .request_buffer(instance, local_memory.clone())
                .await?;
            Ok((local_memory, rdma_remote_buf))
        }
    }
1063
1064    /// Two separate procs, one buffer each.
1065    async fn setup_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1066        Ok(vec![
1067            TcpTestProcEnv::new(buf_size).await?,
1068            TcpTestProcEnv::new(buf_size).await?,
1069        ])
1070    }
1071
1072    /// Single proc, two buffers.
1073    async fn setup_same_proc_tcp_env(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1074        let first = TcpTestProcEnv::new(buf_size).await?;
1075        let second = TcpTestProcEnv::on_proc(
1076            &first.proc,
1077            &first.rdma_handle,
1078            first.tcp_backend.clone(),
1079            buf_size,
1080        )
1081        .await?;
1082        Ok(vec![first, second])
1083    }
1084
1085    /// Two procs, two buffers each (4 total). For concurrent tests that
1086    /// need independent source/dest pairs.
1087    async fn setup_tcp_env_pairs(buf_size: usize) -> anyhow::Result<Vec<TcpTestProcEnv>> {
1088        let e0 = TcpTestProcEnv::new(buf_size).await?;
1089        let e1 = TcpTestProcEnv::new(buf_size).await?;
1090        let e2 =
1091            TcpTestProcEnv::on_proc(&e0.proc, &e0.rdma_handle, e0.tcp_backend.clone(), buf_size)
1092                .await?;
1093        let e3 =
1094            TcpTestProcEnv::on_proc(&e1.proc, &e1.rdma_handle, e1.tcp_backend.clone(), buf_size)
1095                .await?;
1096        Ok(vec![e0, e1, e2, e3])
1097    }
1098
1099    // --- Shared test helpers ---
1100
1101    /// Fill envs[0], write to envs[1], verify.
1102    async fn do_write_test(
1103        envs: &mut [TcpTestProcEnv],
1104        buf_size: usize,
1105        timeout: Duration,
1106    ) -> anyhow::Result<()> {
1107        let mut src = vec![0u8; buf_size];
1108        for (i, byte) in src.iter_mut().enumerate() {
1109            *byte = (i % 256) as u8;
1110        }
1111        envs[0].local_memory.write_at(0, &src)?;
1112
1113        let remote = envs[1].rdma_remote_buf.clone();
1114        let env = &mut envs[0];
1115        env.tcp_backend
1116            .submit(
1117                &env.instance,
1118                vec![RdmaOp {
1119                    op_type: RdmaOpType::WriteFromLocal,
1120                    local: env.local_memory.clone(),
1121                    remote,
1122                }],
1123                timeout,
1124            )
1125            .await?;
1126
1127        let mut dst = vec![0u8; buf_size];
1128        envs[1].local_memory.read_at(0, &mut dst)?;
1129        for (i, byte) in dst.iter().enumerate() {
1130            assert_eq!(*byte, (i % 256) as u8, "mismatch at offset {i} after write");
1131        }
1132        Ok(())
1133    }
1134
1135    /// Fill envs[1], read into envs[0], verify.
1136    async fn do_read_test(
1137        envs: &mut [TcpTestProcEnv],
1138        buf_size: usize,
1139        timeout: Duration,
1140    ) -> anyhow::Result<()> {
1141        let mut src = vec![0u8; buf_size];
1142        for (i, byte) in src.iter_mut().enumerate() {
1143            *byte = ((i * 7 + 3) % 256) as u8;
1144        }
1145        envs[1].local_memory.write_at(0, &src)?;
1146
1147        let remote = envs[1].rdma_remote_buf.clone();
1148        let env = &mut envs[0];
1149        env.tcp_backend
1150            .submit(
1151                &env.instance,
1152                vec![RdmaOp {
1153                    op_type: RdmaOpType::ReadIntoLocal,
1154                    local: env.local_memory.clone(),
1155                    remote,
1156                }],
1157                timeout,
1158            )
1159            .await?;
1160
1161        let mut dst = vec![0u8; buf_size];
1162        envs[0].local_memory.read_at(0, &mut dst)?;
1163        for (i, byte) in dst.iter().enumerate() {
1164            assert_eq!(
1165                *byte,
1166                ((i * 7 + 3) % 256) as u8,
1167                "mismatch at offset {i} after read"
1168            );
1169        }
1170        Ok(())
1171    }
1172
1173    /// Write, clear, read-back, verify round-trip.
1174    async fn do_round_trip_test(
1175        envs: &mut [TcpTestProcEnv],
1176        buf_size: usize,
1177        timeout: Duration,
1178    ) -> anyhow::Result<()> {
1179        let mut src = vec![0u8; buf_size];
1180        for (i, byte) in src.iter_mut().enumerate() {
1181            *byte = ((i * 13 + 5) % 256) as u8;
1182        }
1183        envs[0].local_memory.write_at(0, &src)?;
1184
1185        let remote = envs[1].rdma_remote_buf.clone();
1186        let env = &mut envs[0];
1187        env.tcp_backend
1188            .submit(
1189                &env.instance,
1190                vec![RdmaOp {
1191                    op_type: RdmaOpType::WriteFromLocal,
1192                    local: env.local_memory.clone(),
1193                    remote: remote.clone(),
1194                }],
1195                timeout,
1196            )
1197            .await?;
1198
1199        envs[0].local_memory.write_at(0, &vec![0u8; buf_size])?;
1200
1201        let env = &mut envs[0];
1202        env.tcp_backend
1203            .submit(
1204                &env.instance,
1205                vec![RdmaOp {
1206                    op_type: RdmaOpType::ReadIntoLocal,
1207                    local: env.local_memory.clone(),
1208                    remote,
1209                }],
1210                timeout,
1211            )
1212            .await?;
1213
1214        let mut dst = vec![0u8; buf_size];
1215        envs[0].local_memory.read_at(0, &mut dst)?;
1216        for (i, byte) in dst.iter().enumerate() {
1217            assert_eq!(
1218                *byte,
1219                ((i * 13 + 5) % 256) as u8,
1220                "mismatch at offset {i} after round-trip"
1221            );
1222        }
1223        Ok(())
1224    }
1225
1226    // --- Non-parallel two-proc tests ---
1227
1228    /// Write from local buffer 0 into remote buffer 1.
1229    #[timed_test::async_timed_test(timeout_secs = 30)]
1230    async fn test_tcp_write_from_local() -> anyhow::Result<()> {
1231        let config = hyperactor_config::global::lock();
1232        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1233
1234        let mut envs = setup_tcp_env(4096).await?;
1235        do_write_test(&mut envs, 4096, Duration::from_secs(30)).await
1236    }
1237
1238    /// Read from remote buffer 1 into local buffer 0.
1239    #[timed_test::async_timed_test(timeout_secs = 30)]
1240    async fn test_tcp_read_into_local() -> anyhow::Result<()> {
1241        let config = hyperactor_config::global::lock();
1242        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1243
1244        let mut envs = setup_tcp_env(2048).await?;
1245        do_read_test(&mut envs, 2048, Duration::from_secs(30)).await
1246    }
1247
1248    /// Write, clear, read-back, verify round-trip.
1249    #[timed_test::async_timed_test(timeout_secs = 30)]
1250    async fn test_tcp_write_then_read_back() -> anyhow::Result<()> {
1251        let config = hyperactor_config::global::lock();
1252        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1253
1254        let mut envs = setup_tcp_env(4096).await?;
1255        do_round_trip_test(&mut envs, 4096, Duration::from_secs(30)).await
1256    }
1257
1258    /// Multi-chunk write (1 MiB chunks, 1.5 MiB buffer).
1259    #[timed_test::async_timed_test(timeout_secs = 30)]
1260    async fn test_tcp_multi_chunk_write() -> anyhow::Result<()> {
1261        let config = hyperactor_config::global::lock();
1262        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1263        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1264
1265        let buf_size = 3 * 1024 * 512;
1266        let mut envs = setup_tcp_env(buf_size).await?;
1267
1268        let mut src = vec![0u8; buf_size];
1269        for (i, byte) in src.iter_mut().enumerate() {
1270            *byte = (i % 251) as u8;
1271        }
1272        envs[0].local_memory.write_at(0, &src)?;
1273
1274        let remote = envs[1].rdma_remote_buf.clone();
1275        let env = &mut envs[0];
1276        env.tcp_backend
1277            .submit(
1278                &env.instance,
1279                vec![RdmaOp {
1280                    op_type: RdmaOpType::WriteFromLocal,
1281                    local: env.local_memory.clone(),
1282                    remote,
1283                }],
1284                Duration::from_secs(30),
1285            )
1286            .await?;
1287
1288        let mut dst = vec![0u8; buf_size];
1289        envs[1].local_memory.read_at(0, &mut dst)?;
1290        for (i, byte) in dst.iter().enumerate() {
1291            assert_eq!(*byte, (i % 251) as u8, "mismatch at offset {i}");
1292        }
1293
1294        Ok(())
1295    }
1296
1297    /// Multi-chunk read.
1298    #[timed_test::async_timed_test(timeout_secs = 30)]
1299    async fn test_tcp_multi_chunk_read() -> anyhow::Result<()> {
1300        let config = hyperactor_config::global::lock();
1301        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1302        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1303
1304        let buf_size = 3 * 1024 * 512;
1305        let mut envs = setup_tcp_env(buf_size).await?;
1306
1307        let mut src = vec![0u8; buf_size];
1308        for (i, byte) in src.iter_mut().enumerate() {
1309            *byte = ((i * 3 + 17) % 256) as u8;
1310        }
1311        envs[1].local_memory.write_at(0, &src)?;
1312
1313        let remote = envs[1].rdma_remote_buf.clone();
1314        let env = &mut envs[0];
1315        env.tcp_backend
1316            .submit(
1317                &env.instance,
1318                vec![RdmaOp {
1319                    op_type: RdmaOpType::ReadIntoLocal,
1320                    local: env.local_memory.clone(),
1321                    remote,
1322                }],
1323                Duration::from_secs(30),
1324            )
1325            .await?;
1326
1327        let mut dst = vec![0u8; buf_size];
1328        envs[0].local_memory.read_at(0, &mut dst)?;
1329        for (i, byte) in dst.iter().enumerate() {
1330            assert_eq!(*byte, ((i * 3 + 17) % 256) as u8, "mismatch at offset {i}");
1331        }
1332
1333        Ok(())
1334    }
1335
1336    /// Multi-chunk write-then-read round-trip.
1337    #[timed_test::async_timed_test(timeout_secs = 30)]
1338    async fn test_tcp_multi_chunk_round_trip() -> anyhow::Result<()> {
1339        let config = hyperactor_config::global::lock();
1340        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1341        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1342
1343        let buf_size = 5 * 1024 * 512;
1344        let mut envs = setup_tcp_env(buf_size).await?;
1345
1346        let mut src = vec![0u8; buf_size];
1347        for (i, byte) in src.iter_mut().enumerate() {
1348            *byte = ((i * 41 + 7) % 256) as u8;
1349        }
1350        envs[0].local_memory.write_at(0, &src)?;
1351
1352        let remote = envs[1].rdma_remote_buf.clone();
1353        let env = &mut envs[0];
1354        env.tcp_backend
1355            .submit(
1356                &env.instance,
1357                vec![RdmaOp {
1358                    op_type: RdmaOpType::WriteFromLocal,
1359                    local: env.local_memory.clone(),
1360                    remote: remote.clone(),
1361                }],
1362                Duration::from_secs(30),
1363            )
1364            .await?;
1365
1366        envs[0].local_memory.write_at(0, &vec![0u8; buf_size])?;
1367
1368        let env = &mut envs[0];
1369        env.tcp_backend
1370            .submit(
1371                &env.instance,
1372                vec![RdmaOp {
1373                    op_type: RdmaOpType::ReadIntoLocal,
1374                    local: env.local_memory.clone(),
1375                    remote,
1376                }],
1377                Duration::from_secs(30),
1378            )
1379            .await?;
1380
1381        let mut dst = vec![0u8; buf_size];
1382        envs[0].local_memory.read_at(0, &mut dst)?;
1383        for (i, byte) in dst.iter().enumerate() {
1384            assert_eq!(
1385                *byte,
1386                ((i * 41 + 7) % 256) as u8,
1387                "mismatch at offset {i} after multi-chunk round-trip"
1388            );
1389        }
1390
1391        Ok(())
1392    }
1393
1394    /// resolve_tcp finds the Tcp backend context in a buffer.
1395    #[timed_test::async_timed_test(timeout_secs = 30)]
1396    async fn test_tcp_resolve_tcp() -> anyhow::Result<()> {
1397        let config = hyperactor_config::global::lock();
1398        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1399
1400        let envs = setup_tcp_env(64).await?;
1401
1402        for (i, env) in envs.iter().enumerate() {
1403            let (tcp_ref, id) = env.rdma_remote_buf.resolve_tcp()?;
1404            assert_eq!(id, env.rdma_remote_buf.id, "buf id mismatch for env {i}");
1405            let expected: hyperactor::ActorRef<TcpManagerActor> = env.tcp_backend.bind();
1406            assert_eq!(tcp_ref.actor_id(), expected.actor_id());
1407        }
1408
1409        Ok(())
1410    }
1411
    /// Write to a released buffer returns an error without crashing.
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_write_to_released_buffer() -> anyhow::Result<()> {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);

        let buf_size = 64;
        let mut envs = setup_tcp_env(buf_size).await?;

        // Deterministic fill so the initial write is meaningful.
        let mut src = vec![0u8; buf_size];
        for (i, byte) in src.iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        envs[0].local_memory.write_at(0, &src)?;

        // Normal write should succeed.
        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::WriteFromLocal,
                    local: env.local_memory.clone(),
                    remote: remote.clone(),
                }],
                Duration::from_secs(10),
            )
            .await?;

        // Release the remote buffer.
        use crate::rdma_manager_actor::ReleaseBufferClient;
        let owner_ref = envs[1].rdma_remote_buf.owner.clone();
        owner_ref
            .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
            .await?;

        // Writing to the released buffer should fail. The same `remote`
        // handle is reused, now pointing at a released id.
        let env = &mut envs[0];
        let result = env
            .tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::WriteFromLocal,
                    local: env.local_memory.clone(),
                    remote: remote.clone(),
                }],
                Duration::from_secs(10),
            )
            .await;
        assert!(result.is_err(), "expected error writing to released buffer");

        Ok(())
    }
1467
    /// Read from a released buffer returns an error without crashing.
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_read_from_released_buffer() -> anyhow::Result<()> {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);

        let buf_size = 64;
        let mut envs = setup_tcp_env(buf_size).await?;

        // Release the remote buffer. Unlike the write test, no
        // successful transfer happens first: the buffer is released
        // immediately after setup.
        use crate::rdma_manager_actor::ReleaseBufferClient;
        let owner_ref = envs[1].rdma_remote_buf.owner.clone();
        owner_ref
            .release_buffer(&envs[0].instance, envs[1].rdma_remote_buf.id)
            .await?;

        // Reading from the released buffer should fail.
        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        let result = env
            .tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::ReadIntoLocal,
                    local: env.local_memory.clone(),
                    remote,
                }],
                Duration::from_secs(10),
            )
            .await;
        assert!(
            result.is_err(),
            "expected error reading from released buffer"
        );

        Ok(())
    }
1506
1507    // --- Non-parallel same-proc tests ---
1508
1509    /// Same-process write.
1510    #[timed_test::async_timed_test(timeout_secs = 30)]
1511    async fn test_tcp_same_process_write() -> anyhow::Result<()> {
1512        let config = hyperactor_config::global::lock();
1513        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1514
1515        let mut envs = setup_same_proc_tcp_env(4096).await?;
1516        do_write_test(&mut envs, 4096, Duration::from_secs(10)).await
1517    }
1518
1519    /// Same-process read.
1520    #[timed_test::async_timed_test(timeout_secs = 30)]
1521    async fn test_tcp_same_process_read() -> anyhow::Result<()> {
1522        let config = hyperactor_config::global::lock();
1523        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1524
1525        let mut envs = setup_same_proc_tcp_env(2048).await?;
1526        do_read_test(&mut envs, 2048, Duration::from_secs(10)).await
1527    }
1528
1529    /// Same-process write-then-read round-trip.
1530    #[timed_test::async_timed_test(timeout_secs = 30)]
1531    async fn test_tcp_same_process_round_trip() -> anyhow::Result<()> {
1532        let config = hyperactor_config::global::lock();
1533        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1534
1535        let mut envs = setup_same_proc_tcp_env(4096).await?;
1536        do_round_trip_test(&mut envs, 4096, Duration::from_secs(10)).await
1537    }
1538
1539    /// When TCP fallback is disabled and ibverbs is unavailable,
1540    /// RdmaManagerActor::new returns an error.
1541    #[timed_test::async_timed_test(timeout_secs = 30)]
1542    async fn test_tcp_fallback_disabled_fails() -> anyhow::Result<()> {
1543        let config = hyperactor_config::global::lock();
1544        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, false);
1545
1546        let result = RdmaManagerActor::new(None, Flattrs::default()).await;
1547        if crate::ibverbs_supported() {
1548            assert!(result.is_ok());
1549        } else {
1550            assert!(result.is_err());
1551        }
1552
1553        Ok(())
1554    }
1555
1556    // --- Multi-GPU TCP fallback tests ---
1557
1558    use crate::backend::cuda_test_utils::CudaAllocator;
1559    use crate::backend::cuda_test_utils::cuda_device_count;
1560
    impl TcpTestProcEnv {
        /// Create a test environment backed by CUDA device memory.
        async fn new_gpu(device: i32, buffer_size: usize) -> anyhow::Result<Self> {
            // Unique suffix keeps proc names distinct across tests.
            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
            let proc = Proc::direct(
                ChannelAddr::any(hyperactor::channel::ChannelTransport::Unix),
                format!("tcp_gpu_test_{id}"),
            )?;
            let (instance, _) = proc.instance("client")?;

            let rdma_actor = RdmaManagerActor::new(None, Flattrs::default()).await?;
            let rdma_handle = proc.spawn("rdma_manager", rdma_actor)?;

            let tcp_ref = rdma_handle.get_tcp_actor_ref(&instance).await?;
            let tcp_backend = TcpBackend(
                tcp_ref
                    .downcast_handle(&instance)
                    .ok_or_else(|| anyhow::anyhow!("tcp actor not local"))?,
            );

            // The CUDA allocation is owned by the KeepaliveLocalMemory
            // wrapper, mirroring the CPU path in `alloc_cpu_buffer`.
            let alloc = CudaAllocator::get().allocate(device, buffer_size);
            let local_memory: Arc<dyn RdmaLocalMemory> = Arc::new(KeepaliveLocalMemory::new(
                alloc.ptr(),
                buffer_size,
                Arc::new(alloc),
            ));
            let rdma_remote_buf = rdma_handle
                .request_buffer(&instance, local_memory.clone())
                .await?;

            Ok(Self {
                proc,
                rdma_handle,
                instance,
                tcp_backend,
                rdma_remote_buf,
                local_memory,
            })
        }
    }
1601
1602    /// TCP write from GPU on cuda:0 to GPU on cuda:1.
1603    #[timed_test::async_timed_test(timeout_secs = 60)]
1604    async fn test_tcp_write_multi_gpu() -> anyhow::Result<()> {
1605        if cuda_device_count() < 2 {
1606            println!("Skipping: need at least 2 CUDA devices");
1607            return Ok(());
1608        }
1609
1610        let config = hyperactor_config::global::lock();
1611        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1612
1613        let buf_size = 2 * 1024 * 1024;
1614        let mut envs = vec![
1615            TcpTestProcEnv::new_gpu(0, buf_size).await?,
1616            TcpTestProcEnv::new_gpu(1, buf_size).await?,
1617        ];
1618        do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1619    }
1620
1621    /// TCP read from GPU on cuda:1 into GPU on cuda:0.
1622    #[timed_test::async_timed_test(timeout_secs = 60)]
1623    async fn test_tcp_read_multi_gpu() -> anyhow::Result<()> {
1624        if cuda_device_count() < 2 {
1625            println!("Skipping: need at least 2 CUDA devices");
1626            return Ok(());
1627        }
1628
1629        let config = hyperactor_config::global::lock();
1630        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1631
1632        let buf_size = 2 * 1024 * 1024;
1633        let mut envs = vec![
1634            TcpTestProcEnv::new_gpu(0, buf_size).await?,
1635            TcpTestProcEnv::new_gpu(1, buf_size).await?,
1636        ];
1637        do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1638    }
1639
1640    /// TCP write-then-read round-trip between cuda:0 and cuda:1.
1641    #[timed_test::async_timed_test(timeout_secs = 60)]
1642    async fn test_tcp_round_trip_multi_gpu() -> anyhow::Result<()> {
1643        if cuda_device_count() < 2 {
1644            println!("Skipping: need at least 2 CUDA devices");
1645            return Ok(());
1646        }
1647
1648        let config = hyperactor_config::global::lock();
1649        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1650
1651        let buf_size = 2 * 1024 * 1024;
1652        let mut envs = vec![
1653            TcpTestProcEnv::new_gpu(0, buf_size).await?,
1654            TcpTestProcEnv::new_gpu(1, buf_size).await?,
1655        ];
1656        do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1657    }
1658
    /// Stopping the RdmaManagerActor with parallelism enabled cleanly
    /// shuts down the TcpManagerActor's receive loop without hanging.
    #[timed_test::async_timed_test(timeout_secs = 30)]
    async fn test_tcp_parallel_clean_shutdown() -> anyhow::Result<()> {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
        // Parallelism > 1 with small chunks so the parallel path (and
        // its worker connections) is actually exercised.
        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);

        let buf_size = 3 * 1024 * 1024;
        let mut envs = setup_tcp_env(buf_size).await?;

        // Do a transfer so the receive loop and outbound connections are live.
        let mut src = vec![0u8; buf_size];
        for (i, byte) in src.iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }
        envs[0].local_memory.write_at(0, &src)?;
        let remote = envs[1].rdma_remote_buf.clone();
        let env = &mut envs[0];
        env.tcp_backend
            .submit(
                &env.instance,
                vec![RdmaOp {
                    op_type: RdmaOpType::WriteFromLocal,
                    local: env.local_memory.clone(),
                    remote: remote.clone(),
                }],
                Duration::from_secs(30),
            )
            .await?;

        // Stop the RdmaManagerActor, which cascades to TcpManagerActor.
        // The test timeout ensures we detect hangs in the cleanup path.
        // Awaiting a clone of each handle blocks until that actor has
        // actually finished stopping.
        envs[0].rdma_handle.drain_and_stop("test")?;
        envs[0].rdma_handle.clone().await;
        envs[1].rdma_handle.drain_and_stop("test")?;
        envs[1].rdma_handle.clone().await;

        Ok(())
    }
1700
1701    // --- Parallel transfer tests ---
1702
1703    /// Parallel write via direct channels.
1704    #[timed_test::async_timed_test(timeout_secs = 30)]
1705    async fn test_tcp_parallel_write() -> anyhow::Result<()> {
1706        let config = hyperactor_config::global::lock();
1707        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1708        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1709        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1710
1711        // 3 MiB, 3 chunks spread across 2 workers.
1712        let buf_size = 3 * 1024 * 1024;
1713        let mut envs = setup_tcp_env(buf_size).await?;
1714        do_write_test(&mut envs, buf_size, Duration::from_secs(30)).await
1715    }
1716
1717    /// Parallel read via direct channels.
1718    #[timed_test::async_timed_test(timeout_secs = 30)]
1719    async fn test_tcp_parallel_read() -> anyhow::Result<()> {
1720        let config = hyperactor_config::global::lock();
1721        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1722        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1723        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1724
1725        let buf_size = 3 * 1024 * 1024;
1726        let mut envs = setup_tcp_env(buf_size).await?;
1727        do_read_test(&mut envs, buf_size, Duration::from_secs(30)).await
1728    }
1729
1730    /// Parallel write-then-read round-trip.
1731    #[timed_test::async_timed_test(timeout_secs = 30)]
1732    async fn test_tcp_parallel_round_trip() -> anyhow::Result<()> {
1733        let config = hyperactor_config::global::lock();
1734        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1735        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1736        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1737
1738        let buf_size = 3 * 1024 * 1024;
1739        let mut envs = setup_tcp_env(buf_size).await?;
1740        do_round_trip_test(&mut envs, buf_size, Duration::from_secs(30)).await
1741    }
1742
1743    /// Same-process parallel write.
1744    #[timed_test::async_timed_test(timeout_secs = 30)]
1745    async fn test_tcp_parallel_same_process_write() -> anyhow::Result<()> {
1746        let config = hyperactor_config::global::lock();
1747        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1748        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1749        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1750
1751        let buf_size = 3 * 1024 * 1024;
1752        let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1753        do_write_test(&mut envs, buf_size, Duration::from_secs(10)).await
1754    }
1755
1756    /// Same-process parallel read.
1757    #[timed_test::async_timed_test(timeout_secs = 30)]
1758    async fn test_tcp_parallel_same_process_read() -> anyhow::Result<()> {
1759        let config = hyperactor_config::global::lock();
1760        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1761        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1762        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1763
1764        let buf_size = 3 * 1024 * 1024;
1765        let mut envs = setup_same_proc_tcp_env(buf_size).await?;
1766        do_read_test(&mut envs, buf_size, Duration::from_secs(10)).await
1767    }
1768
1769    // --- Concurrent parallel tests (4 envs, 2 independent pairs) ---
1770
1771    /// Two concurrent parallel writes to independent buffer pairs.
1772    #[timed_test::async_timed_test(timeout_secs = 30)]
1773    async fn test_tcp_parallel_concurrent_writes() -> anyhow::Result<()> {
1774        let config = hyperactor_config::global::lock();
1775        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1776        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1777        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1778
1779        let buf_size = 3 * 1024 * 1024;
1780        let envs = setup_tcp_env_pairs(buf_size).await?;
1781
1782        // Fill source buffers with distinct patterns.
1783        let mut src0 = vec![0u8; buf_size];
1784        for (i, byte) in src0.iter_mut().enumerate() {
1785            *byte = (i % 256) as u8;
1786        }
1787        envs[0].local_memory.write_at(0, &src0)?;
1788        let mut src2 = vec![0u8; buf_size];
1789        for (i, byte) in src2.iter_mut().enumerate() {
1790            *byte = ((i * 3 + 7) % 256) as u8;
1791        }
1792        envs[2].local_memory.write_at(0, &src2)?;
1793
1794        // Pair 1: envs[0] -> envs[1], Pair 2: envs[2] -> envs[3].
1795        let remote_1 = envs[1].rdma_remote_buf.clone();
1796        let remote_3 = envs[3].rdma_remote_buf.clone();
1797        let mut h0 = envs[0].tcp_backend.clone();
1798        let mut h2 = envs[2].tcp_backend.clone();
1799        let inst_0 = &envs[0].instance;
1800        let inst_2 = &envs[2].instance;
1801        let mem_0 = envs[0].local_memory.clone();
1802        let mem_2 = envs[2].local_memory.clone();
1803        let (r1, r2) = tokio::join!(
1804            h0.submit(
1805                inst_0,
1806                vec![RdmaOp {
1807                    op_type: RdmaOpType::WriteFromLocal,
1808                    local: mem_0,
1809                    remote: remote_1,
1810                }],
1811                Duration::from_secs(30),
1812            ),
1813            h2.submit(
1814                inst_2,
1815                vec![RdmaOp {
1816                    op_type: RdmaOpType::WriteFromLocal,
1817                    local: mem_2,
1818                    remote: remote_3,
1819                }],
1820                Duration::from_secs(30),
1821            ),
1822        );
1823        r1?;
1824        r2?;
1825
1826        let mut dst1 = vec![0u8; buf_size];
1827        envs[1].local_memory.read_at(0, &mut dst1)?;
1828        for (i, byte) in dst1.iter().enumerate() {
1829            assert_eq!(*byte, (i % 256) as u8, "pair 1 mismatch at offset {i}");
1830        }
1831        let mut dst3 = vec![0u8; buf_size];
1832        envs[3].local_memory.read_at(0, &mut dst3)?;
1833        for (i, byte) in dst3.iter().enumerate() {
1834            assert_eq!(
1835                *byte,
1836                ((i * 3 + 7) % 256) as u8,
1837                "pair 2 mismatch at offset {i}"
1838            );
1839        }
1840
1841        Ok(())
1842    }
1843
1844    /// Two concurrent parallel reads from independent buffer pairs.
1845    #[timed_test::async_timed_test(timeout_secs = 30)]
1846    async fn test_tcp_parallel_concurrent_reads() -> anyhow::Result<()> {
1847        let config = hyperactor_config::global::lock();
1848        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1849        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1850        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1851
1852        let buf_size = 3 * 1024 * 1024;
1853        let envs = setup_tcp_env_pairs(buf_size).await?;
1854
1855        // Fill remote buffers with distinct patterns.
1856        let mut src1 = vec![0u8; buf_size];
1857        for (i, byte) in src1.iter_mut().enumerate() {
1858            *byte = ((i * 11 + 3) % 256) as u8;
1859        }
1860        envs[1].local_memory.write_at(0, &src1)?;
1861        let mut src3 = vec![0u8; buf_size];
1862        for (i, byte) in src3.iter_mut().enumerate() {
1863            *byte = ((i * 5 + 13) % 256) as u8;
1864        }
1865        envs[3].local_memory.write_at(0, &src3)?;
1866
1867        // Pair 1: envs[0] <- envs[1], Pair 2: envs[2] <- envs[3].
1868        let remote_1 = envs[1].rdma_remote_buf.clone();
1869        let remote_3 = envs[3].rdma_remote_buf.clone();
1870        let mut h0 = envs[0].tcp_backend.clone();
1871        let mut h2 = envs[2].tcp_backend.clone();
1872        let inst_0 = &envs[0].instance;
1873        let inst_2 = &envs[2].instance;
1874        let mem_0 = envs[0].local_memory.clone();
1875        let mem_2 = envs[2].local_memory.clone();
1876        let (r1, r2) = tokio::join!(
1877            h0.submit(
1878                inst_0,
1879                vec![RdmaOp {
1880                    op_type: RdmaOpType::ReadIntoLocal,
1881                    local: mem_0,
1882                    remote: remote_1,
1883                }],
1884                Duration::from_secs(30),
1885            ),
1886            h2.submit(
1887                inst_2,
1888                vec![RdmaOp {
1889                    op_type: RdmaOpType::ReadIntoLocal,
1890                    local: mem_2,
1891                    remote: remote_3,
1892                }],
1893                Duration::from_secs(30),
1894            ),
1895        );
1896        r1?;
1897        r2?;
1898
1899        let mut dst0 = vec![0u8; buf_size];
1900        envs[0].local_memory.read_at(0, &mut dst0)?;
1901        for (i, byte) in dst0.iter().enumerate() {
1902            assert_eq!(
1903                *byte,
1904                ((i * 11 + 3) % 256) as u8,
1905                "pair 1 mismatch at offset {i}"
1906            );
1907        }
1908        let mut dst2 = vec![0u8; buf_size];
1909        envs[2].local_memory.read_at(0, &mut dst2)?;
1910        for (i, byte) in dst2.iter().enumerate() {
1911            assert_eq!(
1912                *byte,
1913                ((i * 5 + 13) % 256) as u8,
1914                "pair 2 mismatch at offset {i}"
1915            );
1916        }
1917
1918        Ok(())
1919    }
1920
1921    /// Concurrent parallel write and read on independent buffer pairs.
1922    #[timed_test::async_timed_test(timeout_secs = 30)]
1923    async fn test_tcp_parallel_concurrent_write_and_read() -> anyhow::Result<()> {
1924        let config = hyperactor_config::global::lock();
1925        let _guard = config.override_key(crate::config::RDMA_ALLOW_TCP_FALLBACK, true);
1926        let _par_guard = config.override_key(crate::config::RDMA_TCP_FALLBACK_PARALLELISM, 2);
1927        let _chunk_guard = config.override_key(crate::config::RDMA_MAX_CHUNK_SIZE_MB, 1);
1928
1929        let buf_size = 3 * 1024 * 1024;
1930        let envs = setup_tcp_env_pairs(buf_size).await?;
1931
1932        // Fill source buffers.
1933        let mut src0 = vec![0u8; buf_size];
1934        for (i, byte) in src0.iter_mut().enumerate() {
1935            *byte = (i % 256) as u8;
1936        }
1937        envs[0].local_memory.write_at(0, &src0)?;
1938        let mut src3 = vec![0u8; buf_size];
1939        for (i, byte) in src3.iter_mut().enumerate() {
1940            *byte = ((i * 7 + 13) % 256) as u8;
1941        }
1942        envs[3].local_memory.write_at(0, &src3)?;
1943
1944        // Write envs[0] -> envs[1], read envs[2] <- envs[3] concurrently.
1945        let remote_1 = envs[1].rdma_remote_buf.clone();
1946        let remote_3 = envs[3].rdma_remote_buf.clone();
1947        let mut h0 = envs[0].tcp_backend.clone();
1948        let mut h2 = envs[2].tcp_backend.clone();
1949        let inst_0 = &envs[0].instance;
1950        let inst_2 = &envs[2].instance;
1951        let mem_0 = envs[0].local_memory.clone();
1952        let mem_2 = envs[2].local_memory.clone();
1953        let (write_result, read_result) = tokio::join!(
1954            h0.submit(
1955                inst_0,
1956                vec![RdmaOp {
1957                    op_type: RdmaOpType::WriteFromLocal,
1958                    local: mem_0,
1959                    remote: remote_1,
1960                }],
1961                Duration::from_secs(30),
1962            ),
1963            h2.submit(
1964                inst_2,
1965                vec![RdmaOp {
1966                    op_type: RdmaOpType::ReadIntoLocal,
1967                    local: mem_2,
1968                    remote: remote_3,
1969                }],
1970                Duration::from_secs(30),
1971            ),
1972        );
1973        write_result?;
1974        read_result?;
1975
1976        let mut dst1 = vec![0u8; buf_size];
1977        envs[1].local_memory.read_at(0, &mut dst1)?;
1978        for (i, byte) in dst1.iter().enumerate() {
1979            assert_eq!(*byte, (i % 256) as u8, "write mismatch at offset {i}");
1980        }
1981        let mut dst2 = vec![0u8; buf_size];
1982        envs[2].local_memory.read_at(0, &mut dst2)?;
1983        for (i, byte) in dst2.iter().enumerate() {
1984            assert_eq!(
1985                *byte,
1986                ((i * 7 + 13) % 256) as u8,
1987                "read mismatch at offset {i}"
1988            );
1989        }
1990
1991        Ok(())
1992    }
1993}