1use std::collections::HashMap;
44use std::ffi::CStr;
45use std::fs;
46use std::io::Error;
47use std::result::Result;
48use std::time::Duration;
49
50use hyperactor::ActorRef;
51use hyperactor::Mailbox;
52use hyperactor::Named;
53use hyperactor::clock::Clock;
54use hyperactor::clock::RealClock;
55use serde::Deserialize;
56use serde::Serialize;
57
58use crate::RdmaDevice;
59use crate::RdmaManagerActor;
60use crate::RdmaManagerMessageClient;
61use crate::ibverbs_primitives::Gid;
62use crate::ibverbs_primitives::IbvWc;
63use crate::ibverbs_primitives::IbverbsConfig;
64use crate::ibverbs_primitives::RdmaMemoryRegionView;
65use crate::ibverbs_primitives::RdmaOperation;
66use crate::ibverbs_primitives::RdmaQpInfo;
67
/// Handle describing a pending doorbell ring for an mlx5 QP.
///
/// Produced by `RdmaQueuePair::send_wqe`: `src_ptr` points at the send-queue
/// buffer and `dst_ptr` at the device's blue-flame doorbell register.
/// Ringing (see [`DoorBell::ring`]) notifies the HCA that new WQEs are ready.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Address of the WQE data to present to the doorbell register.
    pub src_ptr: usize,
    /// Address of the doorbell (blue-flame) register to write.
    pub dst_ptr: usize,
    /// Size in bytes. NOTE(review): not consulted by `ring` — confirm intent.
    pub size: usize,
}
74
impl DoorBell {
    /// Rings the doorbell: forwards `src_ptr`/`dst_ptr` to
    /// `rdmaxcel_sys::db_ring`, which performs the actual register write.
    ///
    /// Always returns `Ok(())`; the `Result` exists for interface symmetry.
    pub fn ring(&self) -> Result<(), anyhow::Error> {
        // SAFETY: soundness depends on this DoorBell having been produced by
        // a still-alive RdmaQueuePair, so both raw addresses reference live,
        // correctly mapped QP resources. Nothing here validates them.
        unsafe {
            let src_ptr = self.src_ptr as *mut std::ffi::c_void;
            let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
            rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
            Ok(())
        }
    }
}
93
/// Serializable description of a registered RDMA memory region, addressable
/// by remote peers via its keys.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct RdmaBuffer {
    /// Actor that owns the underlying memory registration and queue pairs.
    pub owner: ActorRef<RdmaManagerActor>,
    /// Id of the memory region inside the owner's `RdmaDomain::mr_map`.
    pub mr_id: u32,
    /// Local access key of the MR.
    pub lkey: u32,
    /// Remote access key of the MR.
    pub rkey: u32,
    /// Start address of the registered region.
    pub addr: usize,
    /// Length of the registered region in bytes.
    pub size: usize,
}
103
impl RdmaBuffer {
    /// Copies this buffer's contents into `remote` using a one-sided RDMA
    /// WRITE (`qp.put`), then polls the send CQ until completion.
    ///
    /// # Arguments
    /// * `client` - mailbox used to request a queue pair from `self.owner`.
    /// * `remote` - destination buffer; its owner identifies the peer.
    /// * `timeout` - seconds to wait for the work completion.
    ///
    /// Returns `Ok(true)` once the completion is reaped; errors on poll
    /// failure or timeout.
    pub async fn read_into(
        &self,
        client: &Mailbox,
        remote: RdmaBuffer,
        timeout: u64,
    ) -> Result<bool, anyhow::Error> {
        tracing::debug!(
            "[buffer] reading from {:?} into remote ({:?}) at {:?}",
            self,
            remote.owner.actor_id(),
            remote,
        );
        let mut qp = self
            .owner
            .request_queue_pair(client, remote.owner.clone())
            .await?;

        qp.put(self.clone(), remote)?;
        self.wait_for_completion(&mut qp, PollTarget::Send, timeout)
            .await
    }

