Skip to main content

monarch_rdma/backend/ibverbs/
primitives.rs

1/*
2 * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/*
10 * Sections of code adapted from
11 * Copyright (c) 2016 Jon Gjengset under MIT License (MIT)
12*/
13
14//! This file contains primitive data structures for interacting with ibverbs.
15//!
16//! Primitives:
17//! - `IbvConfig`: Represents ibverbs specific configurations, holding parameters required to establish and
18//!   manage an RDMA connection, including settings for the RDMA device, queue pair attributes, and other
19//!   connection-specific parameters.
20//! - `IbvDevice`: Represents an RDMA device, i.e. 'mlx5_0'. Contains information about the device, such as:
21//!   its name, vendor ID, vendor part ID, hardware version, firmware version, node GUID, and capabilities.
22//! - `IbvPort`: Represents information about the port of an RDMA device, including state, physical state,
23//!   LID (Local Identifier), and GID (Global Identifier) information.
24//! - `IbvMemoryRegionView`: Represents a memory region that can be registered with an RDMA device for direct
25//!   memory access operations.
26//! - `IbvOperation`: Represents the type of RDMA operation to perform (Read or Write).
27//! - `IbvQpInfo`: Contains connection information needed to establish an RDMA connection with a remote endpoint.
28//! - `IbvWc`: Wrapper around ibverbs work completion structure, used to track the status of RDMA operations.
29use std::ffi::CStr;
30use std::fmt;
31use std::sync::Arc;
32use std::sync::OnceLock;
33
34use serde::Deserialize;
35use serde::Serialize;
36use typeuri::Named;
37
38use super::domain::IbvDomain;
39
40#[derive(
41    Default,
42    Copy,
43    Clone,
44    Debug,
45    Eq,
46    PartialEq,
47    Hash,
48    serde::Serialize,
49    serde::Deserialize
50)]
51#[repr(transparent)]
52pub struct Gid {
53    raw: [u8; 16],
54}
55
56impl Gid {
57    #[allow(dead_code)]
58    fn subnet_prefix(&self) -> u64 {
59        u64::from_be_bytes(self.raw[..8].try_into().unwrap())
60    }
61
62    #[allow(dead_code)]
63    fn interface_id(&self) -> u64 {
64        u64::from_be_bytes(self.raw[8..].try_into().unwrap())
65    }
66}
67impl From<rdmaxcel_sys::ibv_gid> for Gid {
68    fn from(gid: rdmaxcel_sys::ibv_gid) -> Self {
69        Self {
70            raw: unsafe { gid.raw },
71        }
72    }
73}
74
75impl From<Gid> for rdmaxcel_sys::ibv_gid {
76    fn from(mut gid: Gid) -> Self {
77        *gid.as_mut()
78    }
79}
80
81impl AsRef<rdmaxcel_sys::ibv_gid> for Gid {
82    fn as_ref(&self) -> &rdmaxcel_sys::ibv_gid {
83        unsafe { &*self.raw.as_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
84    }
85}
86
87impl AsMut<rdmaxcel_sys::ibv_gid> for Gid {
88    fn as_mut(&mut self) -> &mut rdmaxcel_sys::ibv_gid {
89        unsafe { &mut *self.raw.as_mut_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
90    }
91}
92
93/// Queue pair type for RDMA operations.
94///
95/// Controls whether to use standard ibverbs queue pairs, mlx5dv extended queue pairs,
96/// or EFA SRD queue pairs. Auto mode automatically selects based on device capabilities.
97#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
98pub enum IbvQpType {
99    /// Auto-detect based on device capabilities
100    Auto,
101    /// Force standard ibverbs queue pair
102    Standard,
103    /// Force mlx5dv extended queue pair
104    Mlx5dv,
105    /// Force EFA SRD queue pair
106    Efa,
107}
108
109/// Converts `IbvQpType` to the corresponding integer enum value in rdmaxcel_sys.
110pub fn resolve_qp_type(qp_type: IbvQpType) -> u32 {
111    match qp_type {
112        IbvQpType::Auto => {
113            if crate::efa::is_efa_device() {
114                rdmaxcel_sys::RDMA_QP_TYPE_EFA
115            } else if mlx5dv_supported() {
116                rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV
117            } else {
118                rdmaxcel_sys::RDMA_QP_TYPE_STANDARD
119            }
120        }
121        IbvQpType::Standard => rdmaxcel_sys::RDMA_QP_TYPE_STANDARD,
122        IbvQpType::Mlx5dv => rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV,
123        IbvQpType::Efa => rdmaxcel_sys::RDMA_QP_TYPE_EFA,
124    }
125}
126
127/// Represents ibverbs specific configurations.
128///
129/// This struct holds various parameters required to establish and manage an RDMA connection.
130/// It includes settings for the RDMA device, queue pair attributes, and other connection-specific
131/// parameters.
132#[derive(Debug, Named, Clone, Serialize, Deserialize)]
133pub struct IbvConfig {
134    /// `device` - The RDMA device to use for the connection.
135    pub device: IbvDevice,
136    /// `cq_entries` - The number of completion queue entries.
137    pub cq_entries: i32,
138    /// `port_num` - The physical port number on the device.
139    pub port_num: u8,
140    /// `gid_index` - The GID index for the RDMA device.
141    pub gid_index: u8,
142    /// `max_send_wr` - The maximum number of outstanding send work requests.
143    pub max_send_wr: u32,
144    /// `max_recv_wr` - The maximum number of outstanding receive work requests.
145    pub max_recv_wr: u32,
146    /// `max_send_sge` - Te maximum number of scatter/gather elements in a send work request.
147    pub max_send_sge: u32,
148    /// `max_recv_sge` - The maximum number of scatter/gather elements in a receive work request.
149    pub max_recv_sge: u32,
150    /// `path_mtu` - The path MTU (Maximum Transmission Unit) for the connection.
151    pub path_mtu: u32,
152    /// `retry_cnt` - The number of retry attempts for a connection request.
153    pub retry_cnt: u8,
154    /// `rnr_retry` - The number of retry attempts for a receiver not ready (RNR) condition.
155    pub rnr_retry: u8,
156    /// `qp_timeout` - The timeout for a queue pair operation.
157    pub qp_timeout: u8,
158    /// `min_rnr_timer` - The minimum RNR timer value.
159    pub min_rnr_timer: u8,
160    /// `max_dest_rd_atomic` - The maximum number of outstanding RDMA read operations at the destination.
161    pub max_dest_rd_atomic: u8,
162    /// `max_rd_atomic` - The maximum number of outstanding RDMA read operations at the initiator.
163    pub max_rd_atomic: u8,
164    /// `pkey_index` - The partition key index.
165    pub pkey_index: u16,
166    /// `psn` - The packet sequence number.
167    pub psn: u32,
168    /// `use_gpu_direct` - Whether to enable GPU Direct RDMA support on init.
169    pub use_gpu_direct: bool,
170    /// `hw_init_delay_ms` - The delay in milliseconds before initializing the hardware.
171    /// This is used to allow the hardware to settle before starting the first transmission.
172    pub hw_init_delay_ms: u64,
173    /// `qp_type` - The type of queue pair to create (Auto, Standard, or Mlx5dv).
174    pub qp_type: IbvQpType,
175    /// Test-only override for `register_segments`'s `max_sge`. `<= 0`
176    /// (default) uses `ibv_query_device`; small positive values force
177    /// `RDMAXCEL_MKEY_REG_LIMIT` to exercise the dmabuf fallback.
178    pub max_sge_override: i32,
179}
180wirevalue::register_type!(IbvConfig);
181
182/// Default RDMA parameters below are based on common values from rdma-core examples
183/// For high-performance or production use, consider tuning
184/// based on ibv_query_device() results and workload characteristics
185impl Default for IbvConfig {
186    fn default() -> Self {
187        let mut config = Self {
188            device: IbvDevice::default(),
189            cq_entries: 1024,
190            port_num: 1,
191            gid_index: 3,
192            max_send_wr: 512,
193            max_recv_wr: 512,
194            max_send_sge: 30,
195            max_recv_sge: 30,
196            path_mtu: rdmaxcel_sys::IBV_MTU_4096,
197            retry_cnt: 7,
198            rnr_retry: 7,
199            qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms
200            min_rnr_timer: 12,
201            max_dest_rd_atomic: 16,
202            max_rd_atomic: 16,
203            pkey_index: 0,
204            psn: rand::random::<u32>() & 0xffffff,
205            use_gpu_direct: false, // nv_peermem enabled for cuda
206            hw_init_delay_ms: 2,
207            qp_type: IbvQpType::Auto,
208            max_sge_override: 0,
209        };
210        if crate::efa::is_efa_device() {
211            crate::efa::apply_efa_defaults(&mut config);
212        }
213        config
214    }
215}
216
217impl IbvConfig {
218    /// Create a new IbvConfig targeting a specific device
219    ///
220    /// Device targets use a unified "type:id" format:
221    /// - "cpu:N" -> finds RDMA device closest to NUMA node N
222    /// - "cuda:N" -> finds RDMA device closest to CUDA device N
223    /// - "nic:mlx5_N" -> returns the specified NIC directly
224    ///
225    /// Shortcuts:
226    /// - "cpu" -> defaults to "cpu:0"
227    /// - "cuda" -> defaults to "cuda:0"
228    ///
229    /// # Arguments
230    ///
231    /// * `target` - Target device specification
232    ///
233    /// # Returns
234    ///
235    /// * `IbvConfig` with resolved device, or default device if resolution fails
236    pub fn targeting(target: &str) -> Self {
237        // Normalize shortcuts
238        let normalized_target = match target {
239            "cpu" => "cpu:0",
240            "cuda" => "cuda:0",
241            _ => target,
242        };
243
244        let device = super::device_selection::select_optimal_ibv_device(Some(normalized_target))
245            .unwrap_or_else(IbvDevice::default);
246
247        Self {
248            device,
249            ..Default::default()
250        }
251    }
252}
253
254impl std::fmt::Display for IbvConfig {
255    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256        write!(
257            f,
258            "IbvConfig {{ device: {}, port_num: {}, gid_index: {}, max_send_wr: {}, max_recv_wr: {}, max_send_sge: {}, max_recv_sge: {}, path_mtu: {:?}, retry_cnt: {}, rnr_retry: {}, qp_timeout: {}, min_rnr_timer: {}, max_dest_rd_atomic: {}, max_rd_atomic: {}, pkey_index: {}, psn: 0x{:x} }}",
259            self.device.name(),
260            self.port_num,
261            self.gid_index,
262            self.max_send_wr,
263            self.max_recv_wr,
264            self.max_send_sge,
265            self.max_recv_sge,
266            self.path_mtu,
267            self.retry_cnt,
268            self.rnr_retry,
269            self.qp_timeout,
270            self.min_rnr_timer,
271            self.max_dest_rd_atomic,
272            self.max_rd_atomic,
273            self.pkey_index,
274            self.psn,
275        )
276    }
277}
278
279/// Represents an RDMA device in the system.
280///
281/// This struct encapsulates information about an RDMA device, including its hardware
282/// characteristics, capabilities, and port information. It provides access to device
283/// attributes such as vendor information, firmware version, and supported features.
284///
285/// # Examples
286///
287/// ```
288/// use monarch_rdma::get_all_devices;
289///
290/// let devices = get_all_devices();
291/// if let Some(device) = devices.first() {
292///     // Access device name and firmware version
293///     let device_name = device.name();
294///     let firmware_version = device.fw_ver();
295/// }
296/// ```
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct IbvDevice {
299    /// `name` - The name of the RDMA device (e.g., "mlx5_0").
300    pub name: String,
301    /// `vendor_id` - The vendor ID of the device.
302    vendor_id: u32,
303    /// `vendor_part_id` - The vendor part ID of the device.
304    vendor_part_id: u32,
305    /// `hw_ver` - Hardware version of the device.
306    hw_ver: u32,
307    /// `fw_ver` - Firmware version of the device.
308    fw_ver: String,
309    /// `node_guid` - Node GUID (Globally Unique Identifier) of the device.
310    node_guid: u64,
311    /// `ports` - Vector of ports available on this device.
312    ports: Vec<IbvPort>,
313    /// `max_qp` - Maximum number of queue pairs supported.
314    max_qp: i32,
315    /// `max_cq` - Maximum number of completion queues supported.
316    max_cq: i32,
317    /// `max_mr` - Maximum number of memory regions supported.
318    max_mr: i32,
319    /// `max_pd` - Maximum number of protection domains supported.
320    max_pd: i32,
321    /// `max_qp_wr` - Maximum number of work requests per queue pair.
322    max_qp_wr: i32,
323    /// `max_sge` - Maximum number of scatter/gather elements per work request.
324    max_sge: i32,
325}
326
327impl IbvDevice {
328    /// Returns the name of the RDMA device.
329    pub fn name(&self) -> &String {
330        &self.name
331    }
332
333    /// Returns the first available RDMA device, if any.
334    pub fn first_available() -> Option<IbvDevice> {
335        let devices = get_all_devices();
336        if devices.is_empty() {
337            None
338        } else {
339            Some(devices.into_iter().next().unwrap())
340        }
341    }
342
343    /// Returns the vendor ID of the RDMA device.
344    pub fn vendor_id(&self) -> u32 {
345        self.vendor_id
346    }
347
348    /// Returns the vendor part ID of the RDMA device.
349    pub fn vendor_part_id(&self) -> u32 {
350        self.vendor_part_id
351    }
352
353    /// Returns the hardware version of the RDMA device.
354    pub fn hw_ver(&self) -> u32 {
355        self.hw_ver
356    }
357
358    /// Returns the firmware version of the RDMA device.
359    pub fn fw_ver(&self) -> &String {
360        &self.fw_ver
361    }
362
363    /// Returns the node GUID of the RDMA device.
364    pub fn node_guid(&self) -> u64 {
365        self.node_guid
366    }
367
368    /// Returns a reference to the vector of ports available on the RDMA device.
369    pub fn ports(&self) -> &Vec<IbvPort> {
370        &self.ports
371    }
372
373    /// Returns the maximum number of queue pairs supported by the RDMA device.
374    pub fn max_qp(&self) -> i32 {
375        self.max_qp
376    }
377
378    /// Returns the maximum number of completion queues supported by the RDMA device.
379    pub fn max_cq(&self) -> i32 {
380        self.max_cq
381    }
382
383    /// Returns the maximum number of memory regions supported by the RDMA device.
384    pub fn max_mr(&self) -> i32 {
385        self.max_mr
386    }
387
388    /// Returns the maximum number of protection domains supported by the RDMA device.
389    pub fn max_pd(&self) -> i32 {
390        self.max_pd
391    }
392
393    /// Returns the maximum number of work requests per queue pair supported by the RDMA device.
394    pub fn max_qp_wr(&self) -> i32 {
395        self.max_qp_wr
396    }
397
398    /// Returns the maximum number of scatter/gather elements per work request supported by the RDMA device.
399    pub fn max_sge(&self) -> i32 {
400        self.max_sge
401    }
402}
403
404impl Default for IbvDevice {
405    fn default() -> Self {
406        // Try to get a smart default using device selection logic (defaults to cpu:0)
407        if let Some(device) = super::device_selection::select_optimal_ibv_device(Some("cpu:0")) {
408            device
409        } else {
410            // Fallback to first available device
411            get_all_devices()
412                .into_iter()
413                .next()
414                .unwrap_or_else(|| panic!("No RDMA devices found"))
415        }
416    }
417}
418
419#[derive(Debug, Clone, Serialize, Deserialize)]
420pub struct IbvPort {
421    /// `port_num` - The physical port number on the device.
422    port_num: u8,
423    /// `state` - The current state of the port.
424    state: String,
425    /// `physical_state` - The physical state of the port.
426    physical_state: String,
427    /// `base_lid` - Base Local Identifier for the port.
428    base_lid: u16,
429    /// `lmc` - LID Mask Control.
430    lmc: u8,
431    /// `sm_lid` - Subnet Manager Local Identifier.
432    sm_lid: u16,
433    /// `capability_mask` - Capability mask of the port.
434    capability_mask: u32,
435    /// `link_layer` - The link layer type (e.g., InfiniBand, Ethernet).
436    link_layer: String,
437    /// `gid` - Global Identifier for the port.
438    gid: String,
439    /// `gid_tbl_len` - Length of the GID table.
440    gid_tbl_len: i32,
441}
442
443impl fmt::Display for IbvDevice {
444    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
445        writeln!(f, "{}", self.name)?;
446        writeln!(f, "\tNumber of ports: {}", self.ports.len())?;
447        writeln!(f, "\tFirmware version: {}", self.fw_ver)?;
448        writeln!(f, "\tHardware version: {}", self.hw_ver)?;
449        writeln!(f, "\tNode GUID: 0x{:016x}", self.node_guid)?;
450        writeln!(f, "\tVendor ID: 0x{:x}", self.vendor_id)?;
451        writeln!(f, "\tVendor part ID: {}", self.vendor_part_id)?;
452        writeln!(f, "\tMax QPs: {}", self.max_qp)?;
453        writeln!(f, "\tMax CQs: {}", self.max_cq)?;
454        writeln!(f, "\tMax MRs: {}", self.max_mr)?;
455        writeln!(f, "\tMax PDs: {}", self.max_pd)?;
456        writeln!(f, "\tMax QP WRs: {}", self.max_qp_wr)?;
457        writeln!(f, "\tMax SGE: {}", self.max_sge)?;
458
459        for port in &self.ports {
460            write!(f, "{}", port)?;
461        }
462
463        Ok(())
464    }
465}
466
467impl fmt::Display for IbvPort {
468    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469        writeln!(f, "\tPort {}:", self.port_num)?;
470        writeln!(f, "\t\tState: {}", self.state)?;
471        writeln!(f, "\t\tPhysical state: {}", self.physical_state)?;
472        writeln!(f, "\t\tBase lid: {}", self.base_lid)?;
473        writeln!(f, "\t\tLMC: {}", self.lmc)?;
474        writeln!(f, "\t\tSM lid: {}", self.sm_lid)?;
475        writeln!(f, "\t\tCapability mask: 0x{:08x}", self.capability_mask)?;
476        writeln!(f, "\t\tLink layer: {}", self.link_layer)?;
477        writeln!(f, "\t\tGID: {}", self.gid)?;
478        writeln!(f, "\t\tGID table length: {}", self.gid_tbl_len)?;
479        Ok(())
480    }
481}
482
483/// Converts the given port state to a human-readable string.
484///
485/// # Arguments
486///
487/// * `state` - The port state as defined by `ffi::ibv_port_state::Type`.
488///
489/// # Returns
490///
491/// A string representation of the port state.
492pub fn get_port_state_str(state: rdmaxcel_sys::ibv_port_state::Type) -> String {
493    // SAFETY: We are calling a C function that returns a C string.
494    unsafe {
495        let c_str = rdmaxcel_sys::ibv_port_state_str(state);
496        if c_str.is_null() {
497            return "Unknown".to_string();
498        }
499        CStr::from_ptr(c_str).to_string_lossy().into_owned()
500    }
501}
502
/// Returns a human-readable name for a port's physical state code.
///
/// # Arguments
///
/// * `phys_state` - The physical state as a `u8`.
///
/// # Returns
///
/// The name for codes `1..=7`; every other code, including `0`, yields
/// `"No state change"`.
pub fn get_port_phy_state_str(phys_state: u8) -> String {
    // Names for codes 1 through 7, stored at index `code - 1`.
    const NAMED_STATES: [&str; 7] = [
        "Sleep",
        "Polling",
        "Disabled",
        "PortConfigurationTraining",
        "LinkUp",
        "LinkErrorRecovery",
        "PhyTest",
    ];

    phys_state
        .checked_sub(1)
        .and_then(|index| NAMED_STATES.get(usize::from(index)))
        .copied()
        .unwrap_or("No state change")
        .to_string()
}
524
/// Returns a human-readable name for a link layer code.
///
/// # Arguments
///
/// * `link_layer` - The link layer type as a `u8`.
///
/// # Returns
///
/// `"InfiniBand"` for `1`, `"Ethernet"` for `2`, and `"Unknown"` for any
/// other value.
pub fn get_link_layer_str(link_layer: u8) -> String {
    let label = if link_layer == 1 {
        "InfiniBand"
    } else if link_layer == 2 {
        "Ethernet"
    } else {
        "Unknown"
    };
    String::from(label)
}
541
/// Formats a GID (Global Identifier) as eight colon-separated groups of four
/// lowercase hex digits.
///
/// # Arguments
///
/// * `gid` - A reference to a 16-byte array representing the GID.
///
/// # Returns
///
/// A string such as `fe80:0000:0000:0000:0000:0000:0000:0001`; each group
/// covers two consecutive bytes and zero groups are not abbreviated.
pub fn format_gid(gid: &[u8; 16]) -> String {
    gid.chunks_exact(2)
        .map(|pair| format!("{:02x}{:02x}", pair[0], pair[1]))
        .collect::<Vec<_>>()
        .join(":")
}
572
573/// Retrieves information about all available RDMA devices in the system.
574///
575/// This function queries the system for all available RDMA devices and returns
576/// detailed information about each device, including its capabilities, ports,
577/// and attributes.
578///
579/// # Returns
580///
581/// A vector of `IbvDevice` structures, each representing an RDMA device in the system.
582/// Returns an empty vector if no devices are found or if there was an error querying
583/// the devices.
584pub fn get_all_devices() -> Vec<IbvDevice> {
585    let mut devices = Vec::new();
586
587    // SAFETY: We are calling several C functions from libibverbs.
588    unsafe {
589        let mut num_devices = 0;
590        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
591        if device_list.is_null() || num_devices == 0 {
592            return devices;
593        }
594
595        for i in 0..num_devices {
596            let device = *device_list.add(i as usize);
597            if device.is_null() {
598                continue;
599            }
600
601            let context = rdmaxcel_sys::ibv_open_device(device);
602            if context.is_null() {
603                continue;
604            }
605
606            let device_name = CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(device))
607                .to_string_lossy()
608                .into_owned();
609
610            let mut device_attr = rdmaxcel_sys::ibv_device_attr::default();
611            if rdmaxcel_sys::ibv_query_device(context, &mut device_attr) != 0 {
612                rdmaxcel_sys::ibv_close_device(context);
613                continue;
614            }
615
616            let fw_ver = CStr::from_ptr(device_attr.fw_ver.as_ptr())
617                .to_string_lossy()
618                .into_owned();
619
620            let mut rdma_device = IbvDevice {
621                name: device_name,
622                vendor_id: device_attr.vendor_id,
623                vendor_part_id: device_attr.vendor_part_id,
624                hw_ver: device_attr.hw_ver,
625                fw_ver,
626                node_guid: device_attr.node_guid,
627                ports: Vec::new(),
628                max_qp: device_attr.max_qp,
629                max_cq: device_attr.max_cq,
630                max_mr: device_attr.max_mr,
631                max_pd: device_attr.max_pd,
632                max_qp_wr: device_attr.max_qp_wr,
633                max_sge: device_attr.max_sge,
634            };
635
636            for port_num in 1..=device_attr.phys_port_cnt {
637                let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
638                if rdmaxcel_sys::ibv_query_port(
639                    context,
640                    port_num,
641                    &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
642                ) != 0
643                {
644                    continue;
645                }
646                let state = get_port_state_str(port_attr.state);
647                let physical_state = get_port_phy_state_str(port_attr.phys_state);
648
649                let link_layer = get_link_layer_str(port_attr.link_layer);
650
651                let mut gid = rdmaxcel_sys::ibv_gid::default();
652                let gid_str = if rdmaxcel_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
653                    format_gid(&gid.raw)
654                } else {
655                    "N/A".to_string()
656                };
657
658                let rdma_port = IbvPort {
659                    port_num,
660                    state,
661                    physical_state,
662                    base_lid: port_attr.lid,
663                    lmc: port_attr.lmc,
664                    sm_lid: port_attr.sm_lid,
665                    capability_mask: port_attr.port_cap_flags,
666                    link_layer,
667                    gid: gid_str,
668                    gid_tbl_len: port_attr.gid_tbl_len,
669                };
670
671                rdma_device.ports.push(rdma_port);
672            }
673
674            devices.push(rdma_device);
675            rdmaxcel_sys::ibv_close_device(context);
676        }
677
678        rdmaxcel_sys::ibv_free_device_list(device_list);
679    }
680
681    devices
682}
683
684/// Cached result of mlx5dv support check.
685static MLX5DV_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
686
687/// Checks if mlx5dv (Mellanox device-specific verbs extension) is supported.
688///
689/// This function attempts to open the first available RDMA device and check if
690/// mlx5dv extensions can be initialized. The mlx5dv extensions are required for
691/// advanced features like GPU Direct RDMA and direct queue pair manipulation.
692///
693/// The result is cached after the first call, making subsequent calls essentially free.
694///
695/// # Returns
696///
697/// `true` if mlx5dv extensions are supported, `false` otherwise.
698pub fn mlx5dv_supported() -> bool {
699    *MLX5DV_SUPPORTED_CACHE.get_or_init(mlx5dv_supported_impl)
700}
701
702fn mlx5dv_supported_impl() -> bool {
703    // SAFETY: We are calling C functions from libibverbs and libmlx5.
704    unsafe {
705        let mut mlx5dv_supported = false;
706        let mut num_devices = 0;
707        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
708        if !device_list.is_null() && num_devices > 0 {
709            let device = *device_list;
710            if !device.is_null() {
711                mlx5dv_supported = rdmaxcel_sys::mlx5dv_is_supported(device);
712            }
713            rdmaxcel_sys::ibv_free_device_list(device_list);
714        }
715        mlx5dv_supported
716    }
717}
718
719/// Cached result of ibverbs support check.
720static IBVERBS_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
721
722/// Checks if ibverbs devices can be retrieved successfully.
723///
724/// This function attempts to retrieve the list of RDMA devices using the
725/// `ibv_get_device_list` function from the ibverbs library. It returns `true`
726/// if devices are found, and `false` otherwise.
727///
728/// The result is cached after the first call, making subsequent calls essentially free.
729///
730/// # Returns
731///
732/// `true` if devices are successfully retrieved, `false` otherwise.
733pub fn ibverbs_supported() -> bool {
734    *IBVERBS_SUPPORTED_CACHE.get_or_init(ibverbs_supported_impl)
735}
736
737fn ibverbs_supported_impl() -> bool {
738    // SAFETY: We are calling a C function from libibverbs.
739    unsafe {
740        let mut num_devices = 0;
741        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
742        if !device_list.is_null() {
743            rdmaxcel_sys::ibv_free_device_list(device_list);
744        }
745        num_devices > 0
746    }
747}
748
749/// Refcounted owner of an ibverbs memory region. The intended
750/// sharing pattern is `Arc<IbvMemoryRegion>` cloned into every
751/// [`IbvMemoryRegionView`] backed by the region; the FFI resource
752/// is released by [`Drop`] when the last clone goes away.
753///
754/// `Direct` carries an `Arc<IbvDomain>` so the PD the MR was
755/// registered against outlives the `ibv_dereg_mr` call. The struct
756/// field order is significant: the `mr` pointer is dereg'd in
757/// `Drop` before the `_domain` field is released, so the PD is
758/// still alive when libibverbs walks back to it. The eventual
759/// `ibv_dealloc_pd` only fires when every other `Arc<IbvDomain>`
760/// (e.g. the manager's `device_domains` entry) is also gone.
761#[derive(Debug)]
762pub(super) enum IbvMemoryRegion {
763    /// A standalone `ibv_mr*` registered via `ibv_reg_mr` /
764    /// `ibv_reg_dmabuf_mr`.
765    Direct {
766        mr: *mut rdmaxcel_sys::ibv_mr,
767        /// PD the MR was registered against, kept alive past
768        /// `ibv_dereg_mr` via this clone. Never read directly.
769        _domain: Arc<IbvDomain>,
770    },
771    /// Singleton owner shared by every segment-backed view from the
772    /// mlx5dv segment scanner. `Drop` calls
773    /// `rdmaxcel_sys::deregister_segments`.
774    Segments,
775}
776
777// SAFETY: `IbvMemoryRegion::Direct`'s `*mut ibv_mr` is treated as
778// thread-safe by libibverbs for deregistration; `Segments` holds no
779// data. The intended sharing pattern is `Arc<IbvMemoryRegion>`.
780unsafe impl Send for IbvMemoryRegion {}
781unsafe impl Sync for IbvMemoryRegion {}
782
783impl Drop for IbvMemoryRegion {
784    fn drop(&mut self) {
785        match self {
786            IbvMemoryRegion::Direct { mr, .. } => {
787                if mr.is_null() {
788                    return;
789                }
790                // SAFETY: `*mr` was returned by `ibv_reg_mr` /
791                // `ibv_reg_dmabuf_mr` at construction (see
792                // `register_mr_impl`) and is non-null per the
793                // null-check above. The intended sharing pattern
794                // is `Arc<IbvMemoryRegion>`, so `Drop` runs exactly
795                // once when the last clone goes away, meaning
796                // `ibv_dereg_mr` is called exactly once on this
797                // pointer. The PD the MR was registered against is
798                // kept alive by the sibling `_domain` `Arc<IbvDomain>`
799                // until after this dereg returns.
800                let result = unsafe { rdmaxcel_sys::ibv_dereg_mr(*mr) };
801                if result != 0 {
802                    tracing::error!(
803                        "failed to deregister MR at {:p}: error code {}",
804                        *mr,
805                        result
806                    );
807                }
808            }
809            IbvMemoryRegion::Segments => {
810                // SAFETY: `IbvMemoryRegion::Segments` is the
811                // singleton owner of the mlx5dv scanner's globally
812                // registered segments. `register_mr_impl` lazily
813                // wraps it in `Arc::new(...)` at most once per
814                // manager (stored in `IbvManagerActor::segments_mr`)
815                // and clones it into every segment-backed view, so
816                // `deregister_segments` runs exactly once when the
817                // last reference is dropped.
818                let result = unsafe { rdmaxcel_sys::deregister_segments() };
819                if result != 0 {
820                    tracing::error!("failed to deregister CUDA segments: error code {}", result);
821                }
822            }
823        }
824    }
825}
826
827/// Represents a view of a memory region that can be registered with an RDMA device.
828///
829/// This is a 'view' of a registered Memory Region, allowing multiple views into a single
830/// large MR registration. This is commonly used with PyTorch's caching allocator, which
831/// reserves large memory blocks and provides different data pointers into that space.
832///
833/// # Example
834/// PyTorch Caching Allocator creates a 16GB segment at virtual address `0x01000000`.
835/// The underlying Memory Region registers 16GB but at RDMA address `0x0`.
836/// To access virtual address `0x01100000`, we return a view at RDMA address `0x100000`.
837///
838/// # Safety
839/// The caller must ensure the memory remains valid and is not freed, moved, or
840/// overwritten while RDMA operations are in progress.
841#[derive(Debug, Clone)]
842pub struct IbvMemoryRegionView {
843    // id should be unique with a given rdmam manager
844    pub id: usize,
845    /// Virtual address in the process address space.
846    /// This is the pointer/address as seen by the local process.
847    pub virtual_addr: usize,
848    /// Memory address assigned after Memory Region (MR) registration.
849    /// This is the address may be offset a base MR addr.
850    pub rdma_addr: usize,
851    pub size: usize,
852    pub lkey: u32,
853    pub rkey: u32,
854    /// Name of the RDMA device the view's protection domain is on.
855    pub device_name: String,
856    /// Keeps the underlying FFI MR (and, transitively, the PD it
857    /// was registered against) alive for the lifetime of any clone
858    /// of this view; the last drop triggers deregistration. Never
859    /// read directly.
860    pub(super) _mr: Arc<IbvMemoryRegion>,
861}
862
863impl IbvMemoryRegionView {
864    pub(super) fn new(
865        id: usize,
866        virtual_addr: usize,
867        rdma_addr: usize,
868        size: usize,
869        lkey: u32,
870        rkey: u32,
871        device_name: String,
872        mr: Arc<IbvMemoryRegion>,
873    ) -> Self {
874        Self {
875            id,
876            virtual_addr,
877            rdma_addr,
878            size,
879            lkey,
880            rkey,
881            device_name,
882            _mr: mr,
883        }
884    }
885}
886
887/// Compact identifier for log messages and per-op errors. Carries
888/// the MR-identifying fields all in one place so callers that
889/// embed an `IbvMemoryRegionView` in their error format get a
890/// consistent shape.
891impl fmt::Display for IbvMemoryRegionView {
892    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
893        write!(
894            f,
895            "lkey={}, rkey={}, virtual_addr=0x{:x}, rdma_addr=0x{:x}, size={}",
896            self.lkey, self.rkey, self.virtual_addr, self.rdma_addr, self.size,
897        )
898    }
899}
900
/// The RDMA operations this module distinguishes.
///
/// This is a more ergonomic stand-in for the underlying `ibv_wr_opcode`
/// values. RDMA operations allow direct memory access between two machines
/// without involving the CPU of the target machine.
///
/// # Variants
///
/// * `Write` - RDMA write: data is written from local memory to a remote
///   memory region.
/// * `WriteWithImm` - RDMA write that maps to the "write with immediate"
///   work-request opcode.
/// * `Read` - RDMA read: data is read from a remote memory region into local
///   memory.
/// * `Recv` - receive operation; it has no work-request opcode equivalent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IbvOperation {
    /// RDMA write operation.
    Write,
    /// RDMA write operation carrying immediate data.
    WriteWithImm,
    /// RDMA read operation.
    Read,
    /// RDMA recv operation.
    Recv,
}
923
924impl From<IbvOperation> for rdmaxcel_sys::ibv_wr_opcode::Type {
925    fn from(op: IbvOperation) -> Self {
926        match op {
927            IbvOperation::Write => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
928            IbvOperation::WriteWithImm => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
929            IbvOperation::Read => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
930            IbvOperation::Recv => panic!("Invalid wr opcode"),
931        }
932    }
933}
934
935impl From<rdmaxcel_sys::ibv_wc_opcode::Type> for IbvOperation {
936    fn from(op: rdmaxcel_sys::ibv_wc_opcode::Type) -> Self {
937        match op {
938            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => IbvOperation::Write,
939            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => IbvOperation::Read,
940            _ => panic!("Unsupported operation type"),
941        }
942    }
943}
944
945/// Contains information needed to establish an RDMA queue pair with a remote endpoint.
946///
947/// `IbvQpInfo` encapsulates all the necessary information to establish a queue pair
948/// with a remote RDMA device. This includes queue pair number, LID (Local Identifier),
949/// GID (Global Identifier), remote memory address, remote key, and packet sequence number.
950#[derive(Default, Named, Clone, serde::Serialize, serde::Deserialize)]
951pub struct IbvQpInfo {
952    /// `qp_num` - Queue Pair Number, uniquely identifies a queue pair on the remote device
953    pub qp_num: u32,
954    /// `lid` - Local Identifier, used for addressing in InfiniBand subnet
955    pub lid: u16,
956    /// `gid` - Global Identifier, used for routing across subnets (similar to IPv6 address)
957    pub gid: Option<Gid>,
958    /// `psn` - Packet Sequence Number, used for ordering packets
959    pub psn: u32,
960}
961wirevalue::register_type!(IbvQpInfo);
962
963impl std::fmt::Debug for IbvQpInfo {
964    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
965        write!(
966            f,
967            "IbvQpInfo {{ qp_num: {}, lid: {}, gid: {:?}, psn: 0x{:x} }}",
968            self.qp_num, self.lid, self.gid, self.psn
969        )
970    }
971}
972
973/// Wrapper around ibv_wc (ibverbs work completion).
974///
975/// This exposes only the public fields of rdmaxcel_sys::ibv_wc, allowing us to more easily
976/// interact with it from Rust. Work completions are used to track the status of
977/// RDMA operations and are generated when an operation completes.
978#[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)]
979pub struct IbvWc {
980    /// `wr_id` - Work Request ID, used to identify the completed operation
981    wr_id: u64,
982    /// `len` - Length of the data transferred
983    len: usize,
984    /// `valid` - Whether the work completion is valid
985    valid: bool,
986    /// `error` - Error information if the operation failed
987    error: Option<(rdmaxcel_sys::ibv_wc_status::Type, u32)>,
988    /// `opcode` - Type of operation that completed (read, write, etc.)
989    opcode: rdmaxcel_sys::ibv_wc_opcode::Type,
990    /// `bytes` - Immediate data (if any)
991    bytes: Option<u32>,
992    /// `qp_num` - Queue Pair Number
993    qp_num: u32,
994    /// `src_qp` - Source Queue Pair Number
995    src_qp: u32,
996    /// `pkey_index` - Partition Key Index
997    pkey_index: u16,
998    /// `slid` - Source LID
999    slid: u16,
1000    /// `sl` - Service Level
1001    sl: u8,
1002    /// `dlid_path_bits` - Destination LID Path Bits
1003    dlid_path_bits: u8,
1004}
1005wirevalue::register_type!(IbvWc);
1006
1007impl From<rdmaxcel_sys::ibv_wc> for IbvWc {
1008    fn from(wc: rdmaxcel_sys::ibv_wc) -> Self {
1009        IbvWc {
1010            wr_id: wc.wr_id(),
1011            len: wc.len(),
1012            valid: wc.is_valid(),
1013            error: wc.error(),
1014            opcode: wc.opcode(),
1015            bytes: wc.imm_data(),
1016            qp_num: wc.qp_num,
1017            src_qp: wc.src_qp,
1018            pkey_index: wc.pkey_index,
1019            slid: wc.slid,
1020            sl: wc.sl,
1021            dlid_path_bits: wc.dlid_path_bits,
1022        }
1023    }
1024}
1025
1026impl IbvWc {
1027    /// Returns the Work Request ID associated with this work completion.
1028    ///
1029    /// The Work Request ID is used to identify the specific operation that completed.
1030    /// It is set by the application when posting the work request and is returned
1031    /// unchanged in the work completion.
1032    pub fn wr_id(&self) -> u64 {
1033        self.wr_id
1034    }
1035
1036    /// Returns whether this work completion is valid.
1037    ///
1038    /// A valid work completion indicates that the operation completed successfully.
1039    /// If false, the `error` field may contain additional information about the failure.
1040    pub fn is_valid(&self) -> bool {
1041        self.valid
1042    }
1043}
1044
#[cfg(test)]
mod tests {
    use super::*;

