monarch_rdma/
rdma_components.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
13//! ## Core Components
14//!
15//! * `RdmaDomain` - Manages RDMA resources including context, protection domain, and memory region
16//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `RdmaDomain` with `new()`
36//! 2. Create an `RdmaQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
42
43use std::collections::HashMap;
44use std::ffi::CStr;
45use std::fs;
46use std::io::Error;
47use std::result::Result;
48use std::time::Duration;
49
50use hyperactor::ActorRef;
51use hyperactor::Mailbox;
52use hyperactor::Named;
53use hyperactor::clock::Clock;
54use hyperactor::clock::RealClock;
55use serde::Deserialize;
56use serde::Serialize;
57
58use crate::RdmaDevice;
59use crate::RdmaManagerActor;
60use crate::RdmaManagerMessageClient;
61use crate::ibverbs_primitives::Gid;
62use crate::ibverbs_primitives::IbvWc;
63use crate::ibverbs_primitives::IbverbsConfig;
64use crate::ibverbs_primitives::RdmaMemoryRegionView;
65use crate::ibverbs_primitives::RdmaOperation;
66use crate::ibverbs_primitives::RdmaQpInfo;
67
68#[derive(Debug, Named, Clone, Serialize, Deserialize)]
69pub struct DoorBell {
70    pub src_ptr: usize,
71    pub dst_ptr: usize,
72    pub size: usize,
73}
74
75impl DoorBell {
76    /// Rings the doorbell to trigger the execution of previously enqueued operations.
77    ///
78    /// This method uses unsafe code to directly interact with the RDMA device,
79    /// sending a signal from the source pointer to the destination pointer.
80    ///
81    /// # Returns
82    /// * `Ok(())` if the operation is successful.
83    /// * `Err(anyhow::Error)` if an error occurs during the operation.
84    pub fn ring(&self) -> Result<(), anyhow::Error> {
85        unsafe {
86            let src_ptr = self.src_ptr as *mut std::ffi::c_void;
87            let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
88            rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
89            Ok(())
90        }
91    }
92}
93
94#[derive(Debug, Serialize, Deserialize, Named, Clone)]
95pub struct RdmaBuffer {
96    pub owner: ActorRef<RdmaManagerActor>,
97    pub mr_id: u32,
98    pub lkey: u32,
99    pub rkey: u32,
100    pub addr: usize,
101    pub size: usize,
102}
103
104impl RdmaBuffer {
105    /// Read from the RdmaBuffer into the provided memory.
106    ///
107    /// This method transfers data from the buffer into the local memory region provided over RDMA.
108    /// This involves calling a `Put` operation on the RdmaBuffer's actor side.
109    ///
110    /// # Arguments
111    /// * `client` - Mailbox used for communication
112    /// * `remote` - RdmaBuffer representing the remote memory region
113    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
114    ///
115    /// # Returns
116    /// `Ok(bool)` indicating if the operation completed successfully.
117    pub async fn read_into(
118        &self,
119        client: &Mailbox,
120        remote: RdmaBuffer,
121        timeout: u64,
122    ) -> Result<bool, anyhow::Error> {
123        tracing::debug!(
124            "[buffer] reading from {:?} into remote ({:?}) at {:?}",
125            self,
126            remote.owner.actor_id(),
127            remote,
128        );
129        let mut qp = self
130            .owner
131            .request_queue_pair(client, remote.owner.clone())
132            .await?;
133
134        qp.put(self.clone(), remote)?;
135        self.wait_for_completion(&mut qp, PollTarget::Send, timeout)
136            .await
137    }
138
139    /// Write from the provided memory into the RdmaBuffer.
140    ///
141    /// This method performs an RDMA write operation, transferring data from the caller's
142    /// memory region to this buffer.
143    /// This involves calling a `Fetch` operation on the RdmaBuffer's actor side.
144    ///
145    /// # Arguments
146    /// * `client` - Mailbox used for communication
147    /// * `remote` - RdmaBuffer representing the remote memory region
148    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
149    ///
150    /// # Returns
151    /// `Ok(bool)` indicating if the operation completed successfully.
152    pub async fn write_from(
153        &self,
154        client: &Mailbox,
155        remote: RdmaBuffer,
156        timeout: u64,
157    ) -> Result<bool, anyhow::Error> {
158        tracing::debug!(
159            "[buffer] writing into {:?} from remote ({:?}) at {:?}",
160            self,
161            remote.owner.actor_id(),
162            remote,
163        );
164        let mut qp = self
165            .owner
166            .request_queue_pair(client, remote.owner.clone())
167            .await?;
168        qp.get(self.clone(), remote)?;
169        self.wait_for_completion(&mut qp, PollTarget::Send, timeout)
170            .await
171    }
172    /// Waits for the completion of an RDMA operation.
173    ///
174    /// This method polls the completion queue until the specified work request completes
175    /// or until the timeout is reached.
176    ///
177    /// # Arguments
178    /// * `qp` - The RDMA Queue Pair to poll for completion
179    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
180    ///
181    /// # Returns
182    /// `Ok(true)` if the operation completes successfully within the timeout,
183    /// or an error if the timeout is reached
184    async fn wait_for_completion(
185        &self,
186        qp: &mut RdmaQueuePair,
187        poll_target: PollTarget,
188        timeout: u64,
189    ) -> Result<bool, anyhow::Error> {
190        let timeout = Duration::from_secs(timeout);
191        let start_time = std::time::Instant::now();
192
193        while start_time.elapsed() < timeout {
194            match qp.poll_completion_target(poll_target) {
195                Ok(Some(_wc)) => {
196                    tracing::debug!("work completed");
197                    return Ok(true);
198                }
199                Ok(None) => {
200                    RealClock.sleep(Duration::from_millis(1)).await;
201                }
202                Err(e) => {
203                    tracing::error!("polling completion failed: {}", e);
204                    return Err(anyhow::anyhow!(e));
205                }
206            }
207        }
208        tracing::error!("timed out while waiting on request completion");
209        Err(anyhow::anyhow!(
210            "[buffer({:?})] rdma operation did not complete in time",
211            self
212        ))
213    }
214}
215
/// Represents a domain for RDMA operations, encapsulating the necessary resources
/// for establishing and managing RDMA connections.
///
/// An `RdmaDomain` holds the device context, the protection domain (PD), and the table of
/// memory regions (MRs) registered through it. Its `context` and `pd` pointers are what
/// `RdmaQueuePair::new` takes to create queue pairs.
///
/// # Fields
///
/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
/// * `mr_map`: A map of memory region IDs to pointers, representing registered memory regions.
/// * `counter`: A counter for generating unique memory region IDs.
pub struct RdmaDomain {
    /// Device context returned by `ibv_open_device`.
    pub context: *mut rdmaxcel_sys::ibv_context,
    /// Protection domain allocated on `context`.
    pub pd: *mut rdmaxcel_sys::ibv_pd,
    /// Registered memory regions, keyed by the ID handed out at registration time.
    mr_map: HashMap<u32, *mut rdmaxcel_sys::ibv_mr>,
    /// Next memory region ID to hand out; incremented after every successful registration.
    counter: u32,
}
235
236impl std::fmt::Debug for RdmaDomain {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("RdmaDomain")
239            .field("context", &format!("{:p}", self.context))
240            .field("pd", &format!("{:p}", self.pd))
241            .field("mr", &format!("{:?}", self.mr_map))
242            .field("counter", &self.counter)
243            .finish()
244    }
245}
246
// SAFETY:
// `RdmaDomain` stores raw pointers to ibverbs structs, so `Send` is not derived
// automatically and has to be asserted here.
// RdmaDomain is `Send` because the raw pointers to ibverbs structs can be
// accessed from any thread, and it is safe to drop `RdmaDomain` (and run the
// ibverbs destructors) from any thread.
unsafe impl Send for RdmaDomain {}
253
// SAFETY:
// `RdmaDomain` stores raw pointers to ibverbs structs, so `Sync` is not derived
// automatically and has to be asserted here.
// RdmaDomain is `Sync` because the underlying ibverbs APIs are thread-safe.
unsafe impl Sync for RdmaDomain {}
258
259impl Drop for RdmaDomain {
260    fn drop(&mut self) {
261        unsafe {
262            rdmaxcel_sys::ibv_dealloc_pd(self.pd);
263        }
264    }
265}
266
267impl RdmaDomain {
268    /// Creates a new RdmaDomain.
269    ///
270    /// This function initializes the RDMA device context, creates a protection domain,
271    /// and registers a memory region with appropriate access permissions.
272    ///
273    /// SAFETY:
274    /// Our memory region (MR) registration uses implicit ODP for RDMA access, which maps large virtual
275    /// address ranges without explicit pinning. This is convenient, but it broadens the memory footprint
276    /// exposed to the NIC and introduces a security liability.
277    ///
278    /// We currently assume a trusted, single-environment and are not enforcing finer-grained memory isolation
279    /// at this layer. We plan to investigate mitigations - such as memory windows or tighter registration
280    /// boundaries in future follow-ups.
281    ///
282    /// # Arguments
283    ///
284    /// * `config` - Configuration settings for the RDMA operations
285    ///
286    /// # Errors
287    ///
288    /// This function may return errors if:
289    /// * No RDMA devices are found
290    /// * The specified device cannot be found
291    /// * Device context creation fails
292    /// * Protection domain allocation fails
293    /// * Memory region registration fails
294    pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
295        tracing::debug!("creating RdmaDomain for device {}", device.name());
296        // SAFETY:
297        // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
298        // - All pointers are properly initialized and checked for null before use
299        // - Memory registration follows the ibverbs API contract with proper access flags
300        // - Resources are properly cleaned up in error cases to prevent leaks
301        // - The operations follow the documented RDMA protocol for device initialization
302        unsafe {
303            // Get the device based on the provided RdmaDevice
304            let device_name = device.name();
305            let mut num_devices = 0i32;
306            let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
307
308            if devices.is_null() || num_devices == 0 {
309                return Err(anyhow::anyhow!("no RDMA devices found"));
310            }
311
312            // Find the device with the matching name
313            let mut device_ptr = std::ptr::null_mut();
314            for i in 0..num_devices {
315                let dev = *devices.offset(i as isize);
316                let dev_name =
317                    CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
318
319                if dev_name == *device_name {
320                    device_ptr = dev;
321                    break;
322                }
323            }
324
325            // If we didn't find the device, return an error
326            if device_ptr.is_null() {
327                rdmaxcel_sys::ibv_free_device_list(devices);
328                return Err(anyhow::anyhow!("device '{}' not found", device_name));
329            }
330            tracing::info!("using RDMA device: {}", device_name);
331
332            // Open device
333            let context = rdmaxcel_sys::ibv_open_device(device_ptr);
334            if context.is_null() {
335                rdmaxcel_sys::ibv_free_device_list(devices);
336                let os_error = Error::last_os_error();
337                return Err(anyhow::anyhow!("failed to create context: {}", os_error));
338            }
339
340            // Create protection domain
341            let pd = rdmaxcel_sys::ibv_alloc_pd(context);
342            if pd.is_null() {
343                rdmaxcel_sys::ibv_close_device(context);
344                rdmaxcel_sys::ibv_free_device_list(devices);
345                let os_error = Error::last_os_error();
346                return Err(anyhow::anyhow!(
347                    "failed to create protection domain (PD): {}",
348                    os_error
349                ));
350            }
351
352            // Avoids memory leaks
353            rdmaxcel_sys::ibv_free_device_list(devices);
354
355            Ok(RdmaDomain {
356                context,
357                pd,
358                mr_map: HashMap::new(),
359                counter: 0,
360            })
361        }
362    }
363
364    fn register_mr(
365        &mut self,
366        addr: usize,
367        size: usize,
368    ) -> Result<RdmaMemoryRegionView, anyhow::Error> {
369        unsafe {
370            let mut mem_type: i32 = 0;
371            let ptr = addr as cuda_sys::CUdeviceptr;
372            let err = cuda_sys::cuPointerGetAttribute(
373                &mut mem_type as *mut _ as *mut std::ffi::c_void,
374                cuda_sys::CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
375                ptr,
376            );
377            let is_cuda = err == cuda_sys::CUresult::CUDA_SUCCESS;
378
379            let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
380                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
381                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
382                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
383
384            let mr;
385            if is_cuda {
386                let mut fd: i32 = -1;
387                cuda_sys::cuMemGetHandleForAddressRange(
388                    &mut fd as *mut i32 as *mut std::ffi::c_void,
389                    addr as cuda_sys::CUdeviceptr,
390                    size,
391                    cuda_sys::CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
392                    0,
393                );
394                mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(self.pd, 0, size, 0, fd, access.0 as i32);
395            } else {
396                mr = rdmaxcel_sys::ibv_reg_mr(
397                    self.pd,
398                    addr as *mut std::ffi::c_void,
399                    size,
400                    access.0 as i32,
401                );
402            }
403
404            if mr.is_null() {
405                return Err(anyhow::anyhow!("failed to register memory region (MR)"));
406            }
407            let id = self.counter;
408            self.mr_map.insert(id, mr);
409            self.counter += 1;
410
411            Ok(RdmaMemoryRegionView {
412                id,
413                addr: (*mr).addr as usize,
414                size: (*mr).length,
415                lkey: (*mr).lkey,
416                rkey: (*mr).rkey,
417            })
418        }
419    }
420
421    fn deregister_mr(&mut self, id: u32) -> Result<(), anyhow::Error> {
422        let mr = self.mr_map.remove(&id);
423        if mr.is_some() {
424            unsafe {
425                rdmaxcel_sys::ibv_dereg_mr(mr.expect("mr is required"));
426            }
427        }
428        Ok(())
429    }
430
431    pub fn register_buffer(
432        &mut self,
433        addr: usize,
434        size: usize,
435    ) -> Result<RdmaMemoryRegionView, anyhow::Error> {
436        let region_view = self.register_mr(addr, size)?;
437        Ok(region_view)
438    }
439
440    // Removes a specific address from memory region.   Currently we only support single address,
441    // but in future we can expand/contract effective memory region.
442    pub fn deregister_buffer(&mut self, buffer: RdmaBuffer) -> Result<(), anyhow::Error> {
443        self.deregister_mr(buffer.mr_id)?;
444        Ok(())
445    }
446}
/// Selects which completion queue of a queue pair should be polled.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue.
    Send,
    /// Poll the receive completion queue.
    Recv,
}
453
454/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
455///
456/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
457/// and other resources needed for RDMA communication. It provides methods for
458/// establishing connections and performing RDMA operations like read and write.
459///
460/// # Fields
461///
462/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
463/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
464/// * `qp` - Queue Pair pointer that manages send and receive operations
465/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
466/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
467/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
468/// * `context` - RDMA device context pointer
469/// * `config` - Configuration settings for the queue pair
470///
471/// # Connection Lifecycle
472///
473/// 1. Create with `new()` from an `RdmaDomain`
474/// 2. Get connection info with `get_qp_info()`
475/// 3. Exchange connection info with remote peer (application must handle this)
476/// 4. Connect to remote endpoint with `connect()`
477/// 5. Perform RDMA operations with `put()` or `get()`
478/// 6. Poll for completions with `poll_send_completion()` or `poll_recv_completion()`
479
480#[derive(Debug, Serialize, Deserialize, Named, Clone)]
481pub struct RdmaQueuePair {
482    pub send_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
483    pub recv_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
484    pub qp: usize,         // *mut rdmaxcel_sys::ibv_qp,
485    pub dv_qp: usize,      // *mut rdmaxcel_sys::mlx5dv_qp,
486    pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
487    pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
488    context: usize,        // *mut rdmaxcel_sys::ibv_context,
489    config: IbverbsConfig,
490    pub send_wqe_idx: u32,
491    pub send_db_idx: u32,
492    pub send_cq_idx: u32,
493    pub recv_wqe_idx: u32,
494    pub recv_db_idx: u32,
495    pub recv_cq_idx: u32,
496}
497
498impl RdmaQueuePair {
499    /// Creates a new RdmaQueuePair from a given RdmaDomain.
500    ///
501    /// This function initializes a new Queue Pair (QP) and associated Completion Queue (CQ)
502    /// using the resources from the provided RdmaDomain. The QP is created in the RESET state
503    /// and must be transitioned to other states via the `connect()` method before use.
504    ///
505    /// # Arguments
506    ///
507    /// * `domain` - Reference to an RdmaDomain that provides the context, protection domain,
508    ///   and memory region for this queue pair
509    ///
510    /// # Returns
511    ///
512    /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
513    ///
514    /// # Errors
515    ///
516    /// This function may return errors if:
517    /// * Completion queue (CQ) creation fails
518    /// * Queue pair (QP) creation fails
519    pub fn new(
520        context: *mut rdmaxcel_sys::ibv_context,
521        pd: *mut rdmaxcel_sys::ibv_pd,
522        config: IbverbsConfig,
523    ) -> Result<Self, anyhow::Error> {
524        tracing::debug!("creating an RdmaQueuePair from config {}", config);
525        unsafe {
526            // standard ibverbs QP
527            let qp = rdmaxcel_sys::create_qp(
528                context,
529                pd,
530                config.cq_entries,
531                config.max_send_wr.try_into().unwrap(),
532                config.max_recv_wr.try_into().unwrap(),
533                config.max_send_sge.try_into().unwrap(),
534                config.max_recv_sge.try_into().unwrap(),
535            );
536
537            if qp.is_null() {
538                let os_error = Error::last_os_error();
539                return Err(anyhow::anyhow!(
540                    "failed to create queue pair (QP): {}",
541                    os_error
542                ));
543            }
544
545            let send_cq = (*qp).send_cq;
546            let recv_cq = (*qp).recv_cq;
547
548            // mlx5dv provider APIs
549            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp(qp);
550            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq(qp);
551            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq(qp);
552
553            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
554                rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
555                rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
556                rdmaxcel_sys::ibv_destroy_qp(qp);
557                return Err(anyhow::anyhow!(
558                    "failed to init mlx5dv_qp or completion queues"
559                ));
560            }
561
562            // GPU Direct RDMA specific registrations
563            if config.use_gpu_direct {
564                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
565                if ret != 0 {
566                    rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
567                    rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
568                    rdmaxcel_sys::ibv_destroy_qp(qp);
569                    return Err(anyhow::anyhow!(
570                        "failed to register GPU Direct RDMA memory: {:?}",
571                        ret
572                    ));
573                }
574            }
575
576            Ok(RdmaQueuePair {
577                send_cq: send_cq as usize,
578                recv_cq: recv_cq as usize,
579                qp: qp as usize,
580                dv_qp: dv_qp as usize,
581                dv_send_cq: dv_send_cq as usize,
582                dv_recv_cq: dv_recv_cq as usize,
583                context: context as usize,
584                config,
585                recv_db_idx: 0,
586                recv_wqe_idx: 0,
587                recv_cq_idx: 0,
588                send_db_idx: 0,
589                send_wqe_idx: 0,
590                send_cq_idx: 0,
591            })
592        }
593    }
594
595    /// Returns the information required for a remote peer to connect to this queue pair.
596    ///
597    /// This method retrieves the local queue pair attributes and port information needed by
598    /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
599    /// the queue pair number, LID, GID, and other necessary connection parameters.
600    ///
601    /// # Returns
602    ///
603    /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
604    ///
605    /// # Errors
606    ///
607    /// This function may return errors if:
608    /// * Port attribute query fails
609    /// * GID query fails
610    pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
611        // SAFETY:
612        // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
613        // - All pointers are properly initialized before use
614        // - Port and GID queries follow the documented ibverbs API contract
615        // - Error handling properly checks return codes from ibverbs functions
616        // - The memory address provided is only stored, not dereferenced in this function
617        unsafe {
618            let context = self.context as *mut rdmaxcel_sys::ibv_context;
619            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
620            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
621            let errno = rdmaxcel_sys::ibv_query_port(
622                context,
623                self.config.port_num,
624                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
625            );
626            if errno != 0 {
627                let os_error = Error::last_os_error();
628                return Err(anyhow::anyhow!(
629                    "Failed to query port attributes: {}",
630                    os_error
631                ));
632            }
633
634            let mut gid = Gid::default();
635            let ret = rdmaxcel_sys::ibv_query_gid(
636                context,
637                self.config.port_num,
638                i32::from(self.config.gid_index),
639                gid.as_mut(),
640            );
641            if ret != 0 {
642                return Err(anyhow::anyhow!("Failed to query GID"));
643            }
644
645            Ok(RdmaQpInfo {
646                qp_num: (*qp).qp_num,
647                lid: port_attr.lid,
648                gid: Some(gid),
649                psn: self.config.psn,
650            })
651        }
652    }
653
654    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
655        // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
656        unsafe {
657            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
658            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
659                ..Default::default()
660            };
661            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
662                ..Default::default()
663            };
664            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
665            let errno =
666                rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr);
667            if errno != 0 {
668                let os_error = Error::last_os_error();
669                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
670            }
671            Ok(qp_attr.qp_state)
672        }
673    }
674    /// Connect to a remote Rdma connection point.
675    ///
676    /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection.
677    ///
678    /// # Arguments
679    ///
680    /// * `connection_info` - The remote connection info to connect to
681    pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
682        // SAFETY:
683        // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
684        // The operations are safe because:
685        // 1. We're following the documented ibverbs API contract
686        // 2. All pointers used are properly initialized and owned by this struct
687        // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
688        // 4. Memory access is properly bounded by the registered memory regions
689        unsafe {
690            // Transition to INIT
691            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
692
693            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
694                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
695                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
696                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
697
698            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
699                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
700                qp_access_flags: qp_access_flags.0,
701                pkey_index: self.config.pkey_index,
702                port_num: self.config.port_num,
703                ..Default::default()
704            };
705
706            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
707                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
708                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
709                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
710
711            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
712            if errno != 0 {
713                let os_error = Error::last_os_error();
714                return Err(anyhow::anyhow!(
715                    "failed to transition QP to INIT: {}",
716                    os_error
717                ));
718            }
719
720            // Transition to RTR (Ready to Receive)
721            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
722                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
723                path_mtu: self.config.path_mtu,
724                dest_qp_num: connection_info.qp_num,
725                rq_psn: connection_info.psn,
726                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
727                min_rnr_timer: self.config.min_rnr_timer,
728                ah_attr: rdmaxcel_sys::ibv_ah_attr {
729                    dlid: connection_info.lid,
730                    sl: 0,
731                    src_path_bits: 0,
732                    port_num: self.config.port_num,
733                    grh: Default::default(),
734                    ..Default::default()
735                },
736                ..Default::default()
737            };
738
739            // If the remote connection info contains a Gid, the routing will be global.
740            // Otherwise, it will be local, i.e. using LID.
741            if let Some(gid) = connection_info.gid {
742                qp_attr.ah_attr.is_global = 1;
743                qp_attr.ah_attr.grh.dgid = gid.into();
744                qp_attr.ah_attr.grh.hop_limit = 0xff;
745                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
746            } else {
747                // Use LID-based routing, e.g. for Infiniband/RoCEv1
748                qp_attr.ah_attr.is_global = 0;
749            }
750
751            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
752                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
753                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
754                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
755                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
756                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
757                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
758
759            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
760            if errno != 0 {
761                let os_error = Error::last_os_error();
762                return Err(anyhow::anyhow!(
763                    "failed to transition QP to RTR: {}",
764                    os_error
765                ));
766            }
767
768            // Transition to RTS (Ready to Send)
769            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
770                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
771                sq_psn: self.config.psn,
772                max_rd_atomic: self.config.max_rd_atomic,
773                retry_cnt: self.config.retry_cnt,
774                rnr_retry: self.config.rnr_retry,
775                timeout: self.config.qp_timeout,
776                ..Default::default()
777            };
778
779            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
780                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
781                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
782                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
783                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
784                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
785
786            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
787            if errno != 0 {
788                let os_error = Error::last_os_error();
789                return Err(anyhow::anyhow!(
790                    "failed to transition QP to RTS: {}",
791                    os_error
792                ));
793            }
794            tracing::debug!(
795                "connection sequence has successfully completed (qp: {:?})",
796                qp
797            );
798
799            Ok(())
800        }
801    }
802
803    pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
804        let idx = self.recv_wqe_idx;
805        self.recv_wqe_idx += 1;
806        self.send_wqe(
807            0,
808            lhandle.lkey,
809            0,
810            idx,
811            true,
812            RdmaOperation::Recv,
813            0,
814            rhandle.rkey,
815        )
816        .unwrap();
817        Ok(())
818    }
819
820    pub fn put_with_recv(
821        &mut self,
822        lhandle: RdmaBuffer,
823        rhandle: RdmaBuffer,
824    ) -> Result<(), anyhow::Error> {
825        let idx = self.send_wqe_idx;
826        self.send_wqe_idx += 1;
827        self.post_op(
828            lhandle.addr,
829            lhandle.lkey,
830            lhandle.size,
831            idx,
832            true,
833            RdmaOperation::WriteWithImm,
834            rhandle.addr,
835            rhandle.rkey,
836        )
837        .unwrap();
838        self.send_db_idx += 1;
839        Ok(())
840    }
841
842    pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
843        let idx = self.send_wqe_idx;
844        self.send_wqe_idx += 1;
845        self.post_op(
846            lhandle.addr,
847            lhandle.lkey,
848            lhandle.size,
849            idx,
850            true,
851            RdmaOperation::Write,
852            rhandle.addr,
853            rhandle.rkey,
854        )
855        .unwrap();
856        self.send_db_idx += 1;
857        Ok(())
858    }
859
860    /// Get a doorbell for the queue pair.
861    ///
862    /// This method returns a doorbell that can be used to trigger the execution of
863    /// previously enqueued operations.
864    ///
865    /// # Returns
866    ///
867    /// * `Result<DoorBell, anyhow::Error>` - A doorbell for the queue pair
868    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
869        unsafe {
870            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
871            let base_ptr = (*dv_qp).sq.buf as *mut u8;
872            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
873            let stride = (*dv_qp).sq.stride;
874            if wqe_cnt < (self.send_wqe_idx - self.send_db_idx) {
875                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
876            }
877            while self.send_db_idx < self.send_wqe_idx {
878                let offset = (self.send_db_idx % wqe_cnt) * stride;
879                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
880                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
881                self.send_db_idx += 1;
882            }
883            Ok(())
884        }
885    }
886
887    /// Enqueues a put operation without ringing the doorbell.
888    ///
889    /// This method prepares a put operation but does not execute it.
890    /// Use `get_doorbell().ring()` to execute the operation.
891    ///
892    /// # Arguments
893    ///
894    /// * `lhandle` - Local buffer handle
895    /// * `rhandle` - Remote buffer handle
896    ///
897    /// # Returns
898    ///
899    /// * `Result<(), anyhow::Error>` - Success or error
900    pub fn enqueue_put(
901        &mut self,
902        lhandle: RdmaBuffer,
903        rhandle: RdmaBuffer,
904    ) -> Result<(), anyhow::Error> {
905        let idx = self.send_wqe_idx;
906        self.send_wqe_idx += 1;
907        self.send_wqe(
908            lhandle.addr,
909            lhandle.lkey,
910            lhandle.size,
911            idx,
912            true,
913            RdmaOperation::Write,
914            rhandle.addr,
915            rhandle.rkey,
916        )?;
917        Ok(())
918    }
919
920    /// Enqueues a put with receive operation without ringing the doorbell.
921    ///
922    /// This method prepares a put with receive operation but does not execute it.
923    /// Use `get_doorbell().ring()` to execute the operation.
924    ///
925    /// # Arguments
926    ///
927    /// * `lhandle` - Local buffer handle
928    /// * `rhandle` - Remote buffer handle
929    ///
930    /// # Returns
931    ///
932    /// * `Result<(), anyhow::Error>` - Success or error
933    pub fn enqueue_put_with_recv(
934        &mut self,
935        lhandle: RdmaBuffer,
936        rhandle: RdmaBuffer,
937    ) -> Result<(), anyhow::Error> {
938        let idx = self.send_wqe_idx;
939        self.send_wqe_idx += 1;
940        self.send_wqe(
941            lhandle.addr,
942            lhandle.lkey,
943            lhandle.size,
944            idx,
945            true,
946            RdmaOperation::WriteWithImm,
947            rhandle.addr,
948            rhandle.rkey,
949        )?;
950        Ok(())
951    }
952
953    /// Enqueues a get operation without ringing the doorbell.
954    ///
955    /// This method prepares a get operation but does not execute it.
956    /// Use `get_doorbell().ring()` to execute the operation.
957    ///
958    /// # Arguments
959    ///
960    /// * `lhandle` - Local buffer handle
961    /// * `rhandle` - Remote buffer handle
962    ///
963    /// # Returns
964    ///
965    /// * `Result<(), anyhow::Error>` - Success or error
966    pub fn enqueue_get(
967        &mut self,
968        lhandle: RdmaBuffer,
969        rhandle: RdmaBuffer,
970    ) -> Result<(), anyhow::Error> {
971        let idx = self.send_wqe_idx;
972        self.send_wqe_idx += 1;
973        self.send_wqe(
974            lhandle.addr,
975            lhandle.lkey,
976            lhandle.size,
977            idx,
978            true,
979            RdmaOperation::Read,
980            rhandle.addr,
981            rhandle.rkey,
982        )?;
983        Ok(())
984    }
985
986    pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
987        let idx = self.send_wqe_idx;
988        self.send_wqe_idx += 1;
989        self.post_op(
990            lhandle.addr,
991            lhandle.lkey,
992            lhandle.size,
993            idx,
994            true,
995            RdmaOperation::Read,
996            rhandle.addr,
997            rhandle.rkey,
998        )
999        .unwrap();
1000        self.send_db_idx += 1;
1001        Ok(())
1002    }
1003
1004    /// Posts a request to the queue pair.
1005    ///
1006    /// # Arguments
1007    ///
1008    /// * `local_addr` - The local address containing data to send
1009    /// * `length` - Length of the data to send
1010    /// * `wr_id` - Work request ID for completion identification
1011    /// * `signaled` - Whether to generate a completion event
1012    /// * `op_type` - Optional operation type
1013    /// * `raddr` - the remote address, representing the memory location on the remote peer
1014    /// * `rkey` - the remote key, representing the key required to access the remote memory region
1015    fn post_op(
1016        &mut self,
1017        laddr: usize,
1018        lkey: u32,
1019        length: usize,
1020        wr_id: u32,
1021        signaled: bool,
1022        op_type: RdmaOperation,
1023        raddr: usize,
1024        rkey: u32,
1025    ) -> Result<(), anyhow::Error> {
1026        // SAFETY:
1027        // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
1028        // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
1029        // - The memory address in `local_addr` is not dereferenced, only passed to the device
1030        // - The remote connection info is verified to exist before accessing
1031        // - The ibverbs post_send operation follows the documented API contract
1032        // - Error codes from the device are properly checked and propagated
1033        unsafe {
1034            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
1035            let context = self.context as *mut rdmaxcel_sys::ibv_context;
1036            let ops = &mut (*context).ops;
1037            let errno;
1038            if op_type == RdmaOperation::Recv {
1039                let mut sge = rdmaxcel_sys::ibv_sge {
1040                    addr: laddr as u64,
1041                    length: length as u32,
1042                    lkey,
1043                };
1044                let mut wr = rdmaxcel_sys::ibv_recv_wr {
1045                    wr_id: wr_id.try_into().unwrap(),
1046                    sg_list: &mut sge as *mut _,
1047                    num_sge: 1,
1048                    ..Default::default()
1049                };
1050                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
1051                errno = ops.post_recv.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
1052            } else if op_type == RdmaOperation::Write
1053                || op_type == RdmaOperation::Read
1054                || op_type == RdmaOperation::WriteWithImm
1055            {
1056                let send_flags = if signaled {
1057                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
1058                } else {
1059                    0
1060                };
1061                let mut sge = rdmaxcel_sys::ibv_sge {
1062                    addr: laddr as u64,
1063                    length: length as u32,
1064                    lkey,
1065                };
1066                let mut wr = rdmaxcel_sys::ibv_send_wr {
1067                    wr_id: wr_id.try_into().unwrap(),
1068                    next: std::ptr::null_mut(),
1069                    sg_list: &mut sge as *mut _,
1070                    num_sge: 1,
1071                    opcode: op_type.into(),
1072                    send_flags,
1073                    wr: Default::default(),
1074                    qp_type: Default::default(),
1075                    __bindgen_anon_1: Default::default(),
1076                    __bindgen_anon_2: Default::default(),
1077                };
1078
1079                wr.wr.rdma.remote_addr = raddr as u64;
1080                wr.wr.rdma.rkey = rkey;
1081
1082                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
1083
1084                errno = ops.post_send.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
1085            } else {
1086                panic!("Not Implemented");
1087            }
1088
1089            if errno != 0 {
1090                let os_error = Error::last_os_error();
1091                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
1092            }
1093            tracing::debug!(
1094                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
1095                op_type,
1096                lkey,
1097                laddr,
1098                length,
1099                raddr,
1100                rkey,
1101            );
1102
1103            Ok(())
1104        }
1105    }
1106
1107    fn send_wqe(
1108        &mut self,
1109        laddr: usize,
1110        lkey: u32,
1111        length: usize,
1112        wr_id: u32,
1113        signaled: bool,
1114        op_type: RdmaOperation,
1115        raddr: usize,
1116        rkey: u32,
1117    ) -> Result<DoorBell, anyhow::Error> {
1118        unsafe {
1119            let op_type_val = match op_type {
1120                RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
1121                RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
1122                RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
1123                RdmaOperation::Recv => 0,
1124            };
1125
1126            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
1127            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
1128            let _dv_cq = if op_type == RdmaOperation::Recv {
1129                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
1130            } else {
1131                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
1132            };
1133
1134            // Create the WQE parameters struct
1135
1136            let buf = if op_type == RdmaOperation::Recv {
1137                (*dv_qp).rq.buf as *mut u8
1138            } else {
1139                (*dv_qp).sq.buf as *mut u8
1140            };
1141
1142            let params = rdmaxcel_sys::wqe_params_t {
1143                laddr,
1144                lkey,
1145                length,
1146                wr_id: wr_id.try_into().unwrap(),
1147                signaled,
1148                op_type: op_type_val,
1149                raddr,
1150                rkey,
1151                qp_num: (*qp).qp_num,
1152                buf,
1153                dbrec: (*dv_qp).dbrec,
1154                wqe_cnt: (*dv_qp).sq.wqe_cnt,
1155            };
1156
1157            // Call the C function to post the WQE
1158            if op_type == RdmaOperation::Recv {
1159                rdmaxcel_sys::recv_wqe(params);
1160                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
1161            } else {
1162                rdmaxcel_sys::send_wqe(params);
1163            };
1164
1165            // Create and return a DoorBell struct
1166            Ok(DoorBell {
1167                dst_ptr: (*dv_qp).bf.reg as usize,
1168                src_ptr: (*dv_qp).sq.buf as usize,
1169                size: 8,
1170            })
1171        }
1172    }
1173
1174    /// Poll for completions on the specified completion queue(s)
1175    ///
1176    /// # Arguments
1177    ///
1178    /// * `target` - Which completion queue(s) to poll (Send, Receive, or Both)
1179    ///
1180    /// # Returns
1181    ///
1182    /// * `Ok(Some(wc))` - A completion was found
1183    /// * `Ok(None)` - No completion was found
1184    /// * `Err(e)` - An error occurred
1185    pub fn poll_completion_target(
1186        &mut self,
1187        target: PollTarget,
1188    ) -> Result<Option<IbvWc>, anyhow::Error> {
1189        unsafe {
1190            let context = self.context as *mut rdmaxcel_sys::ibv_context;
1191            let _outstanding_wqe =
1192                self.send_db_idx + self.recv_db_idx - self.send_cq_idx - self.recv_cq_idx;
1193
1194            // Check for send completions if requested
1195            if (target == PollTarget::Send) && self.send_db_idx > self.send_cq_idx {
1196                let send_cq = self.send_cq as *mut rdmaxcel_sys::ibv_cq;
1197                let ops = &mut (*context).ops;
1198                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1199                let ret = ops.poll_cq.as_mut().unwrap()(send_cq, 1, &mut wc);
1200
1201                if ret < 0 {
1202                    return Err(anyhow::anyhow!(
1203                        "Failed to poll send CQ: {}",
1204                        Error::last_os_error()
1205                    ));
1206                }
1207
1208                if ret > 0 {
1209                    if !wc.is_valid() {
1210                        if let Some((status, vendor_err)) = wc.error() {
1211                            return Err(anyhow::anyhow!(
1212                                "Send work completion failed with status: {:?}, vendor error: {}",
1213                                status,
1214                                vendor_err
1215                            ));
1216                        }
1217                    }
1218
1219                    // This should be a send completion
1220                    self.send_cq_idx += 1;
1221
1222                    return Ok(Some(IbvWc::from(wc)));
1223                }
1224            }
1225
1226            // Check for receive completions if requested
1227            if (target == PollTarget::Recv) && self.recv_db_idx > self.recv_cq_idx {
1228                let recv_cq = self.recv_cq as *mut rdmaxcel_sys::ibv_cq;
1229                let ops = &mut (*context).ops;
1230                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1231                let ret = ops.poll_cq.as_mut().unwrap()(recv_cq, 1, &mut wc);
1232
1233                if ret < 0 {
1234                    return Err(anyhow::anyhow!(
1235                        "Failed to poll receive CQ: {}",
1236                        Error::last_os_error()
1237                    ));
1238                }
1239
1240                if ret > 0 {
1241                    if !wc.is_valid() {
1242                        if let Some((status, vendor_err)) = wc.error() {
1243                            return Err(anyhow::anyhow!(
1244                                "Receive work completion failed with status: {:?}, vendor error: {}",
1245                                status,
1246                                vendor_err
1247                            ));
1248                        }
1249                    }
1250
1251                    // This should be a receive completion
1252                    self.recv_cq_idx += 1;
1253
1254                    return Ok(Some(IbvWc::from(wc)));
1255                }
1256            }
1257
1258            // No completion found
1259            Ok(None)
1260        }
1261    }
1262
1263    pub fn poll_send_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
1264        self.poll_completion_target(PollTarget::Send)
1265    }
1266
1267    pub fn poll_recv_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
1268        self.poll_completion_target(PollTarget::Recv)
1269    }
1270}
1271
1272/// Utility to validate execution context.
1273///
1274/// Remote Execution environments do not always have access to the nvidia_peermem module
1275/// and/or set the PeerMappingOverride parameter due to security. This function can be
1276/// used to validate that the execution context when running operations that need this
1277/// functionality (ie. cudaHostRegisterIoMemory).
1278///
1279/// # Returns
1280///
1281/// * `Ok(())` if the execution context is valid
1282/// * `Err(anyhow::Error)` if the execution context is invalid
1283pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1284    // Check for nvidia peermem
1285    match fs::read_to_string("/proc/modules") {
1286        Ok(contents) => {
1287            if !contents.contains("nvidia_peermem") {
1288                return Err(anyhow::anyhow!(
1289                    "nvidia_peermem module not found in /proc/modules"
1290                ));
1291            }
1292        }
1293        Err(e) => {
1294            return Err(anyhow::anyhow!(e));
1295        }
1296    }
1297
1298    // Test file access to nvidia params
1299    match fs::read_to_string("/proc/driver/nvidia/params") {
1300        Ok(contents) => {
1301            if !contents.contains("PeerMappingOverride=1") {
1302                return Err(anyhow::anyhow!(
1303                    "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1304                ));
1305            }
1306        }
1307        Err(e) => {
1308            return Err(anyhow::anyhow!(e));
1309        }
1310    }
1311    Ok(())
1312}
1313
#[cfg(test)]
mod tests {
    use super::*;

    /// Reports whether the test should be skipped because no RDMA device is present,
    /// printing the skip notice when that is the case.
    fn skip_without_rdma() -> bool {
        let missing = crate::ibverbs_primitives::get_all_devices().is_empty();
        if missing {
            println!("Skipping test: RDMA devices not available");
        }
        missing
    }

    /// Default configuration with GPU direct disabled.
    fn host_only_config() -> IbverbsConfig {
        IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        }
    }

    /// A domain and a queue pair can be created on the default device.
    #[test]
    fn test_create_connection() {
        if skip_without_rdma() {
            return;
        }

        let config = host_only_config();

        let domain = RdmaDomain::new(config.device.clone());
        assert!(domain.is_ok());
        let domain = domain.unwrap();

        let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config);
        assert!(queue_pair.is_ok());
    }

    /// Two queue pairs on the same host can connect to each other.
    #[test]
    fn test_loopback_connection() {
        if skip_without_rdma() {
            return;
        }

        let server_config = host_only_config();
        let client_config = host_only_config();

        let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
        let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();

        let mut server_qp =
            RdmaQueuePair::new(server_domain.context, server_domain.pd, server_config).unwrap();
        let mut client_qp =
            RdmaQueuePair::new(client_domain.context, client_domain.pd, client_config).unwrap();

        // Each side connects using the other side's connection info.
        let server_info = server_qp.get_qp_info().unwrap();
        let client_info = client_qp.get_qp_info().unwrap();

        assert!(server_qp.connect(&client_info).is_ok());
        assert!(client_qp.connect(&server_info).is_ok());
    }
}
1377}