monarch_rdma/
rdma_components.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
13//! ## Core Components
14//!
//! * `RdmaDomain` - Manages RDMA resources: the opened device context and a protection domain
16//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `RdmaDomain` with `new()`
36//! 2. Create an `RdmaQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
42
/// Largest payload, in bytes, that a single RDMA operation may carry: 2^30 bytes (1 GiB).
const MAX_RDMA_MSG_SIZE: usize = 1 << 30;
45
46use std::ffi::CStr;
47use std::fs;
48use std::io::Error;
49use std::result::Result;
50use std::thread::sleep;
51use std::time::Duration;
52
53use hyperactor::ActorRef;
54use hyperactor::clock::Clock;
55use hyperactor::clock::RealClock;
56use hyperactor::context;
57use serde::Deserialize;
58use serde::Serialize;
59use typeuri::Named;
60
61use crate::RdmaDevice;
62use crate::RdmaManagerActor;
63use crate::RdmaManagerMessageClient;
64use crate::ibverbs_primitives::Gid;
65use crate::ibverbs_primitives::IbvWc;
66use crate::ibverbs_primitives::IbverbsConfig;
67use crate::ibverbs_primitives::RdmaOperation;
68use crate::ibverbs_primitives::RdmaQpInfo;
69use crate::ibverbs_primitives::resolve_qp_type;
70
71#[derive(Debug, Named, Clone, Serialize, Deserialize)]
72pub struct DoorBell {
73    pub src_ptr: usize,
74    pub dst_ptr: usize,
75    pub size: usize,
76}
77wirevalue::register_type!(DoorBell);
78
79impl DoorBell {
80    /// Rings the doorbell to trigger the execution of previously enqueued operations.
81    ///
82    /// This method uses unsafe code to directly interact with the RDMA device,
83    /// sending a signal from the source pointer to the destination pointer.
84    ///
85    /// # Returns
86    /// * `Ok(())` if the operation is successful.
87    /// * `Err(anyhow::Error)` if an error occurs during the operation.
88    pub fn ring(&self) -> Result<(), anyhow::Error> {
89        unsafe {
90            let src_ptr = self.src_ptr as *mut std::ffi::c_void;
91            let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
92            rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
93            Ok(())
94        }
95    }
96}
97
98#[derive(Debug, Serialize, Deserialize, Named, Clone)]
99pub struct RdmaBuffer {
100    pub owner: ActorRef<RdmaManagerActor>,
101    pub mr_id: usize,
102    pub lkey: u32,
103    pub rkey: u32,
104    pub addr: usize,
105    pub size: usize,
106    pub device_name: String,
107}
108wirevalue::register_type!(RdmaBuffer);
109
110impl RdmaBuffer {
111    /// Read from the RdmaBuffer into the provided memory.
112    ///
113    /// This method transfers data from the buffer into the local memory region provided over RDMA.
114    /// This involves calling a `Put` operation on the RdmaBuffer's actor side.
115    ///
116    /// # Arguments
117    /// * `client` - The actor who is reading.
118    /// * `remote` - RdmaBuffer representing the remote memory region
119    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
120    ///
121    /// # Returns
122    /// `Ok(bool)` indicating if the operation completed successfully.
123    pub async fn read_into(
124        &self,
125        client: &impl context::Actor,
126        remote: RdmaBuffer,
127        timeout: u64,
128    ) -> Result<bool, anyhow::Error> {
129        tracing::debug!(
130            "[buffer] reading from {:?} into remote ({:?}) at {:?}",
131            self,
132            remote.owner.actor_id(),
133            remote,
134        );
135        let remote_owner = remote.owner.clone();
136
137        let local_device = self.device_name.clone();
138        let remote_device = remote.device_name.clone();
139        let mut qp = self
140            .owner
141            .request_queue_pair(
142                client,
143                remote_owner.clone(),
144                local_device.clone(),
145                remote_device.clone(),
146            )
147            .await?;
148
149        let wr_id = qp.put(self.clone(), remote)?;
150        let result = self
151            .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
152            .await;
153
154        // Release the queue pair back to the actor
155        self.owner
156            .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
157            .await?;
158
159        result
160    }
161
162    /// Write from the provided memory into the RdmaBuffer.
163    ///
164    /// This method performs an RDMA write operation, transferring data from the caller's
165    /// memory region to this buffer.
166    /// This involves calling a `Fetch` operation on the RdmaBuffer's actor side.
167    ///
168    /// # Arguments
169    /// * `client` - The actor who is writing.
170    /// * `remote` - RdmaBuffer representing the remote memory region
171    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
172    ///
173    /// # Returns
174    /// `Ok(bool)` indicating if the operation completed successfully.
175    pub async fn write_from(
176        &self,
177        client: &impl context::Actor,
178        remote: RdmaBuffer,
179        timeout: u64,
180    ) -> Result<bool, anyhow::Error> {
181        tracing::debug!(
182            "[buffer] writing into {:?} from remote ({:?}) at {:?}",
183            self,
184            remote.owner.actor_id(),
185            remote,
186        );
187        let remote_owner = remote.owner.clone(); // Clone before the move!
188
189        // Extract device name from buffer, fallback to a default if not present
190        let local_device = self.device_name.clone();
191        let remote_device = remote.device_name.clone();
192
193        let mut qp = self
194            .owner
195            .request_queue_pair(
196                client,
197                remote_owner.clone(),
198                local_device.clone(),
199                remote_device.clone(),
200            )
201            .await?;
202        let wr_id = qp.get(self.clone(), remote)?;
203        let result = self
204            .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
205            .await;
206
207        // Release the queue pair back to the actor
208        self.owner
209            .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
210            .await?;
211
212        result
213    }
214    /// Waits for the completion of RDMA operations.
215    ///
216    /// This method polls the completion queue until all specified work requests complete
217    /// or until the timeout is reached.
218    ///
219    /// # Arguments
220    /// * `qp` - The RDMA Queue Pair to poll for completion
221    /// * `poll_target` - Which CQ to poll (Send or Recv)
222    /// * `expected_wr_ids` - The work request IDs to wait for
223    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
224    ///
225    /// # Returns
226    /// `Ok(true)` if all operations complete successfully within the timeout,
227    /// or an error if the timeout is reached
228    async fn wait_for_completion(
229        &self,
230        qp: &mut RdmaQueuePair,
231        poll_target: PollTarget,
232        expected_wr_ids: &[u64],
233        timeout: u64,
234    ) -> Result<bool, anyhow::Error> {
235        let timeout = Duration::from_secs(timeout);
236        let start_time = std::time::Instant::now();
237
238        let mut remaining: std::collections::HashSet<u64> =
239            expected_wr_ids.iter().copied().collect();
240
241        while start_time.elapsed() < timeout {
242            if remaining.is_empty() {
243                return Ok(true);
244            }
245
246            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
247            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
248                Ok(completions) => {
249                    for (wr_id, _wc) in completions {
250                        remaining.remove(&wr_id);
251                    }
252                    if remaining.is_empty() {
253                        return Ok(true);
254                    }
255                    RealClock.sleep(Duration::from_millis(1)).await;
256                }
257                Err(e) => {
258                    return Err(anyhow::anyhow!(
259                        "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
260                        e,
261                        self.lkey,
262                        self.rkey,
263                        self.addr,
264                        self.size
265                    ));
266                }
267            }
268        }
269        tracing::error!(
270            "timed out while waiting on request completion for wr_ids={:?}",
271            remaining
272        );
273        Err(anyhow::anyhow!(
274            "[buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
275            self,
276            expected_wr_ids
277        ))
278    }
279
280    /// Drop the buffer and release remote handles.
281    ///
282    /// This method calls the owning RdmaManagerActor to release the buffer and clean up
283    /// associated memory regions. This is typically called when the buffer is no longer
284    /// needed and resources should be freed.
285    ///
286    /// # Arguments
287    /// * `client` - Mailbox used for communication
288    ///
289    /// # Returns
290    /// `Ok(())` if the operation completed successfully.
291    pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
292        tracing::debug!("[buffer] dropping buffer {:?}", self);
293        self.owner.release_buffer(client, self.clone()).await?;
294        Ok(())
295    }
296}
297
298/// Represents a domain for RDMA operations, encapsulating the necessary resources
299/// for establishing and managing RDMA connections.
300///
301/// An `RdmaDomain` manages the context, protection domain (PD), and memory region (MR)
302/// required for RDMA operations. It provides the foundation for creating queue pairs
303/// and establishing connections between RDMA devices.
304///
305/// # Fields
306///
307/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
308/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
309#[derive(Clone)]
310pub struct RdmaDomain {
311    pub context: *mut rdmaxcel_sys::ibv_context,
312    pub pd: *mut rdmaxcel_sys::ibv_pd,
313}
314
315impl std::fmt::Debug for RdmaDomain {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        f.debug_struct("RdmaDomain")
318            .field("context", &format!("{:p}", self.context))
319            .field("pd", &format!("{:p}", self.pd))
320            .finish()
321    }
322}
323
// SAFETY:
// `RdmaDomain` holds only raw pointers to ibverbs objects (a device context and
// a protection domain). It is declared `Send` on the assumption that these
// objects may be used from any thread. It further assumes they may be destroyed
// from any thread: `Drop` calls `ibv_dealloc_pd`.
unsafe impl Send for RdmaDomain {}

