monarch_rdma/
test_utils.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::sync::Once;
10use std::sync::atomic::AtomicBool;
11use std::sync::atomic::Ordering;
12
/// Cached result of CUDA availability check
static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
/// Guards the one-time probe that populates `CUDA_AVAILABLE`.
static INIT: Once = Once::new();
16
17/// Safely checks if CUDA is available on the system.
18///
19/// This function attempts to initialize CUDA and determine if it's available.
20/// The result is cached after the first call, so subsequent calls are very fast.
21///
22/// # Returns
23///
24/// `true` if CUDA is available and can be initialized, `false` otherwise.
25///
26/// # Examples
27///
28/// ```
29/// use monarch_rdma::is_cuda_available;
30///
31/// if is_cuda_available() {
32///     println!("CUDA is available, can use GPU features");
33/// } else {
34///     println!("CUDA is not available, falling back to CPU-only mode");
35/// }
36/// ```
37pub fn is_cuda_available() -> bool {
38    INIT.call_once(|| {
39        let available = check_cuda_available();
40        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
41    });
42    CUDA_AVAILABLE.load(Ordering::SeqCst)
43}
44
45/// Internal function that performs the actual CUDA availability check
46fn check_cuda_available() -> bool {
47    unsafe {
48        // Try to initialize CUDA
49        let result = rdmaxcel_sys::rdmaxcel_cuInit(0);
50
51        if result != rdmaxcel_sys::CUDA_SUCCESS {
52            return false;
53        }
54
55        // Check if there are any CUDA devices
56        let mut device_count: i32 = 0;
57        let count_result = rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut device_count);
58
59        if count_result != rdmaxcel_sys::CUDA_SUCCESS || device_count <= 0 {
60            return false;
61        }
62
63        // Try to get the first device to verify it's actually accessible
64        let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
65        let device_result = rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, 0);
66
67        if device_result != rdmaxcel_sys::CUDA_SUCCESS {
68            return false;
69        }
70
71        true
72    }
73}
74
75#[cfg(test)]
76pub mod test_utils {
77    use std::time::Duration;
78    use std::time::Instant;
79
80    use hyperactor::Actor;
81    use hyperactor::ActorRef;
82    use hyperactor::Context;
83    use hyperactor::HandleClient;
84    use hyperactor::Handler;
85    use hyperactor::Instance;
86    use hyperactor::RefClient;
87    use hyperactor::RemoteSpawn;
88    use hyperactor::channel::ChannelTransport;
89    use hyperactor::clock::Clock;
90    use hyperactor::clock::RealClock;
91    use hyperactor_mesh::Mesh;
92    use hyperactor_mesh::ProcMesh;
93    use hyperactor_mesh::RootActorMesh;
94    use hyperactor_mesh::alloc::AllocSpec;
95    use hyperactor_mesh::alloc::Allocator;
96    use hyperactor_mesh::alloc::LocalAllocator;
97    use hyperactor_mesh::proc_mesh::global_root_client;
98    use ndslice::extent;
99
100    use crate::IbverbsConfig;
101    use crate::RdmaBuffer;
102    use crate::cu_check;
103    use crate::rdma_components::PollTarget;
104    use crate::rdma_components::RdmaQueuePair;
105    use crate::rdma_manager_actor::RdmaManagerActor;
106    use crate::rdma_manager_actor::RdmaManagerMessageClient;
107    use crate::validate_execution_context;
108
    /// Thin newtype around a raw `CUcontext` so it can be stored in an actor.
    #[derive(Debug)]
    struct SendSyncCudaContext(rdmaxcel_sys::CUcontext);
    // SAFETY: NOTE(review) — `CUcontext` is a raw handle; these impls assert it
    // may be moved/shared across threads. All uses below bind the context with
    // `cuCtxSetCurrent` before issuing driver calls, but confirm the CUDA
    // driver's thread-safety guarantees cover every call site.
    unsafe impl Send for SendSyncCudaContext {}
    unsafe impl Sync for SendSyncCudaContext {}
113
    /// Actor responsible for CUDA initialization and buffer management within its own process context.
    /// This is important because CUDA operations must be performed within the same process as the RDMA operations.
    #[hyperactor::export(
        spawn = true,
        handlers = [
            CudaActorMessage,
        ],
    )]
    #[derive(Debug)]
    pub struct CudaActor {
        // Device handle from cuDeviceGet. NOTE(review): always Some after
        // `RemoteSpawn::new`; the Option only signals "not yet initialized".
        device: Option<i32>,
        // Driver context for `device`; bound via cuCtxSetCurrent in handlers.
        context: SendSyncCudaContext,
    }

    impl Actor for CudaActor {}