    /// Fills this buffer from `remote` using a one-sided RDMA READ
    /// (`qp.get`), then polls the send CQ until completion.
    ///
    /// Same arguments and return contract as [`Self::read_into`].
    pub async fn write_from(
        &self,
        client: &Mailbox,
        remote: RdmaBuffer,
        timeout: u64,
    ) -> Result<bool, anyhow::Error> {
        tracing::debug!(
            "[buffer] writing into {:?} from remote ({:?}) at {:?}",
            self,
            remote.owner.actor_id(),
            remote,
        );
        let mut qp = self
            .owner
            .request_queue_pair(client, remote.owner.clone())
            .await?;
        qp.get(self.clone(), remote)?;
        self.wait_for_completion(&mut qp, PollTarget::Send, timeout)
            .await
    }

    /// Polls `qp` for one work completion on `poll_target`, sleeping 1 ms
    /// between attempts, until it lands or `timeout` seconds elapse.
    async fn wait_for_completion(
        &self,
        qp: &mut RdmaQueuePair,
        poll_target: PollTarget,
        timeout: u64,
    ) -> Result<bool, anyhow::Error> {
        // Shadow the raw seconds with a Duration for the elapsed comparison.
        let timeout = Duration::from_secs(timeout);
        let start_time = std::time::Instant::now();

        while start_time.elapsed() < timeout {
            match qp.poll_completion_target(poll_target) {
                Ok(Some(_wc)) => {
                    tracing::debug!("work completed");
                    return Ok(true);
                }
                Ok(None) => {
                    // Nothing yet — back off briefly before re-polling.
                    RealClock.sleep(Duration::from_millis(1)).await;
                }
                Err(e) => {
                    tracing::error!("polling completion failed: {}", e);
                    return Err(anyhow::anyhow!(e));
                }
            }
        }
        tracing::error!("timed out while waiting on request completion");
        Err(anyhow::anyhow!(
            "[buffer({:?})] rdma operation did not complete in time",
            self
        ))
    }
}
215
/// Owns an opened ibverbs device context, its protection domain (PD), and
/// the memory regions registered through it.
pub struct RdmaDomain {
    /// Device context from `ibv_open_device`.
    pub context: *mut rdmaxcel_sys::ibv_context,
    /// Protection domain from `ibv_alloc_pd`.
    pub pd: *mut rdmaxcel_sys::ibv_pd,
    // Registered MRs keyed by the id handed out in RdmaMemoryRegionView.
    mr_map: HashMap<u32, *mut rdmaxcel_sys::ibv_mr>,
    // Monotonic id source for mr_map keys.
    counter: u32,
}
235
236impl std::fmt::Debug for RdmaDomain {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("RdmaDomain")
239 .field("context", &format!("{:p}", self.context))
240 .field("pd", &format!("{:p}", self.pd))
241 .field("mr", &format!("{:?}", self.mr_map))
242 .field("counter", &self.counter)
243 .finish()
244 }
245}
246
// SAFETY: NOTE(review) — RdmaDomain holds raw pointers to ibverbs objects
// (context/PD/MRs) that it owns exclusively; moving it across threads is
// assumed sound because libibverbs objects are not bound to their creating
// thread. Confirm rdmaxcel_sys adds no thread-local state.
unsafe impl Send for RdmaDomain {}

