monarch_rdma/backend/ibverbs/
primitives.rs

1/*
2 * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/*
10 * Sections of code adapted from
11 * Copyright (c) 2016 Jon Gjengset under MIT License (MIT)
12*/
13
14//! This file contains primitive data structures for interacting with ibverbs.
15//!
16//! Primitives:
17//! - `IbvConfig`: Represents ibverbs specific configurations, holding parameters required to establish and
18//!   manage an RDMA connection, including settings for the RDMA device, queue pair attributes, and other
19//!   connection-specific parameters.
20//! - `IbvDevice`: Represents an RDMA device, i.e. 'mlx5_0'. Contains information about the device, such as:
21//!   its name, vendor ID, vendor part ID, hardware version, firmware version, node GUID, and capabilities.
22//! - `IbvPort`: Represents information about the port of an RDMA device, including state, physical state,
23//!   LID (Local Identifier), and GID (Global Identifier) information.
24//! - `IbvMemoryRegionView`: Represents a memory region that can be registered with an RDMA device for direct
25//!   memory access operations.
26//! - `IbvOperation`: Represents the type of RDMA operation to perform (Read or Write).
27//! - `IbvQpInfo`: Contains connection information needed to establish an RDMA connection with a remote endpoint.
28//! - `IbvWc`: Wrapper around ibverbs work completion structure, used to track the status of RDMA operations.
29use std::ffi::CStr;
30use std::fmt;
31use std::sync::OnceLock;
32
33use serde::Deserialize;
34use serde::Serialize;
35use typeuri::Named;
36
/// A 16-byte RDMA Global Identifier (GID), stored as raw bytes.
///
/// `#[repr(transparent)]` gives this type the same layout as `[u8; 16]`. The
/// `AsRef`/`AsMut` conversions to `rdmaxcel_sys::ibv_gid` in this file
/// reinterpret these bytes in place.
#[derive(
    Default,
    Copy,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Hash,
    serde::Serialize,
    serde::Deserialize
)]
#[repr(transparent)]
pub struct Gid {
    /// The 16 GID bytes. The accessors on this type decode each 8-byte half
    /// as a big-endian integer.
    raw: [u8; 16],
}
52
53impl Gid {
54    #[allow(dead_code)]
55    fn subnet_prefix(&self) -> u64 {
56        u64::from_be_bytes(self.raw[..8].try_into().unwrap())
57    }
58
59    #[allow(dead_code)]
60    fn interface_id(&self) -> u64 {
61        u64::from_be_bytes(self.raw[8..].try_into().unwrap())
62    }
63}
64impl From<rdmaxcel_sys::ibv_gid> for Gid {
65    fn from(gid: rdmaxcel_sys::ibv_gid) -> Self {
66        Self {
67            raw: unsafe { gid.raw },
68        }
69    }
70}
71
72impl From<Gid> for rdmaxcel_sys::ibv_gid {
73    fn from(mut gid: Gid) -> Self {
74        *gid.as_mut()
75    }
76}
77
78impl AsRef<rdmaxcel_sys::ibv_gid> for Gid {
79    fn as_ref(&self) -> &rdmaxcel_sys::ibv_gid {
80        unsafe { &*self.raw.as_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
81    }
82}
83
84impl AsMut<rdmaxcel_sys::ibv_gid> for Gid {
85    fn as_mut(&mut self) -> &mut rdmaxcel_sys::ibv_gid {
86        unsafe { &mut *self.raw.as_mut_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
87    }
88}
89
/// Queue pair type for RDMA operations.
///
/// Controls whether to use standard ibverbs queue pairs, mlx5dv extended queue pairs,
/// or EFA SRD queue pairs. `Auto` defers the choice to `resolve_qp_type`, which checks
/// for an EFA device first, then for mlx5dv support, and otherwise falls back to a
/// standard queue pair.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum IbvQpType {
    /// Auto-detect based on device capabilities
    Auto,
    /// Force standard ibverbs queue pair
    Standard,
    /// Force mlx5dv extended queue pair
    Mlx5dv,
    /// Force EFA SRD queue pair
    Efa,
}
105
106/// Converts `IbvQpType` to the corresponding integer enum value in rdmaxcel_sys.
107pub fn resolve_qp_type(qp_type: IbvQpType) -> u32 {
108    match qp_type {
109        IbvQpType::Auto => {
110            if crate::efa::is_efa_device() {
111                rdmaxcel_sys::RDMA_QP_TYPE_EFA
112            } else if mlx5dv_supported() {
113                rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV
114            } else {
115                rdmaxcel_sys::RDMA_QP_TYPE_STANDARD
116            }
117        }
118        IbvQpType::Standard => rdmaxcel_sys::RDMA_QP_TYPE_STANDARD,
119        IbvQpType::Mlx5dv => rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV,
120        IbvQpType::Efa => rdmaxcel_sys::RDMA_QP_TYPE_EFA,
121    }
122}
123
/// Represents ibverbs specific configurations.
///
/// This struct holds various parameters required to establish and manage an RDMA connection.
/// It includes settings for the RDMA device, queue pair attributes, and other connection-specific
/// parameters.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct IbvConfig {
    /// `device` - The RDMA device to use for the connection.
    pub device: IbvDevice,
    /// `cq_entries` - The number of completion queue entries.
    pub cq_entries: i32,
    /// `port_num` - The physical port number on the device.
    pub port_num: u8,
    /// `gid_index` - The GID index for the RDMA device.
    pub gid_index: u8,
    /// `max_send_wr` - The maximum number of outstanding send work requests.
    pub max_send_wr: u32,
    /// `max_recv_wr` - The maximum number of outstanding receive work requests.
    pub max_recv_wr: u32,
    /// `max_send_sge` - The maximum number of scatter/gather elements in a send work request.
    pub max_send_sge: u32,
    /// `max_recv_sge` - The maximum number of scatter/gather elements in a receive work request.
    pub max_recv_sge: u32,
    /// `path_mtu` - The path MTU (Maximum Transmission Unit) for the connection.
    pub path_mtu: u32,
    /// `retry_cnt` - The number of retry attempts for a connection request.
    pub retry_cnt: u8,
    /// `rnr_retry` - The number of retry attempts for a receiver not ready (RNR) condition.
    pub rnr_retry: u8,
    /// `qp_timeout` - The timeout for a queue pair operation.
    pub qp_timeout: u8,
    /// `min_rnr_timer` - The minimum RNR timer value.
    pub min_rnr_timer: u8,
    /// `max_dest_rd_atomic` - The maximum number of outstanding RDMA read operations at the destination.
    pub max_dest_rd_atomic: u8,
    /// `max_rd_atomic` - The maximum number of outstanding RDMA read operations at the initiator.
    pub max_rd_atomic: u8,
    /// `pkey_index` - The partition key index.
    pub pkey_index: u16,
    /// `psn` - The packet sequence number.
    pub psn: u32,
    /// `use_gpu_direct` - Whether to enable GPU Direct RDMA support on init.
    pub use_gpu_direct: bool,
    /// `hw_init_delay_ms` - The delay in milliseconds before initializing the hardware.
    /// This is used to allow the hardware to settle before starting the first transmission.
    pub hw_init_delay_ms: u64,
    /// `qp_type` - The type of queue pair to create (Auto, Standard, Mlx5dv, or Efa).
    pub qp_type: IbvQpType,
}
wirevalue::register_type!(IbvConfig);
174
175/// Default RDMA parameters below are based on common values from rdma-core examples
176/// For high-performance or production use, consider tuning
177/// based on ibv_query_device() results and workload characteristics
178impl Default for IbvConfig {
179    fn default() -> Self {
180        let mut config = Self {
181            device: IbvDevice::default(),
182            cq_entries: 1024,
183            port_num: 1,
184            gid_index: 3,
185            max_send_wr: 512,
186            max_recv_wr: 512,
187            max_send_sge: 30,
188            max_recv_sge: 30,
189            path_mtu: rdmaxcel_sys::IBV_MTU_4096,
190            retry_cnt: 7,
191            rnr_retry: 7,
192            qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms
193            min_rnr_timer: 12,
194            max_dest_rd_atomic: 16,
195            max_rd_atomic: 16,
196            pkey_index: 0,
197            psn: rand::random::<u32>() & 0xffffff,
198            use_gpu_direct: false, // nv_peermem enabled for cuda
199            hw_init_delay_ms: 2,
200            qp_type: IbvQpType::Auto,
201        };
202        if crate::efa::is_efa_device() {
203            crate::efa::apply_efa_defaults(&mut config);
204        }
205        config
206    }
207}
208
209impl IbvConfig {
210    /// Create a new IbvConfig targeting a specific device
211    ///
212    /// Device targets use a unified "type:id" format:
213    /// - "cpu:N" -> finds RDMA device closest to NUMA node N
214    /// - "cuda:N" -> finds RDMA device closest to CUDA device N
215    /// - "nic:mlx5_N" -> returns the specified NIC directly
216    ///
217    /// Shortcuts:
218    /// - "cpu" -> defaults to "cpu:0"
219    /// - "cuda" -> defaults to "cuda:0"
220    ///
221    /// # Arguments
222    ///
223    /// * `target` - Target device specification
224    ///
225    /// # Returns
226    ///
227    /// * `IbvConfig` with resolved device, or default device if resolution fails
228    pub fn targeting(target: &str) -> Self {
229        // Normalize shortcuts
230        let normalized_target = match target {
231            "cpu" => "cpu:0",
232            "cuda" => "cuda:0",
233            _ => target,
234        };
235
236        let device = super::device_selection::select_optimal_ibv_device(Some(normalized_target))
237            .unwrap_or_else(IbvDevice::default);
238
239        Self {
240            device,
241            ..Default::default()
242        }
243    }
244}
245
246impl std::fmt::Display for IbvConfig {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        write!(
249            f,
250            "IbvConfig {{ device: {}, port_num: {}, gid_index: {}, max_send_wr: {}, max_recv_wr: {}, max_send_sge: {}, max_recv_sge: {}, path_mtu: {:?}, retry_cnt: {}, rnr_retry: {}, qp_timeout: {}, min_rnr_timer: {}, max_dest_rd_atomic: {}, max_rd_atomic: {}, pkey_index: {}, psn: 0x{:x} }}",
251            self.device.name(),
252            self.port_num,
253            self.gid_index,
254            self.max_send_wr,
255            self.max_recv_wr,
256            self.max_send_sge,
257            self.max_recv_sge,
258            self.path_mtu,
259            self.retry_cnt,
260            self.rnr_retry,
261            self.qp_timeout,
262            self.min_rnr_timer,
263            self.max_dest_rd_atomic,
264            self.max_rd_atomic,
265            self.pkey_index,
266            self.psn,
267        )
268    }
269}
270
/// Represents an RDMA device in the system.
///
/// This struct encapsulates information about an RDMA device, including its hardware
/// characteristics, capabilities, and port information. It provides access to device
/// attributes such as vendor information, firmware version, and supported features.
///
/// Only `name` is a public field; the remaining attributes are read through the
/// accessor methods on this type.
///
/// # Examples
///
/// ```
/// use monarch_rdma::get_all_devices;
///
/// let devices = get_all_devices();
/// if let Some(device) = devices.first() {
///     // Access device name and firmware version
///     let device_name = device.name();
///     let firmware_version = device.fw_ver();
/// }
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IbvDevice {
    /// `name` - The name of the RDMA device (e.g., "mlx5_0").
    pub name: String,
    /// `vendor_id` - The vendor ID of the device.
    vendor_id: u32,
    /// `vendor_part_id` - The vendor part ID of the device.
    vendor_part_id: u32,
    /// `hw_ver` - Hardware version of the device.
    hw_ver: u32,
    /// `fw_ver` - Firmware version of the device.
    fw_ver: String,
    /// `node_guid` - Node GUID (Globally Unique Identifier) of the device.
    node_guid: u64,
    /// `ports` - Vector of ports available on this device.
    ports: Vec<IbvPort>,
    /// `max_qp` - Maximum number of queue pairs supported.
    max_qp: i32,
    /// `max_cq` - Maximum number of completion queues supported.
    max_cq: i32,
    /// `max_mr` - Maximum number of memory regions supported.
    max_mr: i32,
    /// `max_pd` - Maximum number of protection domains supported.
    max_pd: i32,
    /// `max_qp_wr` - Maximum number of work requests per queue pair.
    max_qp_wr: i32,
    /// `max_sge` - Maximum number of scatter/gather elements per work request.
    max_sge: i32,
}
318
319impl IbvDevice {
320    /// Returns the name of the RDMA device.
321    pub fn name(&self) -> &String {
322        &self.name
323    }
324
325    /// Returns the first available RDMA device, if any.
326    pub fn first_available() -> Option<IbvDevice> {
327        let devices = get_all_devices();
328        if devices.is_empty() {
329            None
330        } else {
331            Some(devices.into_iter().next().unwrap())
332        }
333    }
334
335    /// Returns the vendor ID of the RDMA device.
336    pub fn vendor_id(&self) -> u32 {
337        self.vendor_id
338    }
339
340    /// Returns the vendor part ID of the RDMA device.
341    pub fn vendor_part_id(&self) -> u32 {
342        self.vendor_part_id
343    }
344
345    /// Returns the hardware version of the RDMA device.
346    pub fn hw_ver(&self) -> u32 {
347        self.hw_ver
348    }
349
350    /// Returns the firmware version of the RDMA device.
351    pub fn fw_ver(&self) -> &String {
352        &self.fw_ver
353    }
354
355    /// Returns the node GUID of the RDMA device.
356    pub fn node_guid(&self) -> u64 {
357        self.node_guid
358    }
359
360    /// Returns a reference to the vector of ports available on the RDMA device.
361    pub fn ports(&self) -> &Vec<IbvPort> {
362        &self.ports
363    }
364
365    /// Returns the maximum number of queue pairs supported by the RDMA device.
366    pub fn max_qp(&self) -> i32 {
367        self.max_qp
368    }
369
370    /// Returns the maximum number of completion queues supported by the RDMA device.
371    pub fn max_cq(&self) -> i32 {
372        self.max_cq
373    }
374
375    /// Returns the maximum number of memory regions supported by the RDMA device.
376    pub fn max_mr(&self) -> i32 {
377        self.max_mr
378    }
379
380    /// Returns the maximum number of protection domains supported by the RDMA device.
381    pub fn max_pd(&self) -> i32 {
382        self.max_pd
383    }
384
385    /// Returns the maximum number of work requests per queue pair supported by the RDMA device.
386    pub fn max_qp_wr(&self) -> i32 {
387        self.max_qp_wr
388    }
389
390    /// Returns the maximum number of scatter/gather elements per work request supported by the RDMA device.
391    pub fn max_sge(&self) -> i32 {
392        self.max_sge
393    }
394}
395
396impl Default for IbvDevice {
397    fn default() -> Self {
398        // Try to get a smart default using device selection logic (defaults to cpu:0)
399        if let Some(device) = super::device_selection::select_optimal_ibv_device(Some("cpu:0")) {
400            device
401        } else {
402            // Fallback to first available device
403            get_all_devices()
404                .into_iter()
405                .next()
406                .unwrap_or_else(|| panic!("No RDMA devices found"))
407        }
408    }
409}
410
/// Information about a single port of an RDMA device.
///
/// `get_all_devices` fills one of these per successfully queried port. The
/// `state`, `physical_state`, `link_layer` and `gid` fields hold already
/// formatted, human-readable strings rather than raw numeric values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IbvPort {
    /// `port_num` - The physical port number on the device.
    port_num: u8,
    /// `state` - The current state of the port.
    state: String,
    /// `physical_state` - The physical state of the port.
    physical_state: String,
    /// `base_lid` - Base Local Identifier for the port.
    base_lid: u16,
    /// `lmc` - LID Mask Control.
    lmc: u8,
    /// `sm_lid` - Subnet Manager Local Identifier.
    sm_lid: u16,
    /// `capability_mask` - Capability mask of the port.
    capability_mask: u32,
    /// `link_layer` - The link layer type (e.g., InfiniBand, Ethernet).
    link_layer: String,
    /// `gid` - Global Identifier for the port.
    gid: String,
    /// `gid_tbl_len` - Length of the GID table.
    gid_tbl_len: i32,
}
434
435impl fmt::Display for IbvDevice {
436    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
437        writeln!(f, "{}", self.name)?;
438        writeln!(f, "\tNumber of ports: {}", self.ports.len())?;
439        writeln!(f, "\tFirmware version: {}", self.fw_ver)?;
440        writeln!(f, "\tHardware version: {}", self.hw_ver)?;
441        writeln!(f, "\tNode GUID: 0x{:016x}", self.node_guid)?;
442        writeln!(f, "\tVendor ID: 0x{:x}", self.vendor_id)?;
443        writeln!(f, "\tVendor part ID: {}", self.vendor_part_id)?;
444        writeln!(f, "\tMax QPs: {}", self.max_qp)?;
445        writeln!(f, "\tMax CQs: {}", self.max_cq)?;
446        writeln!(f, "\tMax MRs: {}", self.max_mr)?;
447        writeln!(f, "\tMax PDs: {}", self.max_pd)?;
448        writeln!(f, "\tMax QP WRs: {}", self.max_qp_wr)?;
449        writeln!(f, "\tMax SGE: {}", self.max_sge)?;
450
451        for port in &self.ports {
452            write!(f, "{}", port)?;
453        }
454
455        Ok(())
456    }
457}
458
459impl fmt::Display for IbvPort {
460    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
461        writeln!(f, "\tPort {}:", self.port_num)?;
462        writeln!(f, "\t\tState: {}", self.state)?;
463        writeln!(f, "\t\tPhysical state: {}", self.physical_state)?;
464        writeln!(f, "\t\tBase lid: {}", self.base_lid)?;
465        writeln!(f, "\t\tLMC: {}", self.lmc)?;
466        writeln!(f, "\t\tSM lid: {}", self.sm_lid)?;
467        writeln!(f, "\t\tCapability mask: 0x{:08x}", self.capability_mask)?;
468        writeln!(f, "\t\tLink layer: {}", self.link_layer)?;
469        writeln!(f, "\t\tGID: {}", self.gid)?;
470        writeln!(f, "\t\tGID table length: {}", self.gid_tbl_len)?;
471        Ok(())
472    }
473}
474
475/// Converts the given port state to a human-readable string.
476///
477/// # Arguments
478///
479/// * `state` - The port state as defined by `ffi::ibv_port_state::Type`.
480///
481/// # Returns
482///
483/// A string representation of the port state.
484pub fn get_port_state_str(state: rdmaxcel_sys::ibv_port_state::Type) -> String {
485    // SAFETY: We are calling a C function that returns a C string.
486    unsafe {
487        let c_str = rdmaxcel_sys::ibv_port_state_str(state);
488        if c_str.is_null() {
489            return "Unknown".to_string();
490        }
491        CStr::from_ptr(c_str).to_string_lossy().into_owned()
492    }
493}
494
/// Converts the given physical state to a human-readable string.
///
/// # Arguments
///
/// * `phys_state` - The physical state as a `u8`.
///
/// # Returns
///
/// A string representation of the physical state. `0` is reported as
/// "No state change"; values outside `0..=7` are reported as "Unknown", in
/// line with the other string helpers in this file, instead of being
/// mislabelled as "No state change".
pub fn get_port_phy_state_str(phys_state: u8) -> String {
    let label = match phys_state {
        0 => "No state change",
        1 => "Sleep",
        2 => "Polling",
        3 => "Disabled",
        4 => "PortConfigurationTraining",
        5 => "LinkUp",
        6 => "LinkErrorRecovery",
        7 => "PhyTest",
        _ => "Unknown",
    };
    label.to_string()
}
516
/// Names a link layer type.
///
/// # Arguments
///
/// * `link_layer` - The link layer type as a `u8`.
///
/// # Returns
///
/// "InfiniBand" for `1`, "Ethernet" for `2`, and "Unknown" for every other value.
pub fn get_link_layer_str(link_layer: u8) -> String {
    let label = if link_layer == 1 {
        "InfiniBand"
    } else if link_layer == 2 {
        "Ethernet"
    } else {
        "Unknown"
    };
    String::from(label)
}
533
/// Renders a GID as eight colon-separated groups of four lowercase hex digits.
///
/// # Arguments
///
/// * `gid` - A reference to the 16 raw GID bytes.
///
/// # Returns
///
/// A string such as `fe80:0000:0000:0000:0215:5dff:fe00:0001`, where each
/// group is two consecutive bytes, zero-padded.
pub fn format_gid(gid: &[u8; 16]) -> String {
    gid.chunks_exact(2)
        .map(|pair| format!("{:02x}{:02x}", pair[0], pair[1]))
        .collect::<Vec<_>>()
        .join(":")
}
564
565/// Retrieves information about all available RDMA devices in the system.
566///
567/// This function queries the system for all available RDMA devices and returns
568/// detailed information about each device, including its capabilities, ports,
569/// and attributes.
570///
571/// # Returns
572///
573/// A vector of `IbvDevice` structures, each representing an RDMA device in the system.
574/// Returns an empty vector if no devices are found or if there was an error querying
575/// the devices.
576pub fn get_all_devices() -> Vec<IbvDevice> {
577    let mut devices = Vec::new();
578
579    // SAFETY: We are calling several C functions from libibverbs.
580    unsafe {
581        let mut num_devices = 0;
582        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
583        if device_list.is_null() || num_devices == 0 {
584            return devices;
585        }
586
587        for i in 0..num_devices {
588            let device = *device_list.add(i as usize);
589            if device.is_null() {
590                continue;
591            }
592
593            let context = rdmaxcel_sys::ibv_open_device(device);
594            if context.is_null() {
595                continue;
596            }
597
598            let device_name = CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(device))
599                .to_string_lossy()
600                .into_owned();
601
602            let mut device_attr = rdmaxcel_sys::ibv_device_attr::default();
603            if rdmaxcel_sys::ibv_query_device(context, &mut device_attr) != 0 {
604                rdmaxcel_sys::ibv_close_device(context);
605                continue;
606            }
607
608            let fw_ver = CStr::from_ptr(device_attr.fw_ver.as_ptr())
609                .to_string_lossy()
610                .into_owned();
611
612            let mut rdma_device = IbvDevice {
613                name: device_name,
614                vendor_id: device_attr.vendor_id,
615                vendor_part_id: device_attr.vendor_part_id,
616                hw_ver: device_attr.hw_ver,
617                fw_ver,
618                node_guid: device_attr.node_guid,
619                ports: Vec::new(),
620                max_qp: device_attr.max_qp,
621                max_cq: device_attr.max_cq,
622                max_mr: device_attr.max_mr,
623                max_pd: device_attr.max_pd,
624                max_qp_wr: device_attr.max_qp_wr,
625                max_sge: device_attr.max_sge,
626            };
627
628            for port_num in 1..=device_attr.phys_port_cnt {
629                let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
630                if rdmaxcel_sys::ibv_query_port(
631                    context,
632                    port_num,
633                    &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
634                ) != 0
635                {
636                    continue;
637                }
638                let state = get_port_state_str(port_attr.state);
639                let physical_state = get_port_phy_state_str(port_attr.phys_state);
640
641                let link_layer = get_link_layer_str(port_attr.link_layer);
642
643                let mut gid = rdmaxcel_sys::ibv_gid::default();
644                let gid_str = if rdmaxcel_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
645                    format_gid(&gid.raw)
646                } else {
647                    "N/A".to_string()
648                };
649
650                let rdma_port = IbvPort {
651                    port_num,
652                    state,
653                    physical_state,
654                    base_lid: port_attr.lid,
655                    lmc: port_attr.lmc,
656                    sm_lid: port_attr.sm_lid,
657                    capability_mask: port_attr.port_cap_flags,
658                    link_layer,
659                    gid: gid_str,
660                    gid_tbl_len: port_attr.gid_tbl_len,
661                };
662
663                rdma_device.ports.push(rdma_port);
664            }
665
666            devices.push(rdma_device);
667            rdmaxcel_sys::ibv_close_device(context);
668        }
669
670        rdmaxcel_sys::ibv_free_device_list(device_list);
671    }
672
673    devices
674}
675
676/// Cached result of mlx5dv support check.
677static MLX5DV_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
678
679/// Checks if mlx5dv (Mellanox device-specific verbs extension) is supported.
680///
681/// This function attempts to open the first available RDMA device and check if
682/// mlx5dv extensions can be initialized. The mlx5dv extensions are required for
683/// advanced features like GPU Direct RDMA and direct queue pair manipulation.
684///
685/// The result is cached after the first call, making subsequent calls essentially free.
686///
687/// # Returns
688///
689/// `true` if mlx5dv extensions are supported, `false` otherwise.
690pub fn mlx5dv_supported() -> bool {
691    *MLX5DV_SUPPORTED_CACHE.get_or_init(mlx5dv_supported_impl)
692}
693
694fn mlx5dv_supported_impl() -> bool {
695    // SAFETY: We are calling C functions from libibverbs and libmlx5.
696    unsafe {
697        let mut mlx5dv_supported = false;
698        let mut num_devices = 0;
699        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
700        if !device_list.is_null() && num_devices > 0 {
701            let device = *device_list;
702            if !device.is_null() {
703                mlx5dv_supported = rdmaxcel_sys::mlx5dv_is_supported(device);
704            }
705            rdmaxcel_sys::ibv_free_device_list(device_list);
706        }
707        mlx5dv_supported
708    }
709}
710
711/// Cached result of ibverbs support check.
712static IBVERBS_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
713
714/// Checks if ibverbs devices can be retrieved successfully.
715///
716/// This function attempts to retrieve the list of RDMA devices using the
717/// `ibv_get_device_list` function from the ibverbs library. It returns `true`
718/// if devices are found, and `false` otherwise.
719///
720/// The result is cached after the first call, making subsequent calls essentially free.
721///
722/// # Returns
723///
724/// `true` if devices are successfully retrieved, `false` otherwise.
725pub fn ibverbs_supported() -> bool {
726    *IBVERBS_SUPPORTED_CACHE.get_or_init(ibverbs_supported_impl)
727}
728
729fn ibverbs_supported_impl() -> bool {
730    // SAFETY: We are calling a C function from libibverbs.
731    unsafe {
732        let mut num_devices = 0;
733        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
734        if !device_list.is_null() {
735            rdmaxcel_sys::ibv_free_device_list(device_list);
736        }
737        num_devices > 0
738    }
739}
740
/// Represents a view of a memory region that can be registered with an RDMA device.
///
/// This is a 'view' of a registered Memory Region, allowing multiple views into a single
/// large MR registration. This is commonly used with PyTorch's caching allocator, which
/// reserves large memory blocks and provides different data pointers into that space.
///
/// # Example
/// PyTorch Caching Allocator creates a 16GB segment at virtual address `0x01000000`.
/// The underlying Memory Region registers 16GB but at RDMA address `0x0`.
/// To access virtual address `0x01100000`, we return a view at RDMA address `0x100000`.
///
/// # Safety
/// The caller must ensure the memory remains valid and is not freed, moved, or
/// overwritten while RDMA operations are in progress.