129
130    #[async_trait::async_trait]
131    impl RemoteSpawn for CudaActor {
132        type Params = i32;
133
134        async fn new(device_id: i32) -> Result<Self, anyhow::Error> {
135            unsafe {
136                cu_check!(rdmaxcel_sys::rdmaxcel_cuInit(0));
137                let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
138                cu_check!(rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, device_id));
139                let mut context: rdmaxcel_sys::CUcontext = std::mem::zeroed();
140                cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxCreate_v2(
141                    &mut context,
142                    0,
143                    device_id
144                ));
145
146                Ok(Self {
147                    device: Some(device),
148                    context: SendSyncCudaContext(context),
149                })
150            }
151        }
152    }
153
    /// Messages handled by [`CudaActor`]; each variant replies through a
    /// `OncePortRef` when the operation finishes.
    #[derive(
        Handler,
        HandleClient,
        RefClient,
        typeuri::Named,
        serde::Serialize,
        serde::Deserialize,
        Debug
    )]
    pub enum CudaActorMessage {
        /// Allocate a GPU-direct-capable device buffer of `size` bytes and
        /// register it with `rdma_actor`; replies with the resulting
        /// `RdmaBuffer` plus the raw device pointer.
        CreateBuffer {
            size: usize,
            rdma_actor: ActorRef<RdmaManagerActor>,
            #[reply]
            reply: hyperactor::OncePortRef<(RdmaBuffer, usize)>,
        },
        /// Fill `size` bytes at `device_ptr` with `value` (device memset).
        FillBuffer {
            device_ptr: usize,
            size: usize,
            value: u8,
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
        /// Copy `size` bytes from `device_ptr` back to host memory at
        /// `cpu_buffer_ptr` so the caller can inspect the contents.
        VerifyBuffer {
            cpu_buffer_ptr: usize,
            device_ptr: usize,
            size: usize,
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
    }
185
    #[async_trait::async_trait]
    impl Handler<CudaActorMessage> for CudaActor {
        /// Dispatches CUDA buffer operations for this actor's device/context.
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            msg: CudaActorMessage,
        ) -> Result<(), anyhow::Error> {
            match msg {
                CudaActorMessage::CreateBuffer {
                    size,
                    rdma_actor,
                    reply,
                } => {
                    let device = self
                        .device
                        .ok_or_else(|| anyhow::anyhow!("Device not initialized"))?;

                    // Allocate device memory through the CUDA virtual memory
                    // management API: create a physical allocation, reserve a
                    // VA range, map it, then grant the device access.
                    let (dptr, padded_size) = unsafe {
                        // Bind this actor's context to the calling thread.
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        let mut dptr: rdmaxcel_sys::CUdeviceptr = std::mem::zeroed();
                        let mut handle: rdmaxcel_sys::CUmemGenericAllocationHandle =
                            std::mem::zeroed();

                        // Pinned allocation on `device`, flagged GPU-direct
                        // RDMA capable and exportable as a POSIX FD.
                        let mut granularity: usize = 0;
                        let mut prop: rdmaxcel_sys::CUmemAllocationProp = std::mem::zeroed();
                        prop.type_ = rdmaxcel_sys::CU_MEM_ALLOCATION_TYPE_PINNED;
                        prop.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
                        prop.location.id = device;
                        prop.allocFlags.gpuDirectRDMACapable = 1;
                        prop.requestedHandleTypes =
                            rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemGetAllocationGranularity(
                            &mut granularity as *mut usize,
                            &prop,
                            rdmaxcel_sys::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                        ));

                        // Round `size` up to the allocation granularity.
                        // NOTE(review): underflows when size == 0 — callers in
                        // this file always pass a positive size; confirm.
                        let padded_size: usize = ((size - 1) / granularity + 1) * granularity;

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemCreate(
                            &mut handle,
                            padded_size,
                            &prop,
                            0
                        ));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressReserve(
                            &mut dptr,
                            padded_size,
                            0,
                            0,
                            0,
                        ));

                        assert!((dptr as usize).is_multiple_of(granularity));
                        assert!(padded_size.is_multiple_of(granularity));

                        // NOTE(review): on this error path the physical handle
                        // and reserved VA range are not released — leak is
                        // tolerated since this is test-only code.
                        let err = rdmaxcel_sys::rdmaxcel_cuMemMap(dptr, padded_size, 0, handle, 0);
                        if err != rdmaxcel_sys::CUDA_SUCCESS {
                            return Err(anyhow::anyhow!("Failed to map CUDA memory: {:?}", err));
                        }

                        // Grant this device read/write access to the mapping.
                        let mut access_desc: rdmaxcel_sys::CUmemAccessDesc = std::mem::zeroed();
                        access_desc.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
                        access_desc.location.id = device;
                        access_desc.flags = rdmaxcel_sys::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemSetAccess(
                            dptr,
                            padded_size,
                            &access_desc,
                            1
                        ));

                        (dptr, padded_size)
                    };

                    // Register the device memory with the RDMA manager.
                    let rdma_handle = rdma_actor
                        .request_buffer(cx, dptr as usize, padded_size)
                        .await?;

                    reply.send(cx, (rdma_handle, dptr as usize))?;
                    Ok(())
                }
                CudaActorMessage::FillBuffer {
                    device_ptr,
                    size,
                    value,
                    reply,
                } => {
                    // SAFETY: `device_ptr` must name a live device allocation
                    // of at least `size` bytes (produced by CreateBuffer).
                    unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemsetD8_v2(
                            device_ptr as rdmaxcel_sys::CUdeviceptr,
                            value,
                            size
                        ));
                    }

                    reply.send(cx, ())?;
                    Ok(())
                }
                CudaActorMessage::VerifyBuffer {
                    cpu_buffer_ptr,
                    device_ptr,
                    size,
                    reply,
                } => {
                    // SAFETY: both pointers must reference live allocations of
                    // at least `size` bytes; this is a device-to-host copy.
                    unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
                            cpu_buffer_ptr as *mut std::ffi::c_void,
                            device_ptr as rdmaxcel_sys::CUdeviceptr,
                            size
                        ));
                    }

                    reply.send(cx, ())?;
                    Ok(())
                }
            }
        }
    }
