monarch_rdma/
test_utils.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::sync::Once;
10use std::sync::atomic::AtomicBool;
11use std::sync::atomic::Ordering;
12
13/// Cached result of CUDA availability check
14static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
15static INIT: Once = Once::new();
16
17/// Safely checks if CUDA is available on the system.
18///
19/// This function attempts to initialize CUDA and determine if it's available.
20/// The result is cached after the first call, so subsequent calls are very fast.
21///
22/// # Returns
23///
24/// `true` if CUDA is available and can be initialized, `false` otherwise.
25///
26/// # Examples
27///
28/// ```
29/// use monarch_rdma::is_cuda_available;
30///
31/// if is_cuda_available() {
32///     println!("CUDA is available, can use GPU features");
33/// } else {
34///     println!("CUDA is not available, falling back to CPU-only mode");
35/// }
36/// ```
37pub fn is_cuda_available() -> bool {
38    INIT.call_once(|| {
39        let available = check_cuda_available();
40        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
41    });
42    CUDA_AVAILABLE.load(Ordering::SeqCst)
43}
44
45/// Internal function that performs the actual CUDA availability check
46fn check_cuda_available() -> bool {
47    unsafe {
48        // Try to initialize CUDA
49        let result = rdmaxcel_sys::rdmaxcel_cuInit(0);
50
51        if result != rdmaxcel_sys::CUDA_SUCCESS {
52            return false;
53        }
54
55        // Check if there are any CUDA devices
56        let mut device_count: i32 = 0;
57        let count_result = rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut device_count);
58
59        if count_result != rdmaxcel_sys::CUDA_SUCCESS || device_count <= 0 {
60            return false;
61        }
62
63        // Try to get the first device to verify it's actually accessible
64        let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
65        let device_result = rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, 0);
66
67        if device_result != rdmaxcel_sys::CUDA_SUCCESS {
68            return false;
69        }
70
71        true
72    }
73}
74
75#[cfg(test)]
76pub mod test_utils {
77    use std::time::Duration;
78    use std::time::Instant;
79
80    use hyperactor::Actor;
81    use hyperactor::ActorRef;
82    use hyperactor::Context;
83    use hyperactor::HandleClient;
84    use hyperactor::Handler;
85    use hyperactor::Instance;
86    use hyperactor::Proc;
87    use hyperactor::RefClient;
88    use hyperactor::channel::ChannelTransport;
89    use hyperactor::clock::Clock;
90    use hyperactor::clock::RealClock;
91    use hyperactor_mesh::Mesh;
92    use hyperactor_mesh::ProcMesh;
93    use hyperactor_mesh::RootActorMesh;
94    use hyperactor_mesh::alloc::AllocSpec;
95    use hyperactor_mesh::alloc::Allocator;
96    use hyperactor_mesh::alloc::LocalAllocator;
97    use ndslice::extent;
98
99    use crate::IbverbsConfig;
100    use crate::RdmaBuffer;
101    use crate::cu_check;
102    use crate::rdma_components::PollTarget;
103    use crate::rdma_components::RdmaQueuePair;
104    use crate::rdma_manager_actor::RdmaManagerActor;
105    use crate::rdma_manager_actor::RdmaManagerMessageClient;
106    use crate::validate_execution_context;
107
    /// Newtype around a raw `rdmaxcel_sys::CUcontext` so that the context can be
    /// stored as a field of [`CudaActor`].
    #[derive(Debug)]
    struct SendSyncCudaContext(rdmaxcel_sys::CUcontext);
    // NOTE(review): these impls are unchecked promises to the compiler. Nothing in
    // this file enforces them; the handlers make the context current before each
    // CUDA call, which is presumably what makes sharing it acceptable - confirm.
    unsafe impl Send for SendSyncCudaContext {}
    unsafe impl Sync for SendSyncCudaContext {}
112
    /// Actor responsible for CUDA initialization and buffer management within its own process context.
    /// This is important because you perform CUDA operations within the same process as the RDMA operations.
    #[hyperactor::export(
        spawn = true,
        handlers = [
            CudaActorMessage,
        ],
    )]
    #[derive(Debug)]
    pub struct CudaActor {
        /// Device handle obtained in `new`; `CreateBuffer` fails with
        /// "Device not initialized" when this is `None`.
        device: Option<i32>,
        /// Context created in `new`; every handler makes it current before
        /// issuing CUDA calls.
        context: SendSyncCudaContext,
    }