// SAFETY: NOTE(review) — all mutation (`mr_map`, `counter`) goes through
// `&mut self`; shared `&RdmaDomain` access only reads the raw pointers, which
// is assumed safe under ibverbs' documented threading guarantees.
unsafe impl Sync for RdmaDomain {}
258
259impl Drop for RdmaDomain {
260 fn drop(&mut self) {
261 unsafe {
262 rdmaxcel_sys::ibv_dealloc_pd(self.pd);
263 }
264 }
265}
266
267impl RdmaDomain {
268 pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
295 tracing::debug!("creating RdmaDomain for device {}", device.name());
296 unsafe {
303 let device_name = device.name();
305 let mut num_devices = 0i32;
306 let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
307
308 if devices.is_null() || num_devices == 0 {
309 return Err(anyhow::anyhow!("no RDMA devices found"));
310 }
311
312 let mut device_ptr = std::ptr::null_mut();
314 for i in 0..num_devices {
315 let dev = *devices.offset(i as isize);
316 let dev_name =
317 CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
318
319 if dev_name == *device_name {
320 device_ptr = dev;
321 break;
322 }
323 }
324
325 if device_ptr.is_null() {
327 rdmaxcel_sys::ibv_free_device_list(devices);
328 return Err(anyhow::anyhow!("device '{}' not found", device_name));
329 }
330 tracing::info!("using RDMA device: {}", device_name);
331
332 let context = rdmaxcel_sys::ibv_open_device(device_ptr);
334 if context.is_null() {
335 rdmaxcel_sys::ibv_free_device_list(devices);
336 let os_error = Error::last_os_error();
337 return Err(anyhow::anyhow!("failed to create context: {}", os_error));
338 }
339
340 let pd = rdmaxcel_sys::ibv_alloc_pd(context);
342 if pd.is_null() {
343 rdmaxcel_sys::ibv_close_device(context);
344 rdmaxcel_sys::ibv_free_device_list(devices);
345 let os_error = Error::last_os_error();
346 return Err(anyhow::anyhow!(
347 "failed to create protection domain (PD): {}",
348 os_error
349 ));
350 }
351
352 rdmaxcel_sys::ibv_free_device_list(devices);
354
355 Ok(RdmaDomain {
356 context,
357 pd,
358 mr_map: HashMap::new(),
359 counter: 0,
360 })
361 }
362 }
363
364 fn register_mr(
365 &mut self,
366 addr: usize,
367 size: usize,
368 ) -> Result<RdmaMemoryRegionView, anyhow::Error> {
369 unsafe {
370 let mut mem_type: i32 = 0;
371 let ptr = addr as cuda_sys::CUdeviceptr;
372 let err = cuda_sys::cuPointerGetAttribute(
373 &mut mem_type as *mut _ as *mut std::ffi::c_void,
374 cuda_sys::CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
375 ptr,
376 );
377 let is_cuda = err == cuda_sys::CUresult::CUDA_SUCCESS;
378
379 let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
380 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
381 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
382 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
383
384 let mr;
385 if is_cuda {
386 let mut fd: i32 = -1;
387 cuda_sys::cuMemGetHandleForAddressRange(
388 &mut fd as *mut i32 as *mut std::ffi::c_void,
389 addr as cuda_sys::CUdeviceptr,
390 size,
391 cuda_sys::CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
392 0,
393 );
394 mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(self.pd, 0, size, 0, fd, access.0 as i32);
395 } else {
396 mr = rdmaxcel_sys::ibv_reg_mr(
397 self.pd,
398 addr as *mut std::ffi::c_void,
399 size,
400 access.0 as i32,
401 );
402 }
403
404 if mr.is_null() {
405 return Err(anyhow::anyhow!("failed to register memory region (MR)"));
406 }
407 let id = self.counter;
408 self.mr_map.insert(id, mr);
409 self.counter += 1;
410
411 Ok(RdmaMemoryRegionView {
412 id,
413 addr: (*mr).addr as usize,
414 size: (*mr).length,
415 lkey: (*mr).lkey,
416 rkey: (*mr).rkey,
417 })
418 }
419 }
420
421 fn deregister_mr(&mut self, id: u32) -> Result<(), anyhow::Error> {
422 let mr = self.mr_map.remove(&id);
423 if mr.is_some() {
424 unsafe {
425 rdmaxcel_sys::ibv_dereg_mr(mr.expect("mr is required"));
426 }
427 }
428 Ok(())
429 }
430
431 pub fn register_buffer(
432 &mut self,
433 addr: usize,
434 size: usize,
435 ) -> Result<RdmaMemoryRegionView, anyhow::Error> {
436 let region_view = self.register_mr(addr, size)?;
437 Ok(region_view)
438 }
439
440 pub fn deregister_buffer(&mut self, buffer: RdmaBuffer) -> Result<(), anyhow::Error> {
443 self.deregister_mr(buffer.mr_id)?;
444 Ok(())
445 }
446}
/// Selects which completion queue `poll_completion_target` inspects.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send CQ.
    Send,
    /// Poll the receive CQ.
    Recv,
}
453
454#[derive(Debug, Serialize, Deserialize, Named, Clone)]
481pub struct RdmaQueuePair {
482 pub send_cq: usize, pub recv_cq: usize, pub qp: usize, pub dv_qp: usize, pub dv_send_cq: usize, pub dv_recv_cq: usize, context: usize, config: IbverbsConfig,
490 pub send_wqe_idx: u32,
491 pub send_db_idx: u32,
492 pub send_cq_idx: u32,
493 pub recv_wqe_idx: u32,
494 pub recv_db_idx: u32,
495 pub recv_cq_idx: u32,
496}
497
498impl RdmaQueuePair {
    /// Creates a QP plus mlx5 direct-verbs views over it and its CQs.
    ///
    /// On partial failure the already-created CQs/QP are torn down before
    /// returning. When `config.use_gpu_direct` is set, the WQ/CQ buffers are
    /// additionally registered with CUDA via `register_cuda_memory`.
    ///
    /// # Errors
    /// Fails if QP creation, any mlx5dv view creation, or (optionally) the
    /// GPU-direct registration fails.
    pub fn new(
        context: *mut rdmaxcel_sys::ibv_context,
        pd: *mut rdmaxcel_sys::ibv_pd,
        config: IbverbsConfig,
    ) -> Result<Self, anyhow::Error> {
        tracing::debug!("creating an RdmaQueuePair from config {}", config);
        unsafe {
            // create_qp allocates both CQs and the QP in one call.
            let qp = rdmaxcel_sys::create_qp(
                context,
                pd,
                config.cq_entries,
                config.max_send_wr.try_into().unwrap(),
                config.max_recv_wr.try_into().unwrap(),
                config.max_send_sge.try_into().unwrap(),
                config.max_recv_sge.try_into().unwrap(),
            );

            if qp.is_null() {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to create queue pair (QP): {}",
                    os_error
                ));
            }

            let send_cq = (*qp).send_cq;
            let recv_cq = (*qp).recv_cq;

            // Direct-verbs views expose the raw WQ/CQ buffers used by the
            // send_wqe / ring_doorbell fast paths.
            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp(qp);
            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq(qp);
            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq(qp);

            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
                // Tear down everything create_qp built before bailing.
                rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
                rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
                rdmaxcel_sys::ibv_destroy_qp(qp);
                return Err(anyhow::anyhow!(
                    "failed to init mlx5dv_qp or completion queues"
                ));
            }

            if config.use_gpu_direct {
                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
                if ret != 0 {
                    rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
                    rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
                    rdmaxcel_sys::ibv_destroy_qp(qp);
                    return Err(anyhow::anyhow!(
                        "failed to register GPU Direct RDMA memory: {:?}",
                        ret
                    ));
                }
            }

            // Pointers are stored as usize so the struct stays serializable.
            Ok(RdmaQueuePair {
                send_cq: send_cq as usize,
                recv_cq: recv_cq as usize,
                qp: qp as usize,
                dv_qp: dv_qp as usize,
                dv_send_cq: dv_send_cq as usize,
                dv_recv_cq: dv_recv_cq as usize,
                context: context as usize,
                config,
                recv_db_idx: 0,
                recv_wqe_idx: 0,
                recv_cq_idx: 0,
                send_db_idx: 0,
                send_wqe_idx: 0,
                send_cq_idx: 0,
            })
        }
    }