312
313    /// Waits for the completion of RDMA operations.
314    ///
315    /// This function polls for the completion of RDMA operations by repeatedly
316    /// checking the completion queue until all expected work requests complete
317    /// or the specified timeout is reached.
318    ///
319    /// # Arguments
320    /// * `qp` - The RDMA Queue Pair to poll for completion
321    /// * `poll_target` - Which CQ to poll (Send or Recv)
322    /// * `expected_wr_ids` - Slice of work request IDs to wait for
323    /// * `timeout_secs` - Timeout in seconds
324    ///
325    /// # Returns
326    /// `Ok(true)` if all operations complete successfully within the timeout,
327    /// or an error if the timeout is reached
328    pub async fn wait_for_completion(
329        qp: &mut RdmaQueuePair,
330        poll_target: PollTarget,
331        expected_wr_ids: &[u64],
332        timeout_secs: u64,
333    ) -> Result<bool, anyhow::Error> {
334        let timeout = Duration::from_secs(timeout_secs);
335        let start_time = Instant::now();
336
337        let mut remaining: std::collections::HashSet<u64> =
338            expected_wr_ids.iter().copied().collect();
339
340        while start_time.elapsed() < timeout {
341            if remaining.is_empty() {
342                return Ok(true);
343            }
344
345            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
346            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
347                Ok(completions) => {
348                    for (wr_id, _wc) in completions {
349                        remaining.remove(&wr_id);
350                    }
351                    if remaining.is_empty() {
352                        return Ok(true);
353                    }
354                    RealClock.sleep(Duration::from_millis(1)).await;
355                }
356                Err(e) => {
357                    return Err(anyhow::anyhow!(e));
358                }
359            }
360        }
361        Err(anyhow::Error::msg(format!(
362            "Timeout while waiting for completion of wr_ids: {:?}",
363            remaining
364        )))
365    }
366
    /// Posts a work request to the send queue of the given RDMA queue pair.
    ///
    /// Builds a `wqe_params_t` from the local/remote buffer handles, launches
    /// the GPU-side `launch_send_wqe` to write the WQE into the send queue,
    /// then advances the send WQE index.
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        // SAFETY: dereferences raw qp/dv_qp pointers owned by `qp`; assumes
        // the queue pair was created by rdmaxcel and is still alive.
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            // The current send WQE index doubles as the work-request id.
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(ibv_qp);
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: send_wqe_idx,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*(*ibv_qp).ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            // Claim the slot that was just filled.
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(ibv_qp);
        }
        Ok(())
    }
