use std::sync::Once;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
static INIT: Once = Once::new();

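/// Reports whether a usable CUDA device is present. The probe runs once; the
/// result is cached in `CUDA_AVAILABLE` and reused on later calls.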
pub fn is_cuda_available() -> bool {
    INIT.call_once(|| {
        let available = check_cuda_available();
        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
    });
    CUDA_AVAILABLE.load(Ordering::SeqCst)
}

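/// One-shot probe: initialize the CUDA driver, count devices, and open device 0.
/// Any failure means CUDA is treated as unavailable.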
fn check_cuda_available() -> bool {
    unsafe {
        let result = rdmaxcel_sys::rdmaxcel_cuInit(0);

        if result != rdmaxcel_sys::CUDA_SUCCESS {
            return false;
        }

        let mut device_count: i32 = 0;
        let count_result = rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut device_count);

        if count_result != rdmaxcel_sys::CUDA_SUCCESS || device_count <= 0 {
            return false;
        }

        let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
        let device_result = rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, 0);

        if device_result != rdmaxcel_sys::CUDA_SUCCESS {
            return false;
        }

        true
    }
}

#[cfg(test)]
pub mod test_utils {
    use std::time::Duration;
    use std::time::Instant;

    use hyperactor::Actor;
    use hyperactor::ActorRef;
    use hyperactor::Context;
    use hyperactor::HandleClient;
    use hyperactor::Handler;
    use hyperactor::Instance;
    use hyperactor::Proc;
    use hyperactor::RefClient;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor_mesh::Mesh;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::RootActorMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use ndslice::extent;

    use crate::IbverbsConfig;
    use crate::RdmaBuffer;
    use crate::cu_check;
    use crate::rdma_components::PollTarget;
    use crate::rdma_components::RdmaQueuePair;
    use crate::rdma_manager_actor::RdmaManagerActor;
    use crate::rdma_manager_actor::RdmaManagerMessageClient;
    use crate::validate_execution_context;

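    // Newtype around a raw `CUcontext` pointer. Raw pointers are neither `Send`
    // nor `Sync`, so we assert both manually to let the actor that owns the
    // context be driven from other threads.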
    #[derive(Debug)]
    struct SendSyncCudaContext(rdmaxcel_sys::CUcontext);
    unsafe impl Send for SendSyncCudaContext {}
    unsafe impl Sync for SendSyncCudaContext {}

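    /// Test-only actor that initializes CUDA, creates a context on the device
    /// given by its spawn parameter, and serves [`CudaActorMessage`] requests
    /// against that context.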
    #[hyperactor::export(
        spawn = true,
        handlers = [
            CudaActorMessage,
        ],
    )]
    #[derive(Debug)]
    pub struct CudaActor {
        device: Option<i32>,
        context: SendSyncCudaContext,
    }

    #[async_trait::async_trait]
    impl Actor for CudaActor {
        type Params = i32;

        async fn new(device_id: i32) -> Result<Self, anyhow::Error> {
            unsafe {
                cu_check!(rdmaxcel_sys::rdmaxcel_cuInit(0));
                let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
                cu_check!(rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, device_id));
                let mut context: rdmaxcel_sys::CUcontext = std::mem::zeroed();
                cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxCreate_v2(
                    &mut context,
                    0,
                    device_id
                ));

                Ok(Self {
                    device: Some(device),
                    context: SendSyncCudaContext(context),
                })
            }
        }
    }

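    /// Requests served by [`CudaActor`]: allocate and RDMA-register a device
    /// buffer, memset a device region, or copy device memory back to host for
    /// verification.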
    #[derive(
        Handler,
        HandleClient,
        RefClient,
        hyperactor::Named,
        serde::Serialize,
        serde::Deserialize,
        Debug
    )]
    pub enum CudaActorMessage {
        CreateBuffer {
            size: usize,
            rdma_actor: ActorRef<RdmaManagerActor>,
            #[reply]
            reply: hyperactor::OncePortRef<(RdmaBuffer, usize)>,
        },
        FillBuffer {
            device_ptr: usize,
            size: usize,
            value: u8,
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
        VerifyBuffer {
            cpu_buffer_ptr: usize,
            device_ptr: usize,
            size: usize,
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
    }

    #[async_trait::async_trait]
    impl Handler<CudaActorMessage> for CudaActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            msg: CudaActorMessage,
        ) -> Result<(), anyhow::Error> {
            match msg {
                CudaActorMessage::CreateBuffer {
                    size,
                    rdma_actor,
                    reply,
                } => {
                    let device = self
                        .device
                        .ok_or_else(|| anyhow::anyhow!("Device not initialized"))?;

                    let (dptr, padded_size) = unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        let mut dptr: rdmaxcel_sys::CUdeviceptr = std::mem::zeroed();
                        let mut handle: rdmaxcel_sys::CUmemGenericAllocationHandle =
                            std::mem::zeroed();

                        // Query the minimum allocation granularity so the requested
                        // size can be padded to a value the VMM APIs accept.
                        let mut granularity: usize = 0;
                        let mut prop: rdmaxcel_sys::CUmemAllocationProp = std::mem::zeroed();
                        prop.type_ = rdmaxcel_sys::CU_MEM_ALLOCATION_TYPE_PINNED;
                        prop.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
                        prop.location.id = device;
                        prop.allocFlags.gpuDirectRDMACapable = 1;
                        prop.requestedHandleTypes =
                            rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemGetAllocationGranularity(
                            &mut granularity as *mut usize,
                            &prop,
                            rdmaxcel_sys::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                        ));

                        let padded_size: usize = ((size - 1) / granularity + 1) * granularity;

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemCreate(
                            &mut handle,
                            padded_size,
                            &prop,
                            0
                        ));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressReserve(
                            &mut dptr,
                            padded_size,
                            0,
                            0,
                            0,
                        ));

                        assert!((dptr as usize).is_multiple_of(granularity));
                        assert!(padded_size.is_multiple_of(granularity));

                        // Map the physical allocation into the reserved VA range and
                        // grant the device read/write access.
                        let err = rdmaxcel_sys::rdmaxcel_cuMemMap(dptr, padded_size, 0, handle, 0);
                        if err != rdmaxcel_sys::CUDA_SUCCESS {
                            return Err(anyhow::anyhow!("Failed to map CUDA memory: {:?}", err));
                        }

                        let mut access_desc: rdmaxcel_sys::CUmemAccessDesc = std::mem::zeroed();
                        access_desc.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
                        access_desc.location.id = device;
                        access_desc.flags = rdmaxcel_sys::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemSetAccess(
                            dptr,
                            padded_size,
                            &access_desc,
                            1
                        ));

                        (dptr, padded_size)
                    };

                    let rdma_handle = rdma_actor
                        .request_buffer(cx, dptr as usize, padded_size)
                        .await?;

                    reply.send(cx, (rdma_handle, dptr as usize))?;
                    Ok(())
                }
                CudaActorMessage::FillBuffer {
                    device_ptr,
                    size,
                    value,
                    reply,
                } => {
                    unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemsetD8_v2(
                            device_ptr as rdmaxcel_sys::CUdeviceptr,
                            value,
                            size
                        ));
                    }

                    reply.send(cx, ())?;
                    Ok(())
                }
                CudaActorMessage::VerifyBuffer {
                    cpu_buffer_ptr,
                    device_ptr,
                    size,
                    reply,
                } => {
                    unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
                            cpu_buffer_ptr as *mut std::ffi::c_void,
                            device_ptr as rdmaxcel_sys::CUdeviceptr,
                            size
                        ));
                    }

                    reply.send(cx, ())?;
                    Ok(())
                }
            }
        }
    }

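    /// Polls `qp` until completions for every id in `expected_wr_ids` have been
    /// observed on `poll_target`, or errors out after `timeout_secs`.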
    pub async fn wait_for_completion(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        expected_wr_ids: &[u64],
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();

        let mut remaining: std::collections::HashSet<u64> =
            expected_wr_ids.iter().copied().collect();

        while start_time.elapsed() < timeout {
            if remaining.is_empty() {
                return Ok(true);
            }

            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
                Ok(completions) => {
                    for (wr_id, _wc) in completions {
                        remaining.remove(&wr_id);
                    }
                    if remaining.is_empty() {
                        return Ok(true);
                    }
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        Err(anyhow::Error::msg(format!(
            "Timeout while waiting for completion of wr_ids: {:?}",
            remaining
        )))
    }

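    /// Builds a send WQE describing an RDMA op of `op_type` from `lhandle`
    /// (local) to `rhandle` (remote), posts it through `launch_send_wqe`, and
    /// advances the send WQE index.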
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(ibv_qp);
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: send_wqe_idx,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*(*ibv_qp).ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(ibv_qp);
        }
        Ok(())
    }

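    /// Builds a receive WQE for `lhandle`, posts it through `launch_recv_wqe`,
    /// then advances the receive WQE and receive doorbell indices.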
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let recv_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp);
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: recv_wqe_idx,
                op_type,
                signaled: true,
                qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp);
        }
        Ok(())
    }

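    /// Rings the send doorbell (via `launch_db_ring`) once for every WQE posted
    /// since the last ring, advancing the stored doorbell index. The initial
    /// 2ms sleep presumably leaves time for previously launched WQE kernels to
    /// finish writing before the doorbell is rung.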
    pub async fn ring_db_gpu(qp: &RdmaQueuePair) -> Result<(), anyhow::Error> {
        RealClock.sleep(Duration::from_millis(2)).await;
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(
                qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
            );
            let mut send_db_idx =
                rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp);
            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while send_db_idx < send_wqe_idx {
                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                send_db_idx += 1;
                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(
                    qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
                    send_db_idx,
                );
            }
        }
        Ok(())
    }

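    /// Polls the send or receive completion queue through `launch_cqe_poll`
    /// until a completion is seen (advancing the matching CQ index), an error
    /// is reported, or `timeout_secs` elapses.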
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        unsafe {
            let start_time = Instant::now();
            let timeout = Duration::from_secs(timeout_secs);
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;

            while start_time.elapsed() < timeout {
                let (cq, idx, cq_type_str) = match poll_target {
                    PollTarget::Send => (
                        qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                        rdmaxcel_sys::rdmaxcel_qp_load_send_cq_idx(ibv_qp),
                        "send",
                    ),
                    PollTarget::Recv => (
                        qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                        rdmaxcel_sys::rdmaxcel_qp_load_recv_cq_idx(ibv_qp),
                        "receive",
                    ),
                };

                let result = rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32);

                match result {
                    rdmaxcel_sys::CQE_POLL_TRUE => {
                        match poll_target {
                            PollTarget::Send => {
                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_cq_idx(ibv_qp);
                            }
                            PollTarget::Recv => {
                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_cq_idx(ibv_qp);
                            }
                        }
                        return Ok(true);
                    }
                    rdmaxcel_sys::CQE_POLL_ERROR => {
                        return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                    }
                    _ => {
                        RealClock.sleep(Duration::from_millis(1)).await;
                    }
                }
            }

            Err(anyhow::Error::msg("Timeout while waiting for completion"))
        }
    }

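    /// Two-sided test environment: a pair of `RdmaManagerActor`s with their
    /// clients and registered buffers, plus optional `CudaActor`s and device
    /// pointers for sides backed by GPU memory.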
    pub struct RdmaManagerTestEnv<'a> {
        buffer_1: Buffer,
        buffer_2: Buffer,
        pub client_1: &'a Instance<()>,
        pub client_2: &'a Instance<()>,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        cuda_actor_1: Option<ActorRef<CudaActor>>,
        cuda_actor_2: Option<ActorRef<CudaActor>>,
        device_ptr_1: Option<usize>,
        device_ptr_2: Option<usize>,
    }

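    /// Raw pointer/length pair for a test buffer; `cpu_ref` keeps host memory
    /// alive when the buffer is CPU-backed and is `None` for GPU buffers.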
    #[derive(Debug, Clone)]
    pub struct Buffer {
        ptr: u64,
        len: usize,
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }

    /// Splits an accelerator spec such as "cpu:0" or "cuda:1" into backend name
    /// and index, enabling GPU Direct in `config` when the CUDA path is usable.
    async fn parse_accel(accel: &str, config: &mut IbverbsConfig) -> (String, usize) {
        let (backend, idx) = accel.split_once(':').unwrap();
        let parsed_idx = idx.parse::<usize>().unwrap();

        if backend == "cuda" {
            config.use_gpu_direct = validate_execution_context().await.is_ok();
            eprintln!("Using GPU Direct: {}", config.use_gpu_direct);
        }

        (backend.to_string(), parsed_idx)
    }

    impl RdmaManagerTestEnv<'_> {
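        /// Builds the test environment with the requested queue-pair type: one
        /// proc mesh, `RdmaManagerActor`, and registered buffer per side. Sides
        /// given as "cuda:<idx>" get a `CudaActor` and a GPU buffer; "cpu:<idx>"
        /// sides get host memory. The first buffer is pre-filled with the byte 42.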
        pub async fn setup_with_qp_type(
            buffer_size: usize,
            accel1: &str,
            accel2: &str,
            qp_type: crate::ibverbs_primitives::RdmaQpType,
        ) -> Result<Self, anyhow::Error> {
            let mut config1 = IbverbsConfig::targeting(accel1);
            let mut config2 = IbverbsConfig::targeting(accel2);

            config1.qp_type = qp_type;
            config2.qp_type = qp_type;

            let parsed_accel1 = parse_accel(accel1, &mut config1).await;
            let parsed_accel2 = parse_accel(accel2, &mut config2).await;

            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                    proc_allocation_mode: Default::default(),
                })
                .await
                .unwrap();

            let (instance, _) = Proc::local().instance("test").unwrap();

            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1
                .spawn(&instance, "rdma_manager", &Some(config1))
                .await
                .unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                    proc_allocation_mode: Default::default(),
                })
                .await
                .unwrap();

            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2
                .spawn(&instance, "rdma_manager", &Some(config2))
                .await
                .unwrap();

            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            let mut buf_vec = Vec::new();
            let mut cuda_actor_1 = None;
            let mut cuda_actor_2 = None;
            let mut device_ptr_1: Option<usize> = None;
            let mut device_ptr_2: Option<usize> = None;

            let rdma_handle_1;
            let rdma_handle_2;

            if parsed_accel1.0 == "cpu" {
                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                buf_vec.push(Buffer {
                    ptr: buffer.as_mut_ptr() as u64,
                    len: buffer.len(),
                    cpu_ref: Some(buffer),
                });
                rdma_handle_1 = actor_1
                    .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                    .await?;
            } else {
                let cuda_actor_mesh_1: RootActorMesh<'_, CudaActor> = proc_mesh_1
                    .spawn(&instance, "cuda_init", &(parsed_accel1.1 as i32))
                    .await?;
                let cuda_actor_ref_1 = cuda_actor_mesh_1.get(0).unwrap();

                let (rdma_buf, dev_ptr) = cuda_actor_ref_1
                    .create_buffer(proc_mesh_1.client(), buffer_size, actor_1.clone())
                    .await?;
                rdma_handle_1 = rdma_buf;
                device_ptr_1 = Some(dev_ptr);

                buf_vec.push(Buffer {
                    ptr: rdma_handle_1.addr as u64,
                    len: buffer_size,
                    cpu_ref: None,
                });
                cuda_actor_1 = Some(cuda_actor_ref_1);
            }

            if parsed_accel2.0 == "cpu" {
                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                buf_vec.push(Buffer {
                    ptr: buffer.as_mut_ptr() as u64,
                    len: buffer.len(),
                    cpu_ref: Some(buffer),
                });
                rdma_handle_2 = actor_2
                    .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                    .await?;
            } else {
                let cuda_actor_mesh_2: RootActorMesh<'_, CudaActor> = proc_mesh_2
                    .spawn(&instance, "cuda_init", &(parsed_accel2.1 as i32))
                    .await?;
                let cuda_actor_ref_2 = cuda_actor_mesh_2.get(0).unwrap();

                let (rdma_buf, dev_ptr) = cuda_actor_ref_2
                    .create_buffer(proc_mesh_2.client(), buffer_size, actor_2.clone())
                    .await?;
                rdma_handle_2 = rdma_buf;
                device_ptr_2 = Some(dev_ptr);

                buf_vec.push(Buffer {
                    ptr: rdma_handle_2.addr as u64,
                    len: buffer_size,
                    cpu_ref: None,
                });
                cuda_actor_2 = Some(cuda_actor_ref_2);
            }

            if parsed_accel1.0 == "cuda" {
                cuda_actor_1
                    .clone()
                    .unwrap()
                    .fill_buffer(proc_mesh_1.client(), device_ptr_1.unwrap(), buffer_size, 42)
                    .await?;
            } else {
                unsafe {
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = 42_u8;
                    }
                }
            }

            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);

            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                cuda_actor_1,
                cuda_actor_2,
                device_ptr_1,
                device_ptr_2,
            })
        }

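        /// Releases both registered RDMA buffers via their owning manager actors.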
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;

            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            Ok(())
        }

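        /// Same as [`Self::setup_with_qp_type`], defaulting the queue-pair type
        /// to `RdmaQpType::Auto`.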
        pub async fn setup(
            buffer_size: usize,
            accel1: &str,
            accel2: &str,
        ) -> Result<Self, anyhow::Error> {
            Self::setup_with_qp_type(
                buffer_size,
                accel1,
                accel2,
                crate::ibverbs_primitives::RdmaQpType::Auto,
            )
            .await
        }

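        /// Copies `size` bytes starting at `offset` out of both buffers (through
        /// the CUDA actors for GPU-backed sides) and fails if any byte differs.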
        pub async fn verify_buffers(
            &self,
            size: usize,
            offset: usize,
        ) -> Result<(), anyhow::Error> {
            let mut temp_buffer_1 = vec![0u8; size];
            let mut temp_buffer_2 = vec![0u8; size];

            if let Some(cuda_actor) = &self.cuda_actor_1 {
                cuda_actor
                    .verify_buffer(
                        self.client_1,
                        temp_buffer_1.as_mut_ptr() as usize,
                        self.device_ptr_1.unwrap() + offset,
                        size,
                    )
                    .await?;
            } else {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        (self.buffer_1.ptr + offset as u64) as *const u8,
                        temp_buffer_1.as_mut_ptr(),
                        size,
                    );
                }
            }

            if let Some(cuda_actor) = &self.cuda_actor_2 {
                cuda_actor
                    .verify_buffer(
                        self.client_2,
                        temp_buffer_2.as_mut_ptr() as usize,
                        self.device_ptr_2.unwrap() + offset,
                        size,
                    )
                    .await?;
            } else {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        (self.buffer_2.ptr + offset as u64) as *const u8,
                        temp_buffer_2.as_mut_ptr(),
                        size,
                    );
                }
            }

            for i in 0..size {
                if temp_buffer_1[i] != temp_buffer_2[i] {
                    return Err(anyhow::anyhow!(
                        "Buffers are not equal at index {}",
                        offset + i
                    ));
                }
            }
            Ok(())
        }
    }
}