monarch_rdma/test_utils.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::sync::Once;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

/// Cached result of CUDA availability check
static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
static INIT: Once = Once::new();

/// Safely checks if CUDA is available on the system.
///
/// This function attempts to initialize CUDA and determine if it's available.
/// The result is cached after the first call, so subsequent calls are very fast.
///
/// # Returns
///
/// `true` if CUDA is available and can be initialized, `false` otherwise.
///
/// # Examples
///
/// ```
/// use monarch_rdma::is_cuda_available;
///
/// if is_cuda_available() {
///     println!("CUDA is available, can use GPU features");
/// } else {
///     println!("CUDA is not available, falling back to CPU-only mode");
/// }
/// ```
pub fn is_cuda_available() -> bool {
    INIT.call_once(|| {
        let available = check_cuda_available();
        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
    });
    CUDA_AVAILABLE.load(Ordering::SeqCst)
}

/// Internal function that performs the actual CUDA availability check
fn check_cuda_available() -> bool {
    unsafe {
        // Try to initialize CUDA
        let result = cuda_sys::cuInit(0);

        if result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        // Check if there are any CUDA devices
        let mut device_count: i32 = 0;
        let count_result = cuda_sys::cuDeviceGetCount(&mut device_count);

        if count_result != cuda_sys::CUresult::CUDA_SUCCESS || device_count <= 0 {
            return false;
        }

        // Try to get the first device to verify it's actually accessible
        let mut device: cuda_sys::CUdevice = std::mem::zeroed();
        let device_result = cuda_sys::cuDeviceGet(&mut device, 0);

        if device_result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        true
    }
}

#[cfg(test)]
pub mod test_utils {
    use std::time::Duration;
    use std::time::Instant;

    use hyperactor::ActorRef;
    use hyperactor::Instance;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor_mesh::Mesh;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::RootActorMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use ndslice::extent;

    use crate::IbverbsConfig;
    use crate::RdmaBuffer;
    use crate::cu_check;
    use crate::ibverbs_primitives::get_all_devices;
    use crate::rdma_components::PollTarget;
    use crate::rdma_components::RdmaQueuePair;
    use crate::rdma_manager_actor::RdmaManagerActor;
    use crate::rdma_manager_actor::RdmaManagerMessageClient;
    use crate::validate_execution_context;
    /// Waits for the completion of an RDMA operation.
    ///
    /// This function polls the send or receive completion queue of the given
    /// queue pair via [`RdmaQueuePair::poll_completion_target`] and checks the
    /// returned work completion status. It continues polling until the operation
    /// completes or the specified timeout is reached.
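    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `qp` is an already-connected [`RdmaQueuePair`]
    /// with an outstanding send operation:
    ///
    /// ```ignore
    /// // Poll the send completion queue for up to 5 seconds.
    /// let completed = wait_for_completion(&mut qp, PollTarget::Send, 5).await?;
    /// assert!(completed);
    /// ```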
    pub async fn wait_for_completion(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();
        while start_time.elapsed() < timeout {
            match qp.poll_completion_target(poll_target) {
                Ok(Some(_wc)) => {
                    return Ok(true);
                }
                Ok(None) => {
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

    /// Posts a work request to the send queue of the given RDMA queue pair.
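    ///
    /// The work queue entry is written via `rdmaxcel_sys::launch_send_wqe` and
    /// `send_wqe_idx` is advanced; the entry is not submitted to the NIC until
    /// [`ring_db_gpu`] rings the doorbell for the posted entries.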
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
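        // SAFETY: `qp.qp` and `qp.dv_qp` are raw pointers to the queue pair's
        // ibverbs/mlx5dv structures and remain valid for the lifetime of `qp`;
        // the buffer handles describe registered RDMA memory.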
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: qp.send_wqe_idx,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            qp.send_wqe_idx += 1;
        }
        Ok(())
    }

    /// Posts a work request to the receive queue of the given RDMA queue pair.
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        // Populate the WQE params from the local buffer handle (the remote
        // handle is not needed for a receive work request).
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: qp.recv_wqe_idx,
                op_type,
                signaled: true,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            qp.recv_wqe_idx += 1;
            qp.recv_db_idx += 1;
        }
        Ok(())
    }

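    /// Rings the doorbell for all send WQEs posted via [`send_wqe_gpu`] that have
    /// not yet been submitted, advancing `send_db_idx` up to `send_wqe_idx`.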
    pub async fn ring_db_gpu(qp: &mut RdmaQueuePair) -> Result<(), anyhow::Error> {
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            if (wqe_cnt as u64) < (qp.send_wqe_idx - qp.send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while qp.send_db_idx < qp.send_wqe_idx {
                let offset = (qp.send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = base_ptr.wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                qp.send_db_idx += 1;
            }
        }
        Ok(())
    }

    /// Wait for completion on a specific completion queue
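    ///
    /// # Examples
    ///
    /// A minimal sketch of the GPU-driven send flow, assuming `qp` is a connected
    /// [`RdmaQueuePair`], `local`/`remote` are registered [`RdmaBuffer`] handles,
    /// and `op_type` is a valid WQE opcode:
    ///
    /// ```ignore
    /// send_wqe_gpu(&mut qp, &local, &remote, op_type).await?;
    /// ring_db_gpu(&mut qp).await?;
    /// let done = wait_for_completion_gpu(&mut qp, PollTarget::Send, 5).await?;
    /// assert!(done);
    /// ```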
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();

        while start_time.elapsed() < timeout {
            // Get the appropriate completion queue and index based on the poll target
            let (cq, idx, cq_type_str) = match poll_target {
                PollTarget::Send => (
                    qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.send_cq_idx,
                    "send",
                ),
                PollTarget::Recv => (
                    qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.recv_cq_idx,
                    "receive",
                ),
            };

            // Poll the completion queue
            let result =
                unsafe { rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32) };

            match result {
                rdmaxcel_sys::CQE_POLL_TRUE => {
                    // Update the appropriate index based on the poll target
                    match poll_target {
                        PollTarget::Send => qp.send_cq_idx += 1,
                        PollTarget::Recv => qp.recv_cq_idx += 1,
                    }
                    return Ok(true);
                }
                rdmaxcel_sys::CQE_POLL_ERROR => {
                    return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                }
                _ => {
                    // No completion yet, sleep and try again
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
            }
        }

        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

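    /// Test environment that wires up two `RdmaManagerActor`s on separate local
    /// proc meshes, each with a registered buffer (host or CUDA memory), for
    /// exercising RDMA operations between them.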
    pub struct RdmaManagerTestEnv<'a> {
        buffer_1: Buffer,
        buffer_2: Buffer,
        pub client_1: &'a Instance<()>,
        pub client_2: &'a Instance<()>,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        cuda_context_1: Option<cuda_sys::CUcontext>,
        cuda_context_2: Option<cuda_sys::CUcontext>,
    }

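    /// A test buffer described by a raw pointer and a length. For host memory,
    /// `cpu_ref` owns the allocation and keeps it alive; for CUDA allocations it
    /// is `None`.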
    #[derive(Debug, Clone)]
    pub struct Buffer {
        ptr: u64,
        len: usize,
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }

    impl RdmaManagerTestEnv<'_> {
        /// Sets up the RDMA test environment.
        ///
        /// This function initializes the RDMA test environment by setting up two actor
        /// meshes with their respective RDMA configurations. It also prepares two buffers
        /// for testing RDMA operations and fills the first buffer with test data.
        ///
        /// # Arguments
        ///
        /// * `buffer_size` - The size of the buffers to be used in the test.
        /// * `nics` - Names of the RDMA devices to use for the two actors; if a name is
        ///   not found, the corresponding `RdmaManagerActor` keeps the default device
        ///   from `IbverbsConfig::default()`.
        /// * `accels` - Accelerators backing the two buffers, either `"cpu"` or
        ///   `"cuda:<device index>"`.
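        ///
        /// # Examples
        ///
        /// A minimal sketch, assuming two RDMA devices named `"mlx5_0"` and `"mlx5_1"`
        /// exist (the device names are illustrative):
        ///
        /// ```ignore
        /// let env = RdmaManagerTestEnv::setup(1024, ("mlx5_0", "mlx5_1"), ("cpu", "cpu")).await?;
        /// // ... perform RDMA operations between env.rdma_handle_1 and env.rdma_handle_2 ...
        /// env.verify_buffers(1024).await?;
        /// env.cleanup().await?;
        /// ```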
        pub async fn setup(
            buffer_size: usize,
            nics: (&str, &str),
            accels: (&str, &str),
        ) -> Result<Self, anyhow::Error> {
            let all_nic_devices = get_all_devices();
            let mut config1 = IbverbsConfig::default();
            let mut config2 = IbverbsConfig::default();

            for nic_device in all_nic_devices.iter() {
                if nic_device.name == nics.0 {
                    config1.device = nic_device.clone();
                }
                if nic_device.name == nics.1 {
                    config2.device = nic_device.clone();
                }
            }
            let accel1;
            let accel2;

            if let Some((backend, idx)) = accels.0.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<usize>()
                    .expect("Device index is not a valid integer");
                accel1 = (backend.to_string(), parsed_idx);
                config1.use_gpu_direct = validate_execution_context().await.is_ok();
            } else {
                assert!(accels.0 == "cpu");
                accel1 = ("cpu".to_string(), 0);
            }

            if let Some((backend, idx)) = accels.1.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<usize>()
                    .expect("Device index is not a valid integer");
                accel2 = (backend.to_string(), parsed_idx);
                config2.use_gpu_direct = validate_execution_context().await.is_ok();
            } else {
                assert!(accels.1 == "cpu");
                accel2 = ("cpu".to_string(), 0);
            }

            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                })
                .await
                .unwrap();

            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1
                .spawn("rdma_manager", &Some(config1))
                .await
                .unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                })
                .await
                .unwrap();

            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2
                .spawn("rdma_manager", &Some(config2))
                .await
                .unwrap();

            let mut buf_vec = Vec::new();
            let mut cuda_contexts = Vec::new();

            for accel in [accel1.clone(), accel2.clone()] {
                if accel.0 == "cpu" {
                    let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                    buf_vec.push(Buffer {
                        ptr: buffer.as_mut_ptr() as u64,
                        len: buffer.len(),
                        cpu_ref: Some(buffer),
                    });
                    cuda_contexts.push(None);
                    continue;
                }
                // CUDA case
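                // Allocate device memory through the CUDA virtual memory management
                // APIs (cuMemCreate/cuMemAddressReserve/cuMemMap) so the allocation
                // can be flagged as GPUDirect-RDMA capable.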
                unsafe {
                    cu_check!(cuda_sys::cuInit(0));

                    let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed();
                    let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed();

                    let mut device: cuda_sys::CUdevice = std::mem::zeroed();
                    cu_check!(cuda_sys::cuDeviceGet(&mut device, accel.1 as i32));

                    let mut context: cuda_sys::CUcontext = std::mem::zeroed();
                    cu_check!(cuda_sys::cuCtxCreate_v2(&mut context, 0, accel.1 as i32));
                    cu_check!(cuda_sys::cuCtxSetCurrent(context));

                    let mut granularity: usize = 0;
                    let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed();
                    prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
                    prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    prop.location.id = device;
                    prop.allocFlags.gpuDirectRDMACapable = 1;
                    prop.requestedHandleTypes =
                        cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                    cu_check!(cuda_sys::cuMemGetAllocationGranularity(
                        &mut granularity as *mut usize,
                        &prop,
                        cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                    ));

                    // ensure our size is aligned
                    let padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity;
                    assert!(padded_size == buffer_size);

                    cu_check!(cuda_sys::cuMemCreate(
                        &mut handle as *mut cuda_sys::CUmemGenericAllocationHandle,
                        padded_size,
                        &prop,
                        0
                    ));
                    // reserve and map the memory
                    cu_check!(cuda_sys::cuMemAddressReserve(
                        &mut dptr as *mut cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        0,
                        0,
                    ));
                    assert!(dptr as usize % granularity == 0);
                    assert!(padded_size % granularity == 0);

                    // Note: wrapping this call in `cu_check!` causes it to fail,
                    // but checking the result manually succeeds.
                    let err = cuda_sys::cuMemMap(
                        dptr as cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        handle as cuda_sys::CUmemGenericAllocationHandle,
                        0,
                    );
                    if err != cuda_sys::CUresult::CUDA_SUCCESS {
                        panic!("failed mapping memory {:?}", err);
                    }

                    // set access
                    let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed();
                    access_desc.location.type_ =
                        cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    access_desc.location.id = device;
                    access_desc.flags =
                        cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                    cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1));
                    buf_vec.push(Buffer {
                        ptr: dptr,
                        len: padded_size,
                        cpu_ref: None,
                    });
                    cuda_contexts.push(Some(context));
                }
            }

            // Fill buffer1 with test data
            if accel1.0 == "cuda" {
                let mut temp_buffer = vec![0u8; buffer_size].into_boxed_slice();
                for (i, val) in temp_buffer.iter_mut().enumerate() {
                    *val = (i % 256) as u8;
                }
                unsafe {
                    // Use the CUDA context that was created for the first buffer
                    cu_check!(cuda_sys::cuCtxSetCurrent(
                        cuda_contexts[0].expect("No CUDA context found")
                    ));

                    cu_check!(cuda_sys::cuMemcpyHtoD_v2(
                        buf_vec[0].ptr,
                        temp_buffer.as_ptr() as *const std::ffi::c_void,
                        temp_buffer.len()
                    ));
                }
            } else {
                unsafe {
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = (i % 256) as u8;
                    }
                }
            }
            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            // Register both buffers with their respective actors and get back
            // RDMA handles (which carry the local and remote keys).
            let rdma_handle_1 = actor_1
                .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                .await?;
            let rdma_handle_2 = actor_2
                .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                .await?;

            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);
            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                cuda_context_1: cuda_contexts.first().cloned().flatten(),
                cuda_context_2: cuda_contexts.get(1).cloned().flatten(),
            })
        }

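        /// Releases both registered RDMA buffers and unmaps/frees any CUDA
        /// allocations created in [`RdmaManagerTestEnv::setup`].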
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;
            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            if self.cuda_context_1.is_some() {
                unsafe {
                    cu_check!(cuda_sys::cuCtxSetCurrent(
                        self.cuda_context_1.expect("No CUDA context found")
                    ));
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                }
            }
            if self.cuda_context_2.is_some() {
                unsafe {
                    cu_check!(cuda_sys::cuCtxSetCurrent(
                        self.cuda_context_2.expect("No CUDA context found")
                    ));
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                }
            }
            Ok(())
        }

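        /// Compares the first `size` bytes of the two test buffers, copying any
        /// CUDA-resident buffer to host memory first. Returns an error at the
        /// first index where the contents differ.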
        pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> {
            let mut buf_vec = Vec::new();
            for (virtual_addr, cuda_context) in [
                (self.buffer_1.ptr, self.cuda_context_1),
                (self.buffer_2.ptr, self.cuda_context_2),
            ] {
                if cuda_context.is_some() {
                    let mut temp_buffer = vec![0u8; size].into_boxed_slice();
                    // SAFETY: The buffer is allocated with the correct size and the pointer is valid.
                    unsafe {
                        cu_check!(cuda_sys::cuCtxSetCurrent(
                            cuda_context.expect("No CUDA context found")
                        ));
                        cu_check!(cuda_sys::cuMemcpyDtoH_v2(
                            temp_buffer.as_mut_ptr() as *mut std::ffi::c_void,
                            virtual_addr as cuda_sys::CUdeviceptr,
                            size
                        ));
                    }
                    buf_vec.push(Buffer {
                        ptr: temp_buffer.as_mut_ptr() as u64,
                        len: size,
                        cpu_ref: Some(temp_buffer),
                    });
                } else {
                    buf_vec.push(Buffer {
                        ptr: virtual_addr,
                        len: size,
                        cpu_ref: None,
                    });
                }
            }
            // SAFETY: The pointers are valid and the buffers have the same length.
            unsafe {
                let ptr1 = buf_vec[0].ptr as *mut u8;
                let ptr2 = buf_vec[1].ptr as *mut u8;
                for i in 0..buf_vec[0].len {
                    if *ptr1.add(i) != *ptr2.add(i) {
                        return Err(anyhow::anyhow!("Buffers are not equal at index {}", i));
                    }
                }
            }
            Ok(())
        }
    }
}