398
    /// Posts a work request to the receive queue of the given RDMA queue pair.
    ///
    /// Mirrors [`send_wqe_gpu`] for the receive side: writes the WQE via the
    /// GPU-side `launch_recv_wqe`, then advances the receive WQE and doorbell
    /// indices. `_rhandle` is unused because a receive needs no remote address.
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        // Populate params using lhandle and rhandle
        // SAFETY: dereferences raw qp/dv_qp pointers owned by `qp`; assumes
        // the queue pair was created by rdmaxcel and is still alive.
        unsafe {
            let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            // The current recv WQE index doubles as the work-request id.
            let recv_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp);
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: recv_wqe_idx,
                op_type,
                signaled: true,
                qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp);
        }
        Ok(())
    }
430
    /// Rings the doorbell for every send WQE posted since the last ring,
    /// notifying the NIC of pending work.
    pub async fn ring_db_gpu(qp: &RdmaQueuePair) -> Result<(), anyhow::Error> {
        // Short delay before ringing — presumably lets the WQE-writing kernel
        // finish; TODO confirm why 2 ms specifically.
        RealClock.sleep(Duration::from_millis(2)).await;
        // SAFETY: dereferences raw dv_qp/qp pointers owned by `qp`; assumes
        // they remain valid for the duration of the call.
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(
                qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
            );
            let mut send_db_idx =
                rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp);
            // More un-rung WQEs than queue slots means the ring buffer wrapped
            // over entries that were never submitted to the NIC.
            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            // Ring once per outstanding WQE, persisting the doorbell index
            // after each ring so progress survives an early exit.
            while send_db_idx < send_wqe_idx {
                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                send_db_idx += 1;
                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(
                    qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
                    send_db_idx,
                );
            }
        }
        Ok(())
    }