    /// Enumerates devices and sanity-checks the first one. Returns early
    /// (effectively skipping) on hosts without RDMA devices.
    #[test]
    fn test_get_all_devices() {
        let devices = get_all_devices();
        let Some(device) = devices.first() else {
            println!("Skipping test: RDMA devices not available");
            return;
        };
        assert!(!device.name().is_empty(), "device name should not be empty");
        assert!(
            !device.ports().is_empty(),
            "device should have at least one port"
        );
    }

    /// Checks that each getter on the first device returns the value of the
    /// field it is named after. Returns early without RDMA devices.
    #[test]
    fn test_first_available() {
        let devices = get_all_devices();
        let Some(dev) = devices.first() else {
            println!("Skipping test: RDMA devices not available");
            return;
        };

        assert_eq!(dev.vendor_id(), dev.vendor_id);
        assert_eq!(dev.vendor_part_id(), dev.vendor_part_id);
        assert_eq!(dev.hw_ver(), dev.hw_ver);
        assert_eq!(dev.fw_ver(), &dev.fw_ver);
        assert_eq!(dev.node_guid(), dev.node_guid);
        assert_eq!(dev.max_qp(), dev.max_qp);
        assert_eq!(dev.max_cq(), dev.max_cq);
        assert_eq!(dev.max_mr(), dev.max_mr);
        assert_eq!(dev.max_pd(), dev.max_pd);
        assert_eq!(dev.max_qp_wr(), dev.max_qp_wr);
        assert_eq!(dev.max_sge(), dev.max_sge);
    }

