monarch_rdma/
ibverbs_primitives.rs

1/*
2 * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/*
10 * Sections of code adapted from
11 * Copyright (c) 2016 Jon Gjengset under MIT License (MIT)
12*/
13
14//! This file contains primitive data structures for interacting with ibverbs.
15//!
16//! Primitives:
17//! - `IbverbsConfig`: Represents ibverbs specific configurations, holding parameters required to establish and
18//!   manage an RDMA connection, including settings for the RDMA device, queue pair attributes, and other
19//!   connection-specific parameters.
20//! - `RdmaDevice`: Represents an RDMA device, i.e. 'mlx5_0'. Contains information about the device, such as:
21//!   its name, vendor ID, vendor part ID, hardware version, firmware version, node GUID, and capabilities.
22//! - `RdmaPort`: Represents information about the port of an RDMA device, including state, physical state,
23//!   LID (Local Identifier), and GID (Global Identifier) information.
24//! - `RdmaMemoryRegionView`: Represents a memory region that can be registered with an RDMA device for direct
25//!   memory access operations.
26//! - `RdmaOperation`: Represents the type of RDMA operation to perform (Read or Write).
27//! - `RdmaQpInfo`: Contains connection information needed to establish an RDMA connection with a remote endpoint.
28//! - `IbvWc`: Wrapper around ibverbs work completion structure, used to track the status of RDMA operations.
29use std::ffi::CStr;
30use std::fmt;
31use std::sync::OnceLock;
32
33use serde::Deserialize;
34use serde::Serialize;
35use typeuri::Named;
36
37#[derive(
38    Default,
39    Copy,
40    Clone,
41    Debug,
42    Eq,
43    PartialEq,
44    Hash,
45    serde::Serialize,
46    serde::Deserialize
47)]
48#[repr(transparent)]
49pub struct Gid {
50    raw: [u8; 16],
51}
52
53impl Gid {
54    #[allow(dead_code)]
55    fn subnet_prefix(&self) -> u64 {
56        u64::from_be_bytes(self.raw[..8].try_into().unwrap())
57    }
58
59    #[allow(dead_code)]
60    fn interface_id(&self) -> u64 {
61        u64::from_be_bytes(self.raw[8..].try_into().unwrap())
62    }
63}
64impl From<rdmaxcel_sys::ibv_gid> for Gid {
65    fn from(gid: rdmaxcel_sys::ibv_gid) -> Self {
66        Self {
67            raw: unsafe { gid.raw },
68        }
69    }
70}
71
72impl From<Gid> for rdmaxcel_sys::ibv_gid {
73    fn from(mut gid: Gid) -> Self {
74        *gid.as_mut()
75    }
76}
77
78impl AsRef<rdmaxcel_sys::ibv_gid> for Gid {
79    fn as_ref(&self) -> &rdmaxcel_sys::ibv_gid {
80        unsafe { &*self.raw.as_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
81    }
82}
83
84impl AsMut<rdmaxcel_sys::ibv_gid> for Gid {
85    fn as_mut(&mut self) -> &mut rdmaxcel_sys::ibv_gid {
86        unsafe { &mut *self.raw.as_mut_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
87    }
88}
89
90/// Queue pair type for RDMA operations.
91///
92/// Controls whether to use standard ibverbs queue pairs or mlx5dv extended queue pairs.
93/// Auto mode automatically selects based on device capabilities.
94#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
95pub enum RdmaQpType {
96    /// Auto-detect based on device capabilities
97    Auto,
98    /// Force standard ibverbs queue pair
99    Standard,
100    /// Force mlx5dv extended queue pair
101    Mlx5dv,
102}
103
104/// Converts `RdmaQpType` to the corresponding integer enum value in rdmaxcel_sys.
105pub fn resolve_qp_type(qp_type: RdmaQpType) -> u32 {
106    match qp_type {
107        RdmaQpType::Auto => {
108            if mlx5dv_supported() {
109                rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV
110            } else {
111                rdmaxcel_sys::RDMA_QP_TYPE_STANDARD
112            }
113        }
114        RdmaQpType::Standard => rdmaxcel_sys::RDMA_QP_TYPE_STANDARD,
115        RdmaQpType::Mlx5dv => rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV,
116    }
117}
118
119/// Represents ibverbs specific configurations.
120///
121/// This struct holds various parameters required to establish and manage an RDMA connection.
122/// It includes settings for the RDMA device, queue pair attributes, and other connection-specific
123/// parameters.
124#[derive(Debug, Named, Clone, Serialize, Deserialize)]
125pub struct IbverbsConfig {
126    /// `device` - The RDMA device to use for the connection.
127    pub device: RdmaDevice,
128    /// `cq_entries` - The number of completion queue entries.
129    pub cq_entries: i32,
130    /// `port_num` - The physical port number on the device.
131    pub port_num: u8,
132    /// `gid_index` - The GID index for the RDMA device.
133    pub gid_index: u8,
134    /// `max_send_wr` - The maximum number of outstanding send work requests.
135    pub max_send_wr: u32,
136    /// `max_recv_wr` - The maximum number of outstanding receive work requests.
137    pub max_recv_wr: u32,
138    /// `max_send_sge` - Te maximum number of scatter/gather elements in a send work request.
139    pub max_send_sge: u32,
140    /// `max_recv_sge` - The maximum number of scatter/gather elements in a receive work request.
141    pub max_recv_sge: u32,
142    /// `path_mtu` - The path MTU (Maximum Transmission Unit) for the connection.
143    pub path_mtu: u32,
144    /// `retry_cnt` - The number of retry attempts for a connection request.
145    pub retry_cnt: u8,
146    /// `rnr_retry` - The number of retry attempts for a receiver not ready (RNR) condition.
147    pub rnr_retry: u8,
148    /// `qp_timeout` - The timeout for a queue pair operation.
149    pub qp_timeout: u8,
150    /// `min_rnr_timer` - The minimum RNR timer value.
151    pub min_rnr_timer: u8,
152    /// `max_dest_rd_atomic` - The maximum number of outstanding RDMA read operations at the destination.
153    pub max_dest_rd_atomic: u8,
154    /// `max_rd_atomic` - The maximum number of outstanding RDMA read operations at the initiator.
155    pub max_rd_atomic: u8,
156    /// `pkey_index` - The partition key index.
157    pub pkey_index: u16,
158    /// `psn` - The packet sequence number.
159    pub psn: u32,
160    /// `use_gpu_direct` - Whether to enable GPU Direct RDMA support on init.
161    pub use_gpu_direct: bool,
162    /// `hw_init_delay_ms` - The delay in milliseconds before initializing the hardware.
163    /// This is used to allow the hardware to settle before starting the first transmission.
164    pub hw_init_delay_ms: u64,
165    /// `qp_type` - The type of queue pair to create (Auto, Standard, or Mlx5dv).
166    pub qp_type: RdmaQpType,
167}
168wirevalue::register_type!(IbverbsConfig);
169
170/// Default RDMA parameters below are based on common values from rdma-core examples
171/// For high-performance or production use, consider tuning
172/// based on ibv_query_device() results and workload characteristics
173impl Default for IbverbsConfig {
174    fn default() -> Self {
175        Self {
176            device: RdmaDevice::default(),
177            cq_entries: 1024,
178            port_num: 1,
179            gid_index: 3,
180            max_send_wr: 512,
181            max_recv_wr: 512,
182            max_send_sge: 30,
183            max_recv_sge: 30,
184            path_mtu: rdmaxcel_sys::IBV_MTU_4096,
185            retry_cnt: 7,
186            rnr_retry: 7,
187            qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms
188            min_rnr_timer: 12,
189            max_dest_rd_atomic: 16,
190            max_rd_atomic: 16,
191            pkey_index: 0,
192            psn: rand::random::<u32>() & 0xffffff,
193            use_gpu_direct: false, // nv_peermem enabled for cuda
194            hw_init_delay_ms: 2,
195            qp_type: RdmaQpType::Auto,
196        }
197    }
198}
199
200impl IbverbsConfig {
201    /// Create a new IbverbsConfig targeting a specific device
202    ///
203    /// Device targets use a unified "type:id" format:
204    /// - "cpu:N" -> finds RDMA device closest to NUMA node N
205    /// - "cuda:N" -> finds RDMA device closest to CUDA device N
206    /// - "nic:mlx5_N" -> returns the specified NIC directly
207    ///
208    /// Shortcuts:
209    /// - "cpu" -> defaults to "cpu:0"
210    /// - "cuda" -> defaults to "cuda:0"
211    ///
212    /// # Arguments
213    ///
214    /// * `target` - Target device specification
215    ///
216    /// # Returns
217    ///
218    /// * `IbverbsConfig` with resolved device, or default device if resolution fails
219    pub fn targeting(target: &str) -> Self {
220        // Normalize shortcuts
221        let normalized_target = match target {
222            "cpu" => "cpu:0",
223            "cuda" => "cuda:0",
224            _ => target,
225        };
226
227        let device = crate::device_selection::select_optimal_rdma_device(Some(normalized_target))
228            .unwrap_or_else(RdmaDevice::default);
229
230        Self {
231            device,
232            ..Default::default()
233        }
234    }
235}
236
237impl std::fmt::Display for IbverbsConfig {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        write!(
240            f,
241            "IbverbsConfig {{ device: {}, port_num: {}, gid_index: {}, max_send_wr: {}, max_recv_wr: {}, max_send_sge: {}, max_recv_sge: {}, path_mtu: {:?}, retry_cnt: {}, rnr_retry: {}, qp_timeout: {}, min_rnr_timer: {}, max_dest_rd_atomic: {}, max_rd_atomic: {}, pkey_index: {}, psn: 0x{:x} }}",
242            self.device.name(),
243            self.port_num,
244            self.gid_index,
245            self.max_send_wr,
246            self.max_recv_wr,
247            self.max_send_sge,
248            self.max_recv_sge,
249            self.path_mtu,
250            self.retry_cnt,
251            self.rnr_retry,
252            self.qp_timeout,
253            self.min_rnr_timer,
254            self.max_dest_rd_atomic,
255            self.max_rd_atomic,
256            self.pkey_index,
257            self.psn,
258        )
259    }
260}
261
262/// Represents an RDMA device in the system.
263///
264/// This struct encapsulates information about an RDMA device, including its hardware
265/// characteristics, capabilities, and port information. It provides access to device
266/// attributes such as vendor information, firmware version, and supported features.
267///
268/// # Examples
269///
270/// ```
271/// use monarch_rdma::get_all_devices;
272///
273/// let devices = get_all_devices();
274/// if let Some(device) = devices.first() {
275///     // Access device name and firmware version
276///     let device_name = device.name();
277///     let firmware_version = device.fw_ver();
278/// }
279/// ```
280#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct RdmaDevice {
282    /// `name` - The name of the RDMA device (e.g., "mlx5_0").
283    pub name: String,
284    /// `vendor_id` - The vendor ID of the device.
285    vendor_id: u32,
286    /// `vendor_part_id` - The vendor part ID of the device.
287    vendor_part_id: u32,
288    /// `hw_ver` - Hardware version of the device.
289    hw_ver: u32,
290    /// `fw_ver` - Firmware version of the device.
291    fw_ver: String,
292    /// `node_guid` - Node GUID (Globally Unique Identifier) of the device.
293    node_guid: u64,
294    /// `ports` - Vector of ports available on this device.
295    ports: Vec<RdmaPort>,
296    /// `max_qp` - Maximum number of queue pairs supported.
297    max_qp: i32,
298    /// `max_cq` - Maximum number of completion queues supported.
299    max_cq: i32,
300    /// `max_mr` - Maximum number of memory regions supported.
301    max_mr: i32,
302    /// `max_pd` - Maximum number of protection domains supported.
303    max_pd: i32,
304    /// `max_qp_wr` - Maximum number of work requests per queue pair.
305    max_qp_wr: i32,
306    /// `max_sge` - Maximum number of scatter/gather elements per work request.
307    max_sge: i32,
308}
309
310impl RdmaDevice {
311    /// Returns the name of the RDMA device.
312    pub fn name(&self) -> &String {
313        &self.name
314    }
315
316    /// Returns the first available RDMA device, if any.
317    pub fn first_available() -> Option<RdmaDevice> {
318        let devices = get_all_devices();
319        if devices.is_empty() {
320            None
321        } else {
322            Some(devices.into_iter().next().unwrap())
323        }
324    }
325
326    /// Returns the vendor ID of the RDMA device.
327    pub fn vendor_id(&self) -> u32 {
328        self.vendor_id
329    }
330
331    /// Returns the vendor part ID of the RDMA device.
332    pub fn vendor_part_id(&self) -> u32 {
333        self.vendor_part_id
334    }
335
336    /// Returns the hardware version of the RDMA device.
337    pub fn hw_ver(&self) -> u32 {
338        self.hw_ver
339    }
340
341    /// Returns the firmware version of the RDMA device.
342    pub fn fw_ver(&self) -> &String {
343        &self.fw_ver
344    }
345
346    /// Returns the node GUID of the RDMA device.
347    pub fn node_guid(&self) -> u64 {
348        self.node_guid
349    }
350
351    /// Returns a reference to the vector of ports available on the RDMA device.
352    pub fn ports(&self) -> &Vec<RdmaPort> {
353        &self.ports
354    }
355
356    /// Returns the maximum number of queue pairs supported by the RDMA device.
357    pub fn max_qp(&self) -> i32 {
358        self.max_qp
359    }
360
361    /// Returns the maximum number of completion queues supported by the RDMA device.
362    pub fn max_cq(&self) -> i32 {
363        self.max_cq
364    }
365
366    /// Returns the maximum number of memory regions supported by the RDMA device.
367    pub fn max_mr(&self) -> i32 {
368        self.max_mr
369    }
370
371    /// Returns the maximum number of protection domains supported by the RDMA device.
372    pub fn max_pd(&self) -> i32 {
373        self.max_pd
374    }
375
376    /// Returns the maximum number of work requests per queue pair supported by the RDMA device.
377    pub fn max_qp_wr(&self) -> i32 {
378        self.max_qp_wr
379    }
380
381    /// Returns the maximum number of scatter/gather elements per work request supported by the RDMA device.
382    pub fn max_sge(&self) -> i32 {
383        self.max_sge
384    }
385}
386
387impl Default for RdmaDevice {
388    fn default() -> Self {
389        // Try to get a smart default using device selection logic (defaults to cpu:0)
390        if let Some(device) = crate::device_selection::select_optimal_rdma_device(Some("cpu:0")) {
391            device
392        } else {
393            // Fallback to first available device
394            get_all_devices()
395                .into_iter()
396                .next()
397                .unwrap_or_else(|| panic!("No RDMA devices found"))
398        }
399    }
400}
401
402#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct RdmaPort {
404    /// `port_num` - The physical port number on the device.
405    port_num: u8,
406    /// `state` - The current state of the port.
407    state: String,
408    /// `physical_state` - The physical state of the port.
409    physical_state: String,
410    /// `base_lid` - Base Local Identifier for the port.
411    base_lid: u16,
412    /// `lmc` - LID Mask Control.
413    lmc: u8,
414    /// `sm_lid` - Subnet Manager Local Identifier.
415    sm_lid: u16,
416    /// `capability_mask` - Capability mask of the port.
417    capability_mask: u32,
418    /// `link_layer` - The link layer type (e.g., InfiniBand, Ethernet).
419    link_layer: String,
420    /// `gid` - Global Identifier for the port.
421    gid: String,
422    /// `gid_tbl_len` - Length of the GID table.
423    gid_tbl_len: i32,
424}
425
426impl fmt::Display for RdmaDevice {
427    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
428        writeln!(f, "{}", self.name)?;
429        writeln!(f, "\tNumber of ports: {}", self.ports.len())?;
430        writeln!(f, "\tFirmware version: {}", self.fw_ver)?;
431        writeln!(f, "\tHardware version: {}", self.hw_ver)?;
432        writeln!(f, "\tNode GUID: 0x{:016x}", self.node_guid)?;
433        writeln!(f, "\tVendor ID: 0x{:x}", self.vendor_id)?;
434        writeln!(f, "\tVendor part ID: {}", self.vendor_part_id)?;
435        writeln!(f, "\tMax QPs: {}", self.max_qp)?;
436        writeln!(f, "\tMax CQs: {}", self.max_cq)?;
437        writeln!(f, "\tMax MRs: {}", self.max_mr)?;
438        writeln!(f, "\tMax PDs: {}", self.max_pd)?;
439        writeln!(f, "\tMax QP WRs: {}", self.max_qp_wr)?;
440        writeln!(f, "\tMax SGE: {}", self.max_sge)?;
441
442        for port in &self.ports {
443            write!(f, "{}", port)?;
444        }
445
446        Ok(())
447    }
448}
449
450impl fmt::Display for RdmaPort {
451    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452        writeln!(f, "\tPort {}:", self.port_num)?;
453        writeln!(f, "\t\tState: {}", self.state)?;
454        writeln!(f, "\t\tPhysical state: {}", self.physical_state)?;
455        writeln!(f, "\t\tBase lid: {}", self.base_lid)?;
456        writeln!(f, "\t\tLMC: {}", self.lmc)?;
457        writeln!(f, "\t\tSM lid: {}", self.sm_lid)?;
458        writeln!(f, "\t\tCapability mask: 0x{:08x}", self.capability_mask)?;
459        writeln!(f, "\t\tLink layer: {}", self.link_layer)?;
460        writeln!(f, "\t\tGID: {}", self.gid)?;
461        writeln!(f, "\t\tGID table length: {}", self.gid_tbl_len)?;
462        Ok(())
463    }
464}
465
466/// Converts the given port state to a human-readable string.
467///
468/// # Arguments
469///
470/// * `state` - The port state as defined by `ffi::ibv_port_state::Type`.
471///
472/// # Returns
473///
474/// A string representation of the port state.
475pub fn get_port_state_str(state: rdmaxcel_sys::ibv_port_state::Type) -> String {
476    // SAFETY: We are calling a C function that returns a C string.
477    unsafe {
478        let c_str = rdmaxcel_sys::ibv_port_state_str(state);
479        if c_str.is_null() {
480            return "Unknown".to_string();
481        }
482        CStr::from_ptr(c_str).to_string_lossy().into_owned()
483    }
484}
485
/// Converts the given physical port state to a human-readable string.
///
/// # Arguments
///
/// * `phys_state` - The physical state as a `u8` (the `phys_state` field of `ibv_port_attr`).
///
/// # Returns
///
/// A string representation of the physical state. `0` is "No state change"; values that do
/// not correspond to a known physical state yield "Unknown", matching the other
/// `get_*_str` helpers in this module.
pub fn get_port_phy_state_str(phys_state: u8) -> String {
    let name = match phys_state {
        0 => "No state change",
        1 => "Sleep",
        2 => "Polling",
        3 => "Disabled",
        4 => "PortConfigurationTraining",
        5 => "LinkUp",
        6 => "LinkErrorRecovery",
        7 => "PhyTest",
        // Reserved / out-of-range values must not be reported as a real state.
        _ => "Unknown",
    };
    name.to_string()
}
507
/// Maps a numeric link-layer code to a readable label.
///
/// # Arguments
///
/// * `link_layer` - The `link_layer` code as a `u8`.
///
/// # Returns
///
/// "InfiniBand" for 1, "Ethernet" for 2, and "Unknown" for anything else.
pub fn get_link_layer_str(link_layer: u8) -> String {
    let label = if link_layer == 1 {
        "InfiniBand"
    } else if link_layer == 2 {
        "Ethernet"
    } else {
        "Unknown"
    };
    String::from(label)
}
524
/// Renders a 16-byte GID as eight colon-separated groups of four lowercase hex digits.
///
/// # Arguments
///
/// * `gid` - The raw GID bytes.
///
/// # Returns
///
/// A string such as `fe80:0000:0000:0000:0211:22ff:fe33:4455`.
pub fn format_gid(gid: &[u8; 16]) -> String {
    gid.chunks_exact(2)
        .map(|pair| format!("{:02x}{:02x}", pair[0], pair[1]))
        .collect::<Vec<String>>()
        .join(":")
}
555
556/// Retrieves information about all available RDMA devices in the system.
557///
558/// This function queries the system for all available RDMA devices and returns
559/// detailed information about each device, including its capabilities, ports,
560/// and attributes.
561///
562/// # Returns
563///
564/// A vector of `RdmaDevice` structures, each representing an RDMA device in the system.
565/// Returns an empty vector if no devices are found or if there was an error querying
566/// the devices.
567pub fn get_all_devices() -> Vec<RdmaDevice> {
568    let mut devices = Vec::new();
569
570    // SAFETY: We are calling several C functions from libibverbs.
571    unsafe {
572        let mut num_devices = 0;
573        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
574        if device_list.is_null() || num_devices == 0 {
575            return devices;
576        }
577
578        for i in 0..num_devices {
579            let device = *device_list.add(i as usize);
580            if device.is_null() {
581                continue;
582            }
583
584            let context = rdmaxcel_sys::ibv_open_device(device);
585            if context.is_null() {
586                continue;
587            }
588
589            let device_name = CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(device))
590                .to_string_lossy()
591                .into_owned();
592
593            let mut device_attr = rdmaxcel_sys::ibv_device_attr::default();
594            if rdmaxcel_sys::ibv_query_device(context, &mut device_attr) != 0 {
595                rdmaxcel_sys::ibv_close_device(context);
596                continue;
597            }
598
599            let fw_ver = CStr::from_ptr(device_attr.fw_ver.as_ptr())
600                .to_string_lossy()
601                .into_owned();
602
603            let mut rdma_device = RdmaDevice {
604                name: device_name,
605                vendor_id: device_attr.vendor_id,
606                vendor_part_id: device_attr.vendor_part_id,
607                hw_ver: device_attr.hw_ver,
608                fw_ver,
609                node_guid: device_attr.node_guid,
610                ports: Vec::new(),
611                max_qp: device_attr.max_qp,
612                max_cq: device_attr.max_cq,
613                max_mr: device_attr.max_mr,
614                max_pd: device_attr.max_pd,
615                max_qp_wr: device_attr.max_qp_wr,
616                max_sge: device_attr.max_sge,
617            };
618
619            for port_num in 1..=device_attr.phys_port_cnt {
620                let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
621                if rdmaxcel_sys::ibv_query_port(
622                    context,
623                    port_num,
624                    &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
625                ) != 0
626                {
627                    continue;
628                }
629                let state = get_port_state_str(port_attr.state);
630                let physical_state = get_port_phy_state_str(port_attr.phys_state);
631
632                let link_layer = get_link_layer_str(port_attr.link_layer);
633
634                let mut gid = rdmaxcel_sys::ibv_gid::default();
635                let gid_str = if rdmaxcel_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
636                    format_gid(&gid.raw)
637                } else {
638                    "N/A".to_string()
639                };
640
641                let rdma_port = RdmaPort {
642                    port_num,
643                    state,
644                    physical_state,
645                    base_lid: port_attr.lid,
646                    lmc: port_attr.lmc,
647                    sm_lid: port_attr.sm_lid,
648                    capability_mask: port_attr.port_cap_flags,
649                    link_layer,
650                    gid: gid_str,
651                    gid_tbl_len: port_attr.gid_tbl_len,
652                };
653
654                rdma_device.ports.push(rdma_port);
655            }
656
657            devices.push(rdma_device);
658            rdmaxcel_sys::ibv_close_device(context);
659        }
660
661        rdmaxcel_sys::ibv_free_device_list(device_list);
662    }
663
664    devices
665}
666
667/// Cached result of mlx5dv support check.
668static MLX5DV_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
669
670/// Checks if mlx5dv (Mellanox device-specific verbs extension) is supported.
671///
672/// This function attempts to open the first available RDMA device and check if
673/// mlx5dv extensions can be initialized. The mlx5dv extensions are required for
674/// advanced features like GPU Direct RDMA and direct queue pair manipulation.
675///
676/// The result is cached after the first call, making subsequent calls essentially free.
677///
678/// # Returns
679///
680/// `true` if mlx5dv extensions are supported, `false` otherwise.
681pub fn mlx5dv_supported() -> bool {
682    *MLX5DV_SUPPORTED_CACHE.get_or_init(mlx5dv_supported_impl)
683}
684
685fn mlx5dv_supported_impl() -> bool {
686    // SAFETY: We are calling C functions from libibverbs and libmlx5.
687    unsafe {
688        let mut mlx5dv_supported = false;
689        let mut num_devices = 0;
690        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
691        if !device_list.is_null() && num_devices > 0 {
692            let device = *device_list;
693            if !device.is_null() {
694                mlx5dv_supported = rdmaxcel_sys::mlx5dv_is_supported(device);
695            }
696            rdmaxcel_sys::ibv_free_device_list(device_list);
697        }
698        mlx5dv_supported
699    }
700}
701
702/// Cached result of ibverbs support check.
703static IBVERBS_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
704
705/// Checks if ibverbs devices can be retrieved successfully.
706///
707/// This function attempts to retrieve the list of RDMA devices using the
708/// `ibv_get_device_list` function from the ibverbs library. It returns `true`
709/// if devices are found, and `false` otherwise.
710///
711/// The result is cached after the first call, making subsequent calls essentially free.
712///
713/// # Returns
714///
715/// `true` if devices are successfully retrieved, `false` otherwise.
716pub fn ibverbs_supported() -> bool {
717    *IBVERBS_SUPPORTED_CACHE.get_or_init(ibverbs_supported_impl)
718}
719
720fn ibverbs_supported_impl() -> bool {
721    // SAFETY: We are calling a C function from libibverbs.
722    unsafe {
723        let mut num_devices = 0;
724        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
725        if !device_list.is_null() {
726            rdmaxcel_sys::ibv_free_device_list(device_list);
727        }
728        num_devices > 0
729    }
730}
731
732/// Checks if RDMA is fully supported on this system.
733///
734/// This is the canonical function to check if RDMA can be used.
735pub fn rdma_supported() -> bool {
736    ibverbs_supported()
737}
738
739/// Represents a view of a memory region that can be registered with an RDMA device.
740///
741/// This is a 'view' of a registered Memory Region, allowing multiple views into a single
742/// large MR registration. This is commonly used with PyTorch's caching allocator, which
743/// reserves large memory blocks and provides different data pointers into that space.
744///
745/// # Example
746/// PyTorch Caching Allocator creates a 16GB segment at virtual address `0x01000000`.
747/// The underlying Memory Region registers 16GB but at RDMA address `0x0`.
748/// To access virtual address `0x01100000`, we return a view at RDMA address `0x100000`.
749///
750/// # Safety
751/// The caller must ensure the memory remains valid and is not freed, moved, or
752/// overwritten while RDMA operations are in progress.
753
754#[derive(
755    Debug,
756    PartialEq,
757    Eq,
758    std::hash::Hash,
759    Serialize,
760    Deserialize,
761    Clone,
762    Copy
763)]
764pub struct RdmaMemoryRegionView {
765    // id should be unique with a given rdmam manager
766    pub id: usize,
767    /// Virtual address in the process address space.
768    /// This is the pointer/address as seen by the local process.
769    pub virtual_addr: usize,
770    /// Memory address assigned after Memory Region (MR) registration.
771    /// This is the address may be offset a base MR addr.
772    pub rdma_addr: usize,
773    pub size: usize,
774    pub lkey: u32,
775    pub rkey: u32,
776}
777
778// SAFETY: RdmaMemoryRegionView can be safely sent between threads because it only
779// contains address and size information without any thread-local state. However,
780// this DOES NOT provide any protection against data races in the underlying memory.
781// If one thread initiates an RDMA operation while another thread modifies the same
782// memory region, undefined behavior will occur. The caller is responsible for proper
783// synchronization of access to the underlying memory.
784unsafe impl Send for RdmaMemoryRegionView {}
785
786// SAFETY: RdmaMemoryRegionView is safe for concurrent access by multiple threads
787// as it only provides a view into memory without modifying its own state. However,
788// it provides NO PROTECTION against concurrent access to the underlying memory region.
789// The caller must ensure proper synchronization when:
790// 1. Initiating RDMA operations while local code reads/writes the same memory
791// 2. Performing multiple overlapping RDMA operations on the same memory region
792// 3. Freeing or reallocating memory that has in-flight RDMA operations
793unsafe impl Sync for RdmaMemoryRegionView {}
794
795impl RdmaMemoryRegionView {
796    /// Creates a new `RdmaMemoryRegionView` with the given address and size.
797    pub fn new(
798        id: usize,
799        virtual_addr: usize,
800        rdma_addr: usize,
801        size: usize,
802        lkey: u32,
803        rkey: u32,
804    ) -> Self {
805        Self {
806            id,
807            virtual_addr,
808            rdma_addr,
809            size,
810            lkey,
811            rkey,
812        }
813    }
814}
815
/// Enum representing the common RDMA operations.
///
/// This provides a more ergonomic interface to the underlying ibverbs opcode types.
/// RDMA operations allow for direct memory access between two machines without
/// involving the CPU of the target machine.
///
/// # Variants
///
/// * `Write` - An RDMA write: data is written from local memory to a remote memory region.
/// * `WriteWithImm` - An RDMA write that also carries immediate data to the remote side.
/// * `Read` - An RDMA read: data is read from a remote memory region into local memory.
/// * `Recv` - A receive operation. It has no send work-request opcode, so converting it to
///   `ibv_wr_opcode` panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RdmaOperation {
    /// RDMA write operation.
    Write,
    /// RDMA write operation carrying immediate data.
    WriteWithImm,
    /// RDMA read operation.
    Read,
    /// RDMA receive operation.
    Recv,
}
838
839impl From<RdmaOperation> for rdmaxcel_sys::ibv_wr_opcode::Type {
840    fn from(op: RdmaOperation) -> Self {
841        match op {
842            RdmaOperation::Write => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
843            RdmaOperation::WriteWithImm => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
844            RdmaOperation::Read => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
845            RdmaOperation::Recv => panic!("Invalid wr opcode"),
846        }
847    }
848}
849
850impl From<rdmaxcel_sys::ibv_wc_opcode::Type> for RdmaOperation {
851    fn from(op: rdmaxcel_sys::ibv_wc_opcode::Type) -> Self {
852        match op {
853            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write,
854            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read,
855            _ => panic!("Unsupported operation type"),
856        }
857    }
858}
859
860/// Contains information needed to establish an RDMA queue pair with a remote endpoint.
861///
862/// `RdmaQpInfo` encapsulates all the necessary information to establish a queue pair
863/// with a remote RDMA device. This includes queue pair number, LID (Local Identifier),
864/// GID (Global Identifier), remote memory address, remote key, and packet sequence number.
865#[derive(Default, Named, Clone, serde::Serialize, serde::Deserialize)]
866pub struct RdmaQpInfo {
867    /// `qp_num` - Queue Pair Number, uniquely identifies a queue pair on the remote device
868    pub qp_num: u32,
869    /// `lid` - Local Identifier, used for addressing in InfiniBand subnet
870    pub lid: u16,
871    /// `gid` - Global Identifier, used for routing across subnets (similar to IPv6 address)
872    pub gid: Option<Gid>,
873    /// `psn` - Packet Sequence Number, used for ordering packets
874    pub psn: u32,
875}
876wirevalue::register_type!(RdmaQpInfo);
877
878impl std::fmt::Debug for RdmaQpInfo {
879    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
880        write!(
881            f,
882            "RdmaQpInfo {{ qp_num: {}, lid: {}, gid: {:?}, psn: 0x{:x} }}",
883            self.qp_num, self.lid, self.gid, self.psn
884        )
885    }
886}
887
/// Wrapper around ibv_wc (ibverbs work completion).
///
/// This exposes only the public fields of rdmaxcel_sys::ibv_wc, allowing us to more easily
/// interact with it from Rust. Work completions are used to track the status of
/// RDMA operations and are generated when an operation completes.
///
/// Unlike the raw `ibv_wc`, this type derives `Clone`, `Serialize`, and `Deserialize`, so a
/// snapshot of a completion can be cloned and serialized.
///
/// NOTE(review): this type is serde-serialized and registered with `wirevalue`; depending on
/// the serializer, reordering or renaming fields may change the wire format — confirm
/// before changing the layout.
#[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)]
pub struct IbvWc {
    /// `wr_id` - Work Request ID, used to identify the completed operation
    /// (copied from `ibv_wc::wr_id()`)
    wr_id: u64,
    /// `len` - Length of the data transferred (from `ibv_wc::len()`)
    len: usize,
    /// `valid` - Whether the work completion is valid, as reported by `ibv_wc::is_valid()`
    /// when this wrapper was built
    valid: bool,
    /// `error` - Error information if the operation failed (from `ibv_wc::error()`);
    /// presumably `(status, vendor error)` — confirm against rdmaxcel_sys
    error: Option<(rdmaxcel_sys::ibv_wc_status::Type, u32)>,
    /// `opcode` - Type of operation that completed (read, write, etc.)
    opcode: rdmaxcel_sys::ibv_wc_opcode::Type,
    /// `bytes` - Immediate data (if any), from `ibv_wc::imm_data()`.
    /// Despite the field name, this is the immediate value, not a byte count.
    bytes: Option<u32>,
    /// `qp_num` - Queue Pair Number
    qp_num: u32,
    /// `src_qp` - Source Queue Pair Number
    src_qp: u32,
    /// `pkey_index` - Partition Key Index
    pkey_index: u16,
    /// `slid` - Source LID
    slid: u16,
    /// `sl` - Service Level
    sl: u8,
    /// `dlid_path_bits` - Destination LID Path Bits
    dlid_path_bits: u8,
}
wirevalue::register_type!(IbvWc);
921
922impl From<rdmaxcel_sys::ibv_wc> for IbvWc {
923    fn from(wc: rdmaxcel_sys::ibv_wc) -> Self {
924        IbvWc {
925            wr_id: wc.wr_id(),
926            len: wc.len(),
927            valid: wc.is_valid(),
928            error: wc.error(),
929            opcode: wc.opcode(),
930            bytes: wc.imm_data(),
931            qp_num: wc.qp_num,
932            src_qp: wc.src_qp,
933            pkey_index: wc.pkey_index,
934            slid: wc.slid,
935            sl: wc.sl,
936            dlid_path_bits: wc.dlid_path_bits,
937        }
938    }
939}
940
941impl IbvWc {
942    /// Returns the Work Request ID associated with this work completion.
943    ///
944    /// The Work Request ID is used to identify the specific operation that completed.
945    /// It is set by the application when posting the work request and is returned
946    /// unchanged in the work completion.
947    pub fn wr_id(&self) -> u64 {
948        self.wr_id
949    }
950
951    /// Returns whether this work completion is valid.
952    ///
953    /// A valid work completion indicates that the operation completed successfully.
954    /// If false, the `error` field may contain additional information about the failure.
955    pub fn is_valid(&self) -> bool {
956        self.valid
957    }
958}
959
#[cfg(test)]
mod tests {
    use super::*;

