/// Largest payload, in bytes, posted in a single work request (1 GiB).
/// `put` and `get` split longer transfers into chunks of at most this size.
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
17
18use std::io::Error;
19use std::result::Result;
20use std::time::Duration;
21
22use serde::Deserialize;
23use serde::Serialize;
24use typeuri::Named;
25
26use super::IbvBuffer;
27use super::primitives::Gid;
28use super::primitives::IbvConfig;
29use super::primitives::IbvOperation;
30use super::primitives::IbvQpInfo;
31use super::primitives::IbvWc;
32use super::primitives::resolve_qp_type;
33
/// Error returned by `IbvQueuePair::poll_completion` when a completion queue
/// poll fails or a polled work completion carries an error.
#[derive(Debug)]
pub struct PollCompletionError {
    /// Work-completion status attached to the failure, if one was available.
    pub status: Option<rdmaxcel_sys::ibv_wc_status::Type>,
    /// Vendor-specific error code attached to the failure, if one was available.
    pub vendor_err: Option<u32>,
    /// Human-readable description; this is the text `Display` emits.
    message: String,
}
44
45impl std::fmt::Display for PollCompletionError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.write_str(&self.message)
48 }
49}
50
// No overrides: the default `std::error::Error` methods are used (no `source`).
impl std::error::Error for PollCompletionError {}
52
53impl PollCompletionError {
54 pub fn is_wr_flush_err(&self) -> bool {
58 self.status == Some(rdmaxcel_sys::ibv_wc_status::IBV_WC_WR_FLUSH_ERR)
59 }
60}
61
/// Addresses and size describing a doorbell, as produced by `IbvQueuePair::send_wqe`.
///
/// On the mlx5 path `send_wqe` fills `dst_ptr` from the QP's `bf.reg`, `src_ptr`
/// from its `sq.buf`, and sets `size` to 8; on the EFA path all fields are zero.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Source address, stored as an integer.
    pub src_ptr: usize,
    /// Destination address, stored as an integer.
    pub dst_ptr: usize,
    /// Size value associated with the doorbell (presumably bytes — TODO confirm).
    pub size: usize,
}
wirevalue::register_type!(DoorBell);
72
/// Selects which completion queue `IbvQueuePair::poll_completion` polls.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue (and the send completion cache).
    Send,
    /// Poll the receive completion queue (and the receive completion cache).
    Recv,
}
79
/// Handle to a queue pair created through `rdmaxcel_sys`, plus its completion queues.
///
/// Every pointer-valued member is stored as a `usize` (see the `as usize` casts
/// in `IbvQueuePair::new`) and cast back to a raw pointer at each use site.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct IbvQueuePair {
    /// Address of the send completion queue (`ibv_cq`).
    pub send_cq: usize,
    /// Address of the receive completion queue (`ibv_cq`).
    pub recv_cq: usize,
    /// Address of the `rdmaxcel_qp` returned by `rdmaxcel_qp_create`.
    pub qp: usize,
    /// Address of the `mlx5dv_qp`; `0` for EFA queue pairs.
    pub dv_qp: usize,
    /// Address of the mlx5dv send CQ; `0` for EFA queue pairs.
    pub dv_send_cq: usize,
    /// Address of the mlx5dv receive CQ; `0` for EFA queue pairs.
    pub dv_recv_cq: usize,
    /// Address of the `ibv_context` the queue pair was created on.
    context: usize,
    /// Configuration the queue pair was created with.
    config: IbvConfig,
    /// `true` when the resolved QP type is `RDMA_QP_TYPE_EFA`.
    is_efa: bool,
}
wirevalue::register_type!(IbvQueuePair);
112
113impl IbvQueuePair {
    /// Returns whether this queue pair was created with the EFA QP type.
    fn is_efa(&self) -> bool {
        self.is_efa
    }
117
118 fn apply_first_op_delay(&self, wr_id: u64) {
120 unsafe {
121 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
122 if wr_id == 0 {
123 let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
124 assert!(
125 rts_timestamp != u64::MAX,
126 "First operation attempted before queue pair reached RTS state! Call connect() first."
127 );
128 let current_nanos = std::time::SystemTime::now()
129 .duration_since(std::time::UNIX_EPOCH)
130 .unwrap()
131 .as_nanos() as u64;
132 let elapsed_nanos = current_nanos - rts_timestamp;
133 let elapsed = Duration::from_nanos(elapsed_nanos);
134 let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
135 if elapsed < init_delay {
136 let remaining_delay = init_delay - elapsed;
137 std::thread::sleep(remaining_delay);
141 }
142 }
143 }
144 }
145
146 pub fn new(
156 context: *mut rdmaxcel_sys::ibv_context,
157 pd: *mut rdmaxcel_sys::ibv_pd,
158 config: IbvConfig,
159 ) -> Result<Self, anyhow::Error> {
160 tracing::debug!("creating an IbvQueuePair from config {}", config);
161 unsafe {
162 let resolved_qp_type = resolve_qp_type(config.qp_type);
164 let is_efa = resolved_qp_type == rdmaxcel_sys::RDMA_QP_TYPE_EFA;
165 let qp = rdmaxcel_sys::rdmaxcel_qp_create(
166 context,
167 pd,
168 config.cq_entries,
169 config.max_send_wr.try_into().unwrap(),
170 config.max_recv_wr.try_into().unwrap(),
171 config.max_send_sge.try_into().unwrap(),
172 config.max_recv_sge.try_into().unwrap(),
173 resolved_qp_type,
174 );
175
176 if qp.is_null() {
177 let os_error = Error::last_os_error();
178 return Err(anyhow::anyhow!(
179 "failed to create queue pair (QP): {}",
180 os_error
181 ));
182 }
183
184 let send_cq = (*(*qp).ibv_qp).send_cq;
185 let recv_cq = (*(*qp).ibv_qp).recv_cq;
186
187 if is_efa {
189 return Ok(IbvQueuePair {
190 send_cq: send_cq as usize,
191 recv_cq: recv_cq as usize,
192 qp: qp as usize,
193 dv_qp: 0,
194 dv_send_cq: 0,
195 dv_recv_cq: 0,
196 context: context as usize,
197 config,
198 is_efa: true,
199 });
200 }
201
202 let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
203 let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
204 let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
205
206 if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
207 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
208 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
209 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
210 return Err(anyhow::anyhow!(
211 "failed to init mlx5dv_qp or completion queues"
212 ));
213 }
214
215 if config.use_gpu_direct {
216 let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
217 if ret != 0 {
218 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
219 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
220 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
221 return Err(anyhow::anyhow!(
222 "failed to register GPU Direct RDMA memory: {:?}",
223 ret
224 ));
225 }
226 }
227 Ok(IbvQueuePair {
228 send_cq: send_cq as usize,
229 recv_cq: recv_cq as usize,
230 qp: qp as usize,
231 dv_qp: dv_qp as usize,
232 dv_send_cq: dv_send_cq as usize,
233 dv_recv_cq: dv_recv_cq as usize,
234 context: context as usize,
235 config,
236 is_efa: false,
237 })
238 }
239 }
240
241 pub fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
243 unsafe {
244 let context = self.context as *mut rdmaxcel_sys::ibv_context;
245 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
246 let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
247 let errno = rdmaxcel_sys::ibv_query_port(
248 context,
249 self.config.port_num,
250 &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
251 );
252 if errno != 0 {
253 let os_error = Error::last_os_error();
254 return Err(anyhow::anyhow!(
255 "Failed to query port attributes: {}",
256 os_error
257 ));
258 }
259
260 let mut gid = Gid::default();
261 let ret = rdmaxcel_sys::ibv_query_gid(
262 context,
263 self.config.port_num,
264 i32::from(self.config.gid_index),
265 gid.as_mut(),
266 );
267 if ret != 0 {
268 return Err(anyhow::anyhow!("Failed to query GID"));
269 }
270
271 Ok(IbvQpInfo {
272 qp_num: (*(*qp).ibv_qp).qp_num,
273 lid: port_attr.lid,
274 gid: Some(gid),
275 psn: self.config.psn,
276 })
277 }
278 }
279
280 pub fn state(&mut self) -> Result<u32, anyhow::Error> {
282 unsafe {
283 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
284 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
285 ..Default::default()
286 };
287 let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
288 ..Default::default()
289 };
290 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
291 let errno = rdmaxcel_sys::ibv_query_qp(
292 (*qp).ibv_qp,
293 &mut qp_attr,
294 mask.0 as i32,
295 &mut qp_init_attr,
296 );
297 if errno != 0 {
298 let os_error = Error::last_os_error();
299 return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
300 }
301 Ok(qp_attr.qp_state)
302 }
303 }
304
305 pub fn connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
311 if self.is_efa() {
313 return self.efa_connect(connection_info);
314 }
315
316 unsafe {
317 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
318
319 let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
320 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
321 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
322 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
323
324 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
326 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
327 qp_access_flags: qp_access_flags.0,
328 pkey_index: self.config.pkey_index,
329 port_num: self.config.port_num,
330 ..Default::default()
331 };
332
333 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
334 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
335 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
336 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
337
338 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
339 if errno != 0 {
340 let os_error = Error::last_os_error();
341 return Err(anyhow::anyhow!(
342 "failed to transition QP to INIT: {}",
343 os_error
344 ));
345 }
346
347 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
349 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
350 path_mtu: self.config.path_mtu,
351 dest_qp_num: connection_info.qp_num,
352 rq_psn: connection_info.psn,
353 max_dest_rd_atomic: self.config.max_dest_rd_atomic,
354 min_rnr_timer: self.config.min_rnr_timer,
355 ah_attr: rdmaxcel_sys::ibv_ah_attr {
356 dlid: connection_info.lid,
357 sl: 0,
358 src_path_bits: 0,
359 port_num: self.config.port_num,
360 grh: Default::default(),
361 ..Default::default()
362 },
363 ..Default::default()
364 };
365
366 if let Some(gid) = connection_info.gid {
367 qp_attr.ah_attr.is_global = 1;
368 qp_attr.ah_attr.grh.dgid = rdmaxcel_sys::ibv_gid::from(gid);
369 qp_attr.ah_attr.grh.hop_limit = 0xff;
370 qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
371 } else {
372 qp_attr.ah_attr.is_global = 0;
373 }
374
375 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
376 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
377 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
378 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
379 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
380 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
381 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
382
383 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
384 if errno != 0 {
385 let os_error = Error::last_os_error();
386 return Err(anyhow::anyhow!(
387 "failed to transition QP to RTR: {}",
388 os_error
389 ));
390 }
391
392 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
394 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
395 sq_psn: self.config.psn,
396 max_rd_atomic: self.config.max_rd_atomic,
397 retry_cnt: self.config.retry_cnt,
398 rnr_retry: self.config.rnr_retry,
399 timeout: self.config.qp_timeout,
400 ..Default::default()
401 };
402
403 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
404 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
405 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
406 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
407 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
408 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
409
410 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
411 if errno != 0 {
412 let os_error = Error::last_os_error();
413 return Err(anyhow::anyhow!(
414 "failed to transition QP to RTS: {}",
415 os_error
416 ));
417 }
418 tracing::debug!(
419 "connection sequence has successfully completed (qp: {:?})",
420 qp
421 );
422
423 let rts_timestamp_nanos = std::time::SystemTime::now()
424 .duration_since(std::time::UNIX_EPOCH)
425 .unwrap()
426 .as_nanos() as u64;
427 rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
428
429 Ok(())
430 }
431 }
432
433 fn efa_connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
435 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
436
437 let gid_ptr = connection_info.gid.as_ref().map_or(std::ptr::null(), |g| {
438 let ibv_gid: &rdmaxcel_sys::ibv_gid = g.as_ref();
439 unsafe { ibv_gid.raw.as_ptr() }
440 });
441
442 unsafe {
443 let ret = rdmaxcel_sys::rdmaxcel_efa_connect(
444 qp,
445 self.config.port_num,
446 self.config.pkey_index,
447 0x4242, self.config.psn,
449 self.config.gid_index,
450 gid_ptr,
451 connection_info.qp_num,
452 );
453 if ret != 0 {
454 let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
455 .to_str()
456 .unwrap_or("unknown");
457 return Err(anyhow::anyhow!("EFA connect failed: {}", msg));
458 }
459 }
460
461 let rts_timestamp_nanos = std::time::SystemTime::now()
463 .duration_since(std::time::UNIX_EPOCH)
464 .unwrap()
465 .as_nanos() as u64;
466 unsafe {
467 rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
468 }
469
470 Ok(())
471 }
472
473 pub fn recv(&mut self, lhandle: IbvBuffer, rhandle: IbvBuffer) -> Result<u64, anyhow::Error> {
474 unsafe {
475 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
476 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
477 self.post_op(
478 0,
479 lhandle.lkey,
480 0,
481 idx,
482 true,
483 IbvOperation::Recv,
484 0,
485 rhandle.rkey,
486 )
487 .unwrap();
488 rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
489 Ok(idx)
490 }
491 }
492
493 pub fn put_with_recv(
494 &mut self,
495 lhandle: IbvBuffer,
496 rhandle: IbvBuffer,
497 ) -> Result<Vec<u64>, anyhow::Error> {
498 unsafe {
499 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
500 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
501 self.post_op(
502 lhandle.addr,
503 lhandle.lkey,
504 lhandle.size,
505 idx,
506 true,
507 IbvOperation::WriteWithImm,
508 rhandle.addr,
509 rhandle.rkey,
510 )
511 .unwrap();
512 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
513 Ok(vec![idx])
514 }
515 }
516
517 pub fn put(
518 &mut self,
519 lhandle: IbvBuffer,
520 rhandle: IbvBuffer,
521 ) -> Result<Vec<u64>, anyhow::Error> {
522 let total_size = lhandle.size;
523 if rhandle.size < total_size {
524 return Err(anyhow::anyhow!(
525 "Remote buffer size ({}) is smaller than local buffer size ({})",
526 rhandle.size,
527 total_size
528 ));
529 }
530
531 let mut remaining = total_size;
532 let mut offset = 0;
533 let mut wr_ids = Vec::new();
534 while remaining > 0 {
535 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
536 let idx = unsafe {
537 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
538 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
539 )
540 };
541 wr_ids.push(idx);
542 self.post_op(
543 lhandle.addr + offset,
544 lhandle.lkey,
545 chunk_size,
546 idx,
547 true,
548 IbvOperation::Write,
549 rhandle.addr + offset,
550 rhandle.rkey,
551 )?;
552 unsafe {
553 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
554 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
555 );
556 }
557
558 remaining -= chunk_size;
559 offset += chunk_size;
560 }
561
562 Ok(wr_ids)
563 }
564
565 pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
567 if self.is_efa() {
569 return Ok(());
570 }
571
572 unsafe {
573 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
574 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
575 let base_ptr = (*dv_qp).sq.buf as *mut u8;
576 let wqe_cnt = (*dv_qp).sq.wqe_cnt;
577 let stride = (*dv_qp).sq.stride;
578 let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
579 let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
580 if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
581 return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
582 }
583 self.apply_first_op_delay(send_db_idx);
584 while send_db_idx < send_wqe_idx {
585 let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
586 let src_ptr = base_ptr.wrapping_add(offset as usize);
587 rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
588 send_db_idx += 1;
589 rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
590 }
591 Ok(())
592 }
593 }
594
595 pub fn enqueue_put(
597 &mut self,
598 lhandle: IbvBuffer,
599 rhandle: IbvBuffer,
600 ) -> Result<Vec<u64>, anyhow::Error> {
601 let idx = unsafe {
602 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
603 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
604 )
605 };
606
607 self.send_wqe(
608 lhandle.addr,
609 lhandle.lkey,
610 lhandle.size,
611 idx,
612 true,
613 IbvOperation::Write,
614 rhandle.addr,
615 rhandle.rkey,
616 )?;
617 Ok(vec![idx])
618 }
619
620 pub fn enqueue_put_with_recv(
622 &mut self,
623 lhandle: IbvBuffer,
624 rhandle: IbvBuffer,
625 ) -> Result<Vec<u64>, anyhow::Error> {
626 let idx = unsafe {
627 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
628 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
629 )
630 };
631
632 self.send_wqe(
633 lhandle.addr,
634 lhandle.lkey,
635 lhandle.size,
636 idx,
637 true,
638 IbvOperation::WriteWithImm,
639 rhandle.addr,
640 rhandle.rkey,
641 )?;
642 Ok(vec![idx])
643 }
644
645 pub fn enqueue_get(
647 &mut self,
648 lhandle: IbvBuffer,
649 rhandle: IbvBuffer,
650 ) -> Result<Vec<u64>, anyhow::Error> {
651 let idx = unsafe {
652 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
653 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
654 )
655 };
656
657 self.send_wqe(
658 lhandle.addr,
659 lhandle.lkey,
660 lhandle.size,
661 idx,
662 true,
663 IbvOperation::Read,
664 rhandle.addr,
665 rhandle.rkey,
666 )?;
667 Ok(vec![idx])
668 }
669
670 pub fn get(
671 &mut self,
672 lhandle: IbvBuffer,
673 rhandle: IbvBuffer,
674 ) -> Result<Vec<u64>, anyhow::Error> {
675 let total_size = lhandle.size;
676 if rhandle.size < total_size {
677 return Err(anyhow::anyhow!(
678 "Remote buffer size ({}) is smaller than local buffer size ({})",
679 rhandle.size,
680 total_size
681 ));
682 }
683
684 let mut remaining = total_size;
685 let mut offset = 0;
686 let mut wr_ids = Vec::new();
687
688 while remaining > 0 {
689 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
690 let idx = unsafe {
691 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
692 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
693 )
694 };
695 wr_ids.push(idx);
696
697 self.post_op(
698 lhandle.addr + offset,
699 lhandle.lkey,
700 chunk_size,
701 idx,
702 true,
703 IbvOperation::Read,
704 rhandle.addr + offset,
705 rhandle.rkey,
706 )?;
707 unsafe {
708 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
709 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
710 );
711 }
712
713 remaining -= chunk_size;
714 offset += chunk_size;
715 }
716
717 Ok(wr_ids)
718 }
719
720 fn post_op(
722 &mut self,
723 laddr: usize,
724 lkey: u32,
725 length: usize,
726 wr_id: u64,
727 signaled: bool,
728 op_type: IbvOperation,
729 raddr: usize,
730 rkey: u32,
731 ) -> Result<(), anyhow::Error> {
732 if self.is_efa() {
734 return self.post_op_efa(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey);
735 }
736
737 unsafe {
738 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
739 let context = self.context as *mut rdmaxcel_sys::ibv_context;
740 let ops = &mut (*context).ops;
741 let errno;
742 if op_type == IbvOperation::Recv {
743 let mut sge = rdmaxcel_sys::ibv_sge {
744 addr: laddr as u64,
745 length: length as u32,
746 lkey,
747 };
748 let mut wr = rdmaxcel_sys::ibv_recv_wr {
749 wr_id,
750 sg_list: &mut sge as *mut _,
751 num_sge: 1,
752 ..Default::default()
753 };
754 let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
755 errno =
756 ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
757 } else if op_type == IbvOperation::Write
758 || op_type == IbvOperation::Read
759 || op_type == IbvOperation::WriteWithImm
760 {
761 self.apply_first_op_delay(wr_id);
762 let send_flags = if signaled {
763 rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
764 } else {
765 0
766 };
767 let mut sge = rdmaxcel_sys::ibv_sge {
768 addr: laddr as u64,
769 length: length as u32,
770 lkey,
771 };
772 let mut wr = rdmaxcel_sys::ibv_send_wr {
773 wr_id,
774 next: std::ptr::null_mut(),
775 sg_list: &mut sge as *mut _,
776 num_sge: 1,
777 opcode: op_type.into(),
778 send_flags,
779 wr: Default::default(),
780 qp_type: Default::default(),
781 __bindgen_anon_1: Default::default(),
782 __bindgen_anon_2: Default::default(),
783 };
784
785 wr.wr.rdma.remote_addr = raddr as u64;
786 wr.wr.rdma.rkey = rkey;
787 let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
788
789 errno =
790 ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
791 } else {
792 panic!("Not Implemented");
793 }
794
795 if errno != 0 {
796 let os_error = Error::last_os_error();
797 return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
798 }
799 tracing::debug!(
800 "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
801 op_type,
802 lkey,
803 laddr,
804 length,
805 raddr,
806 rkey,
807 );
808
809 Ok(())
810 }
811 }
812
813 fn post_op_efa(
815 &mut self,
816 laddr: usize,
817 lkey: u32,
818 length: usize,
819 wr_id: u64,
820 signaled: bool,
821 op_type: IbvOperation,
822 raddr: usize,
823 rkey: u32,
824 ) -> Result<(), anyhow::Error> {
825 let c_op = match op_type {
826 IbvOperation::Write => 0,
827 IbvOperation::Read => 1,
828 IbvOperation::Recv => 2,
829 IbvOperation::WriteWithImm => 3,
830 };
831
832 unsafe {
833 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
834 let ret = rdmaxcel_sys::rdmaxcel_qp_post_op(
835 qp,
836 laddr as *mut std::ffi::c_void,
837 lkey,
838 length,
839 raddr as *mut std::ffi::c_void,
840 rkey,
841 wr_id,
842 signaled as i32,
843 c_op,
844 );
845 if ret != 0 {
846 let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
847 .to_str()
848 .unwrap_or("unknown");
849 return Err(anyhow::anyhow!("EFA post_op failed: {}", msg));
850 }
851 }
852 Ok(())
853 }
854
855 fn send_wqe(
856 &mut self,
857 laddr: usize,
858 lkey: u32,
859 length: usize,
860 wr_id: u64,
861 signaled: bool,
862 op_type: IbvOperation,
863 raddr: usize,
864 rkey: u32,
865 ) -> Result<DoorBell, anyhow::Error> {
866 if self.is_efa() {
868 self.post_op(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey)?;
869 return Ok(DoorBell {
870 dst_ptr: 0,
871 src_ptr: 0,
872 size: 0,
873 });
874 }
875
876 unsafe {
877 let op_type_val = match op_type {
878 IbvOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
879 IbvOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
880 IbvOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
881 IbvOperation::Recv => 0,
882 };
883
884 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
885 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
886 let _dv_cq = if op_type == IbvOperation::Recv {
887 self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
888 } else {
889 self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
890 };
891
892 let buf = if op_type == IbvOperation::Recv {
893 (*dv_qp).rq.buf as *mut u8
894 } else {
895 (*dv_qp).sq.buf as *mut u8
896 };
897
898 let params = rdmaxcel_sys::wqe_params_t {
899 laddr,
900 lkey,
901 length,
902 wr_id,
903 signaled,
904 op_type: op_type_val,
905 raddr,
906 rkey,
907 qp_num: (*(*qp).ibv_qp).qp_num,
908 buf,
909 dbrec: (*dv_qp).dbrec,
910 wqe_cnt: (*dv_qp).sq.wqe_cnt,
911 };
912
913 if op_type == IbvOperation::Recv {
914 rdmaxcel_sys::recv_wqe(params);
915 std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
916 } else {
917 rdmaxcel_sys::send_wqe(params);
918 };
919
920 Ok(DoorBell {
921 dst_ptr: (*dv_qp).bf.reg as usize,
922 src_ptr: (*dv_qp).sq.buf as usize,
923 size: 8,
924 })
925 }
926 }
927
928 pub fn poll_completion(
940 &mut self,
941 target: PollTarget,
942 expected_wr_ids: &[u64],
943 ) -> Result<Vec<(u64, IbvWc)>, PollCompletionError> {
944 if expected_wr_ids.is_empty() {
945 return Ok(Vec::new());
946 }
947
948 unsafe {
949 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
950 let qp_num = (*(*qp).ibv_qp).qp_num;
951
952 let (cq, cache, cq_type) = match target {
953 PollTarget::Send => (
954 self.send_cq as *mut rdmaxcel_sys::ibv_cq,
955 rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
956 "send",
957 ),
958 PollTarget::Recv => (
959 self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
960 rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
961 "recv",
962 ),
963 };
964
965 let mut results = Vec::new();
966
967 for &expected_wr_id in expected_wr_ids {
968 let mut poll_ctx = rdmaxcel_sys::poll_context_t {
969 expected_wr_id,
970 expected_qp_num: qp_num,
971 cache,
972 cq,
973 };
974
975 let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
976 let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
977
978 match ret {
979 1 => {
980 if !wc.is_valid() {
981 if let Some((status, vendor_err)) = wc.error() {
982 return Err(PollCompletionError {
983 status: Some(status),
984 vendor_err: Some(vendor_err),
985 message: format!(
986 "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
987 cq_type, expected_wr_id, status, vendor_err,
988 ),
989 });
990 }
991 }
992 results.push((expected_wr_id, IbvWc::from(wc)));
993 }
994 0 => {
995 }
997 -17 => {
998 let error_msg =
999 std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1000 .to_str()
1001 .unwrap_or("Unknown error");
1002 if let Some((status, vendor_err)) = wc.error() {
1003 return Err(PollCompletionError {
1004 status: Some(status),
1005 vendor_err: Some(vendor_err),
1006 message: format!(
1007 "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1008 cq_type,
1009 expected_wr_id,
1010 error_msg,
1011 status,
1012 vendor_err,
1013 wc.qp_num,
1014 wc.len(),
1015 ),
1016 });
1017 } else {
1018 return Err(PollCompletionError {
1019 status: None,
1020 vendor_err: None,
1021 message: format!(
1022 "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1023 cq_type,
1024 expected_wr_id,
1025 error_msg,
1026 wc.qp_num,
1027 wc.len(),
1028 ),
1029 });
1030 }
1031 }
1032 _ => {
1033 let error_msg =
1034 std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1035 .to_str()
1036 .unwrap_or("Unknown error");
1037 return Err(PollCompletionError {
1038 status: None,
1039 vendor_err: None,
1040 message: format!(
1041 "Failed to poll {} CQ for wr_id={}: {}",
1042 cq_type, expected_wr_id, error_msg,
1043 ),
1044 });
1045 }
1046 }
1047 }
1048
1049 Ok(results)
1050 }
1051 }
1052}
1053
1054#[cfg(test)]
1055mod tests {
1056 use super::*;
1057 use crate::backend::ibverbs::domain::IbvDomain;
1058 use crate::backend::ibverbs::primitives::IbvConfig;
1059 use crate::backend::ibverbs::primitives::get_all_devices;
1060
1061 #[test]
1062 fn test_create_connection() {
1063 if get_all_devices().is_empty() {
1064 println!("Skipping test: RDMA devices not available");
1065 return;
1066 }
1067
1068 let config = IbvConfig {
1069 use_gpu_direct: false,
1070 ..Default::default()
1071 };
1072 let domain = IbvDomain::new(config.device.clone());
1073 assert!(domain.is_ok());
1074
1075 let domain = domain.unwrap();
1076 let queue_pair = IbvQueuePair::new(domain.context, domain.pd, config.clone());
1077 assert!(queue_pair.is_ok());
1078 }
1079
1080 #[test]
1081 fn test_loopback_connection() {
1082 if get_all_devices().is_empty() {
1083 println!("Skipping test: RDMA devices not available");
1084 return;
1085 }
1086
1087 let server_config = IbvConfig {
1088 use_gpu_direct: false,
1089 ..Default::default()
1090 };
1091 let client_config = IbvConfig {
1092 use_gpu_direct: false,
1093 ..Default::default()
1094 };
1095
1096 let server_domain = IbvDomain::new(server_config.device.clone()).unwrap();
1097 let client_domain = IbvDomain::new(client_config.device.clone()).unwrap();
1098
1099 let mut server_qp = IbvQueuePair::new(
1100 server_domain.context,
1101 server_domain.pd,
1102 server_config.clone(),
1103 )
1104 .unwrap();
1105 let mut client_qp = IbvQueuePair::new(
1106 client_domain.context,
1107 client_domain.pd,
1108 client_config.clone(),
1109 )
1110 .unwrap();
1111
1112 let server_connection_info = server_qp.get_qp_info().unwrap();
1113 let client_connection_info = client_qp.get_qp_info().unwrap();
1114
1115 assert!(server_qp.connect(&client_connection_info).is_ok());
1116 assert!(client_qp.connect(&server_connection_info).is_ok());
1117 }
1118}