monarch_rdma/test_utils.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::sync::Once;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

/// Cached result of CUDA availability check
static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
static INIT: Once = Once::new();

/// Safely checks if CUDA is available on the system.
///
/// This function attempts to initialize CUDA and determine if it's available.
/// The result is cached after the first call, so subsequent calls are very fast.
///
/// # Returns
///
/// `true` if CUDA is available and can be initialized, `false` otherwise.
///
/// # Examples
///
/// ```
/// use monarch_rdma::is_cuda_available;
///
/// if is_cuda_available() {
///     println!("CUDA is available, can use GPU features");
/// } else {
///     println!("CUDA is not available, falling back to CPU-only mode");
/// }
/// ```
pub fn is_cuda_available() -> bool {
    INIT.call_once(|| {
        let available = check_cuda_available();
        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
    });
    CUDA_AVAILABLE.load(Ordering::SeqCst)
}

/// Internal function that performs the actual CUDA availability check
fn check_cuda_available() -> bool {
    unsafe {
        // Try to initialize CUDA
        let result = cuda_sys::cuInit(0);

        if result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        // Check if there are any CUDA devices
        let mut device_count: i32 = 0;
        let count_result = cuda_sys::cuDeviceGetCount(&mut device_count);

        if count_result != cuda_sys::CUresult::CUDA_SUCCESS || device_count <= 0 {
            return false;
        }

        // Try to get the first device to verify it's actually accessible
        let mut device: cuda_sys::CUdevice = std::mem::zeroed();
        let device_result = cuda_sys::cuDeviceGet(&mut device, 0);

        if device_result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        true
    }
}

#[cfg(test)]
pub mod test_utils {
    use std::time::Duration;
    use std::time::Instant;

    use hyperactor::ActorRef;
    use hyperactor::Mailbox;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor_mesh::Mesh;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::RootActorMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use ndslice::extent;

    use crate::IbverbsConfig;
    use crate::PollTarget;
    use crate::RdmaBuffer;
    use crate::cu_check;
    use crate::ibverbs_primitives::get_all_devices;
    use crate::rdma_components::RdmaQueuePair;
    use crate::rdma_manager_actor::RdmaManagerActor;
    use crate::rdma_manager_actor::RdmaManagerMessageClient;

    /// Waits for the completion of an RDMA operation.
    ///
    /// This function repeatedly polls the queue pair's target completion queue
    /// (`qp.poll_completion_target`) and checks the returned work completion.
    /// It keeps polling until the operation completes or the specified timeout
    /// is reached.
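    ///
    /// A minimal usage sketch (hedged: `qp` is assumed to be an already-connected
    /// `RdmaQueuePair` with an outstanding work request; illustrative, not a runnable doctest):
    ///
    /// ```ignore
    /// // Poll the send queue for up to 5 seconds.
    /// let completed = wait_for_completion(&mut qp, PollTarget::Send, 5).await?;
    /// assert!(completed);
    /// ```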
    pub async fn wait_for_completion(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();
        while start_time.elapsed() < timeout {
            match qp.poll_completion_target(poll_target) {
                Ok(Some(_wc)) => {
                    return Ok(true);
                }
                Ok(None) => {
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

    /// Posts a work request to the send queue of the given RDMA queue pair by
    /// launching a GPU-side kernel (`rdmaxcel_sys::launch_send_wqe`).
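    ///
    /// Typical GPU-driven send flow, as a hedged sketch (assumes `qp`, `lhandle`, `rhandle`,
    /// and `op_type` were obtained elsewhere, e.g. via `RdmaManagerTestEnv::setup`; the
    /// op-type value must match what `rdmaxcel_sys` expects):
    ///
    /// ```ignore
    /// send_wqe_gpu(&mut qp, &lhandle, &rhandle, op_type).await?;
    /// ring_db_gpu(&mut qp).await?;
    /// wait_for_completion_gpu(&mut qp, PollTarget::Send, 5).await?;
    /// ```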
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: u32::from_be(*(*dv_qp).dbrec.wrapping_add(1)) as u64,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            qp.send_wqe_idx += 1;
        }
        Ok(())
    }

    /// Posts a work request to the receive queue of the given RDMA queue pair by
    /// launching a GPU-side kernel (`rdmaxcel_sys::launch_recv_wqe`).
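    ///
    /// Hedged sketch of the receive side (the receive WQE should be posted before the peer
    /// issues its send; `qp`, `lhandle`, `rhandle`, and `op_type` are assumed to come from
    /// the surrounding test):
    ///
    /// ```ignore
    /// recv_wqe_gpu(&mut qp, &lhandle, &rhandle, op_type).await?;
    /// wait_for_completion_gpu(&mut qp, PollTarget::Recv, 5).await?;
    /// ```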
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        // Populate WQE params from the local buffer handle (`_rhandle` is unused for receives)
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: u32::from_be(*(*dv_qp).dbrec) as u64,
                op_type,
                signaled: true,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            qp.recv_wqe_idx += 1;
            qp.recv_db_idx += 1;
        }
        Ok(())
    }

    /// Rings the doorbell for any send WQEs posted via `send_wqe_gpu` that have not yet been
    /// submitted to the hardware, using a GPU-launched kernel (`rdmaxcel_sys::launch_db_ring`).
    pub async fn ring_db_gpu(qp: &mut RdmaQueuePair) -> Result<(), anyhow::Error> {
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            if wqe_cnt < (qp.send_wqe_idx - qp.send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while qp.send_db_idx < qp.send_wqe_idx {
                let offset = (qp.send_db_idx % wqe_cnt) * stride;
                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                qp.send_db_idx += 1;
            }
        }
        Ok(())
    }

    /// Waits for a completion on the specified completion queue by polling it with a
    /// GPU-launched kernel (`rdmaxcel_sys::launch_cqe_poll`).
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();

        while start_time.elapsed() < timeout {
            // Get the appropriate completion queue and index based on the poll target
            let (cq, idx, cq_type_str) = match poll_target {
                PollTarget::Send => (
                    qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.send_cq_idx as i32,
                    "send",
                ),
                PollTarget::Recv => (
                    qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.recv_cq_idx as i32,
                    "receive",
                ),
            };

            // Poll the completion queue
            let result = unsafe { rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx) };

            match result {
                rdmaxcel_sys::CQE_POLL_TRUE => {
                    // Update the appropriate index based on the poll target
                    match poll_target {
                        PollTarget::Send => qp.send_cq_idx += 1,
                        PollTarget::Recv => qp.recv_cq_idx += 1,
                    }
                    return Ok(true);
                }
                rdmaxcel_sys::CQE_POLL_ERROR => {
                    return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                }
                _ => {
                    // No completion yet, sleep and try again
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
            }
        }

        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

    /// Shared test environment: two RDMA manager actors on separate proc meshes, each with a
    /// registered buffer (host or CUDA memory) used as the source/destination of RDMA operations.
    pub struct RdmaManagerTestEnv<'a> {
        buffer_1: Buffer,
        buffer_2: Buffer,
        pub client_1: &'a Mailbox,
        pub client_2: &'a Mailbox,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        handle_1_cuda: bool,
        handle_2_cuda: bool,
    }

    /// A raw test buffer: a device or host pointer plus length. For host buffers, `cpu_ref`
    /// keeps the backing allocation alive.
    #[derive(Debug, Clone)]
    pub struct Buffer {
        ptr: u64,
        len: usize,
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }

    impl RdmaManagerTestEnv<'_> {
        /// Sets up the RDMA test environment.
        ///
        /// This function initializes the RDMA test environment by setting up two actor meshes
        /// with their respective RDMA configurations. It also prepares two buffers for testing
        /// RDMA operations and fills the first buffer with test data.
        ///
        /// # Arguments
        ///
        /// * `buffer_size` - The size, in bytes, of the buffers to be used in the test.
        /// * `nics` - Names of the RDMA NICs to use for the two actors; if a name does not match
        ///   any device reported by `get_all_devices()`, that actor keeps the default
        ///   `IbverbsConfig` device.
        /// * `devices` - Buffer backends for the two actors, either `"cpu"` or `"cuda:<idx>"`,
        ///   where `<idx>` is the CUDA device index to allocate on.
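        ///
        /// Hedged usage sketch (the NIC names are placeholders for whatever `get_all_devices()`
        /// reports on the host):
        ///
        /// ```ignore
        /// let env = RdmaManagerTestEnv::setup(1024, ("mlx5_0", "mlx5_1"), ("cpu", "cpu")).await?;
        /// // ... perform RDMA operations between env.rdma_handle_1 and env.rdma_handle_2 ...
        /// env.verify_buffers(1024).await?;
        /// env.cleanup().await?;
        /// ```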
        pub async fn setup(
            buffer_size: usize,
            nics: (&str, &str),
            devices: (&str, &str),
        ) -> Result<Self, anyhow::Error> {
            let all_devices = get_all_devices();
            let mut config1 = IbverbsConfig::default();
            let mut config2 = IbverbsConfig::default();

            for device in all_devices.iter() {
                if device.name == nics.0 {
                    config1.device = device.clone();
                }
                if device.name == nics.1 {
                    config2.device = device.clone();
                }
            }

            // Parse the buffer backend for each actor ("cpu" or "cuda:<idx>") into a
            // (backend, device index) pair used below when allocating the test buffers.
            let device_str1;
            let device_str2;

            if let Some((backend, idx)) = devices.0.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<i32>()
                    .expect("Device index is not a valid integer");
                device_str1 = ("cuda".to_string(), parsed_idx);
            } else {
                assert!(devices.0 == "cpu");
                device_str1 = ("cpu".to_string(), 0);
                config1.use_gpu_direct = false;
            }

            if let Some((backend, idx)) = devices.1.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<i32>()
                    .expect("Device index is not a valid integer");
                device_str2 = ("cuda".to_string(), parsed_idx);
            } else {
                assert!(devices.1 == "cpu");
                device_str2 = ("cpu".to_string(), 0);
                config2.use_gpu_direct = false;
            }

            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                })
                .await
                .unwrap();

            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> =
                proc_mesh_1.spawn("rdma_manager", &(config1)).await.unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                })
                .await
                .unwrap();

            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> =
                proc_mesh_2.spawn("rdma_manager", &(config2)).await.unwrap();

            let mut buf_vec = Vec::new();

            for device_str in [device_str1.clone(), device_str2.clone()] {
                // CPU case: a plain host allocation is sufficient.
                if device_str.0 == "cpu" {
                    let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                    buf_vec.push(Buffer {
                        ptr: buffer.as_mut_ptr() as u64,
                        len: buffer.len(),
                        cpu_ref: Some(buffer),
                    });
                    continue;
                }
                // CUDA case
                unsafe {
                    cu_check!(cuda_sys::cuInit(0));

                    let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed();
                    let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed();

                    let mut device: cuda_sys::CUdevice = std::mem::zeroed();
                    cu_check!(cuda_sys::cuDeviceGet(&mut device, device_str.1));

                    let mut context: cuda_sys::CUcontext = std::mem::zeroed();
                    cu_check!(cuda_sys::cuCtxCreate_v2(&mut context, 0, device_str.1));
                    cu_check!(cuda_sys::cuCtxSetCurrent(context));

                    let mut granularity: usize = 0;
                    let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed();
                    prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
                    prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    prop.location.id = device;
                    prop.allocFlags.gpuDirectRDMACapable = 1;
                    prop.requestedHandleTypes =
                        cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                    cu_check!(cuda_sys::cuMemGetAllocationGranularity(
                        &mut granularity as *mut usize,
                        &prop,
                        cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                    ));

                    // Round the requested size up to the allocation granularity; the test
                    // expects `buffer_size` to already be a multiple of the granularity.
                    let padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity;
                    assert!(padded_size == buffer_size);

                    cu_check!(cuda_sys::cuMemCreate(
                        &mut handle as *mut cuda_sys::CUmemGenericAllocationHandle,
                        padded_size,
                        &prop,
                        0
                    ));
                    // reserve and map the memory
                    cu_check!(cuda_sys::cuMemAddressReserve(
                        &mut dptr as *mut cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        0,
                        0,
                    ));
                    assert!(dptr as usize % granularity == 0);
                    assert!(padded_size % granularity == 0);

                    // Note: this call fails when wrapped in the cu_check! macro, but passes when checked manually.
                    let err = cuda_sys::cuMemMap(
                        dptr as cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        handle as cuda_sys::CUmemGenericAllocationHandle,
                        0,
                    );
                    if err != cuda_sys::CUresult::CUDA_SUCCESS {
                        panic!("failed reserving and mapping memory {:?}", err);
                    }

                    // set access
                    let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed();
                    access_desc.location.type_ =
                        cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    access_desc.location.id = device;
                    access_desc.flags =
                        cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                    cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1));
                    buf_vec.push(Buffer {
                        ptr: dptr,
                        len: padded_size,
                        cpu_ref: None,
                    });
                }
            }

            // Fill buffer1 with test data
            if device_str1.0 == "cuda" {
                let mut temp_buffer = vec![0u8; buffer_size].into_boxed_slice();
                for (i, val) in temp_buffer.iter_mut().enumerate() {
                    *val = (i % 256) as u8;
                }
                unsafe {
                    cu_check!(cuda_sys::cuMemcpyHtoD_v2(
                        buf_vec[0].ptr,
                        temp_buffer.as_ptr() as *const std::ffi::c_void,
                        temp_buffer.len()
                    ));
                }
            } else {
                unsafe {
                    // The buffer is written through this pointer, so it must be `*mut u8`.
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = (i % 256) as u8;
                    }
                }
            }
            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            // Register both buffers with their RDMA manager actors to obtain RDMA handles.
            let rdma_handle_1 = actor_1
                .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                .await?;
            let rdma_handle_2 = actor_2
                .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                .await?;

            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);
            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                handle_1_cuda: device_str1.0 == "cuda",
                handle_2_cuda: device_str2.0 == "cuda",
            })
        }

        /// Releases both RDMA buffer handles and frees any CUDA allocations made during setup.
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;
            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            if self.handle_1_cuda {
                unsafe {
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                }
            }
            if self.handle_2_cuda {
                unsafe {
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                }
            }
            Ok(())
        }

        /// Verifies that the first `size` bytes of the two buffers are identical, copying
        /// CUDA-resident buffers back to host memory first.
        pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> {
            let mut buf_vec = Vec::new();
            for (handle, is_cuda) in [
                (self.rdma_handle_1.clone(), self.handle_1_cuda),
                (self.rdma_handle_2.clone(), self.handle_2_cuda),
            ] {
                if is_cuda {
                    let mut temp_buffer = vec![0u8; size].into_boxed_slice();
                    // SAFETY: The buffer is allocated with the correct size and the pointer is valid.
                    unsafe {
                        cu_check!(cuda_sys::cuMemcpyDtoH_v2(
                            temp_buffer.as_mut_ptr() as *mut std::ffi::c_void,
                            handle.addr as cuda_sys::CUdeviceptr,
                            size
                        ));
                    }
                    buf_vec.push(Buffer {
                        ptr: temp_buffer.as_mut_ptr() as u64,
                        len: size,
                        cpu_ref: Some(temp_buffer),
                    });
                } else {
                    buf_vec.push(Buffer {
                        ptr: handle.addr as u64,
                        len: size,
                        cpu_ref: None,
                    });
                }
            }
            // SAFETY: The pointers are valid and the buffers have the same length.
            unsafe {
                let ptr1 = buf_vec[0].ptr as *mut u8;
                let ptr2: *mut u8 = buf_vec[1].ptr as *mut u8;
                for i in 0..buf_vec[0].len {
                    if *ptr1.add(i) != *ptr2.add(i) {
                        return Err(anyhow::anyhow!("Buffers are not equal at index {}", i));
                    }
                }
            }
            Ok(())
        }
    }
}