    /// Enumerates RDMA devices and sanity-checks the first one (non-empty name and at
    /// least one port). Returns early without failing on hosts with no RDMA devices.
    #[test]
    fn test_get_all_devices() {
        // Skip test if RDMA devices are not available
        let devices = get_all_devices();
        if devices.is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }
        // Basic validation of first device
        let device = &devices[0];
        assert!(!device.name().is_empty(), "device name should not be empty");
        assert!(
            !device.ports().is_empty(),
            "device should have at least one port"
        );
    }

    /// Checks that `RdmaDevice`'s getters return the values stored in its fields.
    ///
    /// NOTE(review): despite its name, this test does not call
    /// `RdmaDevice::first_available()`; it inspects the first entry of `get_all_devices()`.
    #[test]
    fn test_first_available() {
        // Skip test if RDMA is not available
        let devices = get_all_devices();
        if devices.is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }
        // Basic validation of first device
        let device = &devices[0];

        let dev = device;
        // Verify getters return expected values
        assert_eq!(dev.vendor_id(), dev.vendor_id);
        assert_eq!(dev.vendor_part_id(), dev.vendor_part_id);
        assert_eq!(dev.hw_ver(), dev.hw_ver);
        assert_eq!(dev.fw_ver(), &dev.fw_ver);
        assert_eq!(dev.node_guid(), dev.node_guid);
        assert_eq!(dev.max_qp(), dev.max_qp);
        assert_eq!(dev.max_cq(), dev.max_cq);
        assert_eq!(dev.max_mr(), dev.max_mr);
        assert_eq!(dev.max_pd(), dev.max_pd);
        assert_eq!(dev.max_qp_wr(), dev.max_qp_wr);
        assert_eq!(dev.max_sge(), dev.max_sge);
    }

    /// Checks that `Display` for `RdmaDevice` includes the device name and firmware
    /// version. Passes trivially when no device is available.
    #[test]
    fn test_device_display() {
        if let Some(device) = RdmaDevice::first_available() {
            let display_output = format!("{}", device);
            assert!(
                display_output.contains(&device.name),
                "display should include device name"
            );
            assert!(
                display_output.contains(&device.fw_ver),
                "display should include firmware version"
            );
        }
    }

    /// Checks that `Display` for the first port includes its state and link layer.
    /// Passes trivially when no device or port is available.
    #[test]
    fn test_port_display() {
        if let Some(device) = RdmaDevice::first_available() {
            if !device.ports().is_empty() {
                let port = &device.ports()[0];
                let display_output = format!("{}", port);
                assert!(
                    display_output.contains(&port.state),
                    "display should include port state"
                );
                assert!(
                    display_output.contains(&port.link_layer),
                    "display should include link layer"
                );
            }
        }
    }

    /// Checks `Write`/`Read` conversions in both directions: `RdmaOperation` to
    /// `ibv_wr_opcode`, and `ibv_wc_opcode` back to `RdmaOperation`.
    #[test]
    fn test_rdma_operation_conversion() {
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Write)
        );
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Read)
        );

        assert_eq!(
            RdmaOperation::Write,
            RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE)
        );
        assert_eq!(
            RdmaOperation::Read,
            RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ)
        );
    }

    /// Checks the custom `Debug` output of `RdmaQpInfo`, including the hex-formatted PSN.
    #[test]
    fn test_rdma_endpoint() {
        let endpoint = RdmaQpInfo {
            qp_num: 42,
            lid: 123,
            gid: None,
            psn: 0x5678,
        };

        let debug_str = format!("{:?}", endpoint);
        assert!(debug_str.contains("qp_num: 42"));
        assert!(debug_str.contains("lid: 123"));
        assert!(debug_str.contains("psn: 0x5678"));
    }

    /// Builds a raw `ibv_wc` with a known `wr_id` and a success status, then checks that
    /// `IbvWc::from` surfaces both.
    #[test]
    fn test_ibv_wc() {
        let mut wc = rdmaxcel_sys::ibv_wc::default();

        // SAFETY: modifies private fields through pointer manipulation.
        // `wc` is a live, exclusively borrowed local. Both writes stay within its bounds and
        // are suitably aligned, assuming the C `struct ibv_wc` layout: `wr_id` (8 bytes) at
        // offset 0, followed by the 4-byte `status` enum at offset 8.
        unsafe {
            // Cast to pointer and modify the fields directly
            let wc_ptr = &mut wc as *mut rdmaxcel_sys::ibv_wc as *mut u8;

            // Set wr_id (at offset 0, u64)
            *(wc_ptr as *mut u64) = 42;

            // Set status to SUCCESS (at offset 8; 4-byte enum, written as i32)
            *(wc_ptr.add(8) as *mut i32) = rdmaxcel_sys::ibv_wc_status::IBV_WC_SUCCESS as i32;
        }
        let ibv_wc = IbvWc::from(wc);
        assert_eq!(ibv_wc.wr_id(), 42);
        assert!(ibv_wc.is_valid());
    }

    /// Checks that a raw 16-byte GID is rendered as eight colon-separated hex groups.
    #[test]
    fn test_format_gid() {
        let gid = [
            0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
            0x77, 0x88,
        ];

        let formatted = format_gid(&gid);
        assert_eq!(formatted, "1234:5678:9abc:def0:1122:3344:5566:7788");
    }

    #[test]
    fn test_mlx5dv_supported_basic() {
        // The test just verifies the function doesn't panic
        let mlx5dv_support = mlx5dv_supported();
        println!("mlx5dv_supported: {}", mlx5dv_support);
    }
}