monarch_rdma/
rdma_components.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
13//! ## Core Components
14//!
15//! * `RdmaDomain` - Manages RDMA resources including context, protection domain, and memory region
16//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `RdmaDomain` with `new()`
36//! 2. Create an `RdmaQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
42
/// Upper bound, in bytes, on the size of a single RDMA operation (1 GiB = 2^30 bytes).
const MAX_RDMA_MSG_SIZE: usize = 1 << 30;
45
46use std::ffi::CStr;
47use std::fs;
48use std::io::Error;
49use std::result::Result;
50use std::thread::sleep;
51use std::time::Duration;
52
53use hyperactor::ActorRef;
54use hyperactor::Named;
55use hyperactor::clock::Clock;
56use hyperactor::clock::RealClock;
57use hyperactor::context;
58use serde::Deserialize;
59use serde::Serialize;
60
61use crate::RdmaDevice;
62use crate::RdmaManagerActor;
63use crate::RdmaManagerMessageClient;
64use crate::ibverbs_primitives::Gid;
65use crate::ibverbs_primitives::IbvWc;
66use crate::ibverbs_primitives::IbverbsConfig;
67use crate::ibverbs_primitives::RdmaOperation;
68use crate::ibverbs_primitives::RdmaQpInfo;
69use crate::ibverbs_primitives::resolve_qp_type;
70
/// Serializable description of a doorbell: two raw addresses (stored as `usize`) and a size.
///
/// `ring()` casts `src_ptr` and `dst_ptr` back to raw pointers and hands them to
/// `rdmaxcel_sys::db_ring`.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Address passed to `db_ring` as its second argument (the source).
    pub src_ptr: usize,
    /// Address passed to `db_ring` as its first argument (the destination).
    pub dst_ptr: usize,
    /// Size associated with the doorbell; not read by `ring()` — presumably bytes, TODO confirm.
    pub size: usize,
}
77
78impl DoorBell {
79    /// Rings the doorbell to trigger the execution of previously enqueued operations.
80    ///
81    /// This method uses unsafe code to directly interact with the RDMA device,
82    /// sending a signal from the source pointer to the destination pointer.
83    ///
84    /// # Returns
85    /// * `Ok(())` if the operation is successful.
86    /// * `Err(anyhow::Error)` if an error occurs during the operation.
87    pub fn ring(&self) -> Result<(), anyhow::Error> {
88        unsafe {
89            let src_ptr = self.src_ptr as *mut std::ffi::c_void;
90            let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
91            rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
92            Ok(())
93        }
94    }
95}
96
/// Serializable handle to a memory region registered for RDMA.
///
/// The handle carries the actor that owns the registration plus the keys and address range
/// describing the region. The `RdmaBuffer` methods use `owner` to check queue pairs out and
/// in, and to release the buffer.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct RdmaBuffer {
    /// Actor that owns this buffer; queue pairs are requested from and released to it.
    pub owner: ActorRef<RdmaManagerActor>,
    /// Identifier of the memory region — presumably assigned by the owning actor, TODO confirm.
    pub mr_id: usize,
    /// Local key of the memory region (reported in completion error messages).
    pub lkey: u32,
    /// Remote key of the memory region (reported in completion error messages).
    pub rkey: u32,
    /// Start address of the region, stored as an integer.
    pub addr: usize,
    /// Size of the region — presumably in bytes, TODO confirm.
    pub size: usize,
    /// Name of the RDMA device this buffer is associated with; forwarded when requesting a QP.
    pub device_name: String,
}
107
108impl RdmaBuffer {
109    /// Read from the RdmaBuffer into the provided memory.
110    ///
111    /// This method transfers data from the buffer into the local memory region provided over RDMA.
112    /// This involves calling a `Put` operation on the RdmaBuffer's actor side.
113    ///
114    /// # Arguments
115    /// * `client` - The actor who is reading.
116    /// * `remote` - RdmaBuffer representing the remote memory region
117    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
118    ///
119    /// # Returns
120    /// `Ok(bool)` indicating if the operation completed successfully.
121    pub async fn read_into(
122        &self,
123        client: &impl context::Actor,
124        remote: RdmaBuffer,
125        timeout: u64,
126    ) -> Result<bool, anyhow::Error> {
127        tracing::debug!(
128            "[buffer] reading from {:?} into remote ({:?}) at {:?}",
129            self,
130            remote.owner.actor_id(),
131            remote,
132        );
133        let remote_owner = remote.owner.clone();
134
135        let local_device = self.device_name.clone();
136        let remote_device = remote.device_name.clone();
137        let mut qp = self
138            .owner
139            .request_queue_pair(
140                client,
141                remote_owner.clone(),
142                local_device.clone(),
143                remote_device.clone(),
144            )
145            .await?;
146
147        let wr_id = qp.put(self.clone(), remote)?;
148        let result = self
149            .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
150            .await;
151
152        // Release the queue pair back to the actor
153        self.owner
154            .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
155            .await?;
156
157        result
158    }
159
160    /// Write from the provided memory into the RdmaBuffer.
161    ///
162    /// This method performs an RDMA write operation, transferring data from the caller's
163    /// memory region to this buffer.
164    /// This involves calling a `Fetch` operation on the RdmaBuffer's actor side.
165    ///
166    /// # Arguments
167    /// * `client` - The actor who is writing.
168    /// * `remote` - RdmaBuffer representing the remote memory region
169    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
170    ///
171    /// # Returns
172    /// `Ok(bool)` indicating if the operation completed successfully.
173    pub async fn write_from(
174        &self,
175        client: &impl context::Actor,
176        remote: RdmaBuffer,
177        timeout: u64,
178    ) -> Result<bool, anyhow::Error> {
179        tracing::debug!(
180            "[buffer] writing into {:?} from remote ({:?}) at {:?}",
181            self,
182            remote.owner.actor_id(),
183            remote,
184        );
185        let remote_owner = remote.owner.clone(); // Clone before the move!
186
187        // Extract device name from buffer, fallback to a default if not present
188        let local_device = self.device_name.clone();
189        let remote_device = remote.device_name.clone();
190
191        let mut qp = self
192            .owner
193            .request_queue_pair(
194                client,
195                remote_owner.clone(),
196                local_device.clone(),
197                remote_device.clone(),
198            )
199            .await?;
200        let wr_id = qp.get(self.clone(), remote)?;
201        let result = self
202            .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
203            .await;
204
205        // Release the queue pair back to the actor
206        self.owner
207            .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
208            .await?;
209
210        result
211    }
212    /// Waits for the completion of RDMA operations.
213    ///
214    /// This method polls the completion queue until all specified work requests complete
215    /// or until the timeout is reached.
216    ///
217    /// # Arguments
218    /// * `qp` - The RDMA Queue Pair to poll for completion
219    /// * `poll_target` - Which CQ to poll (Send or Recv)
220    /// * `expected_wr_ids` - The work request IDs to wait for
221    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
222    ///
223    /// # Returns
224    /// `Ok(true)` if all operations complete successfully within the timeout,
225    /// or an error if the timeout is reached
226    async fn wait_for_completion(
227        &self,
228        qp: &mut RdmaQueuePair,
229        poll_target: PollTarget,
230        expected_wr_ids: &[u64],
231        timeout: u64,
232    ) -> Result<bool, anyhow::Error> {
233        let timeout = Duration::from_secs(timeout);
234        let start_time = std::time::Instant::now();
235
236        let mut remaining: std::collections::HashSet<u64> =
237            expected_wr_ids.iter().copied().collect();
238
239        while start_time.elapsed() < timeout {
240            if remaining.is_empty() {
241                return Ok(true);
242            }
243
244            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
245            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
246                Ok(completions) => {
247                    for (wr_id, _wc) in completions {
248                        remaining.remove(&wr_id);
249                    }
250                    if remaining.is_empty() {
251                        return Ok(true);
252                    }
253                    RealClock.sleep(Duration::from_millis(1)).await;
254                }
255                Err(e) => {
256                    return Err(anyhow::anyhow!(
257                        "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
258                        e,
259                        self.lkey,
260                        self.rkey,
261                        self.addr,
262                        self.size
263                    ));
264                }
265            }
266        }
267        tracing::error!(
268            "timed out while waiting on request completion for wr_ids={:?}",
269            remaining
270        );
271        Err(anyhow::anyhow!(
272            "[buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
273            self,
274            expected_wr_ids
275        ))
276    }
277
278    /// Drop the buffer and release remote handles.
279    ///
280    /// This method calls the owning RdmaManagerActor to release the buffer and clean up
281    /// associated memory regions. This is typically called when the buffer is no longer
282    /// needed and resources should be freed.
283    ///
284    /// # Arguments
285    /// * `client` - Mailbox used for communication
286    ///
287    /// # Returns
288    /// `Ok(())` if the operation completed successfully.
289    pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
290        tracing::debug!("[buffer] dropping buffer {:?}", self);
291        self.owner.release_buffer(client, self.clone()).await?;
292        Ok(())
293    }
294}
295
/// Represents a domain for RDMA operations, encapsulating the necessary resources
/// for establishing and managing RDMA connections.
///
/// An `RdmaDomain` holds the device context and the protection domain (PD) required for
/// RDMA operations. It provides the foundation for creating queue pairs and establishing
/// connections between RDMA devices. It does not hold a memory region.
///
/// # Fields
///
/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
///
/// # Notes
///
/// NOTE(review): this type derives `Clone` while its `Drop` impl calls `ibv_dealloc_pd` on
/// `pd`. A clone copies the raw pointers, so each clone's drop targets the same protection
/// domain. Confirm that clones are never dropped independently, or move to shared ownership.
#[derive(Clone)]
pub struct RdmaDomain {
    /// Raw pointer to the opened ibverbs device context.
    pub context: *mut rdmaxcel_sys::ibv_context,
    /// Raw pointer to the protection domain allocated on `context`.
    pub pd: *mut rdmaxcel_sys::ibv_pd,
}
312
313impl std::fmt::Debug for RdmaDomain {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("RdmaDomain")
316            .field("context", &format!("{:p}", self.context))
317            .field("pd", &format!("{:p}", self.pd))
318            .finish()
319    }
320}
321
// SAFETY:
// `RdmaDomain` only holds raw pointers to ibverbs objects (`ibv_context`, `ibv_pd`), which
// is why `Send` is not derived automatically.
// RdmaDomain is declared `Send` because the raw pointers to ibverbs structs can be
// accessed from any thread, and it is safe to drop `RdmaDomain` (and run the
// ibverbs destructors) from any thread.
unsafe impl Send for RdmaDomain {}