// SAFETY:
// `RdmaDomain` has no interior mutability of its own. It is declared `Sync` on
// the assumption that the ibverbs APIs invoked through these pointers are
// thread-safe.
unsafe impl Sync for RdmaDomain {}
335
336impl Drop for RdmaDomain {
337    fn drop(&mut self) {
338        unsafe {
339            rdmaxcel_sys::ibv_dealloc_pd(self.pd);
340        }
341    }
342}
343
344impl RdmaDomain {
345    /// Creates a new RdmaDomain.
346    ///
347    /// This function initializes the RDMA device context, creates a protection domain,
348    /// and registers a memory region with appropriate access permissions.
349    ///
350    /// SAFETY:
351    /// Our memory region (MR) registration uses implicit ODP for RDMA access, which maps large virtual
352    /// address ranges without explicit pinning. This is convenient, but it broadens the memory footprint
353    /// exposed to the NIC and introduces a security liability.
354    ///
355    /// We currently assume a trusted, single-environment and are not enforcing finer-grained memory isolation
356    /// at this layer. We plan to investigate mitigations - such as memory windows or tighter registration
357    /// boundaries in future follow-ups.
358    ///
359    /// # Arguments
360    ///
361    /// * `config` - Configuration settings for the RDMA operations
362    ///
363    /// # Errors
364    ///
365    /// This function may return errors if:
366    /// * No RDMA devices are found
367    /// * The specified device cannot be found
368    /// * Device context creation fails
369    /// * Protection domain allocation fails
370    /// * Memory region registration fails
371    pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
372        tracing::debug!("creating RdmaDomain for device {}", device.name());
373        // SAFETY:
374        // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
375        // - All pointers are properly initialized and checked for null before use
376        // - Memory registration follows the ibverbs API contract with proper access flags
377        // - Resources are properly cleaned up in error cases to prevent leaks
378        // - The operations follow the documented RDMA protocol for device initialization
379        unsafe {
380            // Get the device based on the provided RdmaDevice
381            let device_name = device.name();
382            let mut num_devices = 0i32;
383            let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
384
385            if devices.is_null() || num_devices == 0 {
386                return Err(anyhow::anyhow!("no RDMA devices found"));
387            }
388
389            // Find the device with the matching name
390            let mut device_ptr = std::ptr::null_mut();
391            for i in 0..num_devices {
392                let dev = *devices.offset(i as isize);
393                let dev_name =
394                    CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
395
396                if dev_name == *device_name {
397                    device_ptr = dev;
398                    break;
399                }
400            }
401
402            // If we didn't find the device, return an error
403            if device_ptr.is_null() {
404                rdmaxcel_sys::ibv_free_device_list(devices);
405                return Err(anyhow::anyhow!("device '{}' not found", device_name));
406            }
407            tracing::info!("using RDMA device: {}", device_name);
408
409            // Open device
410            let context = rdmaxcel_sys::ibv_open_device(device_ptr);
411            if context.is_null() {
412                rdmaxcel_sys::ibv_free_device_list(devices);
413                let os_error = Error::last_os_error();
414                return Err(anyhow::anyhow!("failed to create context: {}", os_error));
415            }
416
417            // Create protection domain
418            let pd = rdmaxcel_sys::ibv_alloc_pd(context);
419            if pd.is_null() {
420                rdmaxcel_sys::ibv_close_device(context);
421                rdmaxcel_sys::ibv_free_device_list(devices);
422                let os_error = Error::last_os_error();
423                return Err(anyhow::anyhow!(
424                    "failed to create protection domain (PD): {}",
425                    os_error
426                ));
427            }
428
429            // Avoids memory leaks
430            rdmaxcel_sys::ibv_free_device_list(devices);
431
432            let domain = RdmaDomain { context, pd };
433
434            Ok(domain)
435        }
436    }
437}
/// Chooses which of a queue pair's completion queues gets polled.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue.
    Send,
    /// Poll the receive completion queue.
    Recv,
}
444
445/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
446///
447/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
448/// and other resources needed for RDMA communication. It provides methods for
449/// establishing connections and performing RDMA operations like read and write.
450///
451/// # Fields
452///
453/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
454/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
455/// * `qp` - Queue Pair pointer that manages send and receive operations
456/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
457/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
458/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
459/// * `context` - RDMA device context pointer
460/// * `config` - Configuration settings for the queue pair
461///
462/// # Connection Lifecycle
463///
464/// 1. Create with `new()` from an `RdmaDomain`
465/// 2. Get connection info with `get_qp_info()`
466/// 3. Exchange connection info with remote peer (application must handle this)
467/// 4. Connect to remote endpoint with `connect()`
468/// 5. Perform RDMA operations with `put()` or `get()`
469/// 6. Poll for completions with `poll_send_completion(wr_id)` or `poll_recv_completion(wr_id)`
470///
471/// # Notes
472/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
473/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
474/// - This makes RdmaQueuePair trivially Clone and Serialize
475/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
476#[derive(Debug, Serialize, Deserialize, Named, Clone)]
477pub struct RdmaQueuePair {
478    pub send_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
479    pub recv_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
480    pub qp: usize,         // *mut rdmaxcel_sys::rdmaxcel_qp_t
481    pub dv_qp: usize,      // *mut rdmaxcel_sys::mlx5dv_qp,
482    pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
483    pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
484    context: usize,        // *mut rdmaxcel_sys::ibv_context,
485    config: IbverbsConfig,
486}
487wirevalue::register_type!(RdmaQueuePair);
488
489impl RdmaQueuePair {
490    /// Applies hardware initialization delay if this is the first operation since RTS.
491    ///
492    /// This ensures the hardware has sufficient time to settle after reaching
493    /// Ready-to-Send state before the first actual operation.
494    fn apply_first_op_delay(&self, wr_id: u64) {
495        unsafe {
496            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
497            if wr_id == 0 {
498                let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
499                assert!(
500                    rts_timestamp != u64::MAX,
501                    "First operation attempted before queue pair reached RTS state! Call connect() first."
502                );
503                let current_nanos = RealClock
504                    .system_time_now()
505                    .duration_since(std::time::UNIX_EPOCH)
506                    .unwrap()
507                    .as_nanos() as u64;
508                let elapsed_nanos = current_nanos - rts_timestamp;
509                let elapsed = Duration::from_nanos(elapsed_nanos);
510                let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
511                if elapsed < init_delay {
512                    let remaining_delay = init_delay - elapsed;
513                    sleep(remaining_delay);
514                }
515            }
516        }
517    }
518
519    /// Creates a new RdmaQueuePair from a given RdmaDomain.
520    ///
521    /// This function initializes a new Queue Pair (QP) and associated Completion Queue (CQ)
522    /// using the resources from the provided RdmaDomain. The QP is created in the RESET state
523    /// and must be transitioned to other states via the `connect()` method before use.
524    ///
525    /// # Arguments
526    ///
527    /// * `domain` - Reference to an RdmaDomain that provides the context, protection domain,
528    ///   and memory region for this queue pair
529    ///
530    /// # Returns
531    ///
532    /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
533    ///
534    /// # Errors
535    ///
536    /// This function may return errors if:
537    /// * Completion queue (CQ) creation fails
538    /// * Queue pair (QP) creation fails
539    pub fn new(
540        context: *mut rdmaxcel_sys::ibv_context,
541        pd: *mut rdmaxcel_sys::ibv_pd,
542        config: IbverbsConfig,
543    ) -> Result<Self, anyhow::Error> {
544        tracing::debug!("creating an RdmaQueuePair from config {}", config);
545        unsafe {
546            // Resolve Auto to a concrete QP type based on device capabilities
547            let resolved_qp_type = resolve_qp_type(config.qp_type);
548            let qp = rdmaxcel_sys::rdmaxcel_qp_create(
549                context,
550                pd,
551                config.cq_entries,
552                config.max_send_wr.try_into().unwrap(),
553                config.max_recv_wr.try_into().unwrap(),
554                config.max_send_sge.try_into().unwrap(),
555                config.max_recv_sge.try_into().unwrap(),
556                resolved_qp_type,
557            );
558
559            if qp.is_null() {
560                let os_error = Error::last_os_error();
561                return Err(anyhow::anyhow!(
562                    "failed to create queue pair (QP): {}",
563                    os_error
564                ));
565            }
566
567            let send_cq = (*(*qp).ibv_qp).send_cq;
568            let recv_cq = (*(*qp).ibv_qp).recv_cq;
569
570            // mlx5dv provider APIs
571            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
572            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
573            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
574
575            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
576                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
577                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
578                rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
579                return Err(anyhow::anyhow!(
580                    "failed to init mlx5dv_qp or completion queues"
581                ));
582            }
583
584            // GPU Direct RDMA specific registrations
585            if config.use_gpu_direct {
586                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
587                if ret != 0 {
588                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
589                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
590                    rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
591                    return Err(anyhow::anyhow!(
592                        "failed to register GPU Direct RDMA memory: {:?}",
593                        ret
594                    ));
595                }
596            }
597            Ok(RdmaQueuePair {
598                send_cq: send_cq as usize,
599                recv_cq: recv_cq as usize,
600                qp: qp as usize,
601                dv_qp: qp as usize,
602                dv_send_cq: dv_send_cq as usize,
603                dv_recv_cq: dv_recv_cq as usize,
604                context: context as usize,
605                config,
606            })
607        }
608    }
609
610    /// Returns the information required for a remote peer to connect to this queue pair.
611    ///
612    /// This method retrieves the local queue pair attributes and port information needed by
613    /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
614    /// the queue pair number, LID, GID, and other necessary connection parameters.
615    ///
616    /// # Returns
617    ///
618    /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
619    ///
620    /// # Errors
621    ///
622    /// This function may return errors if:
623    /// * Port attribute query fails
624    /// * GID query fails
625    pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
626        // SAFETY:
627        // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
628        // - All pointers are properly initialized before use
629        // - Port and GID queries follow the documented ibverbs API contract
630        // - Error handling properly checks return codes from ibverbs functions
631        // - The memory address provided is only stored, not dereferenced in this function
632        unsafe {
633            let context = self.context as *mut rdmaxcel_sys::ibv_context;
634            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
635            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
636            let errno = rdmaxcel_sys::ibv_query_port(
637                context,
638                self.config.port_num,
639                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
640            );
641            if errno != 0 {
642                let os_error = Error::last_os_error();
643                return Err(anyhow::anyhow!(
644                    "Failed to query port attributes: {}",
645                    os_error
646                ));
647            }
648
649            let mut gid = Gid::default();
650            let ret = rdmaxcel_sys::ibv_query_gid(
651                context,
652                self.config.port_num,
653                i32::from(self.config.gid_index),
654                gid.as_mut(),
655            );
656            if ret != 0 {
657                return Err(anyhow::anyhow!("Failed to query GID"));
658            }
659
660            Ok(RdmaQpInfo {
661                qp_num: (*(*qp).ibv_qp).qp_num,
662                lid: port_attr.lid,
663                gid: Some(gid),
664                psn: self.config.psn,
665            })
666        }
667    }
668
669    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
670        // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
671        unsafe {
672            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
673            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
674                ..Default::default()
675            };
676            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
677                ..Default::default()
678            };
679            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
680            let errno = rdmaxcel_sys::ibv_query_qp(
681                (*qp).ibv_qp,
682                &mut qp_attr,
683                mask.0 as i32,
684                &mut qp_init_attr,
685            );
686            if errno != 0 {
687                let os_error = Error::last_os_error();
688                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
689            }
690            Ok(qp_attr.qp_state)
691        }
692    }
693    /// Connect to a remote Rdma connection point.
694    ///
695    /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection.
696    ///
697    /// # Arguments
698    ///
699    /// * `connection_info` - The remote connection info to connect to
700    pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
701        // SAFETY:
702        // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
703        // The operations are safe because:
704        // 1. We're following the documented ibverbs API contract
705        // 2. All pointers used are properly initialized and owned by this struct
706        // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
707        // 4. Memory access is properly bounded by the registered memory regions
708        unsafe {
709            // Transition to INIT
710            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
711
712            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
713                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
714                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
715                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
716
717            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
718                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
719                qp_access_flags: qp_access_flags.0,
720                pkey_index: self.config.pkey_index,
721                port_num: self.config.port_num,
722                ..Default::default()
723            };
724
725            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
726                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
727                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
728                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
729
730            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
731            if errno != 0 {
732                let os_error = Error::last_os_error();
733                return Err(anyhow::anyhow!(
734                    "failed to transition QP to INIT: {}",
735                    os_error
736                ));
737            }
738
739            // Transition to RTR (Ready to Receive)
740            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
741                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
742                path_mtu: self.config.path_mtu,
743                dest_qp_num: connection_info.qp_num,
744                rq_psn: connection_info.psn,
745                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
746                min_rnr_timer: self.config.min_rnr_timer,
747                ah_attr: rdmaxcel_sys::ibv_ah_attr {
748                    dlid: connection_info.lid,
749                    sl: 0,
750                    src_path_bits: 0,
751                    port_num: self.config.port_num,
752                    grh: Default::default(),
753                    ..Default::default()
754                },
755                ..Default::default()
756            };
757
758            // If the remote connection info contains a Gid, the routing will be global.
759            // Otherwise, it will be local, i.e. using LID.
760            if let Some(gid) = connection_info.gid {
761                qp_attr.ah_attr.is_global = 1;
762                qp_attr.ah_attr.grh.dgid = gid.into();
763                qp_attr.ah_attr.grh.hop_limit = 0xff;
764                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
765            } else {
766                // Use LID-based routing, e.g. for Infiniband/RoCEv1
767                qp_attr.ah_attr.is_global = 0;
768            }
769
770            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
771                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
772                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
773                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
774                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
775                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
776                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
777
778            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
779            if errno != 0 {
780                let os_error = Error::last_os_error();
781                return Err(anyhow::anyhow!(
782                    "failed to transition QP to RTR: {}",
783                    os_error
784                ));
785            }
786
787            // Transition to RTS (Ready to Send)
788            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
789                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
790                sq_psn: self.config.psn,
791                max_rd_atomic: self.config.max_rd_atomic,
792                retry_cnt: self.config.retry_cnt,
793                rnr_retry: self.config.rnr_retry,
794                timeout: self.config.qp_timeout,
795                ..Default::default()
796            };
797
798            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
799                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
800                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
801                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
802                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
803                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
804
805            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
806            if errno != 0 {
807                let os_error = Error::last_os_error();
808                return Err(anyhow::anyhow!(
809                    "failed to transition QP to RTS: {}",
810                    os_error
811                ));
812            }
813            tracing::debug!(
814                "connection sequence has successfully completed (qp: {:?})",
815                qp
816            );
817
818            // Record RTS time now that the queue pair is ready to send
819            let rts_timestamp_nanos = RealClock
820                .system_time_now()
821                .duration_since(std::time::UNIX_EPOCH)
822                .unwrap()
823                .as_nanos() as u64;
824            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
825
826            Ok(())
827        }
828    }
829
830    pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<u64, anyhow::Error> {
831        unsafe {
832            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
833            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
834            self.post_op(
835                0,
836                lhandle.lkey,
837                0,
838                idx,
839                true,
840                RdmaOperation::Recv,
841                0,
842                rhandle.rkey,
843            )
844            .unwrap();
845            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
846            Ok(idx)
847        }
848    }
849
850    pub fn put_with_recv(
851        &mut self,
852        lhandle: RdmaBuffer,
853        rhandle: RdmaBuffer,
854    ) -> Result<Vec<u64>, anyhow::Error> {
855        unsafe {
856            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
857            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
858            self.post_op(
859                lhandle.addr,
860                lhandle.lkey,
861                lhandle.size,
862                idx,
863                true,
864                RdmaOperation::WriteWithImm,
865                rhandle.addr,
866                rhandle.rkey,
867            )
868            .unwrap();
869            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
870            Ok(vec![idx])
871        }
872    }
873
874    pub fn put(
875        &mut self,
876        lhandle: RdmaBuffer,
877        rhandle: RdmaBuffer,
878    ) -> Result<Vec<u64>, anyhow::Error> {
879        let total_size = lhandle.size;
880        if rhandle.size < total_size {
881            return Err(anyhow::anyhow!(
882                "Remote buffer size ({}) is smaller than local buffer size ({})",
883                rhandle.size,
884                total_size
885            ));
886        }
887
888        let mut remaining = total_size;
889        let mut offset = 0;
890        let mut wr_ids = Vec::new();
891        while remaining > 0 {
892            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
893            let idx = unsafe {
894                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
895                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
896                )
897            };
898            wr_ids.push(idx);
899            self.post_op(
900                lhandle.addr + offset,
901                lhandle.lkey,
902                chunk_size,
903                idx,
904                true,
905                RdmaOperation::Write,
906                rhandle.addr + offset,
907                rhandle.rkey,
908            )?;
909            unsafe {
910                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
911                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
912                );
913            }
914
915            remaining -= chunk_size;
916            offset += chunk_size;
917        }
918
919        Ok(wr_ids)
920    }
921
922    /// Get a doorbell for the queue pair.
923    ///
924    /// This method returns a doorbell that can be used to trigger the execution of
925    /// previously enqueued operations.
926    ///
927    /// # Returns
928    ///
929    /// * `Result<DoorBell, anyhow::Error>` - A doorbell for the queue pair
930    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
931        unsafe {
932            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
933            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
934            let base_ptr = (*dv_qp).sq.buf as *mut u8;
935            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
936            let stride = (*dv_qp).sq.stride;
937            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
938            let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
939            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
940                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
941            }
942            self.apply_first_op_delay(send_db_idx);
943            while send_db_idx < send_wqe_idx {
944                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
945                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
946                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
947                send_db_idx += 1;
948                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
949            }
950            Ok(())
951        }
952    }
953
954    /// Enqueues a put operation without ringing the doorbell.
955    ///
956    /// This method prepares a put operation but does not execute it.
957    /// Use `get_doorbell().ring()` to execute the operation.
958    ///
959    /// # Arguments
960    ///
961    /// * `lhandle` - Local buffer handle
962    /// * `rhandle` - Remote buffer handle
963    ///
964    /// # Returns
965    ///
966    /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
967    pub fn enqueue_put(
968        &mut self,
969        lhandle: RdmaBuffer,
970        rhandle: RdmaBuffer,
971    ) -> Result<Vec<u64>, anyhow::Error> {
972        let idx = unsafe {
973            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
974                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
975            )
976        };
977
978        self.send_wqe(
979            lhandle.addr,
980            lhandle.lkey,
981            lhandle.size,
982            idx,
983            true,
984            RdmaOperation::Write,
985            rhandle.addr,
986            rhandle.rkey,
987        )?;
988        Ok(vec![idx])
989    }
990
991    /// Enqueues a put with receive operation without ringing the doorbell.
992    ///
993    /// This method prepares a put with receive operation but does not execute it.
994    /// Use `get_doorbell().ring()` to execute the operation.
995    ///
996    /// # Arguments
997    ///
998    /// * `lhandle` - Local buffer handle
999    /// * `rhandle` - Remote buffer handle
1000    ///
1001    /// # Returns
1002    ///
1003    /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1004    pub fn enqueue_put_with_recv(
1005        &mut self,
1006        lhandle: RdmaBuffer,
1007        rhandle: RdmaBuffer,
1008    ) -> Result<Vec<u64>, anyhow::Error> {
1009        let idx = unsafe {
1010            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1011                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1012            )
1013        };
1014
1015        self.send_wqe(
1016            lhandle.addr,
1017            lhandle.lkey,
1018            lhandle.size,
1019            idx,
1020            true,
1021            RdmaOperation::WriteWithImm,
1022            rhandle.addr,
1023            rhandle.rkey,
1024        )?;
1025        Ok(vec![idx])
1026    }
1027
1028    /// Enqueues a get operation without ringing the doorbell.
1029    ///
1030    /// This method prepares a get operation but does not execute it.
1031    /// Use `get_doorbell().ring()` to execute the operation.
1032    ///
1033    /// # Arguments
1034    ///
1035    /// * `lhandle` - Local buffer handle
1036    /// * `rhandle` - Remote buffer handle
1037    ///
1038    /// # Returns
1039    ///
1040    /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1041    pub fn enqueue_get(
1042        &mut self,
1043        lhandle: RdmaBuffer,
1044        rhandle: RdmaBuffer,
1045    ) -> Result<Vec<u64>, anyhow::Error> {
1046        let idx = unsafe {
1047            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1048                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1049            )
1050        };
1051
1052        self.send_wqe(
1053            lhandle.addr,
1054            lhandle.lkey,
1055            lhandle.size,
1056            idx,
1057            true,
1058            RdmaOperation::Read,
1059            rhandle.addr,
1060            rhandle.rkey,
1061        )?;
1062        Ok(vec![idx])
1063    }
1064
1065    pub fn get(
1066        &mut self,
1067        lhandle: RdmaBuffer,
1068        rhandle: RdmaBuffer,
1069    ) -> Result<Vec<u64>, anyhow::Error> {
1070        let total_size = lhandle.size;
1071        if rhandle.size < total_size {
1072            return Err(anyhow::anyhow!(
1073                "Remote buffer size ({}) is smaller than local buffer size ({})",
1074                rhandle.size,
1075                total_size
1076            ));
1077        }
1078
1079        let mut remaining = total_size;
1080        let mut offset = 0;
1081        let mut wr_ids = Vec::new();
1082
1083        while remaining > 0 {
1084            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
1085            let idx = unsafe {
1086                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1087                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1088                )
1089            };
1090            wr_ids.push(idx);
1091
1092            self.post_op(
1093                lhandle.addr + offset,
1094                lhandle.lkey,
1095                chunk_size,
1096                idx,
1097                true,
1098                RdmaOperation::Read,
1099                rhandle.addr + offset,
1100                rhandle.rkey,
1101            )?;
1102            unsafe {
1103                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
1104                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1105                );
1106            }
1107
1108            remaining -= chunk_size;
1109            offset += chunk_size;
1110        }
1111
1112        Ok(wr_ids)
1113    }
1114
1115    /// Posts a request to the queue pair.
1116    ///
1117    /// # Arguments
1118    ///
1119    /// * `local_addr` - The local address containing data to send
1120    /// * `length` - Length of the data to send
1121    /// * `wr_id` - Work request ID for completion identification
1122    /// * `signaled` - Whether to generate a completion event
1123    /// * `op_type` - Optional operation type
1124    /// * `raddr` - the remote address, representing the memory location on the remote peer
1125    /// * `rkey` - the remote key, representing the key required to access the remote memory region
1126    fn post_op(
1127        &mut self,
1128        laddr: usize,
1129        lkey: u32,
1130        length: usize,
1131        wr_id: u64,
1132        signaled: bool,
1133        op_type: RdmaOperation,
1134        raddr: usize,
1135        rkey: u32,
1136    ) -> Result<(), anyhow::Error> {
1137        // SAFETY:
1138        // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
1139        // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
1140        // - The memory address in `local_addr` is not dereferenced, only passed to the device
1141        // - The remote connection info is verified to exist before accessing
1142        // - The ibverbs post_send operation follows the documented API contract
1143        // - Error codes from the device are properly checked and propagated
1144        unsafe {
1145            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1146            let context = self.context as *mut rdmaxcel_sys::ibv_context;
1147            let ops = &mut (*context).ops;
1148            let errno;
1149            if op_type == RdmaOperation::Recv {
1150                let mut sge = rdmaxcel_sys::ibv_sge {
1151                    addr: laddr as u64,
1152                    length: length as u32,
1153                    lkey,
1154                };
1155                let mut wr = rdmaxcel_sys::ibv_recv_wr {
1156                    wr_id,
1157                    sg_list: &mut sge as *mut _,
1158                    num_sge: 1,
1159                    ..Default::default()
1160                };
1161                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
1162                errno =
1163                    ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1164            } else if op_type == RdmaOperation::Write
1165                || op_type == RdmaOperation::Read
1166                || op_type == RdmaOperation::WriteWithImm
1167            {
1168                // Apply hardware initialization delay if this is the first operation
1169                self.apply_first_op_delay(wr_id);
1170                let send_flags = if signaled {
1171                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
1172                } else {
1173                    0
1174                };
1175                let mut sge = rdmaxcel_sys::ibv_sge {
1176                    addr: laddr as u64,
1177                    length: length as u32,
1178                    lkey,
1179                };
1180                let mut wr = rdmaxcel_sys::ibv_send_wr {
1181                    wr_id,
1182                    next: std::ptr::null_mut(),
1183                    sg_list: &mut sge as *mut _,
1184                    num_sge: 1,
1185                    opcode: op_type.into(),
1186                    send_flags,
1187                    wr: Default::default(),
1188                    qp_type: Default::default(),
1189                    __bindgen_anon_1: Default::default(),
1190                    __bindgen_anon_2: Default::default(),
1191                };
1192
1193                wr.wr.rdma.remote_addr = raddr as u64;
1194                wr.wr.rdma.rkey = rkey;
1195                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
1196
1197                errno =
1198                    ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1199            } else {
1200                panic!("Not Implemented");
1201            }
1202
1203            if errno != 0 {
1204                let os_error = Error::last_os_error();
1205                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
1206            }
1207            tracing::debug!(
1208                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
1209                op_type,
1210                lkey,
1211                laddr,
1212                length,
1213                raddr,
1214                rkey,
1215            );
1216
1217            Ok(())
1218        }
1219    }
1220
1221    fn send_wqe(
1222        &mut self,
1223        laddr: usize,
1224        lkey: u32,
1225        length: usize,
1226        wr_id: u64,
1227        signaled: bool,
1228        op_type: RdmaOperation,
1229        raddr: usize,
1230        rkey: u32,
1231    ) -> Result<DoorBell, anyhow::Error> {
1232        unsafe {
1233            let op_type_val = match op_type {
1234                RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
1235                RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
1236                RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
1237                RdmaOperation::Recv => 0,
1238            };
1239
1240            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1241            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
1242            let _dv_cq = if op_type == RdmaOperation::Recv {
1243                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
1244            } else {
1245                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
1246            };
1247
1248            // Create the WQE parameters struct
1249
1250            let buf = if op_type == RdmaOperation::Recv {
1251                (*dv_qp).rq.buf as *mut u8
1252            } else {
1253                (*dv_qp).sq.buf as *mut u8
1254            };
1255
1256            let params = rdmaxcel_sys::wqe_params_t {
1257                laddr,
1258                lkey,
1259                length,
1260                wr_id,
1261                signaled,
1262                op_type: op_type_val,
1263                raddr,
1264                rkey,
1265                qp_num: (*(*qp).ibv_qp).qp_num,
1266                buf,
1267                dbrec: (*dv_qp).dbrec,
1268                wqe_cnt: (*dv_qp).sq.wqe_cnt,
1269            };
1270
1271            // Call the C function to post the WQE
1272            if op_type == RdmaOperation::Recv {
1273                rdmaxcel_sys::recv_wqe(params);
1274                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
1275            } else {
1276                rdmaxcel_sys::send_wqe(params);
1277            };
1278
1279            // Create and return a DoorBell struct
1280            Ok(DoorBell {
1281                dst_ptr: (*dv_qp).bf.reg as usize,
1282                src_ptr: (*dv_qp).sq.buf as usize,
1283                size: 8,
1284            })
1285        }
1286    }
1287
1288    /// Poll for work completions by wr_ids.
1289    ///
1290    /// # Arguments
1291    ///
1292    /// * `target` - Which completion queue to poll (Send, Receive)
1293    /// * `expected_wr_ids` - Slice of work request IDs to wait for
1294    ///
1295    /// # Returns
1296    ///
1297    /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs that were found
1298    /// * `Err(e)` - An error occurred
1299    pub fn poll_completion(
1300        &mut self,
1301        target: PollTarget,
1302        expected_wr_ids: &[u64],
1303    ) -> Result<Vec<(u64, IbvWc)>, anyhow::Error> {
1304        if expected_wr_ids.is_empty() {
1305            return Ok(Vec::new());
1306        }
1307
1308        unsafe {
1309            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1310            let qp_num = (*(*qp).ibv_qp).qp_num;
1311
1312            let (cq, cache, cq_type) = match target {
1313                PollTarget::Send => (
1314                    self.send_cq as *mut rdmaxcel_sys::ibv_cq,
1315                    rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
1316                    "send",
1317                ),
1318                PollTarget::Recv => (
1319                    self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
1320                    rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
1321                    "recv",
1322                ),
1323            };
1324
1325            let mut results = Vec::new();
1326
1327            // Single-shot poll: check each wr_id once and return what we find
1328            for &expected_wr_id in expected_wr_ids {
1329                let mut poll_ctx = rdmaxcel_sys::poll_context_t {
1330                    expected_wr_id,
1331                    expected_qp_num: qp_num,
1332                    cache,
1333                    cq,
1334                };
1335
1336                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1337                let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
1338
1339                match ret {
1340                    1 => {
1341                        // Found completion
1342                        if !wc.is_valid() {
1343                            if let Some((status, vendor_err)) = wc.error() {
1344                                return Err(anyhow::anyhow!(
1345                                    "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
1346                                    cq_type,
1347                                    expected_wr_id,
1348                                    status,
1349                                    vendor_err,
1350                                ));
1351                            }
1352                        }
1353                        results.push((expected_wr_id, IbvWc::from(wc)));
1354                    }
1355                    0 => {
1356                        // Not found yet - this is fine for single-shot poll
1357                    }
1358                    -17 => {
1359                        // RDMAXCEL_COMPLETION_FAILED: Completion found but failed - wc contains the error details
1360                        let error_msg =
1361                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1362                                .to_str()
1363                                .unwrap_or("Unknown error");
1364                        if let Some((status, vendor_err)) = wc.error() {
1365                            return Err(anyhow::anyhow!(
1366                                "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1367                                cq_type,
1368                                expected_wr_id,
1369                                error_msg,
1370                                status,
1371                                vendor_err,
1372                                wc.qp_num,
1373                                wc.len(),
1374                            ));
1375                        } else {
1376                            return Err(anyhow::anyhow!(
1377                                "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1378                                cq_type,
1379                                expected_wr_id,
1380                                error_msg,
1381                                wc.qp_num,
1382                                wc.len(),
1383                            ));
1384                        }
1385                    }
1386                    _ => {
1387                        // Other errors
1388                        let error_msg =
1389                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1390                                .to_str()
1391                                .unwrap_or("Unknown error");
1392                        return Err(anyhow::anyhow!(
1393                            "Failed to poll {} CQ for wr_id={}: {}",
1394                            cq_type,
1395                            expected_wr_id,
1396                            error_msg
1397                        ));
1398                    }
1399                }
1400            }
1401
1402            Ok(results)
1403        }
1404    }
1405}
1406
1407/// Utility to validate execution context.
1408///
1409/// Remote Execution environments do not always have access to the nvidia_peermem module
1410/// and/or set the PeerMappingOverride parameter due to security. This function can be
1411/// used to validate that the execution context when running operations that need this
1412/// functionality (ie. cudaHostRegisterIoMemory).
1413///
1414/// # Returns
1415///
1416/// * `Ok(())` if the execution context is valid
1417/// * `Err(anyhow::Error)` if the execution context is invalid
1418pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1419    // Check for nvidia peermem
1420    match fs::read_to_string("/proc/modules") {
1421        Ok(contents) => {
1422            if !contents.contains("nvidia_peermem") {
1423                return Err(anyhow::anyhow!(
1424                    "nvidia_peermem module not found in /proc/modules"
1425                ));
1426            }
1427        }
1428        Err(e) => {
1429            return Err(anyhow::anyhow!(e));
1430        }
1431    }
1432
1433    // Test file access to nvidia params
1434    match fs::read_to_string("/proc/driver/nvidia/params") {
1435        Ok(contents) => {
1436            if !contents.contains("PeerMappingOverride=1") {
1437                return Err(anyhow::anyhow!(
1438                    "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1439                ));
1440            }
1441        }
1442        Err(e) => {
1443            return Err(anyhow::anyhow!(e));
1444        }
1445    }
1446    Ok(())
1447}
1448
1449/// Get all segments that have been registered with MRs
1450///
1451/// # Returns
1452/// * `Vec<SegmentInfo>` - Vector containing all registered segment information
1453pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
1454    unsafe {
1455        let segment_count = rdmaxcel_sys::rdma_get_active_segment_count();
1456        if segment_count <= 0 {
1457            return Vec::new();
1458        }
1459
1460        let mut segments = vec![
1461            std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
1462                .assume_init();
1463            segment_count as usize
1464        ];
1465        let actual_count =
1466            rdmaxcel_sys::rdma_get_all_segment_info(segments.as_mut_ptr(), segment_count);
1467
1468        if actual_count > 0 {
1469            segments.truncate(actual_count as usize);
1470            segments
1471        } else {
1472            Vec::new()
1473        }
1474    }
1475}
1476
/// Segment scanner callback type alias for convenience.
///
/// Re-exports `rdmaxcel_sys::RdmaxcelSegmentScannerFn`, the callback type
/// accepted by `register_segment_scanner`, so callers need not depend on
/// `rdmaxcel_sys` directly.
pub type SegmentScannerFn = rdmaxcel_sys::RdmaxcelSegmentScannerFn;
1479
1480/// Register a segment scanner callback.
1481///
1482/// The scanner callback is called during RDMA segment registration to discover
1483/// CUDA memory segments. The callback should fill the provided buffer with
1484/// segment information and return the total count of segments found.
1485///
1486/// If the returned count exceeds the buffer size, the caller will allocate
1487/// a larger buffer and retry.
1488///
1489/// Pass `None` to unregister the scanner.
1490///
1491/// # Safety
1492///
1493/// The provided callback function must be safe to call from C code and must
1494/// properly handle the segment buffer.
1495pub fn register_segment_scanner(scanner: SegmentScannerFn) {
1496    // SAFETY: We are registering a callback function pointer with rdmaxcel.
1497    unsafe { rdmaxcel_sys::rdmaxcel_register_segment_scanner(scanner) }
1498}
1499
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports (and logs) whether no RDMA device is present, so callers can
    /// return early instead of failing on machines without RDMA hardware.
    fn rdma_unavailable() -> bool {
        let missing = crate::ibverbs_primitives::get_all_devices().is_empty();
        if missing {
            println!("Skipping test: RDMA devices not available");
        }
        missing
    }

    /// Default ibverbs configuration with GPUDirect disabled.
    fn host_memory_config() -> IbverbsConfig {
        IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_create_connection() {
        if rdma_unavailable() {
            return;
        }

        let config = host_memory_config();
        let domain = RdmaDomain::new(config.device.clone());
        assert!(domain.is_ok());
        let domain = domain.unwrap();

        let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config.clone());
        assert!(queue_pair.is_ok());
    }

    #[test]
    fn test_loopback_connection() {
        if rdma_unavailable() {
            return;
        }

        let server_config = host_memory_config();
        let client_config = host_memory_config();

        let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
        let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();

        let mut server_qp = RdmaQueuePair::new(
            server_domain.context,
            server_domain.pd,
            server_config.clone(),
        )
        .unwrap();
        let mut client_qp = RdmaQueuePair::new(
            client_domain.context,
            client_domain.pd,
            client_config.clone(),
        )
        .unwrap();

        // Exchange connection info in both directions, then connect each side.
        let server_info = server_qp.get_qp_info().unwrap();
        let client_info = client_qp.get_qp_info().unwrap();

        assert!(server_qp.connect(&client_info).is_ok());
        assert!(client_qp.connect(&server_info).is_ok());
    }
}