459
    /// Wait for completion on a specific completion queue
    ///
    /// Polls the mlx5dv CQ selected by `poll_target` (via the GPU-side
    /// `launch_cqe_poll`) until one completion is consumed or `timeout_secs`
    /// elapses; on success the matching CQ index is advanced.
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        // SAFETY: raw CQ/QP pointers come from `qp` and are assumed valid for
        // the lifetime of this call.
        unsafe {
            let start_time = Instant::now();
            let timeout = Duration::from_secs(timeout_secs);
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;

            while start_time.elapsed() < timeout {
                // Get the appropriate completion queue and index based on the poll target.
                // The index is re-loaded every iteration because it advances
                // whenever a completion is consumed.
                let (cq, idx, cq_type_str) = match poll_target {
                    PollTarget::Send => (
                        qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                        rdmaxcel_sys::rdmaxcel_qp_load_send_cq_idx(ibv_qp),
                        "send",
                    ),
                    PollTarget::Recv => (
                        qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                        rdmaxcel_sys::rdmaxcel_qp_load_recv_cq_idx(ibv_qp),
                        "receive",
                    ),
                };

                // Poll the completion queue
                let result = rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32);

                match result {
                    rdmaxcel_sys::CQE_POLL_TRUE => {
                        // Update the appropriate index based on the poll target
                        match poll_target {
                            PollTarget::Send => {
                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_cq_idx(ibv_qp);
                            }
                            PollTarget::Recv => {
                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_cq_idx(ibv_qp);
                            }
                        }
                        return Ok(true);
                    }
                    rdmaxcel_sys::CQE_POLL_ERROR => {
                        return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                    }
                    _ => {
                        // No completion yet, sleep and try again
                        RealClock.sleep(Duration::from_millis(1)).await;
                    }
                }
            }

            Err(anyhow::Error::msg("Timeout while waiting for completion"))
        }
    }
515
    /// Test fixture pairing two RDMA manager actors (each on its own
    /// single-proc mesh) with the pair of buffers they exchange data through.
    pub struct RdmaManagerTestEnv<'a> {
        // Backing storage registered with actor_1 (host or device memory).
        buffer_1: Buffer,
        // Backing storage registered with actor_2 (host or device memory).
        buffer_2: Buffer,
        pub client_1: &'a Instance<()>,
        pub client_2: &'a Instance<()>,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        // Present only when the corresponding accelerator is "cuda:<idx>".
        cuda_actor_1: Option<ActorRef<CudaActor>>,
        cuda_actor_2: Option<ActorRef<CudaActor>>,
        // Raw device pointers for CUDA-backed buffers (None for CPU buffers).
        device_ptr_1: Option<usize>,
        device_ptr_2: Option<usize>,
    }
