monarch_rdma/rdma_components.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! # RDMA Components
//!
//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
//!
//! ## Core Components
//!
//! * `RdmaDomain` - Manages RDMA resources, including the device context and protection domain
//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
//!
//! ## RDMA Overview
//!
//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
//! into the memory of another without involving either computer's operating system. This permits
//! high-throughput, low-latency networking with minimal CPU overhead.
//!
//! ## Connection Architecture
//!
//! The module manages the following ibverbs primitives:
//!
//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
//! 4. **Protection Domains (PD)**: Provide isolation between different connections
//!
//! ## Connection Lifecycle
//!
//! 1. Create an `RdmaDomain` with `new()`
//! 2. Create an `RdmaQueuePair` from the domain
//! 3. Exchange connection info with the remote peer (the application must handle this)
//! 4. Connect to the remote endpoint with `connect()`
//! 5. Perform RDMA operations (read/write)
//! 6. Poll for completions
//! 7. Resources are cleaned up when dropped
//!
//! A sketch of this lifecycle is shown in the example below.

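//! ## Example
//!
//! A minimal loopback sketch of the lifecycle above. The connection-info exchange is
//! done in-process here; a real application must ship `RdmaQpInfo` to its peer out of
//! band, and `local_buf`/`remote_buf` stand in for `RdmaBuffer` handles registered
//! elsewhere.
//!
//! ```ignore
//! let config = IbverbsConfig::default();
//!
//! // 1-2. Create a domain and a queue pair for each side.
//! let domain_a = RdmaDomain::new(config.device.clone())?;
//! let domain_b = RdmaDomain::new(config.device.clone())?;
//! let mut qp_a = RdmaQueuePair::new(domain_a.context, domain_a.pd, config.clone())?;
//! let mut qp_b = RdmaQueuePair::new(domain_b.context, domain_b.pd, config.clone())?;
//!
//! // 3-4. Exchange connection info and connect both ends.
//! let info_a = qp_a.get_qp_info()?;
//! let info_b = qp_b.get_qp_info()?;
//! qp_a.connect(&info_b)?;
//! qp_b.connect(&info_a)?;
//!
//! // 5-6. Post an RDMA write and poll the send completion queue.
//! qp_a.put(local_buf, remote_buf)?;
//! while qp_a.poll_send_completion()?.is_none() {}
//! ```
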
/// Maximum size for a single RDMA operation in bytes (1 GiB)
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;

use std::ffi::CStr;
use std::fs;
use std::io::Error;
use std::result::Result;
use std::thread::sleep;
use std::time::Duration;

use hyperactor::ActorRef;
use hyperactor::Named;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::context;
use serde::Deserialize;
use serde::Serialize;

use crate::RdmaDevice;
use crate::RdmaManagerActor;
use crate::RdmaManagerMessageClient;
use crate::ibverbs_primitives::Gid;
use crate::ibverbs_primitives::IbvWc;
use crate::ibverbs_primitives::IbverbsConfig;
use crate::ibverbs_primitives::RdmaOperation;
use crate::ibverbs_primitives::RdmaQpInfo;

/// A handle for ringing an mlx5 doorbell.
///
/// `src_ptr` points at the staged work queue entry and `dst_ptr` at the device
/// doorbell register; `size` is the doorbell write size in bytes.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    pub src_ptr: usize,
    pub dst_ptr: usize,
    pub size: usize,
}

impl DoorBell {
    /// Rings the doorbell to trigger the execution of previously enqueued operations.
    ///
    /// This method uses unsafe code to directly interact with the RDMA device,
    /// sending a signal from the source pointer to the destination pointer.
    ///
    /// # Returns
    /// * `Ok(())` if the operation is successful.
    /// * `Err(anyhow::Error)` if an error occurs during the operation.
    pub fn ring(&self) -> Result<(), anyhow::Error> {
        // SAFETY: `src_ptr` and `dst_ptr` are expected to point at the WQE buffer
        // and doorbell register captured when this `DoorBell` was created.
        unsafe {
            let src_ptr = self.src_ptr as *mut std::ffi::c_void;
            let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
            rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
            Ok(())
        }
    }
}

/// A remotely accessible memory region handle.
///
/// Carries everything a peer needs to address this registration over RDMA: the owning
/// `RdmaManagerActor`, the memory-region id, the local and remote keys, and the virtual
/// address and size of the region.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct RdmaBuffer {
    pub owner: ActorRef<RdmaManagerActor>,
    pub mr_id: usize,
    pub lkey: u32,
    pub rkey: u32,
    pub addr: usize,
    pub size: usize,
}

impl RdmaBuffer {
    /// Read from this `RdmaBuffer` into the provided memory region.
    ///
    /// This method transfers data from this buffer into the caller-provided `remote`
    /// buffer over RDMA. It is implemented as a `Put` (RDMA write) issued on this
    /// buffer's owning actor side.
    ///
    /// # Arguments
    /// * `client` - The actor who is reading.
    /// * `remote` - RdmaBuffer representing the destination memory region
    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
    ///
    /// # Returns
    /// `Ok(bool)` indicating if the operation completed successfully.
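    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is an actor context and `dst` is another
    /// registered `RdmaBuffer` that should receive the contents of `src`:
    ///
    /// ```ignore
    /// src.read_into(&client, dst, 5).await?;
    /// ```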
119    pub async fn read_into(
120        &self,
121        client: &impl context::Actor,
122        remote: RdmaBuffer,
123        timeout: u64,
124    ) -> Result<bool, anyhow::Error> {
125        tracing::debug!(
126            "[buffer] reading from {:?} into remote ({:?}) at {:?}",
127            self,
128            remote.owner.actor_id(),
129            remote,
130        );
131        let remote_owner = remote.owner.clone(); // Clone before the move!
132        let mut qp = self
133            .owner
134            .request_queue_pair_deprecated(client, remote_owner.clone())
135            .await?;
136
137        qp.put(self.clone(), remote)?;
138        let result = self
139            .wait_for_completion(&mut qp, PollTarget::Send, timeout)
140            .await;
141
142        // Release the queue pair back to the actor
143        self.owner
144            .release_queue_pair_deprecated(client, remote_owner, qp)
145            .await?;
146
147        result
148    }
149
    /// Write from the provided memory into this `RdmaBuffer`.
    ///
    /// This method transfers data from the caller-provided `remote` buffer into this
    /// buffer over RDMA. It is implemented as a `Get` (RDMA read) issued on this
    /// buffer's owning actor side.
    ///
    /// # Arguments
    /// * `client` - The actor who is writing.
    /// * `remote` - RdmaBuffer representing the source memory region
    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
    ///
    /// # Returns
    /// `Ok(bool)` indicating if the operation completed successfully.
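    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `client` is an actor context and `src` is another
    /// registered `RdmaBuffer` whose contents should be pulled into `dst`:
    ///
    /// ```ignore
    /// dst.write_from(&client, src, 5).await?;
    /// ```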
163    pub async fn write_from(
164        &self,
165        client: &impl context::Actor,
166        remote: RdmaBuffer,
167        timeout: u64,
168    ) -> Result<bool, anyhow::Error> {
169        tracing::debug!(
170            "[buffer] writing into {:?} from remote ({:?}) at {:?}",
171            self,
172            remote.owner.actor_id(),
173            remote,
174        );
175        let remote_owner = remote.owner.clone(); // Clone before the move!
176        let mut qp = self
177            .owner
178            .request_queue_pair_deprecated(client, remote_owner.clone())
179            .await?;
180        qp.get(self.clone(), remote)?;
181        let result = self
182            .wait_for_completion(&mut qp, PollTarget::Send, timeout)
183            .await;
184
185        // Release the queue pair back to the actor
186        self.owner
187            .release_queue_pair_deprecated(client, remote_owner, qp)
188            .await?;
189
190        result?;
191        Ok(true)
192    }

    /// Waits for the completion of an RDMA operation.
    ///
    /// This method polls the chosen completion queue until the specified work request
    /// completes or until the timeout is reached.
    ///
    /// # Arguments
    /// * `qp` - The RDMA queue pair to poll for completion
    /// * `poll_target` - Which completion queue to poll (send or receive)
    /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
    ///
    /// # Returns
    /// `Ok(true)` if the operation completes successfully within the timeout,
    /// or an error if the timeout is reached
205    async fn wait_for_completion(
206        &self,
207        qp: &mut RdmaQueuePair,
208        poll_target: PollTarget,
209        timeout: u64,
210    ) -> Result<bool, anyhow::Error> {
211        let timeout = Duration::from_secs(timeout);
212        let start_time = std::time::Instant::now();
213
214        while start_time.elapsed() < timeout {
215            match qp.poll_completion_target(poll_target) {
216                Ok(Some(_wc)) => {
217                    tracing::debug!("work completed");
218                    return Ok(true);
219                }
220                Ok(None) => {
221                    RealClock.sleep(Duration::from_millis(1)).await;
222                }
223                Err(e) => {
224                    return Err(anyhow::anyhow!(
225                        "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
226                        e,
227                        self.lkey,
228                        self.rkey,
229                        self.addr,
230                        self.size
231                    ));
232                }
233            }
234        }
235        tracing::error!("timed out while waiting on request completion");
236        Err(anyhow::anyhow!(
237            "[buffer({:?})] rdma operation did not complete in time",
238            self
239        ))
240    }
241
242    /// Drop the buffer and release remote handles.
243    ///
244    /// This method calls the owning RdmaManagerActor to release the buffer and clean up
245    /// associated memory regions. This is typically called when the buffer is no longer
246    /// needed and resources should be freed.
247    ///
248    /// # Arguments
249    /// * `client` - Mailbox used for communication
250    ///
251    /// # Returns
252    /// `Ok(())` if the operation completed successfully.
253    pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
254        tracing::debug!("[buffer] dropping buffer {:?}", self);
255        self.owner
256            .release_buffer_deprecated(client, self.clone())
257            .await?;
258        Ok(())
259    }
260}
261
/// Represents a domain for RDMA operations, encapsulating the necessary resources
/// for establishing and managing RDMA connections.
///
/// An `RdmaDomain` owns the device context and protection domain (PD) required for
/// RDMA operations. It provides the foundation for creating queue pairs and
/// establishing connections between RDMA devices.
///
/// # Fields
///
/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
pub struct RdmaDomain {
    pub context: *mut rdmaxcel_sys::ibv_context,
    pub pd: *mut rdmaxcel_sys::ibv_pd,
}
277
278impl std::fmt::Debug for RdmaDomain {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        f.debug_struct("RdmaDomain")
281            .field("context", &format!("{:p}", self.context))
282            .field("pd", &format!("{:p}", self.pd))
283            .finish()
284    }
285}
286
// SAFETY:
// RdmaDomain is `Send` because the raw pointers to ibverbs structs can be
// accessed from any thread, and it is safe to drop `RdmaDomain` (and run the
// ibverbs destructors) from any thread.
unsafe impl Send for RdmaDomain {}

// SAFETY:
// RdmaDomain is `Sync` because the underlying ibverbs APIs are thread-safe.
unsafe impl Sync for RdmaDomain {}
298
299impl Drop for RdmaDomain {
300    fn drop(&mut self) {
301        unsafe {
302            rdmaxcel_sys::ibv_dealloc_pd(self.pd);
303        }
304    }
305}
306
impl RdmaDomain {
    /// Creates a new RdmaDomain.
    ///
    /// This function opens the RDMA device context and allocates a protection domain
    /// for the given device.
    ///
    /// SAFETY:
    /// Memory region (MR) registration in this crate uses implicit ODP for RDMA access, which maps
    /// large virtual address ranges without explicit pinning. This is convenient, but it broadens the
    /// memory footprint exposed to the NIC and introduces a security liability.
    ///
    /// We currently assume a trusted, single-tenant environment and do not enforce finer-grained
    /// memory isolation at this layer. We plan to investigate mitigations - such as memory windows or
    /// tighter registration boundaries - in future follow-ups.
    ///
    /// # Arguments
    ///
    /// * `device` - The RDMA device to open
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * No RDMA devices are found
    /// * The specified device cannot be found
    /// * Device context creation fails
    /// * Protection domain allocation fails
334    pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
335        tracing::debug!("creating RdmaDomain for device {}", device.name());
336        // SAFETY:
337        // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
338        // - All pointers are properly initialized and checked for null before use
339        // - Memory registration follows the ibverbs API contract with proper access flags
340        // - Resources are properly cleaned up in error cases to prevent leaks
341        // - The operations follow the documented RDMA protocol for device initialization
342        unsafe {
343            // Get the device based on the provided RdmaDevice
344            let device_name = device.name();
345            let mut num_devices = 0i32;
346            let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
347
348            if devices.is_null() || num_devices == 0 {
349                return Err(anyhow::anyhow!("no RDMA devices found"));
350            }
351
352            // Find the device with the matching name
353            let mut device_ptr = std::ptr::null_mut();
354            for i in 0..num_devices {
355                let dev = *devices.offset(i as isize);
356                let dev_name =
357                    CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
358
359                if dev_name == *device_name {
360                    device_ptr = dev;
361                    break;
362                }
363            }
364
365            // If we didn't find the device, return an error
366            if device_ptr.is_null() {
367                rdmaxcel_sys::ibv_free_device_list(devices);
368                return Err(anyhow::anyhow!("device '{}' not found", device_name));
369            }
370            tracing::info!("using RDMA device: {}", device_name);
371
372            // Open device
373            let context = rdmaxcel_sys::ibv_open_device(device_ptr);
374            if context.is_null() {
375                rdmaxcel_sys::ibv_free_device_list(devices);
376                let os_error = Error::last_os_error();
377                return Err(anyhow::anyhow!("failed to create context: {}", os_error));
378            }
379
380            // Create protection domain
381            let pd = rdmaxcel_sys::ibv_alloc_pd(context);
382            if pd.is_null() {
383                rdmaxcel_sys::ibv_close_device(context);
384                rdmaxcel_sys::ibv_free_device_list(devices);
385                let os_error = Error::last_os_error();
386                return Err(anyhow::anyhow!(
387                    "failed to create protection domain (PD): {}",
388                    os_error
389                ));
390            }
391
392            // Avoids memory leaks
393            rdmaxcel_sys::ibv_free_device_list(devices);
394
395            let domain = RdmaDomain { context, pd };
396
397            Ok(domain)
398        }
399    }
400}
401/// Enum to specify which completion queue to poll
402#[derive(Debug, Clone, Copy, PartialEq)]
403pub enum PollTarget {
404    Send,
405    Recv,
406}
407
408/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
409///
410/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
411/// and other resources needed for RDMA communication. It provides methods for
412/// establishing connections and performing RDMA operations like read and write.
413///
414/// # Fields
415///
416/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
417/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
418/// * `qp` - Queue Pair pointer that manages send and receive operations
419/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
420/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
421/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
422/// * `context` - RDMA device context pointer
423/// * `config` - Configuration settings for the queue pair
424///
425/// # Connection Lifecycle
426///
427/// 1. Create with `new()` from an `RdmaDomain`
428/// 2. Get connection info with `get_qp_info()`
429/// 3. Exchange connection info with remote peer (application must handle this)
430/// 4. Connect to remote endpoint with `connect()`
431/// 5. Perform RDMA operations with `put()` or `get()`
432/// 6. Poll for completions with `poll_send_completion()` or `poll_recv_completion()`
433#[derive(Debug, Serialize, Deserialize, Named, Clone)]
434pub struct RdmaQueuePair {
435    pub send_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
436    pub recv_cq: usize,    // *mut rdmaxcel_sys::ibv_cq,
437    pub qp: usize,         // *mut rdmaxcel_sys::ibv_qp,
438    pub dv_qp: usize,      // *mut rdmaxcel_sys::mlx5dv_qp,
439    pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
440    pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
441    context: usize,        // *mut rdmaxcel_sys::ibv_context,
442    config: IbverbsConfig,
443    pub send_wqe_idx: u64,
444    pub send_db_idx: u64,
445    pub send_cq_idx: u64,
446    pub recv_wqe_idx: u64,
447    pub recv_db_idx: u64,
448    pub recv_cq_idx: u64,
449    rts_timestamp: u64,
450}
451
452impl RdmaQueuePair {
453    /// Applies hardware initialization delay if this is the first operation since RTS.
454    ///
455    /// This ensures the hardware has sufficient time to settle after reaching
456    /// Ready-to-Send state before the first actual operation.
457    fn apply_first_op_delay(&self, wr_id: u64) {
458        if wr_id == 0 {
459            assert!(
460                self.rts_timestamp != u64::MAX,
461                "First operation attempted before queue pair reached RTS state! Call connect() first."
462            );
463            let current_nanos = RealClock
464                .system_time_now()
465                .duration_since(std::time::UNIX_EPOCH)
466                .unwrap()
467                .as_nanos() as u64;
468            let elapsed_nanos = current_nanos - self.rts_timestamp;
469            let elapsed = Duration::from_nanos(elapsed_nanos);
470            let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
471            if elapsed < init_delay {
472                let remaining_delay = init_delay - elapsed;
473                sleep(remaining_delay);
474            }
475        }
476    }
477
478    /// Creates a new RdmaQueuePair from a given RdmaDomain.
479    ///
480    /// This function initializes a new Queue Pair (QP) and associated Completion Queue (CQ)
481    /// using the resources from the provided RdmaDomain. The QP is created in the RESET state
482    /// and must be transitioned to other states via the `connect()` method before use.
483    ///
484    /// # Arguments
485    ///
486    /// * `domain` - Reference to an RdmaDomain that provides the context, protection domain,
487    ///   and memory region for this queue pair
488    ///
489    /// # Returns
490    ///
491    /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
492    ///
493    /// # Errors
494    ///
495    /// This function may return errors if:
496    /// * Completion queue (CQ) creation fails
497    /// * Queue pair (QP) creation fails
498    pub fn new(
499        context: *mut rdmaxcel_sys::ibv_context,
500        pd: *mut rdmaxcel_sys::ibv_pd,
501        config: IbverbsConfig,
502    ) -> Result<Self, anyhow::Error> {
503        tracing::debug!("creating an RdmaQueuePair from config {}", config);
504        unsafe {
505            // standard ibverbs QP
506            let qp = rdmaxcel_sys::create_qp(
507                context,
508                pd,
509                config.cq_entries,
510                config.max_send_wr.try_into().unwrap(),
511                config.max_recv_wr.try_into().unwrap(),
512                config.max_send_sge.try_into().unwrap(),
513                config.max_recv_sge.try_into().unwrap(),
514            );
515
516            if qp.is_null() {
517                let os_error = Error::last_os_error();
518                return Err(anyhow::anyhow!(
519                    "failed to create queue pair (QP): {}",
520                    os_error
521                ));
522            }
523
524            let send_cq = (*qp).send_cq;
525            let recv_cq = (*qp).recv_cq;
526
527            // mlx5dv provider APIs
528            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp(qp);
529            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq(qp);
530            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq(qp);
531
532            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
533                rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
534                rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
535                rdmaxcel_sys::ibv_destroy_qp(qp);
536                return Err(anyhow::anyhow!(
537                    "failed to init mlx5dv_qp or completion queues"
538                ));
539            }
540
541            // GPU Direct RDMA specific registrations
542            if config.use_gpu_direct {
543                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
544                if ret != 0 {
545                    rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
546                    rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
547                    rdmaxcel_sys::ibv_destroy_qp(qp);
548                    return Err(anyhow::anyhow!(
549                        "failed to register GPU Direct RDMA memory: {:?}",
550                        ret
551                    ));
552                }
553            }
554            Ok(RdmaQueuePair {
555                send_cq: send_cq as usize,
556                recv_cq: recv_cq as usize,
557                qp: qp as usize,
558                dv_qp: dv_qp as usize,
559                dv_send_cq: dv_send_cq as usize,
560                dv_recv_cq: dv_recv_cq as usize,
561                context: context as usize,
562                config,
563                recv_db_idx: 0,
564                recv_wqe_idx: 0,
565                recv_cq_idx: 0,
566                send_db_idx: 0,
567                send_wqe_idx: 0,
568                send_cq_idx: 0,
569                rts_timestamp: u64::MAX,
570            })
571        }
572    }
573
574    /// Returns the information required for a remote peer to connect to this queue pair.
575    ///
576    /// This method retrieves the local queue pair attributes and port information needed by
577    /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
578    /// the queue pair number, LID, GID, and other necessary connection parameters.
579    ///
580    /// # Returns
581    ///
582    /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
583    ///
584    /// # Errors
585    ///
586    /// This function may return errors if:
587    /// * Port attribute query fails
588    /// * GID query fails
589    pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
590        // SAFETY:
591        // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
592        // - All pointers are properly initialized before use
593        // - Port and GID queries follow the documented ibverbs API contract
594        // - Error handling properly checks return codes from ibverbs functions
595        // - The memory address provided is only stored, not dereferenced in this function
596        unsafe {
597            let context = self.context as *mut rdmaxcel_sys::ibv_context;
598            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
599            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
600            let errno = rdmaxcel_sys::ibv_query_port(
601                context,
602                self.config.port_num,
603                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
604            );
605            if errno != 0 {
606                let os_error = Error::last_os_error();
607                return Err(anyhow::anyhow!(
608                    "Failed to query port attributes: {}",
609                    os_error
610                ));
611            }
612
613            let mut gid = Gid::default();
614            let ret = rdmaxcel_sys::ibv_query_gid(
615                context,
616                self.config.port_num,
617                i32::from(self.config.gid_index),
618                gid.as_mut(),
619            );
620            if ret != 0 {
621                return Err(anyhow::anyhow!("Failed to query GID"));
622            }
623
624            Ok(RdmaQpInfo {
625                qp_num: (*qp).qp_num,
626                lid: port_attr.lid,
627                gid: Some(gid),
628                psn: self.config.psn,
629            })
630        }
631    }
632
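    /// Queries the current ibverbs state of the queue pair (e.g. RESET, INIT, RTR, RTS),
    /// returned as the raw `ibv_qp_state` value.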
633    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
634        // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
635        unsafe {
636            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
637            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
638                ..Default::default()
639            };
640            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
641                ..Default::default()
642            };
643            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
644            let errno =
645                rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr);
646            if errno != 0 {
647                let os_error = Error::last_os_error();
648                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
649            }
650            Ok(qp_attr.qp_state)
651        }
652    }
653    /// Connect to a remote Rdma connection point.
654    ///
655    /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection.
656    ///
657    /// # Arguments
658    ///
659    /// * `connection_info` - The remote connection info to connect to
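    ///
    /// # Example
    ///
    /// A minimal sketch connecting two locally created queue pairs to each other
    /// (the info exchange normally happens over the network):
    ///
    /// ```ignore
    /// let info_a = qp_a.get_qp_info()?;
    /// let info_b = qp_b.get_qp_info()?;
    /// qp_a.connect(&info_b)?;
    /// qp_b.connect(&info_a)?;
    /// ```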
660    pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
661        // SAFETY:
662        // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
663        // The operations are safe because:
664        // 1. We're following the documented ibverbs API contract
665        // 2. All pointers used are properly initialized and owned by this struct
666        // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
667        // 4. Memory access is properly bounded by the registered memory regions
668        unsafe {
669            // Transition to INIT
670            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
671
672            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
673                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
674                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
675                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
676
677            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
678                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
679                qp_access_flags: qp_access_flags.0,
680                pkey_index: self.config.pkey_index,
681                port_num: self.config.port_num,
682                ..Default::default()
683            };
684
685            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
686                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
687                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
688                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
689
690            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
691            if errno != 0 {
692                let os_error = Error::last_os_error();
693                return Err(anyhow::anyhow!(
694                    "failed to transition QP to INIT: {}",
695                    os_error
696                ));
697            }
698
699            // Transition to RTR (Ready to Receive)
700            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
701                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
702                path_mtu: self.config.path_mtu,
703                dest_qp_num: connection_info.qp_num,
704                rq_psn: connection_info.psn,
705                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
706                min_rnr_timer: self.config.min_rnr_timer,
707                ah_attr: rdmaxcel_sys::ibv_ah_attr {
708                    dlid: connection_info.lid,
709                    sl: 0,
710                    src_path_bits: 0,
711                    port_num: self.config.port_num,
712                    grh: Default::default(),
713                    ..Default::default()
714                },
715                ..Default::default()
716            };
717
718            // If the remote connection info contains a Gid, the routing will be global.
719            // Otherwise, it will be local, i.e. using LID.
720            if let Some(gid) = connection_info.gid {
721                qp_attr.ah_attr.is_global = 1;
722                qp_attr.ah_attr.grh.dgid = gid.into();
723                qp_attr.ah_attr.grh.hop_limit = 0xff;
724                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
725            } else {
726                // Use LID-based routing, e.g. for Infiniband/RoCEv1
727                qp_attr.ah_attr.is_global = 0;
728            }
729
730            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
731                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
732                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
733                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
734                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
735                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
736                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
737
738            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
739            if errno != 0 {
740                let os_error = Error::last_os_error();
741                return Err(anyhow::anyhow!(
742                    "failed to transition QP to RTR: {}",
743                    os_error
744                ));
745            }
746
747            // Transition to RTS (Ready to Send)
748            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
749                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
750                sq_psn: self.config.psn,
751                max_rd_atomic: self.config.max_rd_atomic,
752                retry_cnt: self.config.retry_cnt,
753                rnr_retry: self.config.rnr_retry,
754                timeout: self.config.qp_timeout,
755                ..Default::default()
756            };
757
758            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
759                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
760                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
761                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
762                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
763                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
764
765            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
766            if errno != 0 {
767                let os_error = Error::last_os_error();
768                return Err(anyhow::anyhow!(
769                    "failed to transition QP to RTS: {}",
770                    os_error
771                ));
772            }
773            tracing::debug!(
774                "connection sequence has successfully completed (qp: {:?})",
775                qp
776            );
777
778            // Record RTS time now that the queue pair is ready to send
779            self.rts_timestamp = RealClock
780                .system_time_now()
781                .duration_since(std::time::UNIX_EPOCH)
782                .unwrap()
783                .as_nanos() as u64;
784
785            Ok(())
786        }
787    }
788
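    /// Posts a receive work request for this queue pair.
    ///
    /// The peer's corresponding `put_with_recv()` (RDMA write with immediate) consumes
    /// this receive WQE, producing a completion on the receive queue.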
789    pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
790        let idx = self.recv_wqe_idx;
791        self.recv_wqe_idx += 1;
792        self.send_wqe(
793            0,
794            lhandle.lkey,
795            0,
796            idx,
797            true,
798            RdmaOperation::Recv,
799            0,
800            rhandle.rkey,
801        )
802        .unwrap();
803        Ok(())
804    }
805
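    /// Posts an RDMA write-with-immediate from `lhandle` to `rhandle`.
    ///
    /// The remote peer must have posted a matching receive work request (see `recv()`)
    /// so that the immediate data generates a completion on its receive queue.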
806    pub fn put_with_recv(
807        &mut self,
808        lhandle: RdmaBuffer,
809        rhandle: RdmaBuffer,
810    ) -> Result<(), anyhow::Error> {
811        let idx = self.send_wqe_idx;
812        self.send_wqe_idx += 1;
813        self.post_op(
814            lhandle.addr,
815            lhandle.lkey,
816            lhandle.size,
817            idx,
818            true,
819            RdmaOperation::WriteWithImm,
820            rhandle.addr,
821            rhandle.rkey,
822        )
823        .unwrap();
824        self.send_db_idx += 1;
825        Ok(())
826    }
827
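    /// Posts an RDMA write from the local buffer `lhandle` to the remote buffer `rhandle`.
    ///
    /// Transfers larger than `MAX_RDMA_MSG_SIZE` (1 GiB) are split into multiple work
    /// requests, one per chunk; for example, a 2.5 GiB buffer is posted as three WQEs
    /// of 1 GiB, 1 GiB, and 0.5 GiB. Returns an error if the remote buffer is smaller
    /// than the local buffer.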
828    pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
829        let total_size = lhandle.size;
830        if rhandle.size < total_size {
831            return Err(anyhow::anyhow!(
832                "Remote buffer size ({}) is smaller than local buffer size ({})",
833                rhandle.size,
834                total_size
835            ));
836        }
837
838        let mut remaining = total_size;
839        let mut offset = 0;
840        while remaining > 0 {
841            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
842            let idx = self.send_wqe_idx;
843            self.send_wqe_idx += 1;
844            self.post_op(
845                lhandle.addr + offset,
846                lhandle.lkey,
847                chunk_size,
848                idx,
849                true,
850                RdmaOperation::Write,
851                rhandle.addr + offset,
852                rhandle.rkey,
853            )?;
854            self.send_db_idx += 1;
855
856            remaining -= chunk_size;
857            offset += chunk_size;
858        }
859
860        Ok(())
861    }
862
    /// Rings the doorbell for all enqueued but unsubmitted send work requests.
    ///
    /// This method walks the send queue from `send_db_idx` to `send_wqe_idx` and rings
    /// the device doorbell for each pending WQE, triggering execution of operations
    /// previously staged with the `enqueue_*` methods.
    ///
    /// # Returns
    ///
    /// * `Result<(), anyhow::Error>` - Success or error
871    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
872        unsafe {
873            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
874            let base_ptr = (*dv_qp).sq.buf as *mut u8;
875            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
876            let stride = (*dv_qp).sq.stride;
877            if (wqe_cnt as u64) < (self.send_wqe_idx - self.send_db_idx) {
878                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
879            }
880            while self.send_db_idx < self.send_wqe_idx {
881                let offset = (self.send_db_idx % wqe_cnt as u64) * stride as u64;
882                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
883                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
884                self.send_db_idx += 1;
885            }
886            Ok(())
887        }
888    }
889
890    /// Enqueues a put operation without ringing the doorbell.
891    ///
892    /// This method prepares a put operation but does not execute it.
    /// Use `ring_doorbell()` to execute the operation.
894    ///
895    /// # Arguments
896    ///
897    /// * `lhandle` - Local buffer handle
898    /// * `rhandle` - Remote buffer handle
899    ///
900    /// # Returns
901    ///
902    /// * `Result<(), anyhow::Error>` - Success or error
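    ///
    /// # Example
    ///
    /// A minimal sketch of the enqueue/doorbell pattern, assuming `lbuf` and `rbuf`
    /// are registered `RdmaBuffer` handles on a connected queue pair:
    ///
    /// ```ignore
    /// qp.enqueue_put(lbuf, rbuf)?;  // stage the write; nothing is submitted yet
    /// qp.ring_doorbell()?;          // submit all staged WQEs to the device
    /// while qp.poll_send_completion()?.is_none() {}
    /// ```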
903    pub fn enqueue_put(
904        &mut self,
905        lhandle: RdmaBuffer,
906        rhandle: RdmaBuffer,
907    ) -> Result<(), anyhow::Error> {
908        let idx = self.send_wqe_idx;
909        self.send_wqe_idx += 1;
910        self.send_wqe(
911            lhandle.addr,
912            lhandle.lkey,
913            lhandle.size,
914            idx,
915            true,
916            RdmaOperation::Write,
917            rhandle.addr,
918            rhandle.rkey,
919        )?;
920        Ok(())
921    }
922
923    /// Enqueues a put with receive operation without ringing the doorbell.
924    ///
925    /// This method prepares a put with receive operation but does not execute it.
    /// Use `ring_doorbell()` to execute the operation.
927    ///
928    /// # Arguments
929    ///
930    /// * `lhandle` - Local buffer handle
931    /// * `rhandle` - Remote buffer handle
932    ///
933    /// # Returns
934    ///
935    /// * `Result<(), anyhow::Error>` - Success or error
936    pub fn enqueue_put_with_recv(
937        &mut self,
938        lhandle: RdmaBuffer,
939        rhandle: RdmaBuffer,
940    ) -> Result<(), anyhow::Error> {
941        let idx = self.send_wqe_idx;
942        self.send_wqe_idx += 1;
943        self.send_wqe(
944            lhandle.addr,
945            lhandle.lkey,
946            lhandle.size,
947            idx,
948            true,
949            RdmaOperation::WriteWithImm,
950            rhandle.addr,
951            rhandle.rkey,
952        )?;
953        Ok(())
954    }
955
956    /// Enqueues a get operation without ringing the doorbell.
957    ///
958    /// This method prepares a get operation but does not execute it.
    /// Use `ring_doorbell()` to execute the operation.
960    ///
961    /// # Arguments
962    ///
963    /// * `lhandle` - Local buffer handle
964    /// * `rhandle` - Remote buffer handle
965    ///
966    /// # Returns
967    ///
968    /// * `Result<(), anyhow::Error>` - Success or error
969    pub fn enqueue_get(
970        &mut self,
971        lhandle: RdmaBuffer,
972        rhandle: RdmaBuffer,
973    ) -> Result<(), anyhow::Error> {
974        let idx = self.send_wqe_idx;
975        self.send_wqe_idx += 1;
976        self.send_wqe(
977            lhandle.addr,
978            lhandle.lkey,
979            lhandle.size,
980            idx,
981            true,
982            RdmaOperation::Read,
983            rhandle.addr,
984            rhandle.rkey,
985        )?;
986        Ok(())
987    }
988
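    /// Posts an RDMA read from the remote buffer `rhandle` into the local buffer `lhandle`.
    ///
    /// Like `put()`, transfers larger than `MAX_RDMA_MSG_SIZE` (1 GiB) are split into
    /// multiple chunked work requests. Returns an error if the remote buffer is smaller
    /// than the local buffer.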
989    pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
990        let total_size = lhandle.size;
991        if rhandle.size < total_size {
992            return Err(anyhow::anyhow!(
993                "Remote buffer size ({}) is smaller than local buffer size ({})",
994                rhandle.size,
995                total_size
996            ));
997        }
998
999        let mut remaining = total_size;
1000        let mut offset = 0;
1001
1002        while remaining > 0 {
1003            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
1004            let idx = self.send_wqe_idx;
1005            self.send_wqe_idx += 1;
1006            self.post_op(
1007                lhandle.addr + offset,
1008                lhandle.lkey,
1009                chunk_size,
1010                idx,
1011                true,
1012                RdmaOperation::Read,
1013                rhandle.addr + offset,
1014                rhandle.rkey,
1015            )?;
1016            self.send_db_idx += 1;
1017
1018            remaining -= chunk_size;
1019            offset += chunk_size;
1020        }
1021
1022        Ok(())
1023    }
1024
1025    /// Posts a request to the queue pair.
1026    ///
1027    /// # Arguments
1028    ///
1029    /// * `local_addr` - The local address containing data to send
1030    /// * `length` - Length of the data to send
1031    /// * `wr_id` - Work request ID for completion identification
1032    /// * `signaled` - Whether to generate a completion event
1033    /// * `op_type` - Optional operation type
1034    /// * `raddr` - the remote address, representing the memory location on the remote peer
1035    /// * `rkey` - the remote key, representing the key required to access the remote memory region
1036    fn post_op(
1037        &mut self,
1038        laddr: usize,
1039        lkey: u32,
1040        length: usize,
1041        wr_id: u64,
1042        signaled: bool,
1043        op_type: RdmaOperation,
1044        raddr: usize,
1045        rkey: u32,
1046    ) -> Result<(), anyhow::Error> {
1047        // SAFETY:
1048        // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
1049        // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
1050        // - The memory address in `local_addr` is not dereferenced, only passed to the device
1051        // - The remote connection info is verified to exist before accessing
1052        // - The ibverbs post_send operation follows the documented API contract
1053        // - Error codes from the device are properly checked and propagated
1054        unsafe {
1055            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
1056            let context = self.context as *mut rdmaxcel_sys::ibv_context;
1057            let ops = &mut (*context).ops;
1058            let errno;
1059            if op_type == RdmaOperation::Recv {
1060                let mut sge = rdmaxcel_sys::ibv_sge {
1061                    addr: laddr as u64,
1062                    length: length as u32,
1063                    lkey,
1064                };
1065                let mut wr = rdmaxcel_sys::ibv_recv_wr {
1066                    wr_id: wr_id.into(),
1067                    sg_list: &mut sge as *mut _,
1068                    num_sge: 1,
1069                    ..Default::default()
1070                };
1071                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
1072                errno = ops.post_recv.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
1073            } else if op_type == RdmaOperation::Write
1074                || op_type == RdmaOperation::Read
1075                || op_type == RdmaOperation::WriteWithImm
1076            {
1077                // Apply hardware initialization delay if this is the first operation
1078                self.apply_first_op_delay(wr_id);
1079                let send_flags = if signaled {
1080                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
1081                } else {
1082                    0
1083                };
1084                let mut sge = rdmaxcel_sys::ibv_sge {
1085                    addr: laddr as u64,
1086                    length: length as u32,
1087                    lkey,
1088                };
1089                let mut wr = rdmaxcel_sys::ibv_send_wr {
1090                    wr_id: wr_id.into(),
1091                    next: std::ptr::null_mut(),
1092                    sg_list: &mut sge as *mut _,
1093                    num_sge: 1,
1094                    opcode: op_type.into(),
1095                    send_flags,
1096                    wr: Default::default(),
1097                    qp_type: Default::default(),
1098                    __bindgen_anon_1: Default::default(),
1099                    __bindgen_anon_2: Default::default(),
1100                };
1101
1102                wr.wr.rdma.remote_addr = raddr as u64;
1103                wr.wr.rdma.rkey = rkey;
1104                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
1105
1106                errno = ops.post_send.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
1107            } else {
1108                panic!("Not Implemented");
1109            }
1110
1111            if errno != 0 {
1112                let os_error = Error::last_os_error();
1113                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
1114            }
1115            tracing::debug!(
1116                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
1117                op_type,
1118                lkey,
1119                laddr,
1120                length,
1121                raddr,
1122                rkey,
1123            );
1124
1125            Ok(())
1126        }
1127    }
1128
1129    fn send_wqe(
1130        &mut self,
1131        laddr: usize,
1132        lkey: u32,
1133        length: usize,
1134        wr_id: u64,
1135        signaled: bool,
1136        op_type: RdmaOperation,
1137        raddr: usize,
1138        rkey: u32,
1139    ) -> Result<DoorBell, anyhow::Error> {
1140        unsafe {
1141            let op_type_val = match op_type {
1142                RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
1143                RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
1144                RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
1145                RdmaOperation::Recv => 0,
1146            };
1147
1148            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
1149            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
1150            let _dv_cq = if op_type == RdmaOperation::Recv {
1151                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
1152            } else {
1153                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
1154            };
1155
1156            // Create the WQE parameters struct
1157
1158            let buf = if op_type == RdmaOperation::Recv {
1159                (*dv_qp).rq.buf as *mut u8
1160            } else {
1161                (*dv_qp).sq.buf as *mut u8
1162            };
1163
1164            let params = rdmaxcel_sys::wqe_params_t {
1165                laddr,
1166                lkey,
1167                length,
1168                wr_id: wr_id.into(),
1169                signaled,
1170                op_type: op_type_val,
1171                raddr,
1172                rkey,
1173                qp_num: (*qp).qp_num,
1174                buf,
1175                dbrec: (*dv_qp).dbrec,
1176                wqe_cnt: (*dv_qp).sq.wqe_cnt,
1177            };
1178
1179            // Call the C function to post the WQE
1180            if op_type == RdmaOperation::Recv {
1181                rdmaxcel_sys::recv_wqe(params);
1182                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
1183            } else {
1184                rdmaxcel_sys::send_wqe(params);
1185            };
1186
1187            // Create and return a DoorBell struct
1188            Ok(DoorBell {
1189                dst_ptr: (*dv_qp).bf.reg as usize,
1190                src_ptr: (*dv_qp).sq.buf as usize,
1191                size: 8,
1192            })
1193        }
1194    }
1195
1196    /// Poll for completions on the specified completion queue(s)
1197    ///
1198    /// # Arguments
1199    ///
1200    /// * `target` - Which completion queue(s) to poll (Send, Receive)
1201    ///
1202    /// # Returns
1203    ///
1204    /// * `Ok(Some(wc))` - A completion was found
1205    /// * `Ok(None)` - No completion was found
1206    /// * `Err(e)` - An error occurred
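    ///
    /// # Example
    ///
    /// A minimal sketch of a busy-wait polling loop for a previously posted send:
    ///
    /// ```ignore
    /// while qp.poll_completion_target(PollTarget::Send)?.is_none() {
    ///     std::hint::spin_loop();
    /// }
    /// ```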
1207    pub fn poll_completion_target(
1208        &mut self,
1209        target: PollTarget,
1210    ) -> Result<Option<IbvWc>, anyhow::Error> {
1211        unsafe {
1212            let context = self.context as *mut rdmaxcel_sys::ibv_context;
1213            let _outstanding_wqe =
1214                self.send_db_idx + self.recv_db_idx - self.send_cq_idx - self.recv_cq_idx;
1215
1216            // Check for send completions if requested
1217            if (target == PollTarget::Send) && self.send_db_idx > self.send_cq_idx {
1218                let send_cq = self.send_cq as *mut rdmaxcel_sys::ibv_cq;
1219                let ops = &mut (*context).ops;
1220                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1221                let ret = ops.poll_cq.as_mut().unwrap()(send_cq, 1, &mut wc);
1222
1223                if ret < 0 {
1224                    return Err(anyhow::anyhow!(
1225                        "Failed to poll send CQ: {}",
1226                        Error::last_os_error()
1227                    ));
1228                }
1229
1230                if ret > 0 {
1231                    if !wc.is_valid() {
1232                        if let Some((status, vendor_err)) = wc.error() {
1233                            return Err(anyhow::anyhow!(
1234                                "Send work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}",
1235                                status,
1236                                vendor_err,
1237                                wc.wr_id(),
1238                                self.send_cq_idx,
1239                            ));
1240                        }
1241                    }
1242
1243                    // This should be a send completion - verify it's the one we're waiting for
1244                    if wc.wr_id() == self.send_cq_idx {
1245                        self.send_cq_idx += 1;
1246                    }
1247                    // finished polling, return the last completion
1248                    if self.send_cq_idx == self.send_db_idx {
1249                        return Ok(Some(IbvWc::from(wc)));
1250                    }
1251                }
1252            }
1253
1254            // Check for receive completions if requested
1255            if (target == PollTarget::Recv) && self.recv_db_idx > self.recv_cq_idx {
1256                let recv_cq = self.recv_cq as *mut rdmaxcel_sys::ibv_cq;
1257                let ops = &mut (*context).ops;
1258                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1259                let ret = ops.poll_cq.as_mut().unwrap()(recv_cq, 1, &mut wc);
1260
1261                if ret < 0 {
1262                    return Err(anyhow::anyhow!(
1263                        "Failed to poll receive CQ: {}",
1264                        Error::last_os_error()
1265                    ));
1266                }
1267
1268                if ret > 0 {
1269                    if !wc.is_valid() {
1270                        if let Some((status, vendor_err)) = wc.error() {
1271                            return Err(anyhow::anyhow!(
                                "Recv work completion failed with status: {:?}, vendor error: {}, wr_id: {}, recv_cq_idx: {}",
1273                                status,
1274                                vendor_err,
1275                                wc.wr_id(),
1276                                self.recv_cq_idx,
1277                            ));
1278                        }
1279                    }
1280
                    // This should be a receive completion - verify it's the one we're waiting for
1282                    if wc.wr_id() == self.recv_cq_idx {
1283                        self.recv_cq_idx += 1;
1284                    }
1285                    // finished polling, return the last completion
1286                    if self.recv_cq_idx == self.recv_db_idx {
1287                        return Ok(Some(IbvWc::from(wc)));
1288                    }
1289                }
1290            }
1291
1292            // No completion found
1293            Ok(None)
1294        }
1295    }
1296
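    /// Polls the send completion queue once; equivalent to
    /// `poll_completion_target(PollTarget::Send)`.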
1297    pub fn poll_send_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
1298        self.poll_completion_target(PollTarget::Send)
1299    }
1300
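    /// Polls the receive completion queue once; equivalent to
    /// `poll_completion_target(PollTarget::Recv)`.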
1301    pub fn poll_recv_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
1302        self.poll_completion_target(PollTarget::Recv)
1303    }
1304}
1305
/// Utility to validate the execution context.
///
/// Remote execution environments do not always have the nvidia_peermem module loaded
/// and/or the PeerMappingOverride parameter set, for security reasons. This function
/// can be used to validate the execution context before running operations that need
/// this functionality (i.e. cudaHostRegisterIoMemory).
///
/// # Returns
///
/// * `Ok(())` if the execution context is valid
/// * `Err(anyhow::Error)` if the execution context is invalid
1317pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1318    // Check for nvidia peermem
1319    match fs::read_to_string("/proc/modules") {
1320        Ok(contents) => {
1321            if !contents.contains("nvidia_peermem") {
1322                return Err(anyhow::anyhow!(
1323                    "nvidia_peermem module not found in /proc/modules"
1324                ));
1325            }
1326        }
1327        Err(e) => {
1328            return Err(anyhow::anyhow!(e));
1329        }
1330    }
1331
1332    // Test file access to nvidia params
1333    match fs::read_to_string("/proc/driver/nvidia/params") {
1334        Ok(contents) => {
1335            if !contents.contains("PeerMappingOverride=1") {
1336                return Err(anyhow::anyhow!(
1337                    "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1338                ));
1339            }
1340        }
1341        Err(e) => {
1342            return Err(anyhow::anyhow!(e));
1343        }
1344    }
1345    Ok(())
1346}
1347
/// Get all CUDA segments that have been registered as memory regions (MRs).
///
/// # Returns
/// * `Vec<rdmaxcel_sys::rdma_segment_info_t>` - Vector containing all registered segment information
1352pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
1353    unsafe {
1354        let segment_count = rdmaxcel_sys::rdma_get_active_segment_count();
1355        if segment_count <= 0 {
1356            return Vec::new();
1357        }
1358
1359        let mut segments = vec![
1360            std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
1361                .assume_init();
1362            segment_count as usize
1363        ];
1364        let actual_count =
1365            rdmaxcel_sys::rdma_get_all_segment_info(segments.as_mut_ptr(), segment_count);
1366
1367        if actual_count > 0 {
1368            segments.truncate(actual_count as usize);
1369            segments
1370        } else {
1371            Vec::new()
1372        }
1373    }
1374}
1375
1376/// Check if PyTorch CUDA caching allocator has expandable segments enabled.
1377///
1378/// This function calls the C++ implementation that directly accesses the
1379/// PyTorch C10 CUDA allocator configuration to check if expandable segments
1380/// are enabled, which is required for RDMA operations with CUDA tensors.
1381///
1382/// # Returns
1383///
1384/// `true` if both CUDA caching allocator is enabled AND expandable segments are enabled,
1385/// `false` otherwise.
1386pub fn pt_cuda_allocator_compatibility() -> bool {
1387    // SAFETY: We are calling a C++ function from rdmaxcel that accesses PyTorch C10 APIs.
1388    unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() }
1389}
1390
1391#[cfg(test)]
1392mod tests {
1393    use super::*;
1394
1395    #[test]
1396    fn test_create_connection() {
1397        // Skip test if RDMA devices are not available
1398        if crate::ibverbs_primitives::get_all_devices().is_empty() {
1399            println!("Skipping test: RDMA devices not available");
1400            return;
1401        }
1402
1403        let config = IbverbsConfig {
1404            use_gpu_direct: false,
1405            ..Default::default()
1406        };
1407        let domain = RdmaDomain::new(config.device.clone());
1408        assert!(domain.is_ok());
1409
1410        let domain = domain.unwrap();
1411        let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config.clone());
1412        assert!(queue_pair.is_ok());
1413    }
1414
1415    #[test]
1416    fn test_loopback_connection() {
1417        // Skip test if RDMA devices are not available
1418        if crate::ibverbs_primitives::get_all_devices().is_empty() {
1419            println!("Skipping test: RDMA devices not available");
1420            return;
1421        }
1422
1423        let server_config = IbverbsConfig {
1424            use_gpu_direct: false,
1425            ..Default::default()
1426        };
1427        let client_config = IbverbsConfig {
1428            use_gpu_direct: false,
1429            ..Default::default()
1430        };
1431
1432        let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
1433        let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();
1434
1435        let mut server_qp = RdmaQueuePair::new(
1436            server_domain.context,
1437            server_domain.pd,
1438            server_config.clone(),
1439        )
1440        .unwrap();
1441        let mut client_qp = RdmaQueuePair::new(
1442            client_domain.context,
1443            client_domain.pd,
1444            client_config.clone(),
1445        )
1446        .unwrap();
1447
1448        let server_connection_info = server_qp.get_qp_info().unwrap();
1449        let client_connection_info = client_qp.get_qp_info().unwrap();
1450
1451        assert!(server_qp.connect(&client_connection_info).is_ok());
1452        assert!(client_qp.connect(&server_connection_info).is_ok());
1453    }
1454}