#[derive(
    Debug,
    PartialEq,
    Eq,
    std::hash::Hash,
    Serialize,
    Deserialize,
    Clone,
    Copy
)]
pub struct IbvMemoryRegionView {
    /// Identifier of this view; it should be unique within a given RDMA manager.
    pub id: usize,
    /// Virtual address in the process address space.
    /// This is the pointer/address as seen by the local process.
    pub virtual_addr: usize,
    /// Memory address assigned after Memory Region (MR) registration.
    /// This address may be offset from the base MR address.
    pub rdma_addr: usize,
    /// Size of the viewed region (presumably in bytes — confirm against callers).
    pub size: usize,
    /// Local key (`lkey`) associated with the registration.
    pub lkey: u32,
    /// Remote key (`rkey`) associated with the registration.
    pub rkey: u32,
}
779
// SAFETY: IbvMemoryRegionView can be safely sent between threads because it only
// contains address and size information (plain `usize`/`u32` fields) without any
// thread-local state. However, this DOES NOT provide any protection against data
// races in the underlying memory. If one thread initiates an RDMA operation while
// another thread modifies the same memory region, undefined behavior will occur.
// The caller is responsible for proper synchronization of access to the
// underlying memory.
unsafe impl Send for IbvMemoryRegionView {}

// SAFETY: IbvMemoryRegionView is safe for concurrent access by multiple threads
// as it only provides a view into memory without modifying its own state. However,
// it provides NO PROTECTION against concurrent access to the underlying memory region.
// The caller must ensure proper synchronization when:
// 1. Initiating RDMA operations while local code reads/writes the same memory
// 2. Performing multiple overlapping RDMA operations on the same memory region
// 3. Freeing or reallocating memory that has in-flight RDMA operations
unsafe impl Sync for IbvMemoryRegionView {}
796
797impl IbvMemoryRegionView {
798    /// Creates a new `IbvMemoryRegionView` with the given address and size.
799    pub fn new(
800        id: usize,
801        virtual_addr: usize,
802        rdma_addr: usize,
803        size: usize,
804        lkey: u32,
805        rkey: u32,
806    ) -> Self {
807        Self {
808            id,
809            virtual_addr,
810            rdma_addr,
811            size,
812            lkey,
813            rkey,
814        }
815    }
816}
817
/// Enum representing the common RDMA operations.
///
/// This provides a more ergonomic interface to the underlying ibv_wr_opcode types.
/// RDMA operations allow for direct memory access between two machines without
/// involving the CPU of the target machine.
///
/// # Variants
///
/// * `Write` - Represents an RDMA write operation where data is written from the local
///   memory to a remote memory region.
/// * `WriteWithImm` - An RDMA write that maps to `IBV_WR_RDMA_WRITE_WITH_IMM`.
/// * `Read` - Represents an RDMA read operation where data is read from a remote memory
///   region into the local memory.
/// * `Recv` - A receive operation. It has no work-request opcode: converting it to
///   `ibv_wr_opcode::Type` panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IbvOperation {
    /// RDMA write operations
    Write,
    /// RDMA write carrying immediate data
    WriteWithImm,
    /// RDMA read operation
    Read,
    /// RDMA recv operation
    Recv,
}
840
841impl From<IbvOperation> for rdmaxcel_sys::ibv_wr_opcode::Type {
842    fn from(op: IbvOperation) -> Self {
843        match op {
844            IbvOperation::Write => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
845            IbvOperation::WriteWithImm => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
846            IbvOperation::Read => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
847            IbvOperation::Recv => panic!("Invalid wr opcode"),
848        }
849    }
850}
851
852impl From<rdmaxcel_sys::ibv_wc_opcode::Type> for IbvOperation {
853    fn from(op: rdmaxcel_sys::ibv_wc_opcode::Type) -> Self {
854        match op {
855            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => IbvOperation::Write,
856            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => IbvOperation::Read,
857            _ => panic!("Unsupported operation type"),
858        }
859    }
860}
861
862/// Contains information needed to establish an RDMA queue pair with a remote endpoint.
863///
864/// `IbvQpInfo` encapsulates all the necessary information to establish a queue pair
865/// with a remote RDMA device. This includes queue pair number, LID (Local Identifier),
866/// GID (Global Identifier), remote memory address, remote key, and packet sequence number.
867#[derive(Default, Named, Clone, serde::Serialize, serde::Deserialize)]
868pub struct IbvQpInfo {
869    /// `qp_num` - Queue Pair Number, uniquely identifies a queue pair on the remote device
870    pub qp_num: u32,
871    /// `lid` - Local Identifier, used for addressing in InfiniBand subnet
872    pub lid: u16,
873    /// `gid` - Global Identifier, used for routing across subnets (similar to IPv6 address)
874    pub gid: Option<Gid>,
875    /// `psn` - Packet Sequence Number, used for ordering packets
876    pub psn: u32,
877}
878wirevalue::register_type!(IbvQpInfo);
879
880impl std::fmt::Debug for IbvQpInfo {
881    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
882        write!(
883            f,
884            "IbvQpInfo {{ qp_num: {}, lid: {}, gid: {:?}, psn: 0x{:x} }}",
885            self.qp_num, self.lid, self.gid, self.psn
886        )
887    }
888}
889
890/// Wrapper around ibv_wc (ibverbs work completion).
891///
892/// This exposes only the public fields of rdmaxcel_sys::ibv_wc, allowing us to more easily
893/// interact with it from Rust. Work completions are used to track the status of
894/// RDMA operations and are generated when an operation completes.
895#[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)]
896pub struct IbvWc {
897    /// `wr_id` - Work Request ID, used to identify the completed operation
898    wr_id: u64,
899    /// `len` - Length of the data transferred
900    len: usize,
901    /// `valid` - Whether the work completion is valid
902    valid: bool,
903    /// `error` - Error information if the operation failed
904    error: Option<(rdmaxcel_sys::ibv_wc_status::Type, u32)>,
905    /// `opcode` - Type of operation that completed (read, write, etc.)
906    opcode: rdmaxcel_sys::ibv_wc_opcode::Type,
907    /// `bytes` - Immediate data (if any)
908    bytes: Option<u32>,
909    /// `qp_num` - Queue Pair Number
910    qp_num: u32,
911    /// `src_qp` - Source Queue Pair Number
912    src_qp: u32,
913    /// `pkey_index` - Partition Key Index
914    pkey_index: u16,
915    /// `slid` - Source LID
916    slid: u16,
917    /// `sl` - Service Level
918    sl: u8,
919    /// `dlid_path_bits` - Destination LID Path Bits
920    dlid_path_bits: u8,
921}
922wirevalue::register_type!(IbvWc);
923
924impl From<rdmaxcel_sys::ibv_wc> for IbvWc {
925    fn from(wc: rdmaxcel_sys::ibv_wc) -> Self {
926        IbvWc {
927            wr_id: wc.wr_id(),
928            len: wc.len(),
929            valid: wc.is_valid(),
930            error: wc.error(),
931            opcode: wc.opcode(),
932            bytes: wc.imm_data(),
933            qp_num: wc.qp_num,
934            src_qp: wc.src_qp,
935            pkey_index: wc.pkey_index,
936            slid: wc.slid,
937            sl: wc.sl,
938            dlid_path_bits: wc.dlid_path_bits,
939        }
940    }
941}
942
943impl IbvWc {
944    /// Returns the Work Request ID associated with this work completion.
945    ///
946    /// The Work Request ID is used to identify the specific operation that completed.
947    /// It is set by the application when posting the work request and is returned
948    /// unchanged in the work completion.
949    pub fn wr_id(&self) -> u64 {
950        self.wr_id
951    }
952
953    /// Returns whether this work completion is valid.
954    ///
955    /// A valid work completion indicates that the operation completed successfully.
956    /// If false, the `error` field may contain additional information about the failure.
957    pub fn is_valid(&self) -> bool {
958        self.valid
959    }
960}
961
#[cfg(test)]
mod tests {
    use super::*;