530
    /// A registered test buffer: raw start address plus length.
    #[derive(Debug, Clone)]
    pub struct Buffer {
        // Start address (host pointer or CUDA device pointer).
        ptr: u64,
        // Length in bytes.
        len: usize,
        // Keeps host allocations alive; None for device-backed buffers.
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }
538    /// Helper function to parse accelerator strings
539    async fn parse_accel(accel: &str, config: &mut IbverbsConfig) -> (String, usize) {
540        let (backend, idx) = accel.split_once(':').unwrap();
541        let parsed_idx = idx.parse::<usize>().unwrap();
542
543        if backend == "cuda" {
544            config.use_gpu_direct = validate_execution_context().await.is_ok();
545            eprintln!("Using GPU Direct: {}", config.use_gpu_direct);
546        }
547
548        (backend.to_string(), parsed_idx)
549    }
550
551    impl RdmaManagerTestEnv<'_> {
        /// Sets up the RDMA test environment with a specified QP type.
        ///
        /// This function initializes the RDMA test environment by setting up two actor meshes
        /// with their respective RDMA configurations. It also prepares two buffers for testing
        /// RDMA operations and fills the first buffer with test data.
        ///
        /// # Arguments
        ///
        /// * `buffer_size` - The size of the buffers to be used in the test.
        /// * `accel1` - Accelerator for first actor (e.g., "cpu:0", "cuda:0")
        /// * `accel2` - Accelerator for second actor (e.g., "cpu:0", "cuda:1")
        /// * `qp_type` - The queue pair type to use (Auto, Standard, or Mlx5dv)
        pub async fn setup_with_qp_type(
            buffer_size: usize,
            accel1: &str,
            accel2: &str,
            qp_type: crate::ibverbs_primitives::RdmaQpType,
        ) -> Result<Self, anyhow::Error> {
            // Use device selection logic to find optimal RDMA devices
            let mut config1 = IbverbsConfig::targeting(accel1);
            let mut config2 = IbverbsConfig::targeting(accel2);

            // Set the QP type
            config1.qp_type = qp_type;
            config2.qp_type = qp_type;

            // May also flip use_gpu_direct on the configs for CUDA backends.
            let parsed_accel1 = parse_accel(accel1, &mut config1).await;
            let parsed_accel2 = parse_accel(accel2, &mut config2).await;

            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                    proc_allocation_mode: Default::default(),
                })
                .await
                .unwrap();

            let instance = global_root_client();

            // Leaked on purpose: the returned env borrows the mesh's client
            // for its whole lifetime, so the mesh must live until test exit.
            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1
                .spawn(&instance, "rdma_manager", &Some(config1))
                .await
                .unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                    proc_allocation_mode: Default::default(),
                })
                .await
                .unwrap();

            // Same deliberate leak as proc_mesh_1 (see above).
            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2
                .spawn(&instance, "rdma_manager", &Some(config2))
                .await
                .unwrap();

            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            let mut buf_vec = Vec::new();
            let mut cuda_actor_1 = None;
            let mut cuda_actor_2 = None;
            let mut device_ptr_1: Option<usize> = None;
            let mut device_ptr_2: Option<usize> = None;

            let rdma_handle_1;
            let rdma_handle_2;

            // Process first accelerator
            if parsed_accel1.0 == "cpu" {
                // Host-side buffer: allocate on the heap and register directly.
                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                buf_vec.push(Buffer {
                    ptr: buffer.as_mut_ptr() as u64,
                    len: buffer.len(),
                    cpu_ref: Some(buffer),
                });
                rdma_handle_1 = actor_1
                    .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                    .await?;
            } else {
                // CUDA case - spawn CudaActor in the same process mesh
                // (CUDA and RDMA operations must share a process).
                let cuda_actor_mesh_1: RootActorMesh<'_, CudaActor> = proc_mesh_1
                    .spawn(&instance, "cuda_init", &(parsed_accel1.1 as i32))
                    .await?;
                let cuda_actor_ref_1 = cuda_actor_mesh_1.get(0).unwrap();

                let (rdma_buf, dev_ptr) = cuda_actor_ref_1
                    .create_buffer(proc_mesh_1.client(), buffer_size, actor_1.clone())
                    .await?;
                rdma_handle_1 = rdma_buf;
                device_ptr_1 = Some(dev_ptr);

                buf_vec.push(Buffer {
                    ptr: rdma_handle_1.addr as u64,
                    len: buffer_size,
                    cpu_ref: None,
                });
                cuda_actor_1 = Some(cuda_actor_ref_1);
            }

            // Process second accelerator
            if parsed_accel2.0 == "cpu" {
                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                buf_vec.push(Buffer {
                    ptr: buffer.as_mut_ptr() as u64,
                    len: buffer.len(),
                    cpu_ref: Some(buffer),
                });
                rdma_handle_2 = actor_2
                    .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                    .await?;
            } else {
                // CUDA case - spawn CudaActor in the same process mesh
                let cuda_actor_mesh_2: RootActorMesh<'_, CudaActor> = proc_mesh_2
                    .spawn(&instance, "cuda_init", &(parsed_accel2.1 as i32))
                    .await?;
                let cuda_actor_ref_2 = cuda_actor_mesh_2.get(0).unwrap();

                let (rdma_buf, dev_ptr) = cuda_actor_ref_2
                    .create_buffer(proc_mesh_2.client(), buffer_size, actor_2.clone())
                    .await?;
                rdma_handle_2 = rdma_buf;
                device_ptr_2 = Some(dev_ptr);

                buf_vec.push(Buffer {
                    ptr: rdma_handle_2.addr as u64,
                    len: buffer_size,
                    cpu_ref: None,
                });
                cuda_actor_2 = Some(cuda_actor_ref_2);
            }

            // Fill buffer1 with test data (constant byte 42) so transfers can
            // be verified against buffer2 later.
            if parsed_accel1.0 == "cuda" {
                cuda_actor_1
                    .clone()
                    .unwrap()
                    .fill_buffer(proc_mesh_1.client(), device_ptr_1.unwrap(), buffer_size, 42)
                    .await?;
            } else {
                // SAFETY: writes within the host allocation created above.
                unsafe {
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = 42_u8;
                    }
                }
            }

            // Remove index 1 first so index 0 stays valid.
            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);

            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                cuda_actor_1,
                cuda_actor_2,
                device_ptr_1,
                device_ptr_2,
            })
        }