126
127    #[async_trait::async_trait]
128    impl Actor for CudaActor {
129        type Params = i32;
130
131        async fn new(device_id: i32) -> Result<Self, anyhow::Error> {
132            unsafe {
133                cu_check!(rdmaxcel_sys::rdmaxcel_cuInit(0));
134                let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
135                cu_check!(rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, device_id));
136                let mut context: rdmaxcel_sys::CUcontext = std::mem::zeroed();
137                cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxCreate_v2(
138                    &mut context,
139                    0,
140                    device_id
141                ));
142
143                Ok(Self {
144                    device: Some(device),
145                    context: SendSyncCudaContext(context),
146                })
147            }
148        }
149    }
150
    /// Requests understood by [`CudaActor`].
    #[derive(
        Handler,
        HandleClient,
        RefClient,
        hyperactor::Named,
        serde::Serialize,
        serde::Deserialize,
        Debug
    )]
    pub enum CudaActorMessage {
        /// Allocates device memory of at least `size` bytes, registers it with
        /// `rdma_actor` and replies with the RDMA handle and the device pointer.
        CreateBuffer {
            /// Requested size in bytes; the handler rounds it up to the allocation granularity.
            size: usize,
            /// RDMA manager that the new allocation is registered with.
            rdma_actor: ActorRef<RdmaManagerActor>,
            /// Receives `(rdma_handle, device_pointer)`.
            #[reply]
            reply: hyperactor::OncePortRef<(RdmaBuffer, usize)>,
        },
        /// Sets `size` bytes starting at `device_ptr` to `value`.
        FillBuffer {
            /// Device address, passed as an integer.
            device_ptr: usize,
            /// Number of bytes to set.
            size: usize,
            /// Byte value to write.
            value: u8,
            /// Signalled once the memset call has returned.
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
        /// Copies `size` bytes from `device_ptr` to the host address `cpu_buffer_ptr`.
        /// The handler performs only the copy; any comparison is left to the caller.
        VerifyBuffer {
            /// Destination host address, passed as an integer.
            // NOTE(review): a raw host address is only meaningful if sender and actor
            // share an address space - confirm against how the actor is spawned.
            cpu_buffer_ptr: usize,
            /// Source device address, passed as an integer.
            device_ptr: usize,
            /// Number of bytes to copy.
            size: usize,
            /// Signalled once the copy call has returned.
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
    }
182
183    #[async_trait::async_trait]
184    impl Handler<CudaActorMessage> for CudaActor {
185        async fn handle(
186            &mut self,
187            cx: &Context<Self>,
188            msg: CudaActorMessage,
189        ) -> Result<(), anyhow::Error> {
190            match msg {
191                CudaActorMessage::CreateBuffer {
192                    size,
193                    rdma_actor,
194                    reply,
195                } => {
196                    let device = self
197                        .device
198                        .ok_or_else(|| anyhow::anyhow!("Device not initialized"))?;
199
200                    let (dptr, padded_size) = unsafe {
201                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));
202
203                        let mut dptr: rdmaxcel_sys::CUdeviceptr = std::mem::zeroed();
204                        let mut handle: rdmaxcel_sys::CUmemGenericAllocationHandle =
205                            std::mem::zeroed();
206
207                        let mut granularity: usize = 0;
208                        let mut prop: rdmaxcel_sys::CUmemAllocationProp = std::mem::zeroed();
209                        prop.type_ = rdmaxcel_sys::CU_MEM_ALLOCATION_TYPE_PINNED;
210                        prop.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
211                        prop.location.id = device;
212                        prop.allocFlags.gpuDirectRDMACapable = 1;
213                        prop.requestedHandleTypes =
214                            rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
215
216                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemGetAllocationGranularity(
217                            &mut granularity as *mut usize,
218                            &prop,
219                            rdmaxcel_sys::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
220                        ));
221
222                        let padded_size: usize = ((size - 1) / granularity + 1) * granularity;
223
224                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemCreate(
225                            &mut handle,
226                            padded_size,
227                            &prop,
228                            0
229                        ));
230
231                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressReserve(
232                            &mut dptr,
233                            padded_size,
234                            0,
235                            0,
236                            0,
237                        ));
238
239                        assert!((dptr as usize).is_multiple_of(granularity));
240                        assert!(padded_size.is_multiple_of(granularity));
241
242                        let err = rdmaxcel_sys::rdmaxcel_cuMemMap(dptr, padded_size, 0, handle, 0);
243                        if err != rdmaxcel_sys::CUDA_SUCCESS {
244                            return Err(anyhow::anyhow!("Failed to map CUDA memory: {:?}", err));
245                        }
246
247                        let mut access_desc: rdmaxcel_sys::CUmemAccessDesc = std::mem::zeroed();
248                        access_desc.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
249                        access_desc.location.id = device;
250                        access_desc.flags = rdmaxcel_sys::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
251                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemSetAccess(
252                            dptr,
253                            padded_size,
254                            &access_desc,
255                            1
256                        ));
257
258                        (dptr, padded_size)
259                    };
260
261                    let rdma_handle = rdma_actor
262                        .request_buffer(cx, dptr as usize, padded_size)
263                        .await?;
264
265                    reply.send(cx, (rdma_handle, dptr as usize))?;
266                    Ok(())
267                }
268                CudaActorMessage::FillBuffer {
269                    device_ptr,
270                    size,
271                    value,
272                    reply,
273                } => {
274                    unsafe {
275                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));
276
277                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemsetD8_v2(
278                            device_ptr as rdmaxcel_sys::CUdeviceptr,
279                            value,
280                            size
281                        ));
282                    }
283
284                    reply.send(cx, ())?;
285                    Ok(())
286                }
287                CudaActorMessage::VerifyBuffer {
288                    cpu_buffer_ptr,
289                    device_ptr,
290                    size,
291                    reply,
292                } => {
293                    unsafe {
294                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));
295
296                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
297                            cpu_buffer_ptr as *mut std::ffi::c_void,
298                            device_ptr as rdmaxcel_sys::CUdeviceptr,
299                            size
300                        ));
301                    }
302
303                    reply.send(cx, ())?;
304                    Ok(())
305                }
306            }
307        }
308    }
309
310    /// Waits for the completion of RDMA operations.
311    ///
312    /// This function polls for the completion of RDMA operations by repeatedly
313    /// checking the completion queue until all expected work requests complete
314    /// or the specified timeout is reached.
315    ///
316    /// # Arguments
317    /// * `qp` - The RDMA Queue Pair to poll for completion
318    /// * `poll_target` - Which CQ to poll (Send or Recv)
319    /// * `expected_wr_ids` - Slice of work request IDs to wait for
320    /// * `timeout_secs` - Timeout in seconds
321    ///
322    /// # Returns
323    /// `Ok(true)` if all operations complete successfully within the timeout,
324    /// or an error if the timeout is reached
325    pub async fn wait_for_completion(
326        qp: &mut RdmaQueuePair,
327        poll_target: PollTarget,
328        expected_wr_ids: &[u64],
329        timeout_secs: u64,
330    ) -> Result<bool, anyhow::Error> {
331        let timeout = Duration::from_secs(timeout_secs);
332        let start_time = Instant::now();
333
334        let mut remaining: std::collections::HashSet<u64> =
335            expected_wr_ids.iter().copied().collect();
336
337        while start_time.elapsed() < timeout {
338            if remaining.is_empty() {
339                return Ok(true);
340            }
341
342            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
343            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
344                Ok(completions) => {
345                    for (wr_id, _wc) in completions {
346                        remaining.remove(&wr_id);
347                    }
348                    if remaining.is_empty() {
349                        return Ok(true);
350                    }
351                    RealClock.sleep(Duration::from_millis(1)).await;
352                }
353                Err(e) => {
354                    return Err(anyhow::anyhow!(e));
355                }
356            }
357        }
358        Err(anyhow::Error::msg(format!(
359            "Timeout while waiting for completion of wr_ids: {:?}",
360            remaining
361        )))
362    }
363
364    /// Posts a work request to the send queue of the given RDMA queue pair.
365    pub async fn send_wqe_gpu(
366        qp: &mut RdmaQueuePair,
367        lhandle: &RdmaBuffer,
368        rhandle: &RdmaBuffer,
369        op_type: u32,
370    ) -> Result<(), anyhow::Error> {
371        unsafe {
372            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
373            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
374            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(ibv_qp);
375            let params = rdmaxcel_sys::wqe_params_t {
376                laddr: lhandle.addr,
377                length: lhandle.size,
378                lkey: lhandle.lkey,
379                wr_id: send_wqe_idx,
380                signaled: true,
381                op_type,
382                raddr: rhandle.addr,
383                rkey: rhandle.rkey,
384                qp_num: (*(*ibv_qp).ibv_qp).qp_num,
385                buf: (*dv_qp).sq.buf as *mut u8,
386                wqe_cnt: (*dv_qp).sq.wqe_cnt,
387                dbrec: (*dv_qp).dbrec,
388                ..Default::default()
389            };
390            rdmaxcel_sys::launch_send_wqe(params);
391            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(ibv_qp);
392        }
393        Ok(())
394    }
395
396    /// Posts a work request to the receive queue of the given RDMA queue pair.
397    pub async fn recv_wqe_gpu(
398        qp: &mut RdmaQueuePair,
399        lhandle: &RdmaBuffer,
400        _rhandle: &RdmaBuffer,
401        op_type: u32,
402    ) -> Result<(), anyhow::Error> {
403        // Populate params using lhandle and rhandle
404        unsafe {
405            let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
406            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
407            let recv_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp);
408            let params = rdmaxcel_sys::wqe_params_t {
409                laddr: lhandle.addr,
410                length: lhandle.size,
411                lkey: lhandle.lkey,
412                wr_id: recv_wqe_idx,
413                op_type,
414                signaled: true,
415                qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num,
416                buf: (*dv_qp).rq.buf as *mut u8,
417                wqe_cnt: (*dv_qp).rq.wqe_cnt,
418                dbrec: (*dv_qp).dbrec,
419                ..Default::default()
420            };
421            rdmaxcel_sys::launch_recv_wqe(params);
422            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp);
423            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp);
424        }
425        Ok(())
426    }
427
428    pub async fn ring_db_gpu(qp: &RdmaQueuePair) -> Result<(), anyhow::Error> {
429        RealClock.sleep(Duration::from_millis(2)).await;
430        unsafe {
431            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
432            let base_ptr = (*dv_qp).sq.buf as *mut u8;
433            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
434            let stride = (*dv_qp).sq.stride;
435            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(
436                qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
437            );
438            let mut send_db_idx =
439                rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp);
440            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
441                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
442            }
443            while send_db_idx < send_wqe_idx {
444                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
445                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
446                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
447                send_db_idx += 1;
448                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(
449                    qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
450                    send_db_idx,
451                );
452            }
453        }
454        Ok(())
455    }
456
457    /// Wait for completion on a specific completion queue
458    pub async fn wait_for_completion_gpu(
459        qp: &mut RdmaQueuePair,
460        poll_target: PollTarget,
461        timeout_secs: u64,
462    ) -> Result<bool, anyhow::Error> {
463        unsafe {
464            let start_time = Instant::now();
465            let timeout = Duration::from_secs(timeout_secs);
466            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
467
468            while start_time.elapsed() < timeout {
469                // Get the appropriate completion queue and index based on the poll target
470                let (cq, idx, cq_type_str) = match poll_target {
471                    PollTarget::Send => (
472                        qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
473                        rdmaxcel_sys::rdmaxcel_qp_load_send_cq_idx(ibv_qp),
474                        "send",
475                    ),
476                    PollTarget::Recv => (
477                        qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
478                        rdmaxcel_sys::rdmaxcel_qp_load_recv_cq_idx(ibv_qp),
479                        "receive",
480                    ),
481                };
482
483                // Poll the completion queue
484                let result = rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32);
485
486                match result {
487                    rdmaxcel_sys::CQE_POLL_TRUE => {
488                        // Update the appropriate index based on the poll target
489                        match poll_target {
490                            PollTarget::Send => {
491                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_cq_idx(ibv_qp);
492                            }
493                            PollTarget::Recv => {
494                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_cq_idx(ibv_qp);
495                            }
496                        }
497                        return Ok(true);
498                    }
499                    rdmaxcel_sys::CQE_POLL_ERROR => {
500                        return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
501                    }
502                    _ => {
503                        // No completion yet, sleep and try again
504                        RealClock.sleep(Duration::from_millis(1)).await;
505                    }
506                }
507            }
508
509            Err(anyhow::Error::msg("Timeout while waiting for completion"))
510        }
511    }
512
    /// Two-sided RDMA test fixture produced by `setup` / `setup_with_qp_type`.
    ///
    /// Each side (suffix `_1` / `_2`) has an RDMA manager actor, a client instance
    /// used to talk to it, a buffer registered with that manager and, when the
    /// side was set up on a CUDA accelerator, a [`CudaActor`] plus the device
    /// pointer of its allocation.
    pub struct RdmaManagerTestEnv<'a> {
        /// Bookkeeping for side 1's buffer (owns the host allocation for CPU buffers).
        buffer_1: Buffer,
        /// Bookkeeping for side 2's buffer (owns the host allocation for CPU buffers).
        buffer_2: Buffer,
        /// Client instance of side 1's proc mesh.
        pub client_1: &'a Instance<()>,
        /// Client instance of side 2's proc mesh.
        pub client_2: &'a Instance<()>,
        /// RDMA manager actor of side 1.
        pub actor_1: ActorRef<RdmaManagerActor>,
        /// RDMA manager actor of side 2.
        pub actor_2: ActorRef<RdmaManagerActor>,
        /// RDMA handle of the buffer registered with `actor_1`.
        pub rdma_handle_1: RdmaBuffer,
        /// RDMA handle of the buffer registered with `actor_2`.
        pub rdma_handle_2: RdmaBuffer,
        /// CUDA actor of side 1; `None` for a CPU buffer.
        cuda_actor_1: Option<ActorRef<CudaActor>>,
        /// CUDA actor of side 2; `None` for a CPU buffer.
        cuda_actor_2: Option<ActorRef<CudaActor>>,
        /// Device pointer of side 1's CUDA allocation; `None` for a CPU buffer.
        device_ptr_1: Option<usize>,
        /// Device pointer of side 2's CUDA allocation; `None` for a CPU buffer.
        device_ptr_2: Option<usize>,
    }