    /// The first enumerated device must have a name and at least one port.
    /// Passes trivially when no RDMA device is reported.
    #[test]
    fn test_get_all_devices() {
        let found = get_all_devices();
        if found.is_empty() {
            // Nothing to inspect on a host without RDMA hardware.
            println!("Skipping test: RDMA devices not available");
            return;
        }

        let first = &found[0];
        assert!(!first.name().is_empty(), "device name should not be empty");
        assert!(
            !first.ports().is_empty(),
            "device should have at least one port"
        );
    }

    /// Every getter on the first device must agree with the field it exposes.
    /// Passes trivially when no RDMA device is reported.
    #[test]
    fn test_first_available() {
        let found = get_all_devices();
        if found.is_empty() {
            // Nothing to inspect on a host without RDMA hardware.
            println!("Skipping test: RDMA devices not available");
            return;
        }

        let dev = &found[0];

        // Getter on the left, backing field on the right.
        assert_eq!(dev.vendor_id(), dev.vendor_id);
        assert_eq!(dev.vendor_part_id(), dev.vendor_part_id);
        assert_eq!(dev.hw_ver(), dev.hw_ver);
        assert_eq!(dev.fw_ver(), &dev.fw_ver);
        assert_eq!(dev.node_guid(), dev.node_guid);
        assert_eq!(dev.max_qp(), dev.max_qp);
        assert_eq!(dev.max_cq(), dev.max_cq);
        assert_eq!(dev.max_mr(), dev.max_mr);
        assert_eq!(dev.max_pd(), dev.max_pd);
        assert_eq!(dev.max_qp_wr(), dev.max_qp_wr);
        assert_eq!(dev.max_sge(), dev.max_sge);
    }

