use std::sync::Once;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
static INIT: Once = Once::new();

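/// Returns whether a usable CUDA driver and at least one CUDA device are
/// present. The probe runs once (guarded by `Once`) and the result is cached
/// in `CUDA_AVAILABLE` for later callers.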
pub fn is_cuda_available() -> bool {
    INIT.call_once(|| {
        let available = check_cuda_available();
        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
    });
    CUDA_AVAILABLE.load(Ordering::SeqCst)
}

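/// Probes the CUDA driver directly: initializes it with `cuInit`, counts the
/// visible devices, and confirms that device 0 can be opened. Any failure is
/// reported as "CUDA unavailable" rather than an error.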
fn check_cuda_available() -> bool {
    unsafe {
        let result = cuda_sys::cuInit(0);

        if result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        let mut device_count: i32 = 0;
        let count_result = cuda_sys::cuDeviceGetCount(&mut device_count);

        if count_result != cuda_sys::CUresult::CUDA_SUCCESS || device_count <= 0 {
            return false;
        }

        let mut device: cuda_sys::CUdevice = std::mem::zeroed();
        let device_result = cuda_sys::cuDeviceGet(&mut device, 0);

        if device_result != cuda_sys::CUresult::CUDA_SUCCESS {
            return false;
        }

        true
    }
}

#[cfg(test)]
pub mod test_utils {
    use std::time::Duration;
    use std::time::Instant;

    use hyperactor::ActorRef;
    use hyperactor::Instance;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor_mesh::Mesh;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::RootActorMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use ndslice::extent;

    use crate::IbverbsConfig;
    use crate::RdmaBuffer;
    use crate::cu_check;
    use crate::ibverbs_primitives::get_all_devices;
    use crate::rdma_components::PollTarget;
    use crate::rdma_components::RdmaQueuePair;
    use crate::rdma_manager_actor::RdmaManagerActor;
    use crate::rdma_manager_actor::RdmaManagerMessageClient;
    use crate::validate_execution_context;

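    /// Polls the queue pair from the CPU until a work completion appears on
    /// `poll_target`, sleeping 1ms between attempts. Returns `Ok(true)` on a
    /// completion, or an error if polling fails or `timeout_secs` elapses.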
    pub async fn wait_for_completion(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();
        while start_time.elapsed() < timeout {
            match qp.poll_completion_target(poll_target) {
                Ok(Some(_wc)) => {
                    return Ok(true);
                }
                Ok(None) => {
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

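    /// Builds a `wqe_params_t` describing an RDMA operation of `op_type` from
    /// `lhandle` to `rhandle` and hands it to `rdmaxcel_sys::launch_send_wqe`,
    /// which posts the send WQE on the mlx5 direct-verbs send queue from the
    /// GPU side. The local send WQE index is advanced afterwards.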
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: qp.send_wqe_idx,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            qp.send_wqe_idx += 1;
        }
        Ok(())
    }

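    /// Builds a `wqe_params_t` for a receive into `lhandle` and hands it to
    /// `rdmaxcel_sys::launch_recv_wqe`, the GPU-side counterpart of posting a
    /// receive WQE, then advances the local receive WQE and doorbell indices.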
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: qp.recv_wqe_idx,
                op_type,
                signaled: true,
                qp_num: (*ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            qp.recv_wqe_idx += 1;
            qp.recv_db_idx += 1;
        }
        Ok(())
    }

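    /// Rings the send doorbell for every WQE posted since the last ring by
    /// replaying each pending entry through `rdmaxcel_sys::launch_db_ring`.
    /// Errors out if more WQEs are outstanding than the send queue can hold.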
    pub async fn ring_db_gpu(qp: &mut RdmaQueuePair) -> Result<(), anyhow::Error> {
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            if (wqe_cnt as u64) < (qp.send_wqe_idx - qp.send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while qp.send_db_idx < qp.send_wqe_idx {
                let offset = (qp.send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = base_ptr.wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                qp.send_db_idx += 1;
            }
        }
        Ok(())
    }

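    /// GPU-side analogue of [`wait_for_completion`]: repeatedly launches
    /// `rdmaxcel_sys::launch_cqe_poll` against the selected direct-verbs CQ
    /// until it reports a completion (`CQE_POLL_TRUE`), an error, or the
    /// timeout expires. On success the matching CQ index is advanced.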
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();

        while start_time.elapsed() < timeout {
            let (cq, idx, cq_type_str) = match poll_target {
                PollTarget::Send => (
                    qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.send_cq_idx,
                    "send",
                ),
                PollTarget::Recv => (
                    qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                    qp.recv_cq_idx,
                    "receive",
                ),
            };

            let result =
                unsafe { rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32) };

            match result {
                rdmaxcel_sys::CQE_POLL_TRUE => {
                    match poll_target {
                        PollTarget::Send => qp.send_cq_idx += 1,
                        PollTarget::Recv => qp.recv_cq_idx += 1,
                    }
                    return Ok(true);
                }
                rdmaxcel_sys::CQE_POLL_ERROR => {
                    return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                }
                _ => {
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
            }
        }

        Err(anyhow::Error::msg("Timeout while waiting for completion"))
    }

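    /// Two-sided test harness: a pair of `RdmaManagerActor`s, each with its
    /// own proc mesh, registered buffer, and (for CUDA accelerators) CUDA
    /// context.
    ///
    /// Typical flow, sketched with illustrative device names:
    ///
    /// ```ignore
    /// let env = RdmaManagerTestEnv::setup(1024, ("mlx5_0", "mlx5_1"), ("cpu", "cuda:0")).await?;
    /// // ... issue RDMA operations between env.rdma_handle_1 and env.rdma_handle_2 ...
    /// env.verify_buffers(1024).await?;
    /// env.cleanup().await?;
    /// ```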
    pub struct RdmaManagerTestEnv<'a> {
        buffer_1: Buffer,
        buffer_2: Buffer,
        pub client_1: &'a Instance<()>,
        pub client_2: &'a Instance<()>,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        cuda_context_1: Option<cuda_sys::CUcontext>,
        cuda_context_2: Option<cuda_sys::CUcontext>,
    }

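    /// A test buffer as handed to the RDMA managers: `ptr`/`len` describe the
    /// memory region, while `cpu_ref` keeps host allocations alive (it is
    /// `None` for CUDA allocations, which are released explicitly in
    /// `cleanup`).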
    #[derive(Debug, Clone)]
    pub struct Buffer {
        ptr: u64,
        len: usize,
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }

    impl RdmaManagerTestEnv<'_> {
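        /// Builds the test environment: selects the requested NICs, parses the
        /// accelerator specs ("cpu" or "cuda:<idx>"), spawns one
        /// `RdmaManagerActor` per proc mesh, allocates a buffer on each
        /// accelerator (host memory, or GPUDirect-capable CUDA memory), fills
        /// the first buffer with a test pattern, and registers both buffers
        /// with their managers.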
        pub async fn setup(
            buffer_size: usize,
            nics: (&str, &str),
            accels: (&str, &str),
        ) -> Result<Self, anyhow::Error> {
            let all_nic_devices = get_all_devices();
            let mut config1 = IbverbsConfig::default();
            let mut config2 = IbverbsConfig::default();

            for nic_device in all_nic_devices.iter() {
                if nic_device.name == nics.0 {
                    config1.device = nic_device.clone();
                }
                if nic_device.name == nics.1 {
                    config2.device = nic_device.clone();
                }
            }
            let accel1;
            let accel2;

            if let Some((backend, idx)) = accels.0.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<usize>()
                    .expect("Device index is not a valid integer");
                accel1 = (backend.to_string(), parsed_idx);
                config1.use_gpu_direct = validate_execution_context().await.is_ok();
            } else {
                assert!(accels.0 == "cpu");
                accel1 = ("cpu".to_string(), 0);
            }

            if let Some((backend, idx)) = accels.1.split_once(':') {
                assert!(backend == "cuda");
                let parsed_idx = idx
                    .parse::<usize>()
                    .expect("Device index is not a valid integer");
                accel2 = (backend.to_string(), parsed_idx);
                config2.use_gpu_direct = validate_execution_context().await.is_ok();
            } else {
                assert!(accels.1 == "cpu");
                accel2 = ("cpu".to_string(), 0);
            }

            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                })
                .await
                .unwrap();

            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1
                .spawn("rdma_manager", &Some(config1))
                .await
                .unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                })
                .await
                .unwrap();

            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2
                .spawn("rdma_manager", &Some(config2))
                .await
                .unwrap();

            let mut buf_vec = Vec::new();
            let mut cuda_contexts = Vec::new();

            for accel in [accel1.clone(), accel2.clone()] {
                if accel.0 == "cpu" {
                    let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                    buf_vec.push(Buffer {
                        ptr: buffer.as_mut_ptr() as u64,
                        len: buffer.len(),
                        cpu_ref: Some(buffer),
                    });
                    cuda_contexts.push(None);
                    continue;
                }
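                // CUDA path: allocate a GPUDirect-RDMA-capable buffer with the
                // CUDA virtual memory management API (cuMemCreate +
                // cuMemAddressReserve + cuMemMap + cuMemSetAccess), rounded up
                // to the allocation granularity.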
                unsafe {
                    cu_check!(cuda_sys::cuInit(0));

                    let mut dptr: cuda_sys::CUdeviceptr = std::mem::zeroed();
                    let mut handle: cuda_sys::CUmemGenericAllocationHandle = std::mem::zeroed();

                    let mut device: cuda_sys::CUdevice = std::mem::zeroed();
                    cu_check!(cuda_sys::cuDeviceGet(&mut device, accel.1 as i32));

                    let mut context: cuda_sys::CUcontext = std::mem::zeroed();
                    cu_check!(cuda_sys::cuCtxCreate_v2(&mut context, 0, accel.1 as i32));
                    cu_check!(cuda_sys::cuCtxSetCurrent(context));

                    let mut granularity: usize = 0;
                    let mut prop: cuda_sys::CUmemAllocationProp = std::mem::zeroed();
                    prop.type_ = cuda_sys::CUmemAllocationType::CU_MEM_ALLOCATION_TYPE_PINNED;
                    prop.location.type_ = cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    prop.location.id = device;
                    prop.allocFlags.gpuDirectRDMACapable = 1;
                    prop.requestedHandleTypes =
                        cuda_sys::CUmemAllocationHandleType::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                    cu_check!(cuda_sys::cuMemGetAllocationGranularity(
                        &mut granularity as *mut usize,
                        &prop,
                        cuda_sys::CUmemAllocationGranularity_flags::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                    ));

                    let padded_size: usize = ((buffer_size - 1) / granularity + 1) * granularity;
                    assert!(padded_size == buffer_size);

                    cu_check!(cuda_sys::cuMemCreate(
                        &mut handle as *mut cuda_sys::CUmemGenericAllocationHandle,
                        padded_size,
                        &prop,
                        0
                    ));
                    cu_check!(cuda_sys::cuMemAddressReserve(
                        &mut dptr as *mut cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        0,
                        0,
                    ));
                    assert!(dptr as usize % granularity == 0);
                    assert!(padded_size % granularity == 0);

                    let err = cuda_sys::cuMemMap(
                        dptr as cuda_sys::CUdeviceptr,
                        padded_size,
                        0,
                        handle as cuda_sys::CUmemGenericAllocationHandle,
                        0,
                    );
                    if err != cuda_sys::CUresult::CUDA_SUCCESS {
                        panic!("failed reserving and mapping memory {:?}", err);
                    }

                    let mut access_desc: cuda_sys::CUmemAccessDesc = std::mem::zeroed();
                    access_desc.location.type_ =
                        cuda_sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE;
                    access_desc.location.id = device;
                    access_desc.flags =
                        cuda_sys::CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                    cu_check!(cuda_sys::cuMemSetAccess(dptr, padded_size, &access_desc, 1));
                    buf_vec.push(Buffer {
                        ptr: dptr,
                        len: padded_size,
                        cpu_ref: None,
                    });
                    cuda_contexts.push(Some(context));
                }
            }

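            // Seed the first buffer with a deterministic pattern (i % 256); a
            // later call to `verify_buffers` can then check whether the data
            // reached the second buffer.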
            if accel1.0 == "cuda" {
                let mut temp_buffer = vec![0u8; buffer_size].into_boxed_slice();
                for (i, val) in temp_buffer.iter_mut().enumerate() {
                    *val = (i % 256) as u8;
                }
                unsafe {
                    cu_check!(cuda_sys::cuCtxSetCurrent(
                        cuda_contexts[0].expect("No CUDA context found")
                    ));

                    cu_check!(cuda_sys::cuMemcpyHtoD_v2(
                        buf_vec[0].ptr,
                        temp_buffer.as_ptr() as *const std::ffi::c_void,
                        temp_buffer.len()
                    ));
                }
            } else {
                unsafe {
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = (i % 256) as u8;
                    }
                }
            }
            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            let rdma_handle_1 = actor_1
                .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                .await?;
            let rdma_handle_2 = actor_2
                .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                .await?;
            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);
            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                cuda_context_1: cuda_contexts.first().cloned().flatten(),
                cuda_context_2: cuda_contexts.get(1).cloned().flatten(),
            })
        }

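        /// Releases both registered buffers through their managers and, for
        /// CUDA-backed buffers, unmaps and frees the reserved device address
        /// range.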
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;
            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            if self.cuda_context_1.is_some() {
                unsafe {
                    cu_check!(cuda_sys::cuCtxSetCurrent(
                        self.cuda_context_1.expect("No CUDA context found")
                    ));
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_1.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_1.len
                    ));
                }
            }
            if self.cuda_context_2.is_some() {
                unsafe {
                    cu_check!(cuda_sys::cuCtxSetCurrent(
                        self.cuda_context_2.expect("No CUDA context found")
                    ));
                    cu_check!(cuda_sys::cuMemUnmap(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                    cu_check!(cuda_sys::cuMemAddressFree(
                        self.buffer_2.ptr as cuda_sys::CUdeviceptr,
                        self.buffer_2.len
                    ));
                }
            }
            Ok(())
        }

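        /// Compares the first `size` bytes of the two buffers, copying any
        /// CUDA-backed buffer to host memory first. Returns an error naming
        /// the first mismatching index.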
        pub async fn verify_buffers(&self, size: usize) -> Result<(), anyhow::Error> {
            let mut buf_vec = Vec::new();
            for (virtual_addr, cuda_context) in [
                (self.buffer_1.ptr, self.cuda_context_1),
                (self.buffer_2.ptr, self.cuda_context_2),
            ] {
                if cuda_context.is_some() {
                    let mut temp_buffer = vec![0u8; size].into_boxed_slice();
                    unsafe {
                        cu_check!(cuda_sys::cuCtxSetCurrent(
                            cuda_context.expect("No CUDA context found")
                        ));
                        cu_check!(cuda_sys::cuMemcpyDtoH_v2(
                            temp_buffer.as_mut_ptr() as *mut std::ffi::c_void,
                            virtual_addr as cuda_sys::CUdeviceptr,
                            size
                        ));
                    }
                    buf_vec.push(Buffer {
                        ptr: temp_buffer.as_mut_ptr() as u64,
                        len: size,
                        cpu_ref: Some(temp_buffer),
                    });
                } else {
                    buf_vec.push(Buffer {
                        ptr: virtual_addr,
                        len: size,
                        cpu_ref: None,
                    });
                }
            }
            unsafe {
                let ptr1 = buf_vec[0].ptr as *mut u8;
                let ptr2 = buf_vec[1].ptr as *mut u8;
                for i in 0..buf_vec[0].len {
                    if *ptr1.add(i) != *ptr2.add(i) {
                        return Err(anyhow::anyhow!("Buffers are not equal at index {}", i));
                    }
                }
            }
            Ok(())
        }
    }
}