594
    /// Gathers the connection endpoint info (QP number, LID, GID, PSN) that
    /// a peer needs to pass to [`Self::connect`].
    ///
    /// # Errors
    /// Fails when querying the port attributes or the GID fails.
    pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
        unsafe {
            let context = self.context as *mut rdmaxcel_sys::ibv_context;
            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
            // Port attributes supply the LID used for IB addressing.
            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
            let errno = rdmaxcel_sys::ibv_query_port(
                context,
                self.config.port_num,
                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
            );
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "Failed to query port attributes: {}",
                    os_error
                ));
            }

            // GID is required for RoCE / global routing.
            let mut gid = Gid::default();
            let ret = rdmaxcel_sys::ibv_query_gid(
                context,
                self.config.port_num,
                i32::from(self.config.gid_index),
                gid.as_mut(),
            );
            if ret != 0 {
                return Err(anyhow::anyhow!("Failed to query GID"));
            }

            Ok(RdmaQpInfo {
                qp_num: (*qp).qp_num,
                lid: port_attr.lid,
                gid: Some(gid),
                psn: self.config.psn,
            })
        }
    }
653
654 pub fn state(&mut self) -> Result<u32, anyhow::Error> {
655 unsafe {
657 let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
658 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
659 ..Default::default()
660 };
661 let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
662 ..Default::default()
663 };
664 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
665 let errno =
666 rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr);
667 if errno != 0 {
668 let os_error = Error::last_os_error();
669 return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
670 }
671 Ok(qp_attr.qp_state)
672 }
673 }
    /// Drives the QP through the standard RC connection ladder
    /// (INIT → RTR → RTS) toward the peer described by `connection_info`.
    ///
    /// # Errors
    /// Returns the OS error if any `ibv_modify_qp` transition fails.
    pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
        unsafe {
            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;

            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;

            // Transition 1: RESET -> INIT (pkey, port, access rights).
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
                qp_access_flags: qp_access_flags.0,
                pkey_index: self.config.pkey_index,
                port_num: self.config.port_num,
                ..Default::default()
            };

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;

            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to INIT: {}",
                    os_error
                ));
            }

            // Transition 2: INIT -> RTR (path + peer parameters).
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
                path_mtu: self.config.path_mtu,
                dest_qp_num: connection_info.qp_num,
                rq_psn: connection_info.psn,
                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
                min_rnr_timer: self.config.min_rnr_timer,
                ah_attr: rdmaxcel_sys::ibv_ah_attr {
                    dlid: connection_info.lid,
                    sl: 0,
                    src_path_bits: 0,
                    port_num: self.config.port_num,
                    grh: Default::default(),
                    ..Default::default()
                },
                ..Default::default()
            };

            // With a GID, use global (GRH) addressing — required for RoCE;
            // otherwise fall back to LID-only addressing.
            if let Some(gid) = connection_info.gid {
                qp_attr.ah_attr.is_global = 1;
                qp_attr.ah_attr.grh.dgid = gid.into();
                qp_attr.ah_attr.grh.hop_limit = 0xff;
                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
            } else {
                qp_attr.ah_attr.is_global = 0;
            }

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;

            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to RTR: {}",
                    os_error
                ));
            }

            // Transition 3: RTR -> RTS (send-side timers / retry budget).
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
                sq_psn: self.config.psn,
                max_rd_atomic: self.config.max_rd_atomic,
                retry_cnt: self.config.retry_cnt,
                rnr_retry: self.config.rnr_retry,
                timeout: self.config.qp_timeout,
                ..Default::default()
            };

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;

            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to RTS: {}",
                    os_error
                ));
            }
            tracing::debug!(
                "connection sequence has successfully completed (qp: {:?})",
                qp
            );

            Ok(())
        }
    }