    /// The device's `Display` output must mention its name and firmware
    /// version. Does nothing when no device is available.
    #[test]
    fn test_device_display() {
        if let Some(device) = IbvDevice::first_available() {
            let display_output = format!("{}", device);
            assert!(
                display_output.contains(&device.name),
                "display should include device name"
            );
            assert!(
                display_output.contains(&device.fw_ver),
                "display should include firmware version"
            );
        }
    }

    /// The first port's `Display` output must mention its state and link
    /// layer. Does nothing when no device or no port is available.
    #[test]
    fn test_port_display() {
        if let Some(device) = IbvDevice::first_available() {
            if !device.ports().is_empty() {
                let port = &device.ports()[0];
                let display_output = format!("{}", port);
                assert!(
                    display_output.contains(&port.state),
                    "display should include port state"
                );
                assert!(
                    display_output.contains(&port.link_layer),
                    "display should include link layer"
                );
            }
        }
    }

    /// Round-trips between `IbvOperation` and the raw opcodes for every
    /// variant that has a mapping.
    #[test]
    fn test_rdma_operation_conversion() {
        // IbvOperation -> work-request opcode.
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(IbvOperation::Write)
        );
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(IbvOperation::WriteWithImm)
        );
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(IbvOperation::Read)
        );

        // Work-completion opcode -> IbvOperation.
        assert_eq!(
            IbvOperation::Write,
            IbvOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE)
        );
        assert_eq!(
            IbvOperation::Read,
            IbvOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ)
        );
    }

    /// `Recv` has no work-request opcode; the conversion is expected to panic.
    #[test]
    #[should_panic(expected = "Invalid wr opcode")]
    fn test_recv_has_no_wr_opcode() {
        let _ = rdmaxcel_sys::ibv_wr_opcode::Type::from(IbvOperation::Recv);
    }

    /// The hand-written `Debug` impl prints the fields, with `psn` in hex.
    #[test]
    fn test_rdma_endpoint() {
        let endpoint = IbvQpInfo {
            qp_num: 42,
            lid: 123,
            gid: None,
            psn: 0x5678,
        };

        let debug_str = format!("{:?}", endpoint);
        assert!(debug_str.contains("qp_num: 42"));
        assert!(debug_str.contains("lid: 123"));
        assert!(debug_str.contains("psn: 0x5678"));
    }

    /// Builds a raw completion by poking its leading fields, then checks the
    /// values survive the conversion into `IbvWc`.
    #[test]
    fn test_ibv_wc() {
        let mut wc = rdmaxcel_sys::ibv_wc::default();

        // SAFETY: modifies private fields through pointer manipulation. Both
        // writes stay inside `wc`; this relies on `wr_id` living at offset 0
        // and the status at offset 8 of `ibv_wc`.
        unsafe {
            // Cast to a byte pointer so fields can be addressed by offset.
            let wc_ptr = &mut wc as *mut rdmaxcel_sys::ibv_wc as *mut u8;

            // Set wr_id (at offset 0, u64)
            *(wc_ptr as *mut u64) = 42;

            // Set status to SUCCESS (at offset 8, 32 bits wide; written as i32)
            *(wc_ptr.add(8) as *mut i32) = rdmaxcel_sys::ibv_wc_status::IBV_WC_SUCCESS as i32;
        }
        let ibv_wc = IbvWc::from(wc);
        assert_eq!(ibv_wc.wr_id(), 42);
        assert!(ibv_wc.is_valid());
    }

    /// A 16-byte GID is formatted as eight colon-separated 16-bit hex groups.
    #[test]
    fn test_format_gid() {
        let gid = [
            0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
            0x77, 0x88,
        ];

        let formatted = format_gid(&gid);
        assert_eq!(formatted, "1234:5678:9abc:def0:1122:3344:5566:7788");
    }

    /// Smoke test: only verifies that the probe does not panic.
    #[test]
    fn test_mlx5dv_supported_basic() {
        let mlx5dv_support = mlx5dv_supported();
        println!("mlx5dv_supported: {}", mlx5dv_support);
    }
}