    /// A device's `Display` output must mention its name and firmware version.
    #[test]
    fn test_device_display() {
        let Some(device) = IbvDevice::first_available() else {
            return;
        };

        let rendered = format!("{}", device);
        assert!(
            rendered.contains(&device.name),
            "display should include device name"
        );
        assert!(
            rendered.contains(&device.fw_ver),
            "display should include firmware version"
        );
    }

    /// A port's `Display` output must mention its state and link layer.
    #[test]
    fn test_port_display() {
        let Some(device) = IbvDevice::first_available() else {
            return;
        };
        if device.ports().is_empty() {
            return;
        }

        let port = &device.ports()[0];
        let rendered = format!("{}", port);
        assert!(
            rendered.contains(&port.state),
            "display should include port state"
        );
        assert!(
            rendered.contains(&port.link_layer),
            "display should include link layer"
        );
    }

    /// `IbvOperation` must map to and from the raw opcode constants.
    #[test]
    fn test_rdma_operation_conversion() {
        // Operation -> work-request opcode.
        let write_wr: rdmaxcel_sys::ibv_wr_opcode::Type = IbvOperation::Write.into();
        let read_wr: rdmaxcel_sys::ibv_wr_opcode::Type = IbvOperation::Read.into();
        assert_eq!(rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE, write_wr);
        assert_eq!(rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ, read_wr);

        // Work-completion opcode -> operation.
        let write_op: IbvOperation = rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE.into();
        let read_op: IbvOperation = rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ.into();
        assert_eq!(IbvOperation::Write, write_op);
        assert_eq!(IbvOperation::Read, read_op);
    }