802
803 pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
804 let idx = self.recv_wqe_idx;
805 self.recv_wqe_idx += 1;
806 self.send_wqe(
807 0,
808 lhandle.lkey,
809 0,
810 idx,
811 true,
812 RdmaOperation::Recv,
813 0,
814 rhandle.rkey,
815 )
816 .unwrap();
817 Ok(())
818 }
819
820 pub fn put_with_recv(
821 &mut self,
822 lhandle: RdmaBuffer,
823 rhandle: RdmaBuffer,
824 ) -> Result<(), anyhow::Error> {
825 let idx = self.send_wqe_idx;
826 self.send_wqe_idx += 1;
827 self.post_op(
828 lhandle.addr,
829 lhandle.lkey,
830 lhandle.size,
831 idx,
832 true,
833 RdmaOperation::WriteWithImm,
834 rhandle.addr,
835 rhandle.rkey,
836 )
837 .unwrap();
838 self.send_db_idx += 1;
839 Ok(())
840 }
841
842 pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
843 let idx = self.send_wqe_idx;
844 self.send_wqe_idx += 1;
845 self.post_op(
846 lhandle.addr,
847 lhandle.lkey,
848 lhandle.size,
849 idx,
850 true,
851 RdmaOperation::Write,
852 rhandle.addr,
853 rhandle.rkey,
854 )
855 .unwrap();
856 self.send_db_idx += 1;
857 Ok(())
858 }
859
    /// Rings the hardware doorbell for every WQE built via `send_wqe` but
    /// not yet submitted (indices `send_db_idx..send_wqe_idx`).
    ///
    /// Each pending WQE's address is computed from the SQ ring buffer
    /// (`buf + (idx % wqe_cnt) * stride`) and handed to `db_ring` along with
    /// the blue-flame register.
    ///
    /// # Errors
    /// Fails if more WQEs are pending than the SQ holds, meaning earlier
    /// entries may already have been overwritten.
    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
        unsafe {
            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            // Pending WQEs must fit within the ring.
            if wqe_cnt < (self.send_wqe_idx - self.send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while self.send_db_idx < self.send_wqe_idx {
                let offset = (self.send_db_idx % wqe_cnt) * stride;
                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                self.send_db_idx += 1;
            }
            Ok(())
        }
    }