727
        /// Releases both RDMA buffers from their respective managers.
        ///
        /// CUDA resources are not explicitly freed here; they are reclaimed
        /// when the owning processes exit.
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            // Just release buffers from RDMA manager - CUDA cleanup happens automatically
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;

            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            Ok(())
        }
739
740        /// Sets up the RDMA test environment with auto-detected QP type.
741        ///
742        /// This is a convenience wrapper around `setup_with_qp_type` that uses
743        /// `RdmaQpType::Auto` to automatically select the appropriate QP type.
744        ///
745        /// # Arguments
746        ///
747        /// * `buffer_size` - The size of the buffers to be used in the test.
748        /// * `accel1` - Accelerator for first actor (e.g., "cpu:0", "cuda:0")
749        /// * `accel2` - Accelerator for second actor (e.g., "cpu:0", "cuda:1")
750        pub async fn setup(
751            buffer_size: usize,
752            accel1: &str,
753            accel2: &str,
754        ) -> Result<Self, anyhow::Error> {
755            Self::setup_with_qp_type(
756                buffer_size,
757                accel1,
758                accel2,
759                crate::ibverbs_primitives::RdmaQpType::Auto,
760            )
761            .await
762        }
763
764        pub async fn verify_buffers(
765            &self,
766            size: usize,
767            offset: usize,
768        ) -> Result<(), anyhow::Error> {
769            let mut temp_buffer_1 = vec![0u8; size];
770            let mut temp_buffer_2 = vec![0u8; size];
771
772            // Read buffer 1
773            if let Some(cuda_actor) = &self.cuda_actor_1 {
774                cuda_actor
775                    .verify_buffer(
776                        self.client_1,
777                        temp_buffer_1.as_mut_ptr() as usize,
778                        self.device_ptr_1.unwrap() + offset,
779                        size,
780                    )
781                    .await?;
782            } else {
783                unsafe {
784                    std::ptr::copy_nonoverlapping(
785                        (self.buffer_1.ptr + offset as u64) as *const u8,
786                        temp_buffer_1.as_mut_ptr(),
787                        size,
788                    );
789                }
790            }
791
792            // Read buffer 2
793            if let Some(cuda_actor) = &self.cuda_actor_2 {
794                cuda_actor
795                    .verify_buffer(
796                        self.client_2,
797                        temp_buffer_2.as_mut_ptr() as usize,
798                        self.device_ptr_2.unwrap() + offset,
799                        size,
800                    )
801                    .await?;
802            } else {
803                unsafe {
804                    std::ptr::copy_nonoverlapping(
805                        (self.buffer_2.ptr + offset as u64) as *const u8,
806                        temp_buffer_2.as_mut_ptr(),
807                        size,
808                    );
809                }
810            }
811
812            // Compare buffers
813            for i in 0..size {
814                if temp_buffer_1[i] != temp_buffer_2[i] {
815                    return Err(anyhow::anyhow!(
816                        "Buffers are not equal at index {}",
817                        offset + i
818                    ));
819                }
820            }
821            Ok(())
822        }
823    }
824}