monarch_rdma/backend/ibverbs/
queue_pair.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! ibverbs queue pair, doorbell, and completion polling.
10//!
11//! An [`IbvQueuePair`] encapsulates the send and receive queues, completion
12//! queues, and other resources needed for RDMA communication. It provides
13//! methods for establishing connections and performing RDMA operations.
14
/// Maximum size for a single RDMA operation in bytes (1 GiB).
///
/// `put()` and `get()` split larger transfers into consecutive work requests
/// of at most this many bytes each.
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
17
18use std::io::Error;
19use std::result::Result;
20use std::time::Duration;
21
22use serde::Deserialize;
23use serde::Serialize;
24use typeuri::Named;
25
26use super::IbvBuffer;
27use super::primitives::Gid;
28use super::primitives::IbvConfig;
29use super::primitives::IbvOperation;
30use super::primitives::IbvQpInfo;
31use super::primitives::IbvWc;
32use super::primitives::resolve_qp_type;
33
/// A structured error from [`IbvQueuePair::poll_completion`].
///
/// Carries the `ibv_wc_status` and vendor error code (when available) so
/// callers can match on specific completion statuses without string parsing.
#[derive(Debug)]
pub struct PollCompletionError {
    /// Work-completion status, when one is available.
    pub status: Option<rdmaxcel_sys::ibv_wc_status::Type>,
    /// Vendor-specific error code, when one is available.
    pub vendor_err: Option<u32>,
    /// Human-readable description; this is exactly what `Display` prints.
    message: String,
}
44
45impl std::fmt::Display for PollCompletionError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.write_str(&self.message)
48    }
49}
50
51impl std::error::Error for PollCompletionError {}
52
53impl PollCompletionError {
54    /// Returns `true` when the completion status is `IBV_WC_WR_FLUSH_ERR`,
55    /// which typically indicates a secondary failure after the QP entered
56    /// error state due to a different work request's failure.
57    pub fn is_wr_flush_err(&self) -> bool {
58        self.status == Some(rdmaxcel_sys::ibv_wc_status::IBV_WC_WR_FLUSH_ERR)
59    }
60}
61
/// A doorbell trigger for batched RDMA operations.
///
/// Rings the hardware doorbell to execute previously enqueued work requests.
///
/// All three fields are plain integers so the value can be cloned and
/// serialized; on the EFA path `send_wqe` returns a `DoorBell` with every
/// field set to zero.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    // NOTE(review): presumably the address of the WQE to copy from — confirm
    // against the non-EFA branch of `send_wqe`.
    pub src_ptr: usize,
    // NOTE(review): presumably the doorbell/BlueFlame register address — confirm.
    pub dst_ptr: usize,
    // NOTE(review): presumably the number of bytes to copy — confirm.
    pub size: usize,
}
wirevalue::register_type!(DoorBell);
72
/// Specifies which completion queue to poll.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue.
    Send,
    /// Poll the receive completion queue.
    Recv,
}
79
/// An RDMA Queue Pair (QP) for communication between two endpoints.
///
/// Encapsulates the send/receive queues, completion queues, and mlx5dv
/// device-specific structures needed for RDMA communication.
///
/// # Connection Lifecycle
///
/// 1. Create with `new()` from context and protection domain pointers
/// 2. Get connection info with `get_qp_info()`
/// 3. Exchange connection info with remote peer
/// 4. Connect to remote endpoint with `connect()`
/// 5. Perform RDMA operations with `put()` or `get()`
/// 6. Poll for completions with `poll_completion()`
///
/// # Notes
/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
/// - This makes IbvQueuePair trivially Clone and Serialize
/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
/// - All pointer fields are stored as `usize` addresses and cast back to raw
///   pointers at each use site
/// - For EFA queue pairs `new()` leaves `dv_qp`, `dv_send_cq` and `dv_recv_cq`
///   set to `0`, because the mlx5dv setup is skipped
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct IbvQueuePair {
    pub send_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
    pub recv_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
    pub qp: usize,         // *mut rdmaxcel_sys::rdmaxcel_qp_t
    pub dv_qp: usize,      // *mut rdmaxcel_sys::mlx5dv_qp,
    pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
    pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
    context: usize,        // *mut rdmaxcel_sys::ibv_context,
    // Configuration the QP was created with; also read during connect().
    config: IbvConfig,
    // True when the resolved QP type is RDMA_QP_TYPE_EFA.
    is_efa: bool,
}
wirevalue::register_type!(IbvQueuePair);
112
113impl IbvQueuePair {
    /// Returns `true` when this queue pair was created with the EFA QP type,
    /// in which case the mlx5dv-specific code paths are bypassed.
    fn is_efa(&self) -> bool {
        self.is_efa
    }
117
118    /// Applies hardware initialization delay if this is the first operation since RTS.
119    fn apply_first_op_delay(&self, wr_id: u64) {
120        unsafe {
121            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
122            if wr_id == 0 {
123                let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
124                assert!(
125                    rts_timestamp != u64::MAX,
126                    "First operation attempted before queue pair reached RTS state! Call connect() first."
127                );
128                let current_nanos = std::time::SystemTime::now()
129                    .duration_since(std::time::UNIX_EPOCH)
130                    .unwrap()
131                    .as_nanos() as u64;
132                let elapsed_nanos = current_nanos - rts_timestamp;
133                let elapsed = Duration::from_nanos(elapsed_nanos);
134                let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
135                if elapsed < init_delay {
136                    let remaining_delay = init_delay - elapsed;
137                    // Sync context within unsafe block; tokio::time::sleep is async
138                    // and converting would require propagating async through the
139                    // entire post_op / ring_doorbell call chain.
140                    std::thread::sleep(remaining_delay);
141                }
142            }
143        }
144    }
145
146    /// Creates a new IbvQueuePair.
147    ///
148    /// Initializes a new Queue Pair (QP) and associated Completion Queues (CQ)
149    /// using the provided context and protection domain. The QP is created in
150    /// the RESET state and must be transitioned via `connect()` before use.
151    ///
152    /// # Errors
153    ///
154    /// Returns errors if CQ or QP creation fails.
155    pub fn new(
156        context: *mut rdmaxcel_sys::ibv_context,
157        pd: *mut rdmaxcel_sys::ibv_pd,
158        config: IbvConfig,
159    ) -> Result<Self, anyhow::Error> {
160        tracing::debug!("creating an IbvQueuePair from config {}", config);
161        unsafe {
162            // Resolve Auto to a concrete QP type based on device capabilities
163            let resolved_qp_type = resolve_qp_type(config.qp_type);
164            let is_efa = resolved_qp_type == rdmaxcel_sys::RDMA_QP_TYPE_EFA;
165            let qp = rdmaxcel_sys::rdmaxcel_qp_create(
166                context,
167                pd,
168                config.cq_entries,
169                config.max_send_wr.try_into().unwrap(),
170                config.max_recv_wr.try_into().unwrap(),
171                config.max_send_sge.try_into().unwrap(),
172                config.max_recv_sge.try_into().unwrap(),
173                resolved_qp_type,
174            );
175
176            if qp.is_null() {
177                let os_error = Error::last_os_error();
178                return Err(anyhow::anyhow!(
179                    "failed to create queue pair (QP): {}",
180                    os_error
181                ));
182            }
183
184            let send_cq = (*(*qp).ibv_qp).send_cq;
185            let recv_cq = (*(*qp).ibv_qp).recv_cq;
186
187            // EFA uses standard ibverbs (not mlx5dv), so skip dv setup
188            if is_efa {
189                return Ok(IbvQueuePair {
190                    send_cq: send_cq as usize,
191                    recv_cq: recv_cq as usize,
192                    qp: qp as usize,
193                    dv_qp: 0,
194                    dv_send_cq: 0,
195                    dv_recv_cq: 0,
196                    context: context as usize,
197                    config,
198                    is_efa: true,
199                });
200            }
201
202            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
203            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
204            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
205
206            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
207                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
208                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
209                rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
210                return Err(anyhow::anyhow!(
211                    "failed to init mlx5dv_qp or completion queues"
212                ));
213            }
214
215            if config.use_gpu_direct {
216                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
217                if ret != 0 {
218                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
219                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
220                    rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
221                    return Err(anyhow::anyhow!(
222                        "failed to register GPU Direct RDMA memory: {:?}",
223                        ret
224                    ));
225                }
226            }
227            Ok(IbvQueuePair {
228                send_cq: send_cq as usize,
229                recv_cq: recv_cq as usize,
230                qp: qp as usize,
231                dv_qp: dv_qp as usize,
232                dv_send_cq: dv_send_cq as usize,
233                dv_recv_cq: dv_recv_cq as usize,
234                context: context as usize,
235                config,
236                is_efa: false,
237            })
238        }
239    }
240
241    /// Returns the connection info needed by a remote peer to connect to this QP.
242    pub fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
243        unsafe {
244            let context = self.context as *mut rdmaxcel_sys::ibv_context;
245            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
246            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
247            let errno = rdmaxcel_sys::ibv_query_port(
248                context,
249                self.config.port_num,
250                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
251            );
252            if errno != 0 {
253                let os_error = Error::last_os_error();
254                return Err(anyhow::anyhow!(
255                    "Failed to query port attributes: {}",
256                    os_error
257                ));
258            }
259
260            let mut gid = Gid::default();
261            let ret = rdmaxcel_sys::ibv_query_gid(
262                context,
263                self.config.port_num,
264                i32::from(self.config.gid_index),
265                gid.as_mut(),
266            );
267            if ret != 0 {
268                return Err(anyhow::anyhow!("Failed to query GID"));
269            }
270
271            Ok(IbvQpInfo {
272                qp_num: (*(*qp).ibv_qp).qp_num,
273                lid: port_attr.lid,
274                gid: Some(gid),
275                psn: self.config.psn,
276            })
277        }
278    }
279
280    /// Returns the current state of the QP.
281    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
282        unsafe {
283            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
284            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
285                ..Default::default()
286            };
287            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
288                ..Default::default()
289            };
290            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
291            let errno = rdmaxcel_sys::ibv_query_qp(
292                (*qp).ibv_qp,
293                &mut qp_attr,
294                mask.0 as i32,
295                &mut qp_init_attr,
296            );
297            if errno != 0 {
298                let os_error = Error::last_os_error();
299                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
300            }
301            Ok(qp_attr.qp_state)
302        }
303    }
304
305    /// Transitions the QP through INIT -> RTR -> RTS to establish a connection.
306    ///
307    /// # Arguments
308    ///
309    /// * `connection_info` - The remote connection info to connect to
310    pub fn connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
311        // EFA: use unified C function for QP state transitions
312        if self.is_efa() {
313            return self.efa_connect(connection_info);
314        }
315
316        unsafe {
317            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
318
319            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
320                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
321                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
322                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
323
324            // Transition to INIT
325            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
326                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
327                qp_access_flags: qp_access_flags.0,
328                pkey_index: self.config.pkey_index,
329                port_num: self.config.port_num,
330                ..Default::default()
331            };
332
333            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
334                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
335                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
336                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
337
338            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
339            if errno != 0 {
340                let os_error = Error::last_os_error();
341                return Err(anyhow::anyhow!(
342                    "failed to transition QP to INIT: {}",
343                    os_error
344                ));
345            }
346
347            // Transition to RTR (Ready to Receive)
348            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
349                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
350                path_mtu: self.config.path_mtu,
351                dest_qp_num: connection_info.qp_num,
352                rq_psn: connection_info.psn,
353                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
354                min_rnr_timer: self.config.min_rnr_timer,
355                ah_attr: rdmaxcel_sys::ibv_ah_attr {
356                    dlid: connection_info.lid,
357                    sl: 0,
358                    src_path_bits: 0,
359                    port_num: self.config.port_num,
360                    grh: Default::default(),
361                    ..Default::default()
362                },
363                ..Default::default()
364            };
365
366            if let Some(gid) = connection_info.gid {
367                qp_attr.ah_attr.is_global = 1;
368                qp_attr.ah_attr.grh.dgid = rdmaxcel_sys::ibv_gid::from(gid);
369                qp_attr.ah_attr.grh.hop_limit = 0xff;
370                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
371            } else {
372                qp_attr.ah_attr.is_global = 0;
373            }
374
375            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
376                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
377                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
378                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
379                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
380                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
381                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
382
383            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
384            if errno != 0 {
385                let os_error = Error::last_os_error();
386                return Err(anyhow::anyhow!(
387                    "failed to transition QP to RTR: {}",
388                    os_error
389                ));
390            }
391
392            // Transition to RTS (Ready to Send)
393            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
394                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
395                sq_psn: self.config.psn,
396                max_rd_atomic: self.config.max_rd_atomic,
397                retry_cnt: self.config.retry_cnt,
398                rnr_retry: self.config.rnr_retry,
399                timeout: self.config.qp_timeout,
400                ..Default::default()
401            };
402
403            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
404                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
405                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
406                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
407                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
408                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
409
410            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
411            if errno != 0 {
412                let os_error = Error::last_os_error();
413                return Err(anyhow::anyhow!(
414                    "failed to transition QP to RTS: {}",
415                    os_error
416                ));
417            }
418            tracing::debug!(
419                "connection sequence has successfully completed (qp: {:?})",
420                qp
421            );
422
423            let rts_timestamp_nanos = std::time::SystemTime::now()
424                .duration_since(std::time::UNIX_EPOCH)
425                .unwrap()
426                .as_nanos() as u64;
427            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
428
429            Ok(())
430        }
431    }
432
433    /// Connects via the EFA-specific C function for QP state transitions.
434    fn efa_connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
435        let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
436
437        let gid_ptr = connection_info.gid.as_ref().map_or(std::ptr::null(), |g| {
438            let ibv_gid: &rdmaxcel_sys::ibv_gid = g.as_ref();
439            unsafe { ibv_gid.raw.as_ptr() }
440        });
441
442        unsafe {
443            let ret = rdmaxcel_sys::rdmaxcel_efa_connect(
444                qp,
445                self.config.port_num,
446                self.config.pkey_index,
447                0x4242, // qkey
448                self.config.psn,
449                self.config.gid_index,
450                gid_ptr,
451                connection_info.qp_num,
452            );
453            if ret != 0 {
454                let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
455                    .to_str()
456                    .unwrap_or("unknown");
457                return Err(anyhow::anyhow!("EFA connect failed: {}", msg));
458            }
459        }
460
461        // Store RTS timestamp for first-op delay
462        let rts_timestamp_nanos = std::time::SystemTime::now()
463            .duration_since(std::time::UNIX_EPOCH)
464            .unwrap()
465            .as_nanos() as u64;
466        unsafe {
467            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
468        }
469
470        Ok(())
471    }
472
473    pub fn recv(&mut self, lhandle: IbvBuffer, rhandle: IbvBuffer) -> Result<u64, anyhow::Error> {
474        unsafe {
475            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
476            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
477            self.post_op(
478                0,
479                lhandle.lkey,
480                0,
481                idx,
482                true,
483                IbvOperation::Recv,
484                0,
485                rhandle.rkey,
486            )
487            .unwrap();
488            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
489            Ok(idx)
490        }
491    }
492
493    pub fn put_with_recv(
494        &mut self,
495        lhandle: IbvBuffer,
496        rhandle: IbvBuffer,
497    ) -> Result<Vec<u64>, anyhow::Error> {
498        unsafe {
499            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
500            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
501            self.post_op(
502                lhandle.addr,
503                lhandle.lkey,
504                lhandle.size,
505                idx,
506                true,
507                IbvOperation::WriteWithImm,
508                rhandle.addr,
509                rhandle.rkey,
510            )
511            .unwrap();
512            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
513            Ok(vec![idx])
514        }
515    }
516
517    pub fn put(
518        &mut self,
519        lhandle: IbvBuffer,
520        rhandle: IbvBuffer,
521    ) -> Result<Vec<u64>, anyhow::Error> {
522        let total_size = lhandle.size;
523        if rhandle.size < total_size {
524            return Err(anyhow::anyhow!(
525                "Remote buffer size ({}) is smaller than local buffer size ({})",
526                rhandle.size,
527                total_size
528            ));
529        }
530
531        let mut remaining = total_size;
532        let mut offset = 0;
533        let mut wr_ids = Vec::new();
534        while remaining > 0 {
535            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
536            let idx = unsafe {
537                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
538                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
539                )
540            };
541            wr_ids.push(idx);
542            self.post_op(
543                lhandle.addr + offset,
544                lhandle.lkey,
545                chunk_size,
546                idx,
547                true,
548                IbvOperation::Write,
549                rhandle.addr + offset,
550                rhandle.rkey,
551            )?;
552            unsafe {
553                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
554                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
555                );
556            }
557
558            remaining -= chunk_size;
559            offset += chunk_size;
560        }
561
562        Ok(wr_ids)
563    }
564
565    /// Rings the doorbell to execute all enqueued operations.
566    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
567        // EFA uses standard ibverbs (not mlx5dv), so skip doorbell ringing
568        if self.is_efa() {
569            return Ok(());
570        }
571
572        unsafe {
573            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
574            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
575            let base_ptr = (*dv_qp).sq.buf as *mut u8;
576            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
577            let stride = (*dv_qp).sq.stride;
578            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
579            let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
580            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
581                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
582            }
583            self.apply_first_op_delay(send_db_idx);
584            while send_db_idx < send_wqe_idx {
585                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
586                let src_ptr = base_ptr.wrapping_add(offset as usize);
587                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
588                send_db_idx += 1;
589                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
590            }
591            Ok(())
592        }
593    }
594
595    /// Enqueues a put operation without ringing the doorbell.
596    pub fn enqueue_put(
597        &mut self,
598        lhandle: IbvBuffer,
599        rhandle: IbvBuffer,
600    ) -> Result<Vec<u64>, anyhow::Error> {
601        let idx = unsafe {
602            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
603                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
604            )
605        };
606
607        self.send_wqe(
608            lhandle.addr,
609            lhandle.lkey,
610            lhandle.size,
611            idx,
612            true,
613            IbvOperation::Write,
614            rhandle.addr,
615            rhandle.rkey,
616        )?;
617        Ok(vec![idx])
618    }
619
620    /// Enqueues a put-with-receive operation without ringing the doorbell.
621    pub fn enqueue_put_with_recv(
622        &mut self,
623        lhandle: IbvBuffer,
624        rhandle: IbvBuffer,
625    ) -> Result<Vec<u64>, anyhow::Error> {
626        let idx = unsafe {
627            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
628                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
629            )
630        };
631
632        self.send_wqe(
633            lhandle.addr,
634            lhandle.lkey,
635            lhandle.size,
636            idx,
637            true,
638            IbvOperation::WriteWithImm,
639            rhandle.addr,
640            rhandle.rkey,
641        )?;
642        Ok(vec![idx])
643    }
644
645    /// Enqueues a get operation without ringing the doorbell.
646    pub fn enqueue_get(
647        &mut self,
648        lhandle: IbvBuffer,
649        rhandle: IbvBuffer,
650    ) -> Result<Vec<u64>, anyhow::Error> {
651        let idx = unsafe {
652            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
653                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
654            )
655        };
656
657        self.send_wqe(
658            lhandle.addr,
659            lhandle.lkey,
660            lhandle.size,
661            idx,
662            true,
663            IbvOperation::Read,
664            rhandle.addr,
665            rhandle.rkey,
666        )?;
667        Ok(vec![idx])
668    }
669
670    pub fn get(
671        &mut self,
672        lhandle: IbvBuffer,
673        rhandle: IbvBuffer,
674    ) -> Result<Vec<u64>, anyhow::Error> {
675        let total_size = lhandle.size;
676        if rhandle.size < total_size {
677            return Err(anyhow::anyhow!(
678                "Remote buffer size ({}) is smaller than local buffer size ({})",
679                rhandle.size,
680                total_size
681            ));
682        }
683
684        let mut remaining = total_size;
685        let mut offset = 0;
686        let mut wr_ids = Vec::new();
687
688        while remaining > 0 {
689            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
690            let idx = unsafe {
691                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
692                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
693                )
694            };
695            wr_ids.push(idx);
696
697            self.post_op(
698                lhandle.addr + offset,
699                lhandle.lkey,
700                chunk_size,
701                idx,
702                true,
703                IbvOperation::Read,
704                rhandle.addr + offset,
705                rhandle.rkey,
706            )?;
707            unsafe {
708                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
709                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
710                );
711            }
712
713            remaining -= chunk_size;
714            offset += chunk_size;
715        }
716
717        Ok(wr_ids)
718    }
719
720    /// Posts a request to the queue pair.
721    fn post_op(
722        &mut self,
723        laddr: usize,
724        lkey: u32,
725        length: usize,
726        wr_id: u64,
727        signaled: bool,
728        op_type: IbvOperation,
729        raddr: usize,
730        rkey: u32,
731    ) -> Result<(), anyhow::Error> {
732        // EFA: use unified C function
733        if self.is_efa() {
734            return self.post_op_efa(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey);
735        }
736
737        unsafe {
738            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
739            let context = self.context as *mut rdmaxcel_sys::ibv_context;
740            let ops = &mut (*context).ops;
741            let errno;
742            if op_type == IbvOperation::Recv {
743                let mut sge = rdmaxcel_sys::ibv_sge {
744                    addr: laddr as u64,
745                    length: length as u32,
746                    lkey,
747                };
748                let mut wr = rdmaxcel_sys::ibv_recv_wr {
749                    wr_id,
750                    sg_list: &mut sge as *mut _,
751                    num_sge: 1,
752                    ..Default::default()
753                };
754                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
755                errno =
756                    ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
757            } else if op_type == IbvOperation::Write
758                || op_type == IbvOperation::Read
759                || op_type == IbvOperation::WriteWithImm
760            {
761                self.apply_first_op_delay(wr_id);
762                let send_flags = if signaled {
763                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
764                } else {
765                    0
766                };
767                let mut sge = rdmaxcel_sys::ibv_sge {
768                    addr: laddr as u64,
769                    length: length as u32,
770                    lkey,
771                };
772                let mut wr = rdmaxcel_sys::ibv_send_wr {
773                    wr_id,
774                    next: std::ptr::null_mut(),
775                    sg_list: &mut sge as *mut _,
776                    num_sge: 1,
777                    opcode: op_type.into(),
778                    send_flags,
779                    wr: Default::default(),
780                    qp_type: Default::default(),
781                    __bindgen_anon_1: Default::default(),
782                    __bindgen_anon_2: Default::default(),
783                };
784
785                wr.wr.rdma.remote_addr = raddr as u64;
786                wr.wr.rdma.rkey = rkey;
787                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
788
789                errno =
790                    ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
791            } else {
792                panic!("Not Implemented");
793            }
794
795            if errno != 0 {
796                let os_error = Error::last_os_error();
797                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
798            }
799            tracing::debug!(
800                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
801                op_type,
802                lkey,
803                laddr,
804                length,
805                raddr,
806                rkey,
807            );
808
809            Ok(())
810        }
811    }
812
813    /// Posts an RDMA operation via the EFA-specific C function.
814    fn post_op_efa(
815        &mut self,
816        laddr: usize,
817        lkey: u32,
818        length: usize,
819        wr_id: u64,
820        signaled: bool,
821        op_type: IbvOperation,
822        raddr: usize,
823        rkey: u32,
824    ) -> Result<(), anyhow::Error> {
825        let c_op = match op_type {
826            IbvOperation::Write => 0,
827            IbvOperation::Read => 1,
828            IbvOperation::Recv => 2,
829            IbvOperation::WriteWithImm => 3,
830        };
831
832        unsafe {
833            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
834            let ret = rdmaxcel_sys::rdmaxcel_qp_post_op(
835                qp,
836                laddr as *mut std::ffi::c_void,
837                lkey,
838                length,
839                raddr as *mut std::ffi::c_void,
840                rkey,
841                wr_id,
842                signaled as i32,
843                c_op,
844            );
845            if ret != 0 {
846                let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
847                    .to_str()
848                    .unwrap_or("unknown");
849                return Err(anyhow::anyhow!("EFA post_op failed: {}", msg));
850            }
851        }
852        Ok(())
853    }
854
855    fn send_wqe(
856        &mut self,
857        laddr: usize,
858        lkey: u32,
859        length: usize,
860        wr_id: u64,
861        signaled: bool,
862        op_type: IbvOperation,
863        raddr: usize,
864        rkey: u32,
865    ) -> Result<DoorBell, anyhow::Error> {
866        // Non-mlx5 devices use the unified C post_op path
867        if self.is_efa() {
868            self.post_op(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey)?;
869            return Ok(DoorBell {
870                dst_ptr: 0,
871                src_ptr: 0,
872                size: 0,
873            });
874        }
875
876        unsafe {
877            let op_type_val = match op_type {
878                IbvOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
879                IbvOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
880                IbvOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
881                IbvOperation::Recv => 0,
882            };
883
884            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
885            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
886            let _dv_cq = if op_type == IbvOperation::Recv {
887                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
888            } else {
889                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
890            };
891
892            let buf = if op_type == IbvOperation::Recv {
893                (*dv_qp).rq.buf as *mut u8
894            } else {
895                (*dv_qp).sq.buf as *mut u8
896            };
897
898            let params = rdmaxcel_sys::wqe_params_t {
899                laddr,
900                lkey,
901                length,
902                wr_id,
903                signaled,
904                op_type: op_type_val,
905                raddr,
906                rkey,
907                qp_num: (*(*qp).ibv_qp).qp_num,
908                buf,
909                dbrec: (*dv_qp).dbrec,
910                wqe_cnt: (*dv_qp).sq.wqe_cnt,
911            };
912
913            if op_type == IbvOperation::Recv {
914                rdmaxcel_sys::recv_wqe(params);
915                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
916            } else {
917                rdmaxcel_sys::send_wqe(params);
918            };
919
920            Ok(DoorBell {
921                dst_ptr: (*dv_qp).bf.reg as usize,
922                src_ptr: (*dv_qp).sq.buf as usize,
923                size: 8,
924            })
925        }
926    }
927
928    /// Polls for work completions by wr_ids.
929    ///
930    /// # Arguments
931    ///
932    /// * `target` - Which completion queue to poll (Send, Receive)
933    /// * `expected_wr_ids` - Slice of work request IDs to wait for
934    ///
935    /// # Returns
936    ///
937    /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs found
938    /// * `Err(e)` - An error occurred
939    pub fn poll_completion(
940        &mut self,
941        target: PollTarget,
942        expected_wr_ids: &[u64],
943    ) -> Result<Vec<(u64, IbvWc)>, PollCompletionError> {
944        if expected_wr_ids.is_empty() {
945            return Ok(Vec::new());
946        }
947
948        unsafe {
949            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
950            let qp_num = (*(*qp).ibv_qp).qp_num;
951
952            let (cq, cache, cq_type) = match target {
953                PollTarget::Send => (
954                    self.send_cq as *mut rdmaxcel_sys::ibv_cq,
955                    rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
956                    "send",
957                ),
958                PollTarget::Recv => (
959                    self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
960                    rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
961                    "recv",
962                ),
963            };
964
965            let mut results = Vec::new();
966
967            for &expected_wr_id in expected_wr_ids {
968                let mut poll_ctx = rdmaxcel_sys::poll_context_t {
969                    expected_wr_id,
970                    expected_qp_num: qp_num,
971                    cache,
972                    cq,
973                };
974
975                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
976                let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
977
978                match ret {
979                    1 => {
980                        if !wc.is_valid() {
981                            if let Some((status, vendor_err)) = wc.error() {
982                                return Err(PollCompletionError {
983                                    status: Some(status),
984                                    vendor_err: Some(vendor_err),
985                                    message: format!(
986                                        "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
987                                        cq_type, expected_wr_id, status, vendor_err,
988                                    ),
989                                });
990                            }
991                        }
992                        results.push((expected_wr_id, IbvWc::from(wc)));
993                    }
994                    0 => {
995                        // Not found yet
996                    }
997                    -17 => {
998                        let error_msg =
999                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1000                                .to_str()
1001                                .unwrap_or("Unknown error");
1002                        if let Some((status, vendor_err)) = wc.error() {
1003                            return Err(PollCompletionError {
1004                                status: Some(status),
1005                                vendor_err: Some(vendor_err),
1006                                message: format!(
1007                                    "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1008                                    cq_type,
1009                                    expected_wr_id,
1010                                    error_msg,
1011                                    status,
1012                                    vendor_err,
1013                                    wc.qp_num,
1014                                    wc.len(),
1015                                ),
1016                            });
1017                        } else {
1018                            return Err(PollCompletionError {
1019                                status: None,
1020                                vendor_err: None,
1021                                message: format!(
1022                                    "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1023                                    cq_type,
1024                                    expected_wr_id,
1025                                    error_msg,
1026                                    wc.qp_num,
1027                                    wc.len(),
1028                                ),
1029                            });
1030                        }
1031                    }
1032                    _ => {
1033                        let error_msg =
1034                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1035                                .to_str()
1036                                .unwrap_or("Unknown error");
1037                        return Err(PollCompletionError {
1038                            status: None,
1039                            vendor_err: None,
1040                            message: format!(
1041                                "Failed to poll {} CQ for wr_id={}: {}",
1042                                cq_type, expected_wr_id, error_msg,
1043                            ),
1044                        });
1045                    }
1046                }
1047            }
1048
1049            Ok(results)
1050        }
1051    }
1052}
1053
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::ibverbs::domain::IbvDomain;
    use crate::backend::ibverbs::primitives::IbvConfig;
    use crate::backend::ibverbs::primitives::get_all_devices;

    /// Returns `true`, after printing a notice, when no RDMA device is
    /// available and the calling test should return early.
    fn skip_without_rdma() -> bool {
        if get_all_devices().is_empty() {
            println!("Skipping test: RDMA devices not available");
            return true;
        }
        false
    }

    /// Default configuration with GPU direct disabled.
    fn host_only_config() -> IbvConfig {
        IbvConfig {
            use_gpu_direct: false,
            ..Default::default()
        }
    }

    /// A queue pair can be created on a freshly opened domain.
    #[test]
    fn test_create_connection() {
        if skip_without_rdma() {
            return;
        }

        let config = host_only_config();
        // `expect` reports the underlying error if setup fails, which a bare
        // `assert!(.. .is_ok())` would hide.
        let domain =
            IbvDomain::new(config.device.clone()).expect("failed to open ibverbs domain");
        let _queue_pair = IbvQueuePair::new(domain.context, domain.pd, config)
            .expect("failed to create queue pair");
    }

    /// Two queue pairs on the same host can connect to each other.
    #[test]
    fn test_loopback_connection() {
        if skip_without_rdma() {
            return;
        }

        let server_config = host_only_config();
        let client_config = host_only_config();

        let server_domain = IbvDomain::new(server_config.device.clone())
            .expect("failed to open server ibverbs domain");
        let client_domain = IbvDomain::new(client_config.device.clone())
            .expect("failed to open client ibverbs domain");

        let mut server_qp =
            IbvQueuePair::new(server_domain.context, server_domain.pd, server_config)
                .expect("failed to create server queue pair");
        let mut client_qp =
            IbvQueuePair::new(client_domain.context, client_domain.pd, client_config)
                .expect("failed to create client queue pair");

        let server_connection_info = server_qp
            .get_qp_info()
            .expect("failed to read server queue pair info");
        let client_connection_info = client_qp
            .get_qp_info()
            .expect("failed to read client queue pair info");

        assert!(
            server_qp.connect(&client_connection_info).is_ok(),
            "server queue pair failed to connect to client"
        );
        assert!(
            client_qp.connect(&server_connection_info).is_ok(),
            "client queue pair failed to connect to server"
        );
    }
}