886
887 pub fn enqueue_put(
901 &mut self,
902 lhandle: RdmaBuffer,
903 rhandle: RdmaBuffer,
904 ) -> Result<(), anyhow::Error> {
905 let idx = self.send_wqe_idx;
906 self.send_wqe_idx += 1;
907 self.send_wqe(
908 lhandle.addr,
909 lhandle.lkey,
910 lhandle.size,
911 idx,
912 true,
913 RdmaOperation::Write,
914 rhandle.addr,
915 rhandle.rkey,
916 )?;
917 Ok(())
918 }
919
920 pub fn enqueue_put_with_recv(
934 &mut self,
935 lhandle: RdmaBuffer,
936 rhandle: RdmaBuffer,
937 ) -> Result<(), anyhow::Error> {
938 let idx = self.send_wqe_idx;
939 self.send_wqe_idx += 1;
940 self.send_wqe(
941 lhandle.addr,
942 lhandle.lkey,
943 lhandle.size,
944 idx,
945 true,
946 RdmaOperation::WriteWithImm,
947 rhandle.addr,
948 rhandle.rkey,
949 )?;
950 Ok(())
951 }
952
953 pub fn enqueue_get(
967 &mut self,
968 lhandle: RdmaBuffer,
969 rhandle: RdmaBuffer,
970 ) -> Result<(), anyhow::Error> {
971 let idx = self.send_wqe_idx;
972 self.send_wqe_idx += 1;
973 self.send_wqe(
974 lhandle.addr,
975 lhandle.lkey,
976 lhandle.size,
977 idx,
978 true,
979 RdmaOperation::Read,
980 rhandle.addr,
981 rhandle.rkey,
982 )?;
983 Ok(())
984 }
985
986 pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
987 let idx = self.send_wqe_idx;
988 self.send_wqe_idx += 1;
989 self.post_op(
990 lhandle.addr,
991 lhandle.lkey,
992 lhandle.size,
993 idx,
994 true,
995 RdmaOperation::Read,
996 rhandle.addr,
997 rhandle.rkey,
998 )
999 .unwrap();
1000 self.send_db_idx += 1;
1001 Ok(())
1002 }
1003
    /// Posts a single work request through the classic ibverbs
    /// `post_send`/`post_recv` function pointers (as opposed to the
    /// direct-WQE fast path in `send_wqe`).
    ///
    /// `Recv` posts a receive WR for `laddr`; `Write`/`Read`/`WriteWithImm`
    /// post a send WR targeting (`raddr`, `rkey`). Any other operation
    /// panics with "Not Implemented".
    ///
    /// # Errors
    /// Returns the OS error when the underlying post call fails.
    fn post_op(
        &mut self,
        laddr: usize,
        lkey: u32,
        length: usize,
        wr_id: u32,
        signaled: bool,
        op_type: RdmaOperation,
        raddr: usize,
        rkey: u32,
    ) -> Result<(), anyhow::Error> {
        unsafe {
            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
            let context = self.context as *mut rdmaxcel_sys::ibv_context;
            // Provider-specific entry points live in the context's ops table.
            let ops = &mut (*context).ops;
            let errno;
            if op_type == RdmaOperation::Recv {
                // Scatter entry describing the local landing buffer.
                let mut sge = rdmaxcel_sys::ibv_sge {
                    addr: laddr as u64,
                    length: length as u32,
                    lkey,
                };
                let mut wr = rdmaxcel_sys::ibv_recv_wr {
                    wr_id: wr_id.try_into().unwrap(),
                    sg_list: &mut sge as *mut _,
                    num_sge: 1,
                    ..Default::default()
                };
                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
                errno = ops.post_recv.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
            } else if op_type == RdmaOperation::Write
                || op_type == RdmaOperation::Read
                || op_type == RdmaOperation::WriteWithImm
            {
                // Signaled WRs generate a completion on the send CQ.
                let send_flags = if signaled {
                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
                } else {
                    0
                };
                let mut sge = rdmaxcel_sys::ibv_sge {
                    addr: laddr as u64,
                    length: length as u32,
                    lkey,
                };
                let mut wr = rdmaxcel_sys::ibv_send_wr {
                    wr_id: wr_id.try_into().unwrap(),
                    next: std::ptr::null_mut(),
                    sg_list: &mut sge as *mut _,
                    num_sge: 1,
                    opcode: op_type.into(),
                    send_flags,
                    wr: Default::default(),
                    qp_type: Default::default(),
                    __bindgen_anon_1: Default::default(),
                    __bindgen_anon_2: Default::default(),
                };

                // One-sided target: remote address and its rkey.
                wr.wr.rdma.remote_addr = raddr as u64;
                wr.wr.rdma.rkey = rkey;

                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();

                errno = ops.post_send.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
            } else {
                panic!("Not Implemented");
            }

            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
            }
            tracing::debug!(
                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
                op_type,
                lkey,
                laddr,
                length,
                raddr,
                rkey,
            );

            Ok(())
        }
    }