    /// The hand-written `Debug` impl of `IbvQpInfo` must show each checked
    /// field, with `psn` in hexadecimal.
    #[test]
    fn test_rdma_endpoint() {
        let endpoint = IbvQpInfo {
            qp_num: 42,
            lid: 123,
            gid: None,
            psn: 0x5678,
        };

        let debug_str = format!("{:?}", endpoint);
        for needle in ["qp_num: 42", "lid: 123", "psn: 0x5678"] {
            assert!(debug_str.contains(needle));
        }
    }

    /// A raw completion with `wr_id = 42` and a success status must convert
    /// into a valid `IbvWc` carrying the same `wr_id`.
    #[test]
    fn test_ibv_wc() {
        let mut raw = rdmaxcel_sys::ibv_wc::default();
        let base = &mut raw as *mut rdmaxcel_sys::ibv_wc as *mut u8;

        // SAFETY: both stores go through a pointer to `raw`, a live local that
        // is not otherwise accessed until the block ends. The fields being
        // set are private, hence the raw-pointer writes. This assumes the
        // layout the offsets below encode: `wr_id` is a u64 at offset 0 and
        // the status occupies 4 bytes at offset 8.
        unsafe {
            // wr_id (offset 0, u64)
            base.cast::<u64>().write(42);

            // status = SUCCESS (offset 8, stored as a 4-byte integer)
            base.add(8)
                .cast::<i32>()
                .write(rdmaxcel_sys::ibv_wc_status::IBV_WC_SUCCESS as i32);
        }

        let ibv_wc = IbvWc::from(raw);
        assert_eq!(ibv_wc.wr_id(), 42);
        assert!(ibv_wc.is_valid());
    }

    /// `format_gid` must render 16 raw bytes as eight colon-separated groups
    /// of four hex digits.
    #[test]
    fn test_format_gid() {
        let raw_gid = [
            0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
            0x77, 0x88,
        ];

        assert_eq!(
            format_gid(&raw_gid),
            "1234:5678:9abc:def0:1122:3344:5566:7788"
        );
    }

    /// Smoke test: `mlx5dv_supported` must return without panicking; its
    /// result is only printed.
    #[test]
    fn test_mlx5dv_supported_basic() {
        let supported = mlx5dv_supported();
        println!("mlx5dv_supported: {}", supported);
    }
}