// SAFETY:
// `RdmaDomain` exposes no interior mutability of its own; shared access only reads the
// raw pointer values.
// RdmaDomain is declared `Sync` because the underlying ibverbs APIs are thread-safe.
unsafe impl Sync for RdmaDomain {}
333
334impl Drop for RdmaDomain {
335    fn drop(&mut self) {
336        unsafe {
337            rdmaxcel_sys::ibv_dealloc_pd(self.pd);
338        }
339    }
340}
341
342impl RdmaDomain {
343    /// Creates a new RdmaDomain.
344    ///
345    /// This function initializes the RDMA device context, creates a protection domain,
346    /// and registers a memory region with appropriate access permissions.
347    ///
348    /// SAFETY:
349    /// Our memory region (MR) registration uses implicit ODP for RDMA access, which maps large virtual
350    /// address ranges without explicit pinning. This is convenient, but it broadens the memory footprint
351    /// exposed to the NIC and introduces a security liability.
352    ///
353    /// We currently assume a trusted, single-environment and are not enforcing finer-grained memory isolation
354    /// at this layer. We plan to investigate mitigations - such as memory windows or tighter registration
355    /// boundaries in future follow-ups.
356    ///
357    /// # Arguments
358    ///
359    /// * `config` - Configuration settings for the RDMA operations
360    ///
361    /// # Errors
362    ///
363    /// This function may return errors if:
364    /// * No RDMA devices are found
365    /// * The specified device cannot be found
366    /// * Device context creation fails
367    /// * Protection domain allocation fails
368    /// * Memory region registration fails
369    pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
370        tracing::debug!("creating RdmaDomain for device {}", device.name());
371        // SAFETY:
372        // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
373        // - All pointers are properly initialized and checked for null before use
374        // - Memory registration follows the ibverbs API contract with proper access flags
375        // - Resources are properly cleaned up in error cases to prevent leaks
376        // - The operations follow the documented RDMA protocol for device initialization
377        unsafe {
378            // Get the device based on the provided RdmaDevice
379            let device_name = device.name();
380            let mut num_devices = 0i32;
381            let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
382
383            if devices.is_null() || num_devices == 0 {
384                return Err(anyhow::anyhow!("no RDMA devices found"));
385            }
386
387            // Find the device with the matching name
388            let mut device_ptr = std::ptr::null_mut();
389            for i in 0..num_devices {
390                let dev = *devices.offset(i as isize);
391                let dev_name =
392                    CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
393
394                if dev_name == *device_name {
395                    device_ptr = dev;
396                    break;
397                }
398            }
399
400            // If we didn't find the device, return an error
401            if device_ptr.is_null() {
402                rdmaxcel_sys::ibv_free_device_list(devices);
403                return Err(anyhow::anyhow!("device '{}' not found", device_name));
404            }
405            tracing::info!("using RDMA device: {}", device_name);
406
407            // Open device
408            let context = rdmaxcel_sys::ibv_open_device(device_ptr);
409            if context.is_null() {
410                rdmaxcel_sys::ibv_free_device_list(devices);
411                let os_error = Error::last_os_error();
412                return Err(anyhow::anyhow!("failed to create context: {}", os_error));
413            }
414
415            // Create protection domain
416            let pd = rdmaxcel_sys::ibv_alloc_pd(context);
417            if pd.is_null() {
418                rdmaxcel_sys::ibv_close_device(context);
419                rdmaxcel_sys::ibv_free_device_list(devices);
420                let os_error = Error::last_os_error();
421                return Err(anyhow::anyhow!(
422                    "failed to create protection domain (PD): {}",
423                    os_error
424                ));
425            }
426
427            // Avoids memory leaks
428            rdmaxcel_sys::ibv_free_device_list(devices);
429
430            let domain = RdmaDomain { context, pd };
431
432            Ok(domain)
433        }
434    }
435}
/// Enum to specify which completion queue to poll
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue.
    Send,
    /// Poll the receive completion queue.
    Recv,
}
442
/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
///
/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
/// and other resources needed for RDMA communication. It provides methods for
/// establishing connections and performing RDMA operations like read and write.
///
/// All pointer-valued fields are stored as `usize` addresses and cast back to raw pointers
/// at the point of use.
///
/// # Fields
///
/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
/// * `qp` - Queue Pair pointer that manages send and receive operations
/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
/// * `context` - RDMA device context pointer
/// * `config` - Configuration settings for the queue pair
///
/// # Connection Lifecycle
///
/// 1. Create with `new()` from the context and protection domain of an `RdmaDomain`
/// 2. Get connection info with `get_qp_info()`
/// 3. Exchange connection info with remote peer (application must handle this)
/// 4. Connect to remote endpoint with `connect()`
/// 5. Perform RDMA operations with `put()` or `get()`
/// 6. Poll for completions with `poll_completion()`, selecting the queue via `PollTarget`
///
/// # Notes
/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
/// - This makes RdmaQueuePair trivially Clone and Serialize
/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
/// - NOTE(review): the stored addresses are presumably only meaningful inside the process
///   that created the queue pair — confirm before deserializing in another process.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct RdmaQueuePair {
    /// Address of the send completion queue.
    pub send_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
    /// Address of the receive completion queue.
    pub recv_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
    /// Address of the rdmaxcel queue-pair wrapper; its `ibv_qp` member is the verbs QP.
    pub qp: usize,         // *mut rdmaxcel_sys::rdmaxcel_qp_t
    /// Address of the mlx5 device-specific queue pair structure.
    pub dv_qp: usize,      // *mut rdmaxcel_sys::mlx5dv_qp,
    /// Address of the mlx5 device-specific send completion queue structure.
    pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
    /// Address of the mlx5 device-specific receive completion queue structure.
    pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
    /// Address of the device context the queue pair was created on.
    context: usize,        // *mut rdmaxcel_sys::ibv_context,
    /// Configuration used to create and connect the queue pair.
    config: IbverbsConfig,
}
485
486impl RdmaQueuePair {
487    /// Applies hardware initialization delay if this is the first operation since RTS.
488    ///
489    /// This ensures the hardware has sufficient time to settle after reaching
490    /// Ready-to-Send state before the first actual operation.
491    fn apply_first_op_delay(&self, wr_id: u64) {
492        unsafe {
493            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
494            if wr_id == 0 {
495                let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
496                assert!(
497                    rts_timestamp != u64::MAX,
498                    "First operation attempted before queue pair reached RTS state! Call connect() first."
499                );
500                let current_nanos = RealClock
501                    .system_time_now()
502                    .duration_since(std::time::UNIX_EPOCH)
503                    .unwrap()
504                    .as_nanos() as u64;
505                let elapsed_nanos = current_nanos - rts_timestamp;
506                let elapsed = Duration::from_nanos(elapsed_nanos);
507                let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
508                if elapsed < init_delay {
509                    let remaining_delay = init_delay - elapsed;
510                    sleep(remaining_delay);
511                }
512            }
513        }
514    }
515
516    /// Creates a new RdmaQueuePair from a given RdmaDomain.
517    ///
518    /// This function initializes a new Queue Pair (QP) and associated Completion Queue (CQ)
519    /// using the resources from the provided RdmaDomain. The QP is created in the RESET state
520    /// and must be transitioned to other states via the `connect()` method before use.
521    ///
522    /// # Arguments
523    ///
524    /// * `domain` - Reference to an RdmaDomain that provides the context, protection domain,
525    ///   and memory region for this queue pair
526    ///
527    /// # Returns
528    ///
529    /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
530    ///
531    /// # Errors
532    ///
533    /// This function may return errors if:
534    /// * Completion queue (CQ) creation fails
535    /// * Queue pair (QP) creation fails
536    pub fn new(
537        context: *mut rdmaxcel_sys::ibv_context,
538        pd: *mut rdmaxcel_sys::ibv_pd,
539        config: IbverbsConfig,
540    ) -> Result<Self, anyhow::Error> {
541        tracing::debug!("creating an RdmaQueuePair from config {}", config);
542        unsafe {
543            // Resolve Auto to a concrete QP type based on device capabilities
544            let resolved_qp_type = resolve_qp_type(config.qp_type);
545            let qp = rdmaxcel_sys::rdmaxcel_qp_create(
546                context,
547                pd,
548                config.cq_entries,
549                config.max_send_wr.try_into().unwrap(),
550                config.max_recv_wr.try_into().unwrap(),
551                config.max_send_sge.try_into().unwrap(),
552                config.max_recv_sge.try_into().unwrap(),
553                resolved_qp_type,
554            );
555
556            if qp.is_null() {
557                let os_error = Error::last_os_error();
558                return Err(anyhow::anyhow!(
559                    "failed to create queue pair (QP): {}",
560                    os_error
561                ));
562            }
563
564            let send_cq = (*(*qp).ibv_qp).send_cq;
565            let recv_cq = (*(*qp).ibv_qp).recv_cq;
566
567            // mlx5dv provider APIs
568            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
569            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
570            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
571
572            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
573                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
574                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
575                rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
576                return Err(anyhow::anyhow!(
577                    "failed to init mlx5dv_qp or completion queues"
578                ));
579            }
580
581            // GPU Direct RDMA specific registrations
582            if config.use_gpu_direct {
583                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
584                if ret != 0 {
585                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
586                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
587                    rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
588                    return Err(anyhow::anyhow!(
589                        "failed to register GPU Direct RDMA memory: {:?}",
590                        ret
591                    ));
592                }
593            }
594            Ok(RdmaQueuePair {
595                send_cq: send_cq as usize,
596                recv_cq: recv_cq as usize,
597                qp: qp as usize,
598                dv_qp: qp as usize,
599                dv_send_cq: dv_send_cq as usize,
600                dv_recv_cq: dv_recv_cq as usize,
601                context: context as usize,
602                config,
603            })
604        }
605    }
606
607    /// Returns the information required for a remote peer to connect to this queue pair.
608    ///
609    /// This method retrieves the local queue pair attributes and port information needed by
610    /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
611    /// the queue pair number, LID, GID, and other necessary connection parameters.
612    ///
613    /// # Returns
614    ///
615    /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
616    ///
617    /// # Errors
618    ///
619    /// This function may return errors if:
620    /// * Port attribute query fails
621    /// * GID query fails
622    pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
623        // SAFETY:
624        // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
625        // - All pointers are properly initialized before use
626        // - Port and GID queries follow the documented ibverbs API contract
627        // - Error handling properly checks return codes from ibverbs functions
628        // - The memory address provided is only stored, not dereferenced in this function
629        unsafe {
630            let context = self.context as *mut rdmaxcel_sys::ibv_context;
631            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
632            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
633            let errno = rdmaxcel_sys::ibv_query_port(
634                context,
635                self.config.port_num,
636                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
637            );
638            if errno != 0 {
639                let os_error = Error::last_os_error();
640                return Err(anyhow::anyhow!(
641                    "Failed to query port attributes: {}",
642                    os_error
643                ));
644            }
645
646            let mut gid = Gid::default();
647            let ret = rdmaxcel_sys::ibv_query_gid(
648                context,
649                self.config.port_num,
650                i32::from(self.config.gid_index),
651                gid.as_mut(),
652            );
653            if ret != 0 {
654                return Err(anyhow::anyhow!("Failed to query GID"));
655            }
656
657            Ok(RdmaQpInfo {
658                qp_num: (*(*qp).ibv_qp).qp_num,
659                lid: port_attr.lid,
660                gid: Some(gid),
661                psn: self.config.psn,
662            })
663        }
664    }
665
666    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
667        // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
668        unsafe {
669            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
670            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
671                ..Default::default()
672            };
673            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
674                ..Default::default()
675            };
676            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
677            let errno = rdmaxcel_sys::ibv_query_qp(
678                (*qp).ibv_qp,
679                &mut qp_attr,
680                mask.0 as i32,
681                &mut qp_init_attr,
682            );
683            if errno != 0 {
684                let os_error = Error::last_os_error();
685                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
686            }
687            Ok(qp_attr.qp_state)
688        }
689    }
690    /// Connect to a remote Rdma connection point.
691    ///
692    /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection.
693    ///
694    /// # Arguments
695    ///
696    /// * `connection_info` - The remote connection info to connect to
697    pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
698        // SAFETY:
699        // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
700        // The operations are safe because:
701        // 1. We're following the documented ibverbs API contract
702        // 2. All pointers used are properly initialized and owned by this struct
703        // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
704        // 4. Memory access is properly bounded by the registered memory regions
705        unsafe {
706            // Transition to INIT
707            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
708
709            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
710                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
711                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
712                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
713
714            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
715                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
716                qp_access_flags: qp_access_flags.0,
717                pkey_index: self.config.pkey_index,
718                port_num: self.config.port_num,
719                ..Default::default()
720            };
721
722            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
723                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
724                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
725                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
726
727            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
728            if errno != 0 {
729                let os_error = Error::last_os_error();
730                return Err(anyhow::anyhow!(
731                    "failed to transition QP to INIT: {}",
732                    os_error
733                ));
734            }
735
736            // Transition to RTR (Ready to Receive)
737            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
738                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
739                path_mtu: self.config.path_mtu,
740                dest_qp_num: connection_info.qp_num,
741                rq_psn: connection_info.psn,
742                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
743                min_rnr_timer: self.config.min_rnr_timer,
744                ah_attr: rdmaxcel_sys::ibv_ah_attr {
745                    dlid: connection_info.lid,
746                    sl: 0,
747                    src_path_bits: 0,
748                    port_num: self.config.port_num,
749                    grh: Default::default(),
750                    ..Default::default()
751                },
752                ..Default::default()
753            };
754
755            // If the remote connection info contains a Gid, the routing will be global.
756            // Otherwise, it will be local, i.e. using LID.
757            if let Some(gid) = connection_info.gid {
758                qp_attr.ah_attr.is_global = 1;
759                qp_attr.ah_attr.grh.dgid = gid.into();
760                qp_attr.ah_attr.grh.hop_limit = 0xff;
761                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
762            } else {
763                // Use LID-based routing, e.g. for Infiniband/RoCEv1
764                qp_attr.ah_attr.is_global = 0;
765            }
766
767            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
768                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
769                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
770                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
771                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
772                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
773                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
774
775            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
776            if errno != 0 {
777                let os_error = Error::last_os_error();
778                return Err(anyhow::anyhow!(
779                    "failed to transition QP to RTR: {}",
780                    os_error
781                ));
782            }
783
784            // Transition to RTS (Ready to Send)
785            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
786                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
787                sq_psn: self.config.psn,
788                max_rd_atomic: self.config.max_rd_atomic,
789                retry_cnt: self.config.retry_cnt,
790                rnr_retry: self.config.rnr_retry,
791                timeout: self.config.qp_timeout,
792                ..Default::default()
793            };
794
795            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
796                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
797                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
798                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
799                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
800                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
801
802            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
803            if errno != 0 {
804                let os_error = Error::last_os_error();
805                return Err(anyhow::anyhow!(
806                    "failed to transition QP to RTS: {}",
807                    os_error
808                ));
809            }
810            tracing::debug!(
811                "connection sequence has successfully completed (qp: {:?})",
812                qp
813            );
814
815            // Record RTS time now that the queue pair is ready to send
816            let rts_timestamp_nanos = RealClock
817                .system_time_now()
818                .duration_since(std::time::UNIX_EPOCH)
819                .unwrap()
820                .as_nanos() as u64;
821            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
822
823            Ok(())
824        }
825    }
826
827    pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<u64, anyhow::Error> {
828        unsafe {
829            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
830            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
831            self.post_op(
832                0,
833                lhandle.lkey,
834                0,
835                idx,
836                true,
837                RdmaOperation::Recv,
838                0,
839                rhandle.rkey,
840            )
841            .unwrap();
842            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
843            Ok(idx)
844        }
845    }
846
847    pub fn put_with_recv(
848        &mut self,
849        lhandle: RdmaBuffer,
850        rhandle: RdmaBuffer,
851    ) -> Result<Vec<u64>, anyhow::Error> {
852        unsafe {
853            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
854            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
855            self.post_op(
856                lhandle.addr,
857                lhandle.lkey,
858                lhandle.size,
859                idx,
860                true,
861                RdmaOperation::WriteWithImm,
862                rhandle.addr,
863                rhandle.rkey,
864            )
865            .unwrap();
866            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
867            Ok(vec![idx])
868        }
869    }
870
871    pub fn put(
872        &mut self,
873        lhandle: RdmaBuffer,
874        rhandle: RdmaBuffer,
875    ) -> Result<Vec<u64>, anyhow::Error> {
876        let total_size = lhandle.size;
877        if rhandle.size < total_size {
878            return Err(anyhow::anyhow!(
879                "Remote buffer size ({}) is smaller than local buffer size ({})",
880                rhandle.size,
881                total_size
882            ));
883        }
884
885        let mut remaining = total_size;
886        let mut offset = 0;
887        let mut wr_ids = Vec::new();
888        while remaining > 0 {
889            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
890            let idx = unsafe {
891                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
892                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
893                )
894            };
895            wr_ids.push(idx);
896            self.post_op(
897                lhandle.addr + offset,
898                lhandle.lkey,
899                chunk_size,
900                idx,
901                true,
902                RdmaOperation::Write,
903                rhandle.addr + offset,
904                rhandle.rkey,
905            )?;
906            unsafe {
907                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
908                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
909                );
910            }
911
912            remaining -= chunk_size;
913            offset += chunk_size;
914        }
915
916        Ok(wr_ids)
917    }
918
919    /// Get a doorbell for the queue pair.
920    ///
921    /// This method returns a doorbell that can be used to trigger the execution of
922    /// previously enqueued operations.
923    ///
924    /// # Returns
925    ///
926    /// * `Result<DoorBell, anyhow::Error>` - A doorbell for the queue pair
927    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
928        unsafe {
929            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
930            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
931            let base_ptr = (*dv_qp).sq.buf as *mut u8;
932            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
933            let stride = (*dv_qp).sq.stride;
934            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
935            let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
936            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
937                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
938            }
939            self.apply_first_op_delay(send_db_idx);
940            while send_db_idx < send_wqe_idx {
941                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
942                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
943                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
944                send_db_idx += 1;
945                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
946            }
947            Ok(())
948        }
949    }
950
951    /// Enqueues a put operation without ringing the doorbell.
952    ///
953    /// This method prepares a put operation but does not execute it.
954    /// Use `get_doorbell().ring()` to execute the operation.
955    ///
956    /// # Arguments
957    ///
958    /// * `lhandle` - Local buffer handle
959    /// * `rhandle` - Remote buffer handle
960    ///
961    /// # Returns
962    ///
963    /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
964    pub fn enqueue_put(
965        &mut self,
966        lhandle: RdmaBuffer,
967        rhandle: RdmaBuffer,
968    ) -> Result<Vec<u64>, anyhow::Error> {
969        let idx = unsafe {
970            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
971                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
972            )
973        };
974
975        self.send_wqe(
976            lhandle.addr,
977            lhandle.lkey,
978            lhandle.size,
979            idx,
980            true,
981            RdmaOperation::Write,
982            rhandle.addr,
983            rhandle.rkey,
984        )?;
985        Ok(vec![idx])
986    }
987
988    /// Enqueues a put with receive operation without ringing the doorbell.
989    ///
990    /// This method prepares a put with receive operation but does not execute it.
991    /// Use `get_doorbell().ring()` to execute the operation.
992    ///
993    /// # Arguments
994    ///
995    /// * `lhandle` - Local buffer handle
996    /// * `rhandle` - Remote buffer handle
997    ///
998    /// # Returns
999    ///
1000    /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1001    pub fn enqueue_put_with_recv(
1002        &mut self,
1003        lhandle: RdmaBuffer,
1004        rhandle: RdmaBuffer,
1005    ) -> Result<Vec<u64>, anyhow::Error> {
1006        let idx = unsafe {
1007            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1008                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1009            )
1010        };
1011
1012        self.send_wqe(
1013            lhandle.addr,
1014            lhandle.lkey,
1015            lhandle.size,
1016            idx,
1017            true,
1018            RdmaOperation::WriteWithImm,
1019            rhandle.addr,
1020            rhandle.rkey,
1021        )?;
1022        Ok(vec![idx])
1023    }
1024
1025    /// Enqueues a get operation without ringing the doorbell.
1026    ///
1027    /// This method prepares a get operation but does not execute it.
1028    /// Use `get_doorbell().ring()` to execute the operation.
1029    ///
1030    /// # Arguments
1031    ///
1032    /// * `lhandle` - Local buffer handle
1033    /// * `rhandle` - Remote buffer handle
1034    ///
1035    /// # Returns
1036    ///
1037    /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1038    pub fn enqueue_get(
1039        &mut self,
1040        lhandle: RdmaBuffer,
1041        rhandle: RdmaBuffer,
1042    ) -> Result<Vec<u64>, anyhow::Error> {
1043        let idx = unsafe {
1044            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1045                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1046            )
1047        };
1048
1049        self.send_wqe(
1050            lhandle.addr,
1051            lhandle.lkey,
1052            lhandle.size,
1053            idx,
1054            true,
1055            RdmaOperation::Read,
1056            rhandle.addr,
1057            rhandle.rkey,
1058        )?;
1059        Ok(vec![idx])
1060    }
1061
1062    pub fn get(
1063        &mut self,
1064        lhandle: RdmaBuffer,
1065        rhandle: RdmaBuffer,
1066    ) -> Result<Vec<u64>, anyhow::Error> {
1067        let total_size = lhandle.size;
1068        if rhandle.size < total_size {
1069            return Err(anyhow::anyhow!(
1070                "Remote buffer size ({}) is smaller than local buffer size ({})",
1071                rhandle.size,
1072                total_size
1073            ));
1074        }
1075
1076        let mut remaining = total_size;
1077        let mut offset = 0;
1078        let mut wr_ids = Vec::new();
1079
1080        while remaining > 0 {
1081            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
1082            let idx = unsafe {
1083                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1084                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1085                )
1086            };
1087            wr_ids.push(idx);
1088
1089            self.post_op(
1090                lhandle.addr + offset,
1091                lhandle.lkey,
1092                chunk_size,
1093                idx,
1094                true,
1095                RdmaOperation::Read,
1096                rhandle.addr + offset,
1097                rhandle.rkey,
1098            )?;
1099            unsafe {
1100                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
1101                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1102                );
1103            }
1104
1105            remaining -= chunk_size;
1106            offset += chunk_size;
1107        }
1108
1109        Ok(wr_ids)
1110    }
1111
1112    /// Posts a request to the queue pair.
1113    ///
1114    /// # Arguments
1115    ///
1116    /// * `local_addr` - The local address containing data to send
1117    /// * `length` - Length of the data to send
1118    /// * `wr_id` - Work request ID for completion identification
1119    /// * `signaled` - Whether to generate a completion event
1120    /// * `op_type` - Optional operation type
1121    /// * `raddr` - the remote address, representing the memory location on the remote peer
1122    /// * `rkey` - the remote key, representing the key required to access the remote memory region
1123    fn post_op(
1124        &mut self,
1125        laddr: usize,
1126        lkey: u32,
1127        length: usize,
1128        wr_id: u64,
1129        signaled: bool,
1130        op_type: RdmaOperation,
1131        raddr: usize,
1132        rkey: u32,
1133    ) -> Result<(), anyhow::Error> {
1134        // SAFETY:
1135        // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
1136        // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
1137        // - The memory address in `local_addr` is not dereferenced, only passed to the device
1138        // - The remote connection info is verified to exist before accessing
1139        // - The ibverbs post_send operation follows the documented API contract
1140        // - Error codes from the device are properly checked and propagated
1141        unsafe {
1142            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1143            let context = self.context as *mut rdmaxcel_sys::ibv_context;
1144            let ops = &mut (*context).ops;
1145            let errno;
1146            if op_type == RdmaOperation::Recv {
1147                let mut sge = rdmaxcel_sys::ibv_sge {
1148                    addr: laddr as u64,
1149                    length: length as u32,
1150                    lkey,
1151                };
1152                let mut wr = rdmaxcel_sys::ibv_recv_wr {
1153                    wr_id: wr_id.into(),
1154                    sg_list: &mut sge as *mut _,
1155                    num_sge: 1,
1156                    ..Default::default()
1157                };
1158                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
1159                errno =
1160                    ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1161            } else if op_type == RdmaOperation::Write
1162                || op_type == RdmaOperation::Read
1163                || op_type == RdmaOperation::WriteWithImm
1164            {
1165                // Apply hardware initialization delay if this is the first operation
1166                self.apply_first_op_delay(wr_id);
1167                let send_flags = if signaled {
1168                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
1169                } else {
1170                    0
1171                };
1172                let mut sge = rdmaxcel_sys::ibv_sge {
1173                    addr: laddr as u64,
1174                    length: length as u32,
1175                    lkey,
1176                };
1177                let mut wr = rdmaxcel_sys::ibv_send_wr {
1178                    wr_id: wr_id.into(),
1179                    next: std::ptr::null_mut(),
1180                    sg_list: &mut sge as *mut _,
1181                    num_sge: 1,
1182                    opcode: op_type.into(),
1183                    send_flags,
1184                    wr: Default::default(),
1185                    qp_type: Default::default(),
1186                    __bindgen_anon_1: Default::default(),
1187                    __bindgen_anon_2: Default::default(),
1188                };
1189
1190                wr.wr.rdma.remote_addr = raddr as u64;
1191                wr.wr.rdma.rkey = rkey;
1192                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
1193
1194                errno =
1195                    ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1196            } else {
1197                panic!("Not Implemented");
1198            }
1199
1200            if errno != 0 {
1201                let os_error = Error::last_os_error();
1202                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
1203            }
1204            tracing::debug!(
1205                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
1206                op_type,
1207                lkey,
1208                laddr,
1209                length,
1210                raddr,
1211                rkey,
1212            );
1213
1214            Ok(())
1215        }
1216    }
1217
1218    fn send_wqe(
1219        &mut self,
1220        laddr: usize,
1221        lkey: u32,
1222        length: usize,
1223        wr_id: u64,
1224        signaled: bool,
1225        op_type: RdmaOperation,
1226        raddr: usize,
1227        rkey: u32,
1228    ) -> Result<DoorBell, anyhow::Error> {
1229        unsafe {
1230            let op_type_val = match op_type {
1231                RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
1232                RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
1233                RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
1234                RdmaOperation::Recv => 0,
1235            };
1236
1237            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1238            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
1239            let _dv_cq = if op_type == RdmaOperation::Recv {
1240                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
1241            } else {
1242                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
1243            };
1244
1245            // Create the WQE parameters struct
1246
1247            let buf = if op_type == RdmaOperation::Recv {
1248                (*dv_qp).rq.buf as *mut u8
1249            } else {
1250                (*dv_qp).sq.buf as *mut u8
1251            };
1252
1253            let params = rdmaxcel_sys::wqe_params_t {
1254                laddr,
1255                lkey,
1256                length,
1257                wr_id: wr_id.into(),
1258                signaled,
1259                op_type: op_type_val,
1260                raddr,
1261                rkey,
1262                qp_num: (*(*qp).ibv_qp).qp_num,
1263                buf,
1264                dbrec: (*dv_qp).dbrec,
1265                wqe_cnt: (*dv_qp).sq.wqe_cnt,
1266            };
1267
1268            // Call the C function to post the WQE
1269            if op_type == RdmaOperation::Recv {
1270                rdmaxcel_sys::recv_wqe(params);
1271                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
1272            } else {
1273                rdmaxcel_sys::send_wqe(params);
1274            };
1275
1276            // Create and return a DoorBell struct
1277            Ok(DoorBell {
1278                dst_ptr: (*dv_qp).bf.reg as usize,
1279                src_ptr: (*dv_qp).sq.buf as usize,
1280                size: 8,
1281            })
1282        }
1283    }
1284
1285    /// Poll for work completions by wr_ids.
1286    ///
1287    /// # Arguments
1288    ///
1289    /// * `target` - Which completion queue to poll (Send, Receive)
1290    /// * `expected_wr_ids` - Slice of work request IDs to wait for
1291    ///
1292    /// # Returns
1293    ///
1294    /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs that were found
1295    /// * `Err(e)` - An error occurred
1296    pub fn poll_completion(
1297        &mut self,
1298        target: PollTarget,
1299        expected_wr_ids: &[u64],
1300    ) -> Result<Vec<(u64, IbvWc)>, anyhow::Error> {
1301        if expected_wr_ids.is_empty() {
1302            return Ok(Vec::new());
1303        }
1304
1305        unsafe {
1306            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1307            let qp_num = (*(*qp).ibv_qp).qp_num;
1308
1309            let (cq, cache, cq_type) = match target {
1310                PollTarget::Send => (
1311                    self.send_cq as *mut rdmaxcel_sys::ibv_cq,
1312                    rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
1313                    "send",
1314                ),
1315                PollTarget::Recv => (
1316                    self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
1317                    rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
1318                    "recv",
1319                ),
1320            };
1321
1322            let mut results = Vec::new();
1323
1324            // Single-shot poll: check each wr_id once and return what we find
1325            for &expected_wr_id in expected_wr_ids {
1326                let mut poll_ctx = rdmaxcel_sys::poll_context_t {
1327                    expected_wr_id,
1328                    expected_qp_num: qp_num,
1329                    cache,
1330                    cq,
1331                };
1332
1333                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1334                let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
1335
1336                match ret {
1337                    1 => {
1338                        // Found completion
1339                        if !wc.is_valid() {
1340                            if let Some((status, vendor_err)) = wc.error() {
1341                                return Err(anyhow::anyhow!(
1342                                    "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
1343                                    cq_type,
1344                                    expected_wr_id,
1345                                    status,
1346                                    vendor_err,
1347                                ));
1348                            }
1349                        }
1350                        results.push((expected_wr_id, IbvWc::from(wc)));
1351                    }
1352                    0 => {
1353                        // Not found yet - this is fine for single-shot poll
1354                    }
1355                    -17 => {
1356                        // RDMAXCEL_COMPLETION_FAILED: Completion found but failed - wc contains the error details
1357                        let error_msg =
1358                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1359                                .to_str()
1360                                .unwrap_or("Unknown error");
1361                        if let Some((status, vendor_err)) = wc.error() {
1362                            return Err(anyhow::anyhow!(
1363                                "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1364                                cq_type,
1365                                expected_wr_id,
1366                                error_msg,
1367                                status,
1368                                vendor_err,
1369                                wc.qp_num,
1370                                wc.len(),
1371                            ));
1372                        } else {
1373                            return Err(anyhow::anyhow!(
1374                                "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1375                                cq_type,
1376                                expected_wr_id,
1377                                error_msg,
1378                                wc.qp_num,
1379                                wc.len(),
1380                            ));
1381                        }
1382                    }
1383                    _ => {
1384                        // Other errors
1385                        let error_msg =
1386                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1387                                .to_str()
1388                                .unwrap_or("Unknown error");
1389                        return Err(anyhow::anyhow!(
1390                            "Failed to poll {} CQ for wr_id={}: {}",
1391                            cq_type,
1392                            expected_wr_id,
1393                            error_msg
1394                        ));
1395                    }
1396                }
1397            }
1398
1399            Ok(results)
1400        }
1401    }
1402}
1403
1404/// Utility to validate execution context.
1405///
1406/// Remote Execution environments do not always have access to the nvidia_peermem module
1407/// and/or set the PeerMappingOverride parameter due to security. This function can be
1408/// used to validate that the execution context when running operations that need this
1409/// functionality (ie. cudaHostRegisterIoMemory).
1410///
1411/// # Returns
1412///
1413/// * `Ok(())` if the execution context is valid
1414/// * `Err(anyhow::Error)` if the execution context is invalid
1415pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1416    // Check for nvidia peermem
1417    match fs::read_to_string("/proc/modules") {
1418        Ok(contents) => {
1419            if !contents.contains("nvidia_peermem") {
1420                return Err(anyhow::anyhow!(
1421                    "nvidia_peermem module not found in /proc/modules"
1422                ));
1423            }
1424        }
1425        Err(e) => {
1426            return Err(anyhow::anyhow!(e));
1427        }
1428    }
1429
1430    // Test file access to nvidia params
1431    match fs::read_to_string("/proc/driver/nvidia/params") {
1432        Ok(contents) => {
1433            if !contents.contains("PeerMappingOverride=1") {
1434                return Err(anyhow::anyhow!(
1435                    "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1436                ));
1437            }
1438        }
1439        Err(e) => {
1440            return Err(anyhow::anyhow!(e));
1441        }
1442    }
1443    Ok(())
1444}
1445
1446/// Get all segments that have been registered with MRs
1447///
1448/// # Returns
1449/// * `Vec<SegmentInfo>` - Vector containing all registered segment information
1450pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
1451    unsafe {
1452        let segment_count = rdmaxcel_sys::rdma_get_active_segment_count();
1453        if segment_count <= 0 {
1454            return Vec::new();
1455        }
1456
1457        let mut segments = vec![
1458            std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
1459                .assume_init();
1460            segment_count as usize
1461        ];
1462        let actual_count =
1463            rdmaxcel_sys::rdma_get_all_segment_info(segments.as_mut_ptr(), segment_count);
1464
1465        if actual_count > 0 {
1466            segments.truncate(actual_count as usize);
1467            segments
1468        } else {
1469            Vec::new()
1470        }
1471    }
1472}
1473
1474/// Check if PyTorch CUDA caching allocator has expandable segments enabled.
1475///
1476/// This function calls the C++ implementation that directly accesses the
1477/// PyTorch C10 CUDA allocator configuration to check if expandable segments
1478/// are enabled, which is required for RDMA operations with CUDA tensors.
1479///
1480/// # Returns
1481///
1482/// `true` if both CUDA caching allocator is enabled AND expandable segments are enabled,
1483/// `false` otherwise.
1484pub fn pt_cuda_allocator_compatibility() -> bool {
1485    // SAFETY: We are calling a C++ function from rdmaxcel that accesses PyTorch C10 APIs.
1486    unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() }
1487}
1488
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true`, after printing the reason, when no RDMA device is available
    /// and the calling test should be skipped.
    fn skip_without_rdma_devices() -> bool {
        let unavailable = crate::ibverbs_primitives::get_all_devices().is_empty();
        if unavailable {
            println!("Skipping test: RDMA devices not available");
        }
        unavailable
    }

    /// Default configuration with GPU direct switched off.
    fn host_memory_config() -> IbverbsConfig {
        IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        }
    }

    /// A domain and a queue pair can be created on the default device.
    #[test]
    fn test_create_connection() {
        if skip_without_rdma_devices() {
            return;
        }

        let config = host_memory_config();

        let domain = RdmaDomain::new(config.device.clone());
        assert!(domain.is_ok());
        let domain = domain.unwrap();

        let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config);
        assert!(queue_pair.is_ok());
    }

    /// Two queue pairs on the same host can connect to each other.
    #[test]
    fn test_loopback_connection() {
        if skip_without_rdma_devices() {
            return;
        }

        let server_config = host_memory_config();
        let client_config = host_memory_config();

        let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
        let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();

        let mut server_qp =
            RdmaQueuePair::new(server_domain.context, server_domain.pd, server_config).unwrap();
        let mut client_qp =
            RdmaQueuePair::new(client_domain.context, client_domain.pd, client_config).unwrap();

        // Each side connects using the other side's connection info.
        let server_connection_info = server_qp.get_qp_info().unwrap();
        let client_connection_info = client_qp.get_qp_info().unwrap();

        assert!(server_qp.connect(&client_connection_info).is_ok());
        assert!(client_qp.connect(&server_connection_info).is_ok());
    }
}