1106
    /// Builds a WQE directly in the QP's work-queue buffer (mlx5 direct-verbs
    /// fast path) and returns a [`DoorBell`] the caller can ring to submit.
    ///
    /// For `Recv`, the WQE goes in the RQ buffer and the doorbell record is
    /// updated here immediately; for send-side ops the WQE is only written
    /// and submission is deferred to `ring_doorbell` / `DoorBell::ring`.
    ///
    /// NOTE(review): currently always returns `Ok`; failures inside
    /// `send_wqe`/`recv_wqe` (if any) are not surfaced — confirm the C side.
    fn send_wqe(
        &mut self,
        laddr: usize,
        lkey: u32,
        length: usize,
        wr_id: u32,
        signaled: bool,
        op_type: RdmaOperation,
        raddr: usize,
        rkey: u32,
    ) -> Result<DoorBell, anyhow::Error> {
        unsafe {
            // Map the high-level operation onto the mlx5 opcode.
            let op_type_val = match op_type {
                RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
                RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
                RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
                RdmaOperation::Recv => 0,
            };

            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let _dv_cq = if op_type == RdmaOperation::Recv {
                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
            } else {
                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
            };

            // Receive WQEs live in the RQ ring, send WQEs in the SQ ring.
            let buf = if op_type == RdmaOperation::Recv {
                (*dv_qp).rq.buf as *mut u8
            } else {
                (*dv_qp).sq.buf as *mut u8
            };

            let params = rdmaxcel_sys::wqe_params_t {
                laddr,
                lkey,
                length,
                wr_id: wr_id.try_into().unwrap(),
                signaled,
                op_type: op_type_val,
                raddr,
                rkey,
                qp_num: (*qp).qp_num,
                buf,
                dbrec: (*dv_qp).dbrec,
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
            };

            if op_type == RdmaOperation::Recv {
                rdmaxcel_sys::recv_wqe(params);
                // Doorbell record is big-endian on the device.
                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
            } else {
                rdmaxcel_sys::send_wqe(params);
            };

            Ok(DoorBell {
                dst_ptr: (*dv_qp).bf.reg as usize,
                src_ptr: (*dv_qp).sq.buf as usize,
                size: 8,
            })
        }
    }
