use std::sync::Once;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

static CUDA_AVAILABLE: AtomicBool = AtomicBool::new(false);
static INIT: Once = Once::new();

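/// Returns whether the CUDA driver can be initialized and at least one
/// CUDA device is visible. The probe runs once; the result is cached in
/// `CUDA_AVAILABLE` for the lifetime of the process.
///
/// A minimal usage sketch:
///
/// ```ignore
/// if is_cuda_available() {
///     // safe to take CUDA-dependent test paths
/// }
/// ```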
pub fn is_cuda_available() -> bool {
    INIT.call_once(|| {
        let available = check_cuda_available();
        CUDA_AVAILABLE.store(available, Ordering::SeqCst);
    });
    CUDA_AVAILABLE.load(Ordering::SeqCst)
}

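/// Probes the CUDA driver: initializes it, counts devices, and fetches a
/// handle for device 0. Any failure is treated as "CUDA unavailable".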
fn check_cuda_available() -> bool {
    // SAFETY: raw CUDA driver calls; every failure path returns `false`.
    unsafe {
        let result = rdmaxcel_sys::rdmaxcel_cuInit(0);

        if result != rdmaxcel_sys::CUDA_SUCCESS {
            return false;
        }

        let mut device_count: i32 = 0;
        let count_result = rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut device_count);

        if count_result != rdmaxcel_sys::CUDA_SUCCESS || device_count <= 0 {
            return false;
        }

        let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
        let device_result = rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, 0);

        if device_result != rdmaxcel_sys::CUDA_SUCCESS {
            return false;
        }

        true
    }
}

#[cfg(test)]
pub mod test_utils {
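    //! Test utilities: a CUDA-backed test actor, GPU-driven WQE/doorbell
    //! helpers, and a two-node RDMA test harness.
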
    use std::time::Duration;
    use std::time::Instant;

    use hyperactor::Actor;
    use hyperactor::ActorRef;
    use hyperactor::Context;
    use hyperactor::HandleClient;
    use hyperactor::Handler;
    use hyperactor::Instance;
    use hyperactor::RefClient;
    use hyperactor::RemoteSpawn;
    use hyperactor::channel::ChannelTransport;
    use hyperactor::clock::Clock;
    use hyperactor::clock::RealClock;
    use hyperactor_mesh::Mesh;
    use hyperactor_mesh::ProcMesh;
    use hyperactor_mesh::RootActorMesh;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::LocalAllocator;
    use hyperactor_mesh::proc_mesh::global_root_client;
    use ndslice::extent;

    use crate::IbverbsConfig;
    use crate::RdmaBuffer;
    use crate::cu_check;
    use crate::rdma_components::PollTarget;
    use crate::rdma_components::RdmaQueuePair;
    use crate::rdma_manager_actor::RdmaManagerActor;
    use crate::rdma_manager_actor::RdmaManagerMessageClient;
    use crate::validate_execution_context;

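    /// Wrapper marking a raw `CUcontext` as `Send + Sync` so it can live in
    /// actor state. This is sound only because the owning actor serializes
    /// all access to the context.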
    #[derive(Debug)]
    struct SendSyncCudaContext(rdmaxcel_sys::CUcontext);
    unsafe impl Send for SendSyncCudaContext {}
    unsafe impl Sync for SendSyncCudaContext {}

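    /// Test actor owning a CUDA context on one device; it allocates
    /// RDMA-capable device buffers and fills/verifies them on request.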
    #[hyperactor::export(
        spawn = true,
        handlers = [
            CudaActorMessage,
        ],
    )]
    #[derive(Debug)]
    pub struct CudaActor {
        device: Option<i32>,
        context: SendSyncCudaContext,
    }

    impl Actor for CudaActor {}

    #[async_trait::async_trait]
    impl RemoteSpawn for CudaActor {
        type Params = i32;

        async fn new(device_id: i32) -> Result<Self, anyhow::Error> {
            // SAFETY: raw CUDA driver calls, each checked via `cu_check!`.
            unsafe {
                cu_check!(rdmaxcel_sys::rdmaxcel_cuInit(0));
                let mut device: rdmaxcel_sys::CUdevice = std::mem::zeroed();
                cu_check!(rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, device_id));
                let mut context: rdmaxcel_sys::CUcontext = std::mem::zeroed();
                // Create the context on the device handle obtained above.
                cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxCreate_v2(
                    &mut context,
                    0,
                    device
                ));

                Ok(Self {
                    device: Some(device),
                    context: SendSyncCudaContext(context),
                })
            }
        }
    }

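    /// Messages handled by [`CudaActor`].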
    #[derive(
        Handler,
        HandleClient,
        RefClient,
        typeuri::Named,
        serde::Serialize,
        serde::Deserialize,
        Debug
    )]
    pub enum CudaActorMessage {
        CreateBuffer {
            size: usize,
            rdma_actor: ActorRef<RdmaManagerActor>,
            #[reply]
            reply: hyperactor::OncePortRef<(RdmaBuffer, usize)>,
        },
        FillBuffer {
            device_ptr: usize,
            size: usize,
            value: u8,
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
        VerifyBuffer {
            cpu_buffer_ptr: usize,
            device_ptr: usize,
            size: usize,
            #[reply]
            reply: hyperactor::OncePortRef<()>,
        },
    }

    #[async_trait::async_trait]
    impl Handler<CudaActorMessage> for CudaActor {
        async fn handle(
            &mut self,
            cx: &Context<Self>,
            msg: CudaActorMessage,
        ) -> Result<(), anyhow::Error> {
            match msg {
                CudaActorMessage::CreateBuffer {
                    size,
                    rdma_actor,
                    reply,
                } => {
                    let device = self
                        .device
                        .ok_or_else(|| anyhow::anyhow!("Device not initialized"))?;

                    // Allocate device memory with the CUDA VMM API so the
                    // mapping is RDMA-capable: query granularity, create a
                    // pinned allocation, reserve a VA range, map it, and
                    // grant read/write access to this device.
                    let (dptr, padded_size) = unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        let mut dptr: rdmaxcel_sys::CUdeviceptr = std::mem::zeroed();
                        let mut handle: rdmaxcel_sys::CUmemGenericAllocationHandle =
                            std::mem::zeroed();

                        let mut granularity: usize = 0;
                        let mut prop: rdmaxcel_sys::CUmemAllocationProp = std::mem::zeroed();
                        prop.type_ = rdmaxcel_sys::CU_MEM_ALLOCATION_TYPE_PINNED;
                        prop.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
                        prop.location.id = device;
                        prop.allocFlags.gpuDirectRDMACapable = 1;
                        prop.requestedHandleTypes =
                            rdmaxcel_sys::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemGetAllocationGranularity(
                            &mut granularity as *mut usize,
                            &prop,
                            rdmaxcel_sys::CU_MEM_ALLOC_GRANULARITY_MINIMUM,
                        ));

                        // Round the requested size up to the allocation granularity.
                        let padded_size: usize = ((size - 1) / granularity + 1) * granularity;

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemCreate(
                            &mut handle,
                            padded_size,
                            &prop,
                            0
                        ));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemAddressReserve(
                            &mut dptr,
                            padded_size,
                            0,
                            0,
                            0,
                        ));

                        assert!((dptr as usize).is_multiple_of(granularity));
                        assert!(padded_size.is_multiple_of(granularity));

                        let err = rdmaxcel_sys::rdmaxcel_cuMemMap(dptr, padded_size, 0, handle, 0);
                        if err != rdmaxcel_sys::CUDA_SUCCESS {
                            return Err(anyhow::anyhow!("Failed to map CUDA memory: {:?}", err));
                        }

                        let mut access_desc: rdmaxcel_sys::CUmemAccessDesc = std::mem::zeroed();
                        access_desc.location.type_ = rdmaxcel_sys::CU_MEM_LOCATION_TYPE_DEVICE;
                        access_desc.location.id = device;
                        access_desc.flags = rdmaxcel_sys::CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemSetAccess(
                            dptr,
                            padded_size,
                            &access_desc,
                            1
                        ));

                        (dptr, padded_size)
                    };

                    // Register the mapped region with the RDMA manager.
                    let rdma_handle = rdma_actor
                        .request_buffer(cx, dptr as usize, padded_size)
                        .await?;

                    reply.send(cx, (rdma_handle, dptr as usize))?;
                    Ok(())
                }
                CudaActorMessage::FillBuffer {
                    device_ptr,
                    size,
                    value,
                    reply,
                } => {
                    unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemsetD8_v2(
                            device_ptr as rdmaxcel_sys::CUdeviceptr,
                            value,
                            size
                        ));
                    }

                    reply.send(cx, ())?;
                    Ok(())
                }
                CudaActorMessage::VerifyBuffer {
                    cpu_buffer_ptr,
                    device_ptr,
                    size,
                    reply,
                } => {
                    unsafe {
                        cu_check!(rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.context.0));

                        cu_check!(rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
                            cpu_buffer_ptr as *mut std::ffi::c_void,
                            device_ptr as rdmaxcel_sys::CUdeviceptr,
                            size
                        ));
                    }

                    reply.send(cx, ())?;
                    Ok(())
                }
            }
        }
    }

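    /// Polls `qp` on `poll_target` until every work request in
    /// `expected_wr_ids` completes, sleeping 1ms between polls; errors out
    /// after `timeout_secs`.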
    pub async fn wait_for_completion(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        expected_wr_ids: &[u64],
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        let timeout = Duration::from_secs(timeout_secs);
        let start_time = Instant::now();

        let mut remaining: std::collections::HashSet<u64> =
            expected_wr_ids.iter().copied().collect();

        while start_time.elapsed() < timeout {
            if remaining.is_empty() {
                return Ok(true);
            }

            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
                Ok(completions) => {
                    for (wr_id, _wc) in completions {
                        remaining.remove(&wr_id);
                    }
                    if remaining.is_empty() {
                        return Ok(true);
                    }
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        Err(anyhow::Error::msg(format!(
            "Timeout while waiting for completion of wr_ids: {:?}",
            remaining
        )))
    }

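    /// Builds a send WQE from `lhandle`/`rhandle` and launches a GPU kernel
    /// to post it on the send queue, then advances the send WQE index.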
    pub async fn send_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(ibv_qp);
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: send_wqe_idx,
                signaled: true,
                op_type,
                raddr: rhandle.addr,
                rkey: rhandle.rkey,
                qp_num: (*(*ibv_qp).ibv_qp).qp_num,
                buf: (*dv_qp).sq.buf as *mut u8,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_send_wqe(params);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(ibv_qp);
        }
        Ok(())
    }

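    /// Builds a receive WQE for `lhandle` and launches a GPU kernel to post
    /// it on the receive queue, then advances the receive WQE and doorbell
    /// indices.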
    pub async fn recv_wqe_gpu(
        qp: &mut RdmaQueuePair,
        lhandle: &RdmaBuffer,
        _rhandle: &RdmaBuffer,
        op_type: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let recv_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_recv_wqe_idx(rdmaxcel_qp);
            let params = rdmaxcel_sys::wqe_params_t {
                laddr: lhandle.addr,
                length: lhandle.size,
                lkey: lhandle.lkey,
                wr_id: recv_wqe_idx,
                op_type,
                signaled: true,
                qp_num: (*(*rdmaxcel_qp).ibv_qp).qp_num,
                buf: (*dv_qp).rq.buf as *mut u8,
                wqe_cnt: (*dv_qp).rq.wqe_cnt,
                dbrec: (*dv_qp).dbrec,
                ..Default::default()
            };
            rdmaxcel_sys::launch_recv_wqe(params);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(rdmaxcel_qp);
            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(rdmaxcel_qp);
        }
        Ok(())
    }

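    /// Rings the doorbell from the GPU for every send WQE posted since the
    /// last ring, bringing the doorbell index up to the send WQE index.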
    pub async fn ring_db_gpu(qp: &RdmaQueuePair) -> Result<(), anyhow::Error> {
        RealClock.sleep(Duration::from_millis(2)).await;
        unsafe {
            let dv_qp = qp.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(
                qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
            );
            let mut send_db_idx =
                rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp);
            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while send_db_idx < send_wqe_idx {
                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = base_ptr.wrapping_add(offset as usize);
                rdmaxcel_sys::launch_db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                send_db_idx += 1;
                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(
                    qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
                    send_db_idx,
                );
            }
        }
        Ok(())
    }

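    /// Polls the send or receive completion queue via a GPU kernel until a
    /// CQE arrives or `timeout_secs` elapses, advancing the matching CQ
    /// index on success.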
    pub async fn wait_for_completion_gpu(
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout_secs: u64,
    ) -> Result<bool, anyhow::Error> {
        unsafe {
            let start_time = Instant::now();
            let timeout = Duration::from_secs(timeout_secs);
            let ibv_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;

            while start_time.elapsed() < timeout {
                let (cq, idx, cq_type_str) = match poll_target {
                    PollTarget::Send => (
                        qp.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                        rdmaxcel_sys::rdmaxcel_qp_load_send_cq_idx(ibv_qp),
                        "send",
                    ),
                    PollTarget::Recv => (
                        qp.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq,
                        rdmaxcel_sys::rdmaxcel_qp_load_recv_cq_idx(ibv_qp),
                        "receive",
                    ),
                };

                let result = rdmaxcel_sys::launch_cqe_poll(cq as *mut std::ffi::c_void, idx as i32);

                match result {
                    rdmaxcel_sys::CQE_POLL_TRUE => {
                        match poll_target {
                            PollTarget::Send => {
                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_cq_idx(ibv_qp);
                            }
                            PollTarget::Recv => {
                                rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_cq_idx(ibv_qp);
                            }
                        }
                        return Ok(true);
                    }
                    rdmaxcel_sys::CQE_POLL_ERROR => {
                        return Err(anyhow::anyhow!("Error polling {} completion", cq_type_str));
                    }
                    _ => {
                        RealClock.sleep(Duration::from_millis(1)).await;
                    }
                }
            }

            Err(anyhow::Error::msg("Timeout while waiting for completion"))
        }
    }

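    /// Two-node RDMA test harness: one `RdmaManagerActor` (plus an optional
    /// `CudaActor`) per side, with a registered buffer on each.
    ///
    /// A minimal usage sketch (`ignore`d: it needs RDMA-capable hardware):
    ///
    /// ```ignore
    /// let env = RdmaManagerTestEnv::setup(1024, "cpu:0", "cpu:0").await?;
    /// // ... perform RDMA ops between env.rdma_handle_1 and env.rdma_handle_2 ...
    /// env.verify_buffers(1024, 0).await?;
    /// env.cleanup().await?;
    /// ```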
    pub struct RdmaManagerTestEnv<'a> {
        buffer_1: Buffer,
        buffer_2: Buffer,
        pub client_1: &'a Instance<()>,
        pub client_2: &'a Instance<()>,
        pub actor_1: ActorRef<RdmaManagerActor>,
        pub actor_2: ActorRef<RdmaManagerActor>,
        pub rdma_handle_1: RdmaBuffer,
        pub rdma_handle_2: RdmaBuffer,
        cuda_actor_1: Option<ActorRef<CudaActor>>,
        cuda_actor_2: Option<ActorRef<CudaActor>>,
        device_ptr_1: Option<usize>,
        device_ptr_2: Option<usize>,
    }

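    /// A test buffer: either host memory (kept alive by `cpu_ref`) or a view
    /// of device memory (`cpu_ref` is `None`).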
    #[derive(Debug, Clone)]
    pub struct Buffer {
        ptr: u64,
        len: usize,
        #[allow(dead_code)]
        cpu_ref: Option<Box<[u8]>>,
    }
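
    /// Parses an accelerator spec of the form `<backend>:<index>` (e.g.
    /// `cpu:0`, `cuda:1`). For CUDA backends, enables GPU Direct in `config`
    /// when the execution context supports it.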
    async fn parse_accel(accel: &str, config: &mut IbverbsConfig) -> (String, usize) {
        let (backend, idx) = accel
            .split_once(':')
            .expect("accel must have the form '<backend>:<index>'");
        let parsed_idx = idx.parse::<usize>().expect("accel index must be an integer");

        if backend == "cuda" {
            config.use_gpu_direct = validate_execution_context().await.is_ok();
            eprintln!("Using GPU Direct: {}", config.use_gpu_direct);
        }

        (backend.to_string(), parsed_idx)
    }

    impl RdmaManagerTestEnv<'_> {
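        /// Builds the two-node environment: spawns per-side proc meshes and
        /// RDMA managers with the given queue-pair type, allocates a buffer
        /// on each accelerator, and seeds the first buffer with `42`.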
        pub async fn setup_with_qp_type(
            buffer_size: usize,
            accel1: &str,
            accel2: &str,
            qp_type: crate::ibverbs_primitives::RdmaQpType,
        ) -> Result<Self, anyhow::Error> {
            let mut config1 = IbverbsConfig::targeting(accel1);
            let mut config2 = IbverbsConfig::targeting(accel2);

            config1.qp_type = qp_type;
            config2.qp_type = qp_type;

            let parsed_accel1 = parse_accel(accel1, &mut config1).await;
            let parsed_accel2 = parse_accel(accel2, &mut config2).await;

            // Allocate one single-proc mesh per side and spawn an RDMA
            // manager on each.
            let alloc_1 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                    proc_allocation_mode: Default::default(),
                })
                .await
                .unwrap();

            let instance = global_root_client();

            let proc_mesh_1 = Box::leak(Box::new(ProcMesh::allocate(alloc_1).await.unwrap()));
            let actor_mesh_1: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_1
                .spawn(&instance, "rdma_manager", &Some(config1))
                .await
                .unwrap();

            let alloc_2 = LocalAllocator
                .allocate(AllocSpec {
                    extent: extent! { proc = 1 },
                    constraints: Default::default(),
                    proc_name: None,
                    transport: ChannelTransport::Local,
                    proc_allocation_mode: Default::default(),
                })
                .await
                .unwrap();

            let proc_mesh_2 = Box::leak(Box::new(ProcMesh::allocate(alloc_2).await.unwrap()));
            let actor_mesh_2: RootActorMesh<'_, RdmaManagerActor> = proc_mesh_2
                .spawn(&instance, "rdma_manager", &Some(config2))
                .await
                .unwrap();

            let actor_1 = actor_mesh_1.get(0).unwrap();
            let actor_2 = actor_mesh_2.get(0).unwrap();

            let mut buf_vec = Vec::new();
            let mut cuda_actor_1 = None;
            let mut cuda_actor_2 = None;
            let mut device_ptr_1: Option<usize> = None;
            let mut device_ptr_2: Option<usize> = None;

            let rdma_handle_1;
            let rdma_handle_2;

            if parsed_accel1.0 == "cpu" {
                // Host-backed buffer: allocate on the heap and register it.
                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                buf_vec.push(Buffer {
                    ptr: buffer.as_mut_ptr() as u64,
                    len: buffer.len(),
                    cpu_ref: Some(buffer),
                });
                rdma_handle_1 = actor_1
                    .request_buffer(proc_mesh_1.client(), buf_vec[0].ptr as usize, buffer_size)
                    .await?;
            } else {
                // Device-backed buffer: spawn a CudaActor to allocate and
                // register GPU memory.
                let cuda_actor_mesh_1: RootActorMesh<'_, CudaActor> = proc_mesh_1
                    .spawn(&instance, "cuda_init", &(parsed_accel1.1 as i32))
                    .await?;
                let cuda_actor_ref_1 = cuda_actor_mesh_1.get(0).unwrap();

                let (rdma_buf, dev_ptr) = cuda_actor_ref_1
                    .create_buffer(proc_mesh_1.client(), buffer_size, actor_1.clone())
                    .await?;
                rdma_handle_1 = rdma_buf;
                device_ptr_1 = Some(dev_ptr);

                buf_vec.push(Buffer {
                    ptr: rdma_handle_1.addr as u64,
                    len: buffer_size,
                    cpu_ref: None,
                });
                cuda_actor_1 = Some(cuda_actor_ref_1);
            }

            if parsed_accel2.0 == "cpu" {
                let mut buffer = vec![0u8; buffer_size].into_boxed_slice();
                buf_vec.push(Buffer {
                    ptr: buffer.as_mut_ptr() as u64,
                    len: buffer.len(),
                    cpu_ref: Some(buffer),
                });
                rdma_handle_2 = actor_2
                    .request_buffer(proc_mesh_2.client(), buf_vec[1].ptr as usize, buffer_size)
                    .await?;
            } else {
                let cuda_actor_mesh_2: RootActorMesh<'_, CudaActor> = proc_mesh_2
                    .spawn(&instance, "cuda_init", &(parsed_accel2.1 as i32))
                    .await?;
                let cuda_actor_ref_2 = cuda_actor_mesh_2.get(0).unwrap();

                let (rdma_buf, dev_ptr) = cuda_actor_ref_2
                    .create_buffer(proc_mesh_2.client(), buffer_size, actor_2.clone())
                    .await?;
                rdma_handle_2 = rdma_buf;
                device_ptr_2 = Some(dev_ptr);

                buf_vec.push(Buffer {
                    ptr: rdma_handle_2.addr as u64,
                    len: buffer_size,
                    cpu_ref: None,
                });
                cuda_actor_2 = Some(cuda_actor_ref_2);
            }

            // Seed the first buffer with a known pattern (42) so transfers
            // can be verified.
            if parsed_accel1.0 == "cuda" {
                cuda_actor_1
                    .clone()
                    .unwrap()
                    .fill_buffer(proc_mesh_1.client(), device_ptr_1.unwrap(), buffer_size, 42)
                    .await?;
            } else {
                unsafe {
                    let ptr = buf_vec[0].ptr as *mut u8;
                    for i in 0..buf_vec[0].len {
                        *ptr.add(i) = 42_u8;
                    }
                }
            }

            let buffer_2 = buf_vec.remove(1);
            let buffer_1 = buf_vec.remove(0);

            Ok(Self {
                buffer_1,
                buffer_2,
                client_1: proc_mesh_1.client(),
                client_2: proc_mesh_2.client(),
                actor_1,
                actor_2,
                rdma_handle_1,
                rdma_handle_2,
                cuda_actor_1,
                cuda_actor_2,
                device_ptr_1,
                device_ptr_2,
            })
        }

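        /// Releases both registered RDMA buffers.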
        pub async fn cleanup(self) -> Result<(), anyhow::Error> {
            self.actor_1
                .release_buffer(self.client_1, self.rdma_handle_1.clone())
                .await?;

            self.actor_2
                .release_buffer(self.client_2, self.rdma_handle_2.clone())
                .await?;
            Ok(())
        }

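        /// Like [`Self::setup_with_qp_type`], with `RdmaQpType::Auto`.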
        pub async fn setup(
            buffer_size: usize,
            accel1: &str,
            accel2: &str,
        ) -> Result<Self, anyhow::Error> {
            Self::setup_with_qp_type(
                buffer_size,
                accel1,
                accel2,
                crate::ibverbs_primitives::RdmaQpType::Auto,
            )
            .await
        }

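        /// Copies `size` bytes at `offset` out of each side's buffer (DtoH
        /// copy for CUDA, plain memcpy for CPU) and checks that the two
        /// snapshots match byte-for-byte.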
        pub async fn verify_buffers(
            &self,
            size: usize,
            offset: usize,
        ) -> Result<(), anyhow::Error> {
            let mut temp_buffer_1 = vec![0u8; size];
            let mut temp_buffer_2 = vec![0u8; size];

            if let Some(cuda_actor) = &self.cuda_actor_1 {
                cuda_actor
                    .verify_buffer(
                        self.client_1,
                        temp_buffer_1.as_mut_ptr() as usize,
                        self.device_ptr_1.unwrap() + offset,
                        size,
                    )
                    .await?;
            } else {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        (self.buffer_1.ptr + offset as u64) as *const u8,
                        temp_buffer_1.as_mut_ptr(),
                        size,
                    );
                }
            }

            if let Some(cuda_actor) = &self.cuda_actor_2 {
                cuda_actor
                    .verify_buffer(
                        self.client_2,
                        temp_buffer_2.as_mut_ptr() as usize,
                        self.device_ptr_2.unwrap() + offset,
                        size,
                    )
                    .await?;
            } else {
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        (self.buffer_2.ptr + offset as u64) as *const u8,
                        temp_buffer_2.as_mut_ptr(),
                        size,
                    );
                }
            }

            for i in 0..size {
                if temp_buffer_1[i] != temp_buffer_2[i] {
                    return Err(anyhow::anyhow!(
                        "Buffers are not equal at index {}",
                        offset + i
                    ));
                }
            }
            Ok(())
        }
    }
}