monarch_rdma/backend/ibverbs/
queue_pair.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! ibverbs queue pair, doorbell, and completion polling.
10//!
11//! An [`IbvQueuePair`] encapsulates the send and receive queues, completion
12//! queues, and other resources needed for RDMA communication. It provides
13//! methods for establishing connections and performing RDMA operations.
14
/// Upper bound, in bytes, on the payload of a single RDMA work request (1 GiB).
///
/// `put()` and `get()` split larger transfers into chunks of at most this size.
const MAX_RDMA_MSG_SIZE: usize = 1 << 30;
17
18use std::io::Error;
19use std::result::Result;
20use std::time::Duration;
21
22use serde::Deserialize;
23use serde::Serialize;
24use typeuri::Named;
25
26use super::IbvBuffer;
27use super::primitives::Gid;
28use super::primitives::IbvConfig;
29use super::primitives::IbvOperation;
30use super::primitives::IbvQpInfo;
31use super::primitives::IbvWc;
32use super::primitives::resolve_qp_type;
33
/// A doorbell trigger for batched RDMA operations.
///
/// Rings the hardware doorbell to execute previously enqueued work requests.
///
/// All fields are raw addresses/sizes stored as `usize` so the value can be
/// cloned and serialized. On the EFA path `send_wqe` returns a `DoorBell`
/// with every field set to zero.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Source address of the doorbell write.
    /// NOTE(review): presumably the address of the WQE in the send queue buffer — confirm.
    pub src_ptr: usize,
    /// Destination address of the doorbell write.
    /// NOTE(review): presumably the doorbell / BlueFlame register — confirm.
    pub dst_ptr: usize,
    /// Number of bytes covered by the doorbell write.
    /// NOTE(review): unit assumed to be bytes — confirm against the consumer.
    pub size: usize,
}
wirevalue::register_type!(DoorBell);
44
/// Specifies which completion queue to poll.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue.
    Send,
    /// Poll the receive completion queue.
    Recv,
}
51
/// An RDMA Queue Pair (QP) for communication between two endpoints.
///
/// Encapsulates the send/receive queues, completion queues, and mlx5dv
/// device-specific structures needed for RDMA communication.
///
/// # Connection Lifecycle
///
/// 1. Create with `new()` from context and protection domain pointers
/// 2. Get connection info with `get_qp_info()`
/// 3. Exchange connection info with remote peer
/// 4. Connect to remote endpoint with `connect()`
/// 5. Perform RDMA operations with `put()` or `get()`
/// 6. Poll for completions with `poll_completion()`
///
/// # Notes
/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
/// - This makes IbvQueuePair trivially Clone and Serialize
/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
/// - All pointer fields are stored as `usize` addresses and cast back to raw
///   pointers at the point of use
/// - For EFA queue pairs the three `dv_*` fields are `0` (no mlx5dv setup)
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct IbvQueuePair {
    /// Address of the send completion queue taken from the underlying `ibv_qp`.
    pub send_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
    /// Address of the receive completion queue taken from the underlying `ibv_qp`.
    pub recv_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
    /// Address of the rdmaxcel queue pair wrapper returned by `rdmaxcel_qp_create`.
    pub qp: usize,         // *mut rdmaxcel_sys::rdmaxcel_qp_t
    /// Address of the mlx5dv view of the queue pair; `0` on EFA.
    pub dv_qp: usize,      // *mut rdmaxcel_sys::mlx5dv_qp,
    /// Address of the mlx5dv view of the send completion queue; `0` on EFA.
    pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
    /// Address of the mlx5dv view of the receive completion queue; `0` on EFA.
    pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
    /// Address of the device context this queue pair was created on.
    context: usize,        // *mut rdmaxcel_sys::ibv_context,
    /// Configuration used to create the queue pair and to drive `connect()`.
    config: IbvConfig,
    /// `true` when the resolved QP type is `RDMA_QP_TYPE_EFA`.
    is_efa: bool,
}
wirevalue::register_type!(IbvQueuePair);
84
85impl IbvQueuePair {
86    fn is_efa(&self) -> bool {
87        self.is_efa
88    }
89
90    /// Applies hardware initialization delay if this is the first operation since RTS.
91    fn apply_first_op_delay(&self, wr_id: u64) {
92        unsafe {
93            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
94            if wr_id == 0 {
95                let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
96                assert!(
97                    rts_timestamp != u64::MAX,
98                    "First operation attempted before queue pair reached RTS state! Call connect() first."
99                );
100                let current_nanos = std::time::SystemTime::now()
101                    .duration_since(std::time::UNIX_EPOCH)
102                    .unwrap()
103                    .as_nanos() as u64;
104                let elapsed_nanos = current_nanos - rts_timestamp;
105                let elapsed = Duration::from_nanos(elapsed_nanos);
106                let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
107                if elapsed < init_delay {
108                    let remaining_delay = init_delay - elapsed;
109                    // Sync context within unsafe block; tokio::time::sleep is async
110                    // and converting would require propagating async through the
111                    // entire post_op / ring_doorbell call chain.
112                    std::thread::sleep(remaining_delay);
113                }
114            }
115        }
116    }
117
118    /// Creates a new IbvQueuePair.
119    ///
120    /// Initializes a new Queue Pair (QP) and associated Completion Queues (CQ)
121    /// using the provided context and protection domain. The QP is created in
122    /// the RESET state and must be transitioned via `connect()` before use.
123    ///
124    /// # Errors
125    ///
126    /// Returns errors if CQ or QP creation fails.
127    pub fn new(
128        context: *mut rdmaxcel_sys::ibv_context,
129        pd: *mut rdmaxcel_sys::ibv_pd,
130        config: IbvConfig,
131    ) -> Result<Self, anyhow::Error> {
132        tracing::debug!("creating an IbvQueuePair from config {}", config);
133        unsafe {
134            // Resolve Auto to a concrete QP type based on device capabilities
135            let resolved_qp_type = resolve_qp_type(config.qp_type);
136            let is_efa = resolved_qp_type == rdmaxcel_sys::RDMA_QP_TYPE_EFA;
137            let qp = rdmaxcel_sys::rdmaxcel_qp_create(
138                context,
139                pd,
140                config.cq_entries,
141                config.max_send_wr.try_into().unwrap(),
142                config.max_recv_wr.try_into().unwrap(),
143                config.max_send_sge.try_into().unwrap(),
144                config.max_recv_sge.try_into().unwrap(),
145                resolved_qp_type,
146            );
147
148            if qp.is_null() {
149                let os_error = Error::last_os_error();
150                return Err(anyhow::anyhow!(
151                    "failed to create queue pair (QP): {}",
152                    os_error
153                ));
154            }
155
156            let send_cq = (*(*qp).ibv_qp).send_cq;
157            let recv_cq = (*(*qp).ibv_qp).recv_cq;
158
159            // EFA uses standard ibverbs (not mlx5dv), so skip dv setup
160            if is_efa {
161                return Ok(IbvQueuePair {
162                    send_cq: send_cq as usize,
163                    recv_cq: recv_cq as usize,
164                    qp: qp as usize,
165                    dv_qp: 0,
166                    dv_send_cq: 0,
167                    dv_recv_cq: 0,
168                    context: context as usize,
169                    config,
170                    is_efa: true,
171                });
172            }
173
174            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
175            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
176            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
177
178            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
179                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
180                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
181                rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
182                return Err(anyhow::anyhow!(
183                    "failed to init mlx5dv_qp or completion queues"
184                ));
185            }
186
187            if config.use_gpu_direct {
188                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
189                if ret != 0 {
190                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
191                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
192                    rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
193                    return Err(anyhow::anyhow!(
194                        "failed to register GPU Direct RDMA memory: {:?}",
195                        ret
196                    ));
197                }
198            }
199            Ok(IbvQueuePair {
200                send_cq: send_cq as usize,
201                recv_cq: recv_cq as usize,
202                qp: qp as usize,
203                dv_qp: dv_qp as usize,
204                dv_send_cq: dv_send_cq as usize,
205                dv_recv_cq: dv_recv_cq as usize,
206                context: context as usize,
207                config,
208                is_efa: false,
209            })
210        }
211    }
212
213    /// Returns the connection info needed by a remote peer to connect to this QP.
214    pub fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
215        unsafe {
216            let context = self.context as *mut rdmaxcel_sys::ibv_context;
217            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
218            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
219            let errno = rdmaxcel_sys::ibv_query_port(
220                context,
221                self.config.port_num,
222                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
223            );
224            if errno != 0 {
225                let os_error = Error::last_os_error();
226                return Err(anyhow::anyhow!(
227                    "Failed to query port attributes: {}",
228                    os_error
229                ));
230            }
231
232            let mut gid = Gid::default();
233            let ret = rdmaxcel_sys::ibv_query_gid(
234                context,
235                self.config.port_num,
236                i32::from(self.config.gid_index),
237                gid.as_mut(),
238            );
239            if ret != 0 {
240                return Err(anyhow::anyhow!("Failed to query GID"));
241            }
242
243            Ok(IbvQpInfo {
244                qp_num: (*(*qp).ibv_qp).qp_num,
245                lid: port_attr.lid,
246                gid: Some(gid),
247                psn: self.config.psn,
248            })
249        }
250    }
251
252    /// Returns the current state of the QP.
253    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
254        unsafe {
255            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
256            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
257                ..Default::default()
258            };
259            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
260                ..Default::default()
261            };
262            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
263            let errno = rdmaxcel_sys::ibv_query_qp(
264                (*qp).ibv_qp,
265                &mut qp_attr,
266                mask.0 as i32,
267                &mut qp_init_attr,
268            );
269            if errno != 0 {
270                let os_error = Error::last_os_error();
271                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
272            }
273            Ok(qp_attr.qp_state)
274        }
275    }
276
277    /// Transitions the QP through INIT -> RTR -> RTS to establish a connection.
278    ///
279    /// # Arguments
280    ///
281    /// * `connection_info` - The remote connection info to connect to
282    pub fn connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
283        // EFA: use unified C function for QP state transitions
284        if self.is_efa() {
285            return self.efa_connect(connection_info);
286        }
287
288        unsafe {
289            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
290
291            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
292                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
293                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
294                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
295
296            // Transition to INIT
297            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
298                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
299                qp_access_flags: qp_access_flags.0,
300                pkey_index: self.config.pkey_index,
301                port_num: self.config.port_num,
302                ..Default::default()
303            };
304
305            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
306                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
307                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
308                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
309
310            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
311            if errno != 0 {
312                let os_error = Error::last_os_error();
313                return Err(anyhow::anyhow!(
314                    "failed to transition QP to INIT: {}",
315                    os_error
316                ));
317            }
318
319            // Transition to RTR (Ready to Receive)
320            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
321                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
322                path_mtu: self.config.path_mtu,
323                dest_qp_num: connection_info.qp_num,
324                rq_psn: connection_info.psn,
325                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
326                min_rnr_timer: self.config.min_rnr_timer,
327                ah_attr: rdmaxcel_sys::ibv_ah_attr {
328                    dlid: connection_info.lid,
329                    sl: 0,
330                    src_path_bits: 0,
331                    port_num: self.config.port_num,
332                    grh: Default::default(),
333                    ..Default::default()
334                },
335                ..Default::default()
336            };
337
338            if let Some(gid) = connection_info.gid {
339                qp_attr.ah_attr.is_global = 1;
340                qp_attr.ah_attr.grh.dgid = rdmaxcel_sys::ibv_gid::from(gid);
341                qp_attr.ah_attr.grh.hop_limit = 0xff;
342                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
343            } else {
344                qp_attr.ah_attr.is_global = 0;
345            }
346
347            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
348                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
349                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
350                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
351                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
352                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
353                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
354
355            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
356            if errno != 0 {
357                let os_error = Error::last_os_error();
358                return Err(anyhow::anyhow!(
359                    "failed to transition QP to RTR: {}",
360                    os_error
361                ));
362            }
363
364            // Transition to RTS (Ready to Send)
365            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
366                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
367                sq_psn: self.config.psn,
368                max_rd_atomic: self.config.max_rd_atomic,
369                retry_cnt: self.config.retry_cnt,
370                rnr_retry: self.config.rnr_retry,
371                timeout: self.config.qp_timeout,
372                ..Default::default()
373            };
374
375            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
376                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
377                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
378                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
379                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
380                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
381
382            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
383            if errno != 0 {
384                let os_error = Error::last_os_error();
385                return Err(anyhow::anyhow!(
386                    "failed to transition QP to RTS: {}",
387                    os_error
388                ));
389            }
390            tracing::debug!(
391                "connection sequence has successfully completed (qp: {:?})",
392                qp
393            );
394
395            let rts_timestamp_nanos = std::time::SystemTime::now()
396                .duration_since(std::time::UNIX_EPOCH)
397                .unwrap()
398                .as_nanos() as u64;
399            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
400
401            Ok(())
402        }
403    }
404
405    /// Connects via the EFA-specific C function for QP state transitions.
406    fn efa_connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
407        let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
408
409        let gid_ptr = connection_info.gid.as_ref().map_or(std::ptr::null(), |g| {
410            let ibv_gid: &rdmaxcel_sys::ibv_gid = g.as_ref();
411            unsafe { ibv_gid.raw.as_ptr() }
412        });
413
414        unsafe {
415            let ret = rdmaxcel_sys::rdmaxcel_efa_connect(
416                qp,
417                self.config.port_num,
418                self.config.pkey_index,
419                0x4242, // qkey
420                self.config.psn,
421                self.config.gid_index,
422                gid_ptr,
423                connection_info.qp_num,
424            );
425            if ret != 0 {
426                let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
427                    .to_str()
428                    .unwrap_or("unknown");
429                return Err(anyhow::anyhow!("EFA connect failed: {}", msg));
430            }
431        }
432
433        // Store RTS timestamp for first-op delay
434        let rts_timestamp_nanos = std::time::SystemTime::now()
435            .duration_since(std::time::UNIX_EPOCH)
436            .unwrap()
437            .as_nanos() as u64;
438        unsafe {
439            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
440        }
441
442        Ok(())
443    }
444
445    pub fn recv(&mut self, lhandle: IbvBuffer, rhandle: IbvBuffer) -> Result<u64, anyhow::Error> {
446        unsafe {
447            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
448            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
449            self.post_op(
450                0,
451                lhandle.lkey,
452                0,
453                idx,
454                true,
455                IbvOperation::Recv,
456                0,
457                rhandle.rkey,
458            )
459            .unwrap();
460            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
461            Ok(idx)
462        }
463    }
464
465    pub fn put_with_recv(
466        &mut self,
467        lhandle: IbvBuffer,
468        rhandle: IbvBuffer,
469    ) -> Result<Vec<u64>, anyhow::Error> {
470        unsafe {
471            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
472            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
473            self.post_op(
474                lhandle.addr,
475                lhandle.lkey,
476                lhandle.size,
477                idx,
478                true,
479                IbvOperation::WriteWithImm,
480                rhandle.addr,
481                rhandle.rkey,
482            )
483            .unwrap();
484            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
485            Ok(vec![idx])
486        }
487    }
488
489    pub fn put(
490        &mut self,
491        lhandle: IbvBuffer,
492        rhandle: IbvBuffer,
493    ) -> Result<Vec<u64>, anyhow::Error> {
494        let total_size = lhandle.size;
495        if rhandle.size < total_size {
496            return Err(anyhow::anyhow!(
497                "Remote buffer size ({}) is smaller than local buffer size ({})",
498                rhandle.size,
499                total_size
500            ));
501        }
502
503        let mut remaining = total_size;
504        let mut offset = 0;
505        let mut wr_ids = Vec::new();
506        while remaining > 0 {
507            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
508            let idx = unsafe {
509                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
510                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
511                )
512            };
513            wr_ids.push(idx);
514            self.post_op(
515                lhandle.addr + offset,
516                lhandle.lkey,
517                chunk_size,
518                idx,
519                true,
520                IbvOperation::Write,
521                rhandle.addr + offset,
522                rhandle.rkey,
523            )?;
524            unsafe {
525                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
526                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
527                );
528            }
529
530            remaining -= chunk_size;
531            offset += chunk_size;
532        }
533
534        Ok(wr_ids)
535    }
536
537    /// Rings the doorbell to execute all enqueued operations.
538    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
539        // EFA uses standard ibverbs (not mlx5dv), so skip doorbell ringing
540        if self.is_efa() {
541            return Ok(());
542        }
543
544        unsafe {
545            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
546            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
547            let base_ptr = (*dv_qp).sq.buf as *mut u8;
548            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
549            let stride = (*dv_qp).sq.stride;
550            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
551            let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
552            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
553                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
554            }
555            self.apply_first_op_delay(send_db_idx);
556            while send_db_idx < send_wqe_idx {
557                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
558                let src_ptr = base_ptr.wrapping_add(offset as usize);
559                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
560                send_db_idx += 1;
561                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
562            }
563            Ok(())
564        }
565    }
566
567    /// Enqueues a put operation without ringing the doorbell.
568    pub fn enqueue_put(
569        &mut self,
570        lhandle: IbvBuffer,
571        rhandle: IbvBuffer,
572    ) -> Result<Vec<u64>, anyhow::Error> {
573        let idx = unsafe {
574            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
575                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
576            )
577        };
578
579        self.send_wqe(
580            lhandle.addr,
581            lhandle.lkey,
582            lhandle.size,
583            idx,
584            true,
585            IbvOperation::Write,
586            rhandle.addr,
587            rhandle.rkey,
588        )?;
589        Ok(vec![idx])
590    }
591
592    /// Enqueues a put-with-receive operation without ringing the doorbell.
593    pub fn enqueue_put_with_recv(
594        &mut self,
595        lhandle: IbvBuffer,
596        rhandle: IbvBuffer,
597    ) -> Result<Vec<u64>, anyhow::Error> {
598        let idx = unsafe {
599            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
600                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
601            )
602        };
603
604        self.send_wqe(
605            lhandle.addr,
606            lhandle.lkey,
607            lhandle.size,
608            idx,
609            true,
610            IbvOperation::WriteWithImm,
611            rhandle.addr,
612            rhandle.rkey,
613        )?;
614        Ok(vec![idx])
615    }
616
617    /// Enqueues a get operation without ringing the doorbell.
618    pub fn enqueue_get(
619        &mut self,
620        lhandle: IbvBuffer,
621        rhandle: IbvBuffer,
622    ) -> Result<Vec<u64>, anyhow::Error> {
623        let idx = unsafe {
624            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
625                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
626            )
627        };
628
629        self.send_wqe(
630            lhandle.addr,
631            lhandle.lkey,
632            lhandle.size,
633            idx,
634            true,
635            IbvOperation::Read,
636            rhandle.addr,
637            rhandle.rkey,
638        )?;
639        Ok(vec![idx])
640    }
641
642    pub fn get(
643        &mut self,
644        lhandle: IbvBuffer,
645        rhandle: IbvBuffer,
646    ) -> Result<Vec<u64>, anyhow::Error> {
647        let total_size = lhandle.size;
648        if rhandle.size < total_size {
649            return Err(anyhow::anyhow!(
650                "Remote buffer size ({}) is smaller than local buffer size ({})",
651                rhandle.size,
652                total_size
653            ));
654        }
655
656        let mut remaining = total_size;
657        let mut offset = 0;
658        let mut wr_ids = Vec::new();
659
660        while remaining > 0 {
661            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
662            let idx = unsafe {
663                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
664                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
665                )
666            };
667            wr_ids.push(idx);
668
669            self.post_op(
670                lhandle.addr + offset,
671                lhandle.lkey,
672                chunk_size,
673                idx,
674                true,
675                IbvOperation::Read,
676                rhandle.addr + offset,
677                rhandle.rkey,
678            )?;
679            unsafe {
680                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
681                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
682                );
683            }
684
685            remaining -= chunk_size;
686            offset += chunk_size;
687        }
688
689        Ok(wr_ids)
690    }
691
692    /// Posts a request to the queue pair.
693    fn post_op(
694        &mut self,
695        laddr: usize,
696        lkey: u32,
697        length: usize,
698        wr_id: u64,
699        signaled: bool,
700        op_type: IbvOperation,
701        raddr: usize,
702        rkey: u32,
703    ) -> Result<(), anyhow::Error> {
704        // EFA: use unified C function
705        if self.is_efa() {
706            return self.post_op_efa(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey);
707        }
708
709        unsafe {
710            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
711            let context = self.context as *mut rdmaxcel_sys::ibv_context;
712            let ops = &mut (*context).ops;
713            let errno;
714            if op_type == IbvOperation::Recv {
715                let mut sge = rdmaxcel_sys::ibv_sge {
716                    addr: laddr as u64,
717                    length: length as u32,
718                    lkey,
719                };
720                let mut wr = rdmaxcel_sys::ibv_recv_wr {
721                    wr_id,
722                    sg_list: &mut sge as *mut _,
723                    num_sge: 1,
724                    ..Default::default()
725                };
726                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
727                errno =
728                    ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
729            } else if op_type == IbvOperation::Write
730                || op_type == IbvOperation::Read
731                || op_type == IbvOperation::WriteWithImm
732            {
733                self.apply_first_op_delay(wr_id);
734                let send_flags = if signaled {
735                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
736                } else {
737                    0
738                };
739                let mut sge = rdmaxcel_sys::ibv_sge {
740                    addr: laddr as u64,
741                    length: length as u32,
742                    lkey,
743                };
744                let mut wr = rdmaxcel_sys::ibv_send_wr {
745                    wr_id,
746                    next: std::ptr::null_mut(),
747                    sg_list: &mut sge as *mut _,
748                    num_sge: 1,
749                    opcode: op_type.into(),
750                    send_flags,
751                    wr: Default::default(),
752                    qp_type: Default::default(),
753                    __bindgen_anon_1: Default::default(),
754                    __bindgen_anon_2: Default::default(),
755                };
756
757                wr.wr.rdma.remote_addr = raddr as u64;
758                wr.wr.rdma.rkey = rkey;
759                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
760
761                errno =
762                    ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
763            } else {
764                panic!("Not Implemented");
765            }
766
767            if errno != 0 {
768                let os_error = Error::last_os_error();
769                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
770            }
771            tracing::debug!(
772                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
773                op_type,
774                lkey,
775                laddr,
776                length,
777                raddr,
778                rkey,
779            );
780
781            Ok(())
782        }
783    }
784
785    /// Posts an RDMA operation via the EFA-specific C function.
786    fn post_op_efa(
787        &mut self,
788        laddr: usize,
789        lkey: u32,
790        length: usize,
791        wr_id: u64,
792        signaled: bool,
793        op_type: IbvOperation,
794        raddr: usize,
795        rkey: u32,
796    ) -> Result<(), anyhow::Error> {
797        let c_op = match op_type {
798            IbvOperation::Write => 0,
799            IbvOperation::Read => 1,
800            IbvOperation::Recv => 2,
801            IbvOperation::WriteWithImm => 3,
802        };
803
804        unsafe {
805            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
806            let ret = rdmaxcel_sys::rdmaxcel_qp_post_op(
807                qp,
808                laddr as *mut std::ffi::c_void,
809                lkey,
810                length,
811                raddr as *mut std::ffi::c_void,
812                rkey,
813                wr_id,
814                signaled as i32,
815                c_op,
816            );
817            if ret != 0 {
818                let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
819                    .to_str()
820                    .unwrap_or("unknown");
821                return Err(anyhow::anyhow!("EFA post_op failed: {}", msg));
822            }
823        }
824        Ok(())
825    }
826
827    fn send_wqe(
828        &mut self,
829        laddr: usize,
830        lkey: u32,
831        length: usize,
832        wr_id: u64,
833        signaled: bool,
834        op_type: IbvOperation,
835        raddr: usize,
836        rkey: u32,
837    ) -> Result<DoorBell, anyhow::Error> {
838        // Non-mlx5 devices use the unified C post_op path
839        if self.is_efa() {
840            self.post_op(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey)?;
841            return Ok(DoorBell {
842                dst_ptr: 0,
843                src_ptr: 0,
844                size: 0,
845            });
846        }
847
848        unsafe {
849            let op_type_val = match op_type {
850                IbvOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
851                IbvOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
852                IbvOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
853                IbvOperation::Recv => 0,
854            };
855
856            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
857            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
858            let _dv_cq = if op_type == IbvOperation::Recv {
859                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
860            } else {
861                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
862            };
863
864            let buf = if op_type == IbvOperation::Recv {
865                (*dv_qp).rq.buf as *mut u8
866            } else {
867                (*dv_qp).sq.buf as *mut u8
868            };
869
870            let params = rdmaxcel_sys::wqe_params_t {
871                laddr,
872                lkey,
873                length,
874                wr_id,
875                signaled,
876                op_type: op_type_val,
877                raddr,
878                rkey,
879                qp_num: (*(*qp).ibv_qp).qp_num,
880                buf,
881                dbrec: (*dv_qp).dbrec,
882                wqe_cnt: (*dv_qp).sq.wqe_cnt,
883            };
884
885            if op_type == IbvOperation::Recv {
886                rdmaxcel_sys::recv_wqe(params);
887                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
888            } else {
889                rdmaxcel_sys::send_wqe(params);
890            };
891
892            Ok(DoorBell {
893                dst_ptr: (*dv_qp).bf.reg as usize,
894                src_ptr: (*dv_qp).sq.buf as usize,
895                size: 8,
896            })
897        }
898    }
899
900    /// Polls for work completions by wr_ids.
901    ///
902    /// # Arguments
903    ///
904    /// * `target` - Which completion queue to poll (Send, Receive)
905    /// * `expected_wr_ids` - Slice of work request IDs to wait for
906    ///
907    /// # Returns
908    ///
909    /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs found
910    /// * `Err(e)` - An error occurred
911    pub fn poll_completion(
912        &mut self,
913        target: PollTarget,
914        expected_wr_ids: &[u64],
915    ) -> Result<Vec<(u64, IbvWc)>, anyhow::Error> {
916        if expected_wr_ids.is_empty() {
917            return Ok(Vec::new());
918        }
919
920        unsafe {
921            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
922            let qp_num = (*(*qp).ibv_qp).qp_num;
923
924            let (cq, cache, cq_type) = match target {
925                PollTarget::Send => (
926                    self.send_cq as *mut rdmaxcel_sys::ibv_cq,
927                    rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
928                    "send",
929                ),
930                PollTarget::Recv => (
931                    self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
932                    rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
933                    "recv",
934                ),
935            };
936
937            let mut results = Vec::new();
938
939            for &expected_wr_id in expected_wr_ids {
940                let mut poll_ctx = rdmaxcel_sys::poll_context_t {
941                    expected_wr_id,
942                    expected_qp_num: qp_num,
943                    cache,
944                    cq,
945                };
946
947                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
948                let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
949
950                match ret {
951                    1 => {
952                        if !wc.is_valid() {
953                            if let Some((status, vendor_err)) = wc.error() {
954                                return Err(anyhow::anyhow!(
955                                    "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
956                                    cq_type,
957                                    expected_wr_id,
958                                    status,
959                                    vendor_err,
960                                ));
961                            }
962                        }
963                        results.push((expected_wr_id, IbvWc::from(wc)));
964                    }
965                    0 => {
966                        // Not found yet
967                    }
968                    -17 => {
969                        let error_msg =
970                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
971                                .to_str()
972                                .unwrap_or("Unknown error");
973                        if let Some((status, vendor_err)) = wc.error() {
974                            return Err(anyhow::anyhow!(
975                                "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
976                                cq_type,
977                                expected_wr_id,
978                                error_msg,
979                                status,
980                                vendor_err,
981                                wc.qp_num,
982                                wc.len(),
983                            ));
984                        } else {
985                            return Err(anyhow::anyhow!(
986                                "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
987                                cq_type,
988                                expected_wr_id,
989                                error_msg,
990                                wc.qp_num,
991                                wc.len(),
992                            ));
993                        }
994                    }
995                    _ => {
996                        let error_msg =
997                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
998                                .to_str()
999                                .unwrap_or("Unknown error");
1000                        return Err(anyhow::anyhow!(
1001                            "Failed to poll {} CQ for wr_id={}: {}",
1002                            cq_type,
1003                            expected_wr_id,
1004                            error_msg
1005                        ));
1006                    }
1007                }
1008            }
1009
1010            Ok(results)
1011        }
1012    }
1013}
1014
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::ibverbs::domain::IbvDomain;
    use crate::backend::ibverbs::primitives::IbvConfig;
    use crate::backend::ibverbs::primitives::get_all_devices;

    /// Default configuration with GPU direct switched off, shared by all tests.
    fn host_only_config() -> IbvConfig {
        IbvConfig {
            use_gpu_direct: false,
            ..Default::default()
        }
    }

    /// Returns `true` (after printing a notice) when no RDMA device is listed,
    /// so the calling test can bail out early.
    fn skip_without_rdma() -> bool {
        let missing = get_all_devices().is_empty();
        if missing {
            println!("Skipping test: RDMA devices not available");
        }
        missing
    }

    /// A queue pair can be created on a freshly opened domain.
    #[test]
    fn test_create_connection() {
        if skip_without_rdma() {
            return;
        }

        let cfg = host_only_config();
        let opened = IbvDomain::new(cfg.device.clone());
        assert!(opened.is_ok());

        let domain = opened.unwrap();
        let created = IbvQueuePair::new(domain.context, domain.pd, cfg.clone());
        assert!(created.is_ok());
    }

    /// Two queue pairs on the same host can exchange connection info and connect.
    #[test]
    fn test_loopback_connection() {
        if skip_without_rdma() {
            return;
        }

        let server_cfg = host_only_config();
        let client_cfg = host_only_config();

        let server_dom = IbvDomain::new(server_cfg.device.clone()).unwrap();
        let client_dom = IbvDomain::new(client_cfg.device.clone()).unwrap();

        let mut server = IbvQueuePair::new(
            server_dom.context,
            server_dom.pd,
            server_cfg.clone(),
        )
        .unwrap();
        let mut client = IbvQueuePair::new(
            client_dom.context,
            client_dom.pd,
            client_cfg.clone(),
        )
        .unwrap();

        // Each side needs the other's info before it can connect.
        let server_info = server.get_qp_info().unwrap();
        let client_info = client.get_qp_info().unwrap();

        assert!(server.connect(&client_info).is_ok());
        assert!(client.connect(&server_info).is_ok());
    }
}