use std::sync::Once;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
static INIT: Once = Once::new();

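/// Returns whether CUDA is usable on this host.
///
/// The probe runs once (guarded by `INIT`) and the result is cached in
/// `CUDA_AVAILABLE`, so repeated calls are cheap.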
pub fn is_cuda_available() -> bool {
    INIT.call_once(|| {
        let available = check_cuda_available();
        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
    });
    CUDA_AVAILABLE.load(Ordering::SeqCst)
}

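/// Probes the CUDA driver API: initializes the driver, counts devices, and
/// fetches device 0. Returns `false` if any step fails or no device is found.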
fn check_cuda_available() -> bool {
    unsafe {
        let result = cuda_sys::cuInit(0);

        if result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        let mut device_count: i32 = 0;
        let count_result = cuda_sys::cuDeviceGetCount(&mut device_count);

        if count_result != cuda_sys::CUresult::CUDA_SUCCESS || device_count <= 0 {
            return false;
        }

        let mut device: cuda_sys::CUdevice = std::mem::zeroed();
        let device_result = cuda_sys::cuDeviceGet(&mut device, 0);

        if device_result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        true
    }
}

#[cfg(test)]
pub mod test_utils {
    use std::time::Duration;
    use std::time::Instant;

    use hyperactor::ActorRef;
    use hyperactor::Mailbox;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor_mesh::Mesh;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::RootActorMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use ndslice::extent;

    use crate::IbverbsConfig;
    use crate::PollTarget;
    use crate::RdmaBuffer;
    use crate::cu_check;
    use crate::ibverbs_primitives::get_all_devices;
    use crate::rdma_components::RdmaQueuePair;
    use crate::rdma_manager_actor::RdmaManagerActor;
    use crate::rdma_manager_actor::RdmaManagerMessageClient;
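
    /// Polls `qp` for a work completion on the given `poll_target`, sleeping
    /// 1ms between polls, until a completion arrives or `timeout_secs` elapses.
    ///
    /// Returns `Ok(true)` on completion, or an error on a poll failure or timeout.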
    pub async fn wait_for_completion(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();
        while start_time.elapsed() < timeout {
            match qp.poll_completion_target(poll_target) {
                Ok(Some(_wc)) => {
                    return Ok(true);
                }
                Ok(None) => {
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

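    /// Builds a send work queue entry (WQE) from `lhandle`/`rhandle` and posts
    /// it to the queue pair's send queue via a GPU-side kernel
    /// (`rdmaxcel_sys::launch_send_wqe`), then advances `send_wqe_idx`.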
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: u32::from_be(*(*dv_qp).dbrec.wrapping_add(1)) as u64,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            qp.send_wqe_idx += 1;
        }
        Ok(())
    }

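    /// Builds a receive WQE for `lhandle` and posts it to the queue pair's
    /// receive queue via `rdmaxcel_sys::launch_recv_wqe`, then advances the
    /// receive WQE and doorbell indices.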
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: u32::from_be(*(*dv_qp).dbrec) as u64,
                op_type,
                signaled: true,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            qp.recv_wqe_idx += 1;
            qp.recv_db_idx += 1;
        }
        Ok(())
    }

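    /// Rings the send doorbell for every WQE posted since the last ring,
    /// pushing each outstanding WQE through the queue pair's doorbell register
    /// (`bf.reg`) via `rdmaxcel_sys::launch_db_ring`. Errors if the outstanding
    /// WQEs would overflow the send queue.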
    pub async fn ring_db_gpu(qp: &mut RdmaQueuePair) -> Result<(), anyhow::Error> {
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            if wqe_cnt < (qp.send_wqe_idx - qp.send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while qp.send_db_idx < qp.send_wqe_idx {
                let offset = (qp.send_db_idx % wqe_cnt) * stride;
                let src_ptr = base_ptr.wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                qp.send_db_idx += 1;
            }
        }
        Ok(())
    }

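    /// GPU-side analogue of `wait_for_completion`: polls the mlx5dv send or
    /// receive completion queue through `rdmaxcel_sys::launch_cqe_poll` until a
    /// CQE is observed, an error is reported, or `timeout_secs` elapses.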
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();

        while start_time.elapsed() < timeout {
            let (cq, idx, cq_type_str) = match poll_target {
                PollTarget::Send => (
                    qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.send_cq_idx as i32,
                    "send",
                ),
                PollTarget::Recv => (
                    qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.recv_cq_idx as i32,
                    "receive",
                ),
            };

            let result =
                unsafe { rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx) };

            match result {
                rdmaxcel_sys::CQE_POLL_TRUE => {
                    match poll_target {
                        PollTarget::Send => qp.send_cq_idx += 1,
                        PollTarget::Recv => qp.recv_cq_idx += 1,
                    }
                    return Ok(true);
                }
                rdmaxcel_sys::CQE_POLL_ERROR => {
                    return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                }
                _ => {
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
            }
        }

        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

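    /// Test fixture pairing two `RdmaManagerActor`s, each spawned on its own
    /// proc mesh, with one registered RDMA buffer apiece. Buffers may live in
    /// host memory or in CUDA device memory, depending on the arguments passed
    /// to `setup`.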
    pub struct RdmaManagerTestEnv<'a> {
        buffer_1: Buffer,
        buffer_2: Buffer,
        pub client_1: &'a Mailbox,
        pub client_2: &'a Mailbox,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        handle_1_cuda: bool,
        handle_2_cuda: bool,
    }

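    /// A raw test buffer: a pointer/length pair, plus an owning reference when
    /// the allocation is host memory (kept alive for the fixture's lifetime).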
    #[derive(Debug, Clone)]
    pub struct Buffer {
        ptr: u64,
        len: usize,
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }

    impl RdmaManagerTestEnv<'_> {
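        /// Builds the test environment: picks the ibverbs NICs named in `nics`,
        /// spawns one `RdmaManagerActor` per proc mesh, allocates one buffer of
        /// `buffer_size` bytes per side on the backend named in `devices`
        /// ("cpu" or "cuda:<idx>"), fills the first buffer with a known
        /// pattern, and registers both buffers with their actors.
        ///
        /// Illustrative usage (a sketch only; the NIC names are placeholders
        /// and an RDMA-capable host is assumed):
        ///
        /// ```ignore
        /// let env = RdmaManagerTestEnv::setup(1024, ("mlx5_0", "mlx5_1"), ("cpu", "cpu")).await?;
        /// // ... issue RDMA operations between env.rdma_handle_1 and env.rdma_handle_2 ...
        /// env.verify_buffers(1024).await?;
        /// env.cleanup().await?;
        /// ```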
        pub async fn setup(
            buffer_size: usize,
            nics: (&str, &str),
            devices: (&str, &str),
        ) -> Result<Self, anyhow::Error> {
            let all_devices = get_all_devices();
            let mut config1 = IbverbsConfig::default();
            let mut config2 = IbverbsConfig::default();

            for device in all_devices.iter() {
                if device.name == nics.0 {
                    config1.device = device.clone();
                }
                if device.name == nics.1 {
                    config2.device = device.clone();
                }
            }

            let mut device_str1 = ("cpu".to_string(), 0);
            let mut device_str2 = ("cpu".to_string(), 0);

            if let Some((backend, idx)) = devices.0.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<i32>()
                    .expect("Device index is not a valid integer");
                device_str1 = (backend.to_string(), parsed_idx);
            } else {
                assert!(devices.0 == "cpu");
                config1.use_gpu_direct = false;
            }

            if let Some((backend, idx)) = devices.1.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<i32>()
                    .expect("Device index is not a valid integer");
                device_str2 = (backend.to_string(), parsed_idx);
            } else {
                assert!(devices.1 == "cpu");
                config2.use_gpu_direct = false;
            }

            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                })
                .await
                .unwrap();

            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> =
                proc_mesh_1.spawn("rdma_manager", &(config1)).await.unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                })
                .await
                .unwrap();

            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> =
                proc_mesh_2.spawn("rdma_manager", &(config2)).await.unwrap();

            let mut buf_vec = Vec::new();

            for device_str in [device_str1.clone(), device_str2.clone()] {
                if device_str.0 == "cpu" {
                    let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                    buf_vec.push(Buffer {
                        ptr: buffer.as_mut_ptr() as u64,
                        len: buffer.len(),
                        cpu_ref: Some(buffer),
                    });
                    continue;
                }
                unsafe {
                    cu_check!(cuda_sys::cuInit(0));

                    let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed();
                    let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed();

                    let mut device: cuda_sys::CUdevice = std::mem::zeroed();
                    cu_check!(cuda_sys::cuDeviceGet(&mut device, device_str.1));

                    let mut context: cuda_sys::CUcontext = std::mem::zeroed();
                    cu_check!(cuda_sys::cuCtxCreate_v2(&mut context, 0, device_str.1));
                    cu_check!(cuda_sys::cuCtxSetCurrent(context));

                    let mut granularity: usize = 0;
                    let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed();
                    prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
                    prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    prop.location.id = device;
                    prop.allocFlags.gpuDirectRDMACapable = 1;
                    prop.requestedHandleTypes =
                        cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                    cu_check!(cuda_sys::cuMemGetAllocationGranularity(
                        &mut granularity as *mut usize,
                        &prop,
                        cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                    ));

                    let padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity;
                    assert!(padded_size == buffer_size);

                    cu_check!(cuda_sys::cuMemCreate(
                        &mut handle as *mut cuda_sys::CUmemGenericAllocationHandle,
                        padded_size,
                        &prop,
                        0
                    ));
                    cu_check!(cuda_sys::cuMemAddressReserve(
                        &mut dptr as *mut cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        0,
                        0,
                    ));
                    assert!(dptr as usize % granularity == 0);
                    assert!(padded_size % granularity == 0);

                    let err = cuda_sys::cuMemMap(
                        dptr as cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        handle as cuda_sys::CUmemGenericAllocationHandle,
                        0,
                    );
                    if err != cuda_sys::CUresult::CUDA_SUCCESS {
                        panic!("failed reserving and mapping memory {:?}", err);
                    }

                    let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed();
                    access_desc.location.type_ =
                        cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    access_desc.location.id = device;
                    access_desc.flags =
                        cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                    cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1));
                    buf_vec.push(Buffer {
                        ptr: dptr,
                        len: padded_size,
                        cpu_ref: None,
                    });
                }
            }

            if device_str1.0 == "cuda" {
                let mut temp_buffer = vec![0u8; buffer_size].into_boxed_slice();
                for (i, val) in temp_buffer.iter_mut().enumerate() {
                    *val = (i % 256) as u8;
                }
                unsafe {
                    cu_check!(cuda_sys::cuMemcpyHtoD_v2(
                        buf_vec[0].ptr,
                        temp_buffer.as_ptr() as *const std::ffi::c_void,
                        temp_buffer.len()
                    ));
                }
            } else {
                unsafe {
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = (i % 256) as u8;
                    }
                }
            }
            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            let rdma_handle_1 = actor_1
                .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                .await?;
            let rdma_handle_2 = actor_2
                .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                .await?;
            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);
            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                handle_1_cuda: device_str1.0 == "cuda",
                handle_2_cuda: device_str2.0 == "cuda",
            })
        }

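        /// Releases both registered RDMA buffers and, for CUDA-backed buffers,
        /// unmaps and frees the reserved device address ranges.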
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;
            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            if self.handle_1_cuda {
                unsafe {
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                }
            }
            if self.handle_2_cuda {
                unsafe {
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                }
            }
            Ok(())
        }

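        /// Checks that the first `size` bytes of the two buffers match, copying
        /// CUDA-backed buffers to host memory first. Returns an error naming the
        /// first mismatching index.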
        pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> {
            let mut buf_vec = Vec::new();
            for (handle, is_cuda) in [
                (self.rdma_handle_1.clone(), self.handle_1_cuda),
                (self.rdma_handle_2.clone(), self.handle_2_cuda),
            ] {
                if is_cuda {
                    let mut temp_buffer = vec![0u8; size].into_boxed_slice();
                    unsafe {
                        cu_check!(cuda_sys::cuMemcpyDtoH_v2(
                            temp_buffer.as_mut_ptr() as *mut std::ffi::c_void,
                            handle.addr as cuda_sys::CUdeviceptr,
                            size
                        ));
                    }
                    buf_vec.push(Buffer {
                        ptr: temp_buffer.as_mut_ptr() as u64,
                        len: size,
                        cpu_ref: Some(temp_buffer),
                    });
                } else {
                    buf_vec.push(Buffer {
                        ptr: handle.addr as u64,
                        len: size,
                        cpu_ref: None,
                    });
                }
            }
            unsafe {
                let ptr1 = buf_vec[0].ptr as *mut u8;
                let ptr2: *mut u8 = buf_vec[1].ptr as *mut u8;
                for i in 0..buf_vec[0].len {
                    if *ptr1.add(i) != *ptr2.add(i) {
                        return Err(anyhow::anyhow!("Buffers are not equal at index {}", i));
                    }
                }
            }
            Ok(())
        }
    }
}