527
    /// Address and length of one test buffer, plus ownership of its host memory
    /// when the buffer lives on the CPU.
    // NOTE(review): `Clone` duplicates `cpu_ref` but copies `ptr` verbatim, so a
    // clone's `ptr` would still point into the original allocation - confirm
    // that clones are never dereferenced.
    #[derive(Debug, Clone)]
    pub struct Buffer {
        /// Start address of the buffer (host address, or the RDMA handle's address for CUDA).
        ptr: u64,
        /// Usable length in bytes.
        len: usize,
        /// Keeps the host allocation alive for CPU buffers; `None` for CUDA buffers.
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }
535    /// Helper function to parse accelerator strings
536    async fn parse_accel(accel: &str, config: &mut IbverbsConfig) -> (String, usize) {
537        let (backend, idx) = accel.split_once(':').unwrap();
538        let parsed_idx = idx.parse::<usize>().unwrap();
539
540        if backend == "cuda" {
541            config.use_gpu_direct = validate_execution_context().await.is_ok();
542            eprintln!("Using GPU Direct: {}", config.use_gpu_direct);
543        }
544
545        (backend.to_string(), parsed_idx)
546    }
547
548    impl RdmaManagerTestEnv<'_> {
549        /// Sets up the RDMA test environment with a specified QP type.
550        ///
551        /// This function initializes the RDMA test environment by setting up two actor meshes
552        /// with their respective RDMA configurations. It also prepares two buffers for testing
553        /// RDMA operations and fills the first buffer with test data.
554        ///
555        /// # Arguments
556        ///
557        /// * `buffer_size` - The size of the buffers to be used in the test.
558        /// * `accel1` - Accelerator for first actor (e.g., "cpu:0", "cuda:0")
559        /// * `accel2` - Accelerator for second actor (e.g., "cpu:0", "cuda:1")
560        /// * `qp_type` - The queue pair type to use (Auto, Standard, or Mlx5dv)
561        pub async fn setup_with_qp_type(
562            buffer_size: usize,
563            accel1: &str,
564            accel2: &str,
565            qp_type: crate::ibverbs_primitives::RdmaQpType,
566        ) -> Result<Self, anyhow::Error> {
567            // Use device selection logic to find optimal RDMA devices
568            let mut config1 = IbverbsConfig::targeting(accel1);
569            let mut config2 = IbverbsConfig::targeting(accel2);
570
571            // Set the QP type
572            config1.qp_type = qp_type;
573            config2.qp_type = qp_type;
574
575            let parsed_accel1 = parse_accel(accel1, &mut config1).await;
576            let parsed_accel2 = parse_accel(accel2, &mut config2).await;
577
578            let alloc_1 = LocalAllocator
579                .allocate(AllocSpec {
580                    extent: extent! { proc = 1 },
581                    constraints: Default::default(),
582                    proc_name: None,
583                    transport: ChannelTransport::Local,
584                    proc_allocation_mode: Default::default(),
585                })
586                .await
587                .unwrap();
588
589            let (instance, _) = Proc::local().instance("test").unwrap();
590
591            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
592            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1
593                .spawn(&instance, "rdma_manager", &Some(config1))
594                .await
595                .unwrap();
596
597            let alloc_2 = LocalAllocator
598                .allocate(AllocSpec {
599                    extent: extent! { proc = 1 },
600                    constraints: Default::default(),
601                    proc_name: None,
602                    transport: ChannelTransport::Local,
603                    proc_allocation_mode: Default::default(),
604                })
605                .await
606                .unwrap();
607
608            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
609            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2
610                .spawn(&instance, "rdma_manager", &Some(config2))
611                .await
612                .unwrap();
613
614            let actor_1 = actor_mesh_1.get(0).unwrap();
615            let actor_2 = actor_mesh_2.get(0).unwrap();
616
617            let mut buf_vec = Vec::new();
618            let mut cuda_actor_1 = None;
619            let mut cuda_actor_2 = None;
620            let mut device_ptr_1: Option<usize> = None;
621            let mut device_ptr_2: Option<usize> = None;
622
623            let rdma_handle_1;
624            let rdma_handle_2;
625
626            // Process first accelerator
627            if parsed_accel1.0 == "cpu" {
628                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
629                buf_vec.push(Buffer {
630                    ptr: buffer.as_mut_ptr() as u64,
631                    len: buffer.len(),
632                    cpu_ref: Some(buffer),
633                });
634                rdma_handle_1 = actor_1
635                    .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
636                    .await?;
637            } else {
638                // CUDA case - spawn CudaActor in the same process mesh
639                let cuda_actor_mesh_1: RootActorMesh<'_, CudaActor> = proc_mesh_1
640                    .spawn(&instance, "cuda_init", &(parsed_accel1.1 as i32))
641                    .await?;
642                let cuda_actor_ref_1 = cuda_actor_mesh_1.get(0).unwrap();
643
644                let (rdma_buf, dev_ptr) = cuda_actor_ref_1
645                    .create_buffer(proc_mesh_1.client(), buffer_size, actor_1.clone())
646                    .await?;
647                rdma_handle_1 = rdma_buf;
648                device_ptr_1 = Some(dev_ptr);
649
650                buf_vec.push(Buffer {
651                    ptr: rdma_handle_1.addr as u64,
652                    len: buffer_size,
653                    cpu_ref: None,
654                });
655                cuda_actor_1 = Some(cuda_actor_ref_1);
656            }
657
658            // Process second accelerator
659            if parsed_accel2.0 == "cpu" {
660                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
661                buf_vec.push(Buffer {
662                    ptr: buffer.as_mut_ptr() as u64,
663                    len: buffer.len(),
664                    cpu_ref: Some(buffer),
665                });
666                rdma_handle_2 = actor_2
667                    .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
668                    .await?;
669            } else {
670                // CUDA case - spawn CudaActor in the same process mesh
671                let cuda_actor_mesh_2: RootActorMesh<'_, CudaActor> = proc_mesh_2
672                    .spawn(&instance, "cuda_init", &(parsed_accel2.1 as i32))
673                    .await?;
674                let cuda_actor_ref_2 = cuda_actor_mesh_2.get(0).unwrap();
675
676                let (rdma_buf, dev_ptr) = cuda_actor_ref_2
677                    .create_buffer(proc_mesh_2.client(), buffer_size, actor_2.clone())
678                    .await?;
679                rdma_handle_2 = rdma_buf;
680                device_ptr_2 = Some(dev_ptr);
681
682                buf_vec.push(Buffer {
683                    ptr: rdma_handle_2.addr as u64,
684                    len: buffer_size,
685                    cpu_ref: None,
686                });
687                cuda_actor_2 = Some(cuda_actor_ref_2);
688            }
689
690            // Fill buffer1 with test data
691            if parsed_accel1.0 == "cuda" {
692                cuda_actor_1
693                    .clone()
694                    .unwrap()
695                    .fill_buffer(proc_mesh_1.client(), device_ptr_1.unwrap(), buffer_size, 42)
696                    .await?;
697            } else {
698                unsafe {
699                    let ptr = buf_vec[0].ptr as *mut u8;
700                    for i in 0..buf_vec[0].len {
701                        *ptr.add(i) = 42_u8;
702                    }
703                }
704            }
705
706            let buffer_2 = buf_vec.remove(1);
707            let buffer_1 = buf_vec.remove(0);
708
709            Ok(Self {
710                buffer_1,
711                buffer_2,
712                client_1: proc_mesh_1.client(),
713                client_2: proc_mesh_2.client(),
714                actor_1,
715                actor_2,
716                rdma_handle_1,
717                rdma_handle_2,
718                cuda_actor_1,
719                cuda_actor_2,
720                device_ptr_1,
721                device_ptr_2,
722            })
723        }
724
725        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
726            // Just release buffers from RDMA manager - CUDA cleanup happens automatically
727            self.actor_1
728                .release_buffer(self.client_1, self.rdma_handle_1.clone())
729                .await?;
730
731            self.actor_2
732                .release_buffer(self.client_2, self.rdma_handle_2.clone())
733                .await?;
734            Ok(())
735        }
736
737        /// Sets up the RDMA test environment with auto-detected QP type.
738        ///
739        /// This is a convenience wrapper around `setup_with_qp_type` that uses
740        /// `RdmaQpType::Auto` to automatically select the appropriate QP type.
741        ///
742        /// # Arguments
743        ///
744        /// * `buffer_size` - The size of the buffers to be used in the test.
745        /// * `accel1` - Accelerator for first actor (e.g., "cpu:0", "cuda:0")
746        /// * `accel2` - Accelerator for second actor (e.g., "cpu:0", "cuda:1")
747        pub async fn setup(
748            buffer_size: usize,
749            accel1: &str,
750            accel2: &str,
751        ) -> Result<Self, anyhow::Error> {
752            Self::setup_with_qp_type(
753                buffer_size,
754                accel1,
755                accel2,
756                crate::ibverbs_primitives::RdmaQpType::Auto,
757            )
758            .await
759        }
760
761        pub async fn verify_buffers(
762            &self,
763            size: usize,
764            offset: usize,
765        ) -> Result<(), anyhow::Error> {
766            let mut temp_buffer_1 = vec![0u8; size];
767            let mut temp_buffer_2 = vec![0u8; size];
768
769            // Read buffer 1
770            if let Some(cuda_actor) = &self.cuda_actor_1 {
771                cuda_actor
772                    .verify_buffer(
773                        self.client_1,
774                        temp_buffer_1.as_mut_ptr() as usize,
775                        self.device_ptr_1.unwrap() + offset,
776                        size,
777                    )
778                    .await?;
779            } else {
780                unsafe {
781                    std::ptr::copy_nonoverlapping(
782                        (self.buffer_1.ptr + offset as u64) as *const u8,
783                        temp_buffer_1.as_mut_ptr(),
784                        size,
785                    );
786                }
787            }
788
789            // Read buffer 2
790            if let Some(cuda_actor) = &self.cuda_actor_2 {
791                cuda_actor
792                    .verify_buffer(
793                        self.client_2,
794                        temp_buffer_2.as_mut_ptr() as usize,
795                        self.device_ptr_2.unwrap() + offset,
796                        size,
797                    )
798                    .await?;
799            } else {
800                unsafe {
801                    std::ptr::copy_nonoverlapping(
802                        (self.buffer_2.ptr + offset as u64) as *const u8,
803                        temp_buffer_2.as_mut_ptr(),
804                        size,
805                    );
806                }
807            }
808
809            // Compare buffers
810            for i in 0..size {
811                if temp_buffer_1[i] != temp_buffer_2[i] {
812                    return Err(anyhow::anyhow!(
813                        "Buffers are not equal at index {}",
814                        offset + i
815                    ));
816                }
817            }
818            Ok(())
819        }
820    }
821}