1173
    /// Polls the selected completion queue once (non-blocking).
    ///
    /// Only polls when completions are actually outstanding for that
    /// direction (`*_db_idx > *_cq_idx`). Returns `Ok(Some(wc))` when a
    /// completion was reaped, `Ok(None)` when nothing is ready.
    ///
    /// # Errors
    /// Fails when `poll_cq` itself errors or the reaped completion carries
    /// an error status.
    pub fn poll_completion_target(
        &mut self,
        target: PollTarget,
    ) -> Result<Option<IbvWc>, anyhow::Error> {
        unsafe {
            let context = self.context as *mut rdmaxcel_sys::ibv_context;
            let _outstanding_wqe =
                self.send_db_idx + self.recv_db_idx - self.send_cq_idx - self.recv_cq_idx;

            if (target == PollTarget::Send) && self.send_db_idx > self.send_cq_idx {
                let send_cq = self.send_cq as *mut rdmaxcel_sys::ibv_cq;
                let ops = &mut (*context).ops;
                // Zeroed ibv_wc is a plain C struct, so this init is sound.
                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
                let ret = ops.poll_cq.as_mut().unwrap()(send_cq, 1, &mut wc);

                if ret < 0 {
                    return Err(anyhow::anyhow!(
                        "Failed to poll send CQ: {}",
                        Error::last_os_error()
                    ));
                }

                if ret > 0 {
                    // NOTE(review): an invalid wc whose error() is None still
                    // falls through and counts as success — confirm intent.
                    if !wc.is_valid() {
                        if let Some((status, vendor_err)) = wc.error() {
                            return Err(anyhow::anyhow!(
                                "Send work completion failed with status: {:?}, vendor error: {}",
                                status,
                                vendor_err
                            ));
                        }
                    }

                    self.send_cq_idx += 1;

                    return Ok(Some(IbvWc::from(wc)));
                }
            }

            if (target == PollTarget::Recv) && self.recv_db_idx > self.recv_cq_idx {
                let recv_cq = self.recv_cq as *mut rdmaxcel_sys::ibv_cq;
                let ops = &mut (*context).ops;
                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
                let ret = ops.poll_cq.as_mut().unwrap()(recv_cq, 1, &mut wc);

                if ret < 0 {
                    return Err(anyhow::anyhow!(
                        "Failed to poll receive CQ: {}",
                        Error::last_os_error()
                    ));
                }

                if ret > 0 {
                    if !wc.is_valid() {
                        if let Some((status, vendor_err)) = wc.error() {
                            return Err(anyhow::anyhow!(
                                "Receive work completion failed with status: {:?}, vendor error: {}",
                                status,
                                vendor_err
                            ));
                        }
                    }

                    self.recv_cq_idx += 1;

                    return Ok(Some(IbvWc::from(wc)));
                }
            }

            Ok(None)
        }
    }
1262
    /// Polls the send CQ once; see [`Self::poll_completion_target`].
    pub fn poll_send_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
        self.poll_completion_target(PollTarget::Send)
    }

    /// Polls the receive CQ once; see [`Self::poll_completion_target`].
    pub fn poll_recv_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
        self.poll_completion_target(PollTarget::Recv)
    }
1270}
1271
1272pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1284 match fs::read_to_string("/proc/modules") {
1286 Ok(contents) => {
1287 if !contents.contains("nvidia_peermem") {
1288 return Err(anyhow::anyhow!(
1289 "nvidia_peermem module not found in /proc/modules"
1290 ));
1291 }
1292 }
1293 Err(e) => {
1294 return Err(anyhow::anyhow!(e));
1295 }
1296 }
1297
1298 match fs::read_to_string("/proc/driver/nvidia/params") {
1300 Ok(contents) => {
1301 if !contents.contains("PeerMappingOverride=1") {
1302 return Err(anyhow::anyhow!(
1303 "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1304 ));
1305 }
1306 }
1307 Err(e) => {
1308 return Err(anyhow::anyhow!(e));
1309 }
1310 }
1311 Ok(())
1312}
1313
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a domain and a QP can be created on the default device.
    #[test]
    fn test_create_connection() {
        // `is_empty()` is the idiomatic emptiness check (clippy::len_zero).
        if crate::ibverbs_primitives::get_all_devices().is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }

        let config = IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        };
        let domain = RdmaDomain::new(config.device.clone());
        assert!(domain.is_ok());

        let domain = domain.unwrap();
        let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config.clone());
        assert!(queue_pair.is_ok());
    }

    /// Connects two QPs on the same host through the full INIT/RTR/RTS ladder.
    #[test]
    fn test_loopback_connection() {
        if crate::ibverbs_primitives::get_all_devices().is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }

        let server_config = IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        };
        let client_config = IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        };

        let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
        let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();

        let mut server_qp = RdmaQueuePair::new(
            server_domain.context,
            server_domain.pd,
            server_config.clone(),
        )
        .unwrap();
        let mut client_qp = RdmaQueuePair::new(
            client_domain.context,
            client_domain.pd,
            client_config.clone(),
        )
        .unwrap();

        // Exchange endpoint info and connect each side to the other.
        let server_connection_info = server_qp.get_qp_info().unwrap();
        let client_connection_info = client_qp.get_qp_info().unwrap();

        assert!(server_qp.connect(&client_connection_info).is_ok());
        assert!(client_qp.connect(&server_connection_info).is_ok());
    }
}