monarch_rdma/
ibverbs_primitives.rs

1/*
2 * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/*
10 * Sections of code adapted from
11 * Copyright (c) 2016 Jon Gjengset under MIT License (MIT)
12*/
13
14//! This file contains primitive data structures for interacting with ibverbs.
15//!
16//! Primitives:
17//! - `IbverbsConfig`: Represents ibverbs specific configurations, holding parameters required to establish and
18//!   manage an RDMA connection, including settings for the RDMA device, queue pair attributes, and other
19//!   connection-specific parameters.
20//! - `RdmaDevice`: Represents an RDMA device, i.e. 'mlx5_0'. Contains information about the device, such as:
21//!   its name, vendor ID, vendor part ID, hardware version, firmware version, node GUID, and capabilities.
22//! - `RdmaPort`: Represents information about the port of an RDMA device, including state, physical state,
23//!   LID (Local Identifier), and GID (Global Identifier) information.
24//! - `RdmaMemoryRegionView`: Represents a memory region that can be registered with an RDMA device for direct
25//!   memory access operations.
26//! - `RdmaOperation`: Represents the type of RDMA operation to perform (Read or Write).
27//! - `RdmaQpInfo`: Contains connection information needed to establish an RDMA connection with a remote endpoint.
28//! - `IbvWc`: Wrapper around ibverbs work completion structure, used to track the status of RDMA operations.
29use std::ffi::CStr;
30use std::fmt;
31
32use hyperactor::Named;
33use serde::Deserialize;
34use serde::Serialize;
35
/// A 16-byte RDMA Global Identifier (GID).
///
/// The bytes are stored in the same order as the `raw` field of
/// `rdmaxcel_sys::ibv_gid`; the `From`, `AsRef` and `AsMut` impls in this file
/// convert between the two types.
#[derive(
    Default,
    Copy,
    Clone,
    Debug,
    Eq,
    PartialEq,
    Hash,
    serde::Serialize,
    serde::Deserialize
)]
#[repr(transparent)]
pub struct Gid {
    /// Raw GID bytes. The accessor methods read each 8-byte half as a
    /// big-endian integer.
    raw: [u8; 16],
}
51
52impl Gid {
53    #[allow(dead_code)]
54    fn subnet_prefix(&self) -> u64 {
55        u64::from_be_bytes(self.raw[..8].try_into().unwrap())
56    }
57
58    #[allow(dead_code)]
59    fn interface_id(&self) -> u64 {
60        u64::from_be_bytes(self.raw[8..].try_into().unwrap())
61    }
62}
63impl From<rdmaxcel_sys::ibv_gid> for Gid {
64    fn from(gid: rdmaxcel_sys::ibv_gid) -> Self {
65        Self {
66            raw: unsafe { gid.raw },
67        }
68    }
69}
70
71impl From<Gid> for rdmaxcel_sys::ibv_gid {
72    fn from(mut gid: Gid) -> Self {
73        *gid.as_mut()
74    }
75}
76
77impl AsRef<rdmaxcel_sys::ibv_gid> for Gid {
78    fn as_ref(&self) -> &rdmaxcel_sys::ibv_gid {
79        unsafe { &*self.raw.as_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
80    }
81}
82
83impl AsMut<rdmaxcel_sys::ibv_gid> for Gid {
84    fn as_mut(&mut self) -> &mut rdmaxcel_sys::ibv_gid {
85        unsafe { &mut *self.raw.as_mut_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
86    }
87}
88
/// Represents ibverbs specific configurations.
///
/// This struct holds various parameters required to establish and manage an RDMA connection.
/// It includes settings for the RDMA device, queue pair attributes, and other connection-specific
/// parameters. See the `Default` impl for the values used when nothing is overridden.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct IbverbsConfig {
    /// `device` - The RDMA device to use for the connection.
    pub device: RdmaDevice,
    /// `cq_entries` - The number of completion queue entries.
    pub cq_entries: i32,
    /// `port_num` - The physical port number on the device.
    pub port_num: u8,
    /// `gid_index` - The GID index for the RDMA device.
    pub gid_index: u8,
    /// `max_send_wr` - The maximum number of outstanding send work requests.
    pub max_send_wr: u32,
    /// `max_recv_wr` - The maximum number of outstanding receive work requests.
    pub max_recv_wr: u32,
    /// `max_send_sge` - The maximum number of scatter/gather elements in a send work request.
    pub max_send_sge: u32,
    /// `max_recv_sge` - The maximum number of scatter/gather elements in a receive work request.
    pub max_recv_sge: u32,
    /// `path_mtu` - The path MTU (Maximum Transmission Unit) for the connection.
    /// The default is the `rdmaxcel_sys::IBV_MTU_4096` constant rather than a byte count.
    pub path_mtu: u32,
    /// `retry_cnt` - The number of retry attempts for a connection request.
    pub retry_cnt: u8,
    /// `rnr_retry` - The number of retry attempts for a receiver not ready (RNR) condition.
    pub rnr_retry: u8,
    /// `qp_timeout` - The timeout for a queue pair operation.
    pub qp_timeout: u8,
    /// `min_rnr_timer` - The minimum RNR timer value.
    pub min_rnr_timer: u8,
    /// `max_dest_rd_atomic` - The maximum number of outstanding RDMA read operations at the destination.
    pub max_dest_rd_atomic: u8,
    /// `max_rd_atomic` - The maximum number of outstanding RDMA read operations at the initiator.
    pub max_rd_atomic: u8,
    /// `pkey_index` - The partition key index.
    pub pkey_index: u16,
    /// `psn` - The packet sequence number.
    pub psn: u32,
    /// `use_gpu_direct` - Whether to enable GPU Direct RDMA support on init.
    pub use_gpu_direct: bool,
    /// `hw_init_delay_ms` - The delay in milliseconds before initializing the hardware.
    /// This is used to allow the hardware to settle before starting the first transmission.
    pub hw_init_delay_ms: u64,
}
136
137/// Default RDMA parameters below are based on common values from rdma-core examples
138/// (e.g. rc_pingpong). For high-performance or production use, consider tuning
139/// based on ibv_query_device() results and workload characteristics.
140impl Default for IbverbsConfig {
141    fn default() -> Self {
142        Self {
143            device: RdmaDevice::default(),
144            cq_entries: 256,
145            port_num: 1,
146            gid_index: 3,
147            max_send_wr: 128,
148            max_recv_wr: 128,
149            max_send_sge: 32,
150            max_recv_sge: 32,
151            path_mtu: rdmaxcel_sys::IBV_MTU_4096,
152            retry_cnt: 7,
153            rnr_retry: 7,
154            qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms
155            min_rnr_timer: 12,
156            max_dest_rd_atomic: 16,
157            max_rd_atomic: 16,
158            pkey_index: 0,
159            psn: rand::random::<u32>() & 0xffffff,
160            use_gpu_direct: false, // nv_peermem enabled for cuda
161            hw_init_delay_ms: 2,
162        }
163    }
164}
165
166impl std::fmt::Display for IbverbsConfig {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        write!(
169            f,
170            "IbverbsConfig {{ device: {}, port_num: {}, gid_index: {}, max_send_wr: {}, max_recv_wr: {}, max_send_sge: {}, max_recv_sge: {}, path_mtu: {:?}, retry_cnt: {}, rnr_retry: {}, qp_timeout: {}, min_rnr_timer: {}, max_dest_rd_atomic: {}, max_rd_atomic: {}, pkey_index: {}, psn: 0x{:x} }}",
171            self.device.name(),
172            self.port_num,
173            self.gid_index,
174            self.max_send_wr,
175            self.max_recv_wr,
176            self.max_send_sge,
177            self.max_recv_sge,
178            self.path_mtu,
179            self.retry_cnt,
180            self.rnr_retry,
181            self.qp_timeout,
182            self.min_rnr_timer,
183            self.max_dest_rd_atomic,
184            self.max_rd_atomic,
185            self.pkey_index,
186            self.psn,
187        )
188    }
189}
190
/// Represents an RDMA device in the system.
///
/// This struct encapsulates information about an RDMA device, including its hardware
/// characteristics, capabilities, and port information. It provides access to device
/// attributes such as vendor information, firmware version, and supported features.
///
/// Instances are snapshots built by `get_all_devices`; they hold plain data and do
/// not keep the device open. Only `name` is a public field; everything else is read
/// through the accessor methods.
///
/// # Examples
///
/// ```
/// use monarch_rdma::get_all_devices;
///
/// let devices = get_all_devices();
/// if let Some(device) = devices.first() {
///     // Access device name and firmware version
///     let device_name = device.name();
///     let firmware_version = device.fw_ver();
/// }
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RdmaDevice {
    /// `name` - The name of the RDMA device (e.g., "mlx5_0").
    pub name: String,
    /// `vendor_id` - The vendor ID of the device.
    vendor_id: u32,
    /// `vendor_part_id` - The vendor part ID of the device.
    vendor_part_id: u32,
    /// `hw_ver` - Hardware version of the device.
    hw_ver: u32,
    /// `fw_ver` - Firmware version of the device.
    fw_ver: String,
    /// `node_guid` - Node GUID (Globally Unique Identifier) of the device.
    node_guid: u64,
    /// `ports` - Vector of ports available on this device. Ports whose query
    /// failed during enumeration are not included.
    ports: Vec<RdmaPort>,
    /// `max_qp` - Maximum number of queue pairs supported.
    max_qp: i32,
    /// `max_cq` - Maximum number of completion queues supported.
    max_cq: i32,
    /// `max_mr` - Maximum number of memory regions supported.
    max_mr: i32,
    /// `max_pd` - Maximum number of protection domains supported.
    max_pd: i32,
    /// `max_qp_wr` - Maximum number of work requests per queue pair.
    max_qp_wr: i32,
    /// `max_sge` - Maximum number of scatter/gather elements per work request.
    max_sge: i32,
}
238
239impl RdmaDevice {
240    /// Returns the name of the RDMA device.
241    pub fn name(&self) -> &String {
242        &self.name
243    }
244
245    /// Returns the first available RDMA device, if any.
246    pub fn first_available() -> Option<RdmaDevice> {
247        let devices = get_all_devices();
248        if devices.is_empty() {
249            None
250        } else {
251            Some(devices.into_iter().next().unwrap())
252        }
253    }
254
255    /// Returns the vendor ID of the RDMA device.
256    pub fn vendor_id(&self) -> u32 {
257        self.vendor_id
258    }
259
260    /// Returns the vendor part ID of the RDMA device.
261    pub fn vendor_part_id(&self) -> u32 {
262        self.vendor_part_id
263    }
264
265    /// Returns the hardware version of the RDMA device.
266    pub fn hw_ver(&self) -> u32 {
267        self.hw_ver
268    }
269
270    /// Returns the firmware version of the RDMA device.
271    pub fn fw_ver(&self) -> &String {
272        &self.fw_ver
273    }
274
275    /// Returns the node GUID of the RDMA device.
276    pub fn node_guid(&self) -> u64 {
277        self.node_guid
278    }
279
280    /// Returns a reference to the vector of ports available on the RDMA device.
281    pub fn ports(&self) -> &Vec<RdmaPort> {
282        &self.ports
283    }
284
285    /// Returns the maximum number of queue pairs supported by the RDMA device.
286    pub fn max_qp(&self) -> i32 {
287        self.max_qp
288    }
289
290    /// Returns the maximum number of completion queues supported by the RDMA device.
291    pub fn max_cq(&self) -> i32 {
292        self.max_cq
293    }
294
295    /// Returns the maximum number of memory regions supported by the RDMA device.
296    pub fn max_mr(&self) -> i32 {
297        self.max_mr
298    }
299
300    /// Returns the maximum number of protection domains supported by the RDMA device.
301    pub fn max_pd(&self) -> i32 {
302        self.max_pd
303    }
304
305    /// Returns the maximum number of work requests per queue pair supported by the RDMA device.
306    pub fn max_qp_wr(&self) -> i32 {
307        self.max_qp_wr
308    }
309
310    /// Returns the maximum number of scatter/gather elements per work request supported by the RDMA device.
311    pub fn max_sge(&self) -> i32 {
312        self.max_sge
313    }
314}
315
316impl Default for RdmaDevice {
317    fn default() -> Self {
318        get_all_devices()
319            .into_iter()
320            .next()
321            .unwrap_or_else(|| panic!("No RDMA devices found"))
322    }
323}
324
/// Snapshot of a single port of an RDMA device.
///
/// Built by `get_all_devices` from the port attributes it queries. The state,
/// physical state, link layer and GID are stored as already-formatted strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RdmaPort {
    /// `port_num` - The physical port number on the device.
    port_num: u8,
    /// `state` - The current state of the port.
    state: String,
    /// `physical_state` - The physical state of the port.
    physical_state: String,
    /// `base_lid` - Base Local Identifier for the port.
    base_lid: u16,
    /// `lmc` - LID Mask Control.
    lmc: u8,
    /// `sm_lid` - Subnet Manager Local Identifier.
    sm_lid: u16,
    /// `capability_mask` - Capability mask of the port.
    capability_mask: u32,
    /// `link_layer` - The link layer type (e.g., InfiniBand, Ethernet).
    link_layer: String,
    /// `gid` - Global Identifier for the port, formatted from GID index 0, or
    /// "N/A" when that GID could not be queried.
    gid: String,
    /// `gid_tbl_len` - Length of the GID table.
    gid_tbl_len: i32,
}
348
349impl fmt::Display for RdmaDevice {
350    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351        writeln!(f, "{}", self.name)?;
352        writeln!(f, "\tNumber of ports: {}", self.ports.len())?;
353        writeln!(f, "\tFirmware version: {}", self.fw_ver)?;
354        writeln!(f, "\tHardware version: {}", self.hw_ver)?;
355        writeln!(f, "\tNode GUID: 0x{:016x}", self.node_guid)?;
356        writeln!(f, "\tVendor ID: 0x{:x}", self.vendor_id)?;
357        writeln!(f, "\tVendor part ID: {}", self.vendor_part_id)?;
358        writeln!(f, "\tMax QPs: {}", self.max_qp)?;
359        writeln!(f, "\tMax CQs: {}", self.max_cq)?;
360        writeln!(f, "\tMax MRs: {}", self.max_mr)?;
361        writeln!(f, "\tMax PDs: {}", self.max_pd)?;
362        writeln!(f, "\tMax QP WRs: {}", self.max_qp_wr)?;
363        writeln!(f, "\tMax SGE: {}", self.max_sge)?;
364
365        for port in &self.ports {
366            write!(f, "{}", port)?;
367        }
368
369        Ok(())
370    }
371}
372
373impl fmt::Display for RdmaPort {
374    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375        writeln!(f, "\tPort {}:", self.port_num)?;
376        writeln!(f, "\t\tState: {}", self.state)?;
377        writeln!(f, "\t\tPhysical state: {}", self.physical_state)?;
378        writeln!(f, "\t\tBase lid: {}", self.base_lid)?;
379        writeln!(f, "\t\tLMC: {}", self.lmc)?;
380        writeln!(f, "\t\tSM lid: {}", self.sm_lid)?;
381        writeln!(f, "\t\tCapability mask: 0x{:08x}", self.capability_mask)?;
382        writeln!(f, "\t\tLink layer: {}", self.link_layer)?;
383        writeln!(f, "\t\tGID: {}", self.gid)?;
384        writeln!(f, "\t\tGID table length: {}", self.gid_tbl_len)?;
385        Ok(())
386    }
387}
388
389/// Converts the given port state to a human-readable string.
390///
391/// # Arguments
392///
393/// * `state` - The port state as defined by `ffi::ibv_port_state::Type`.
394///
395/// # Returns
396///
397/// A string representation of the port state.
398pub fn get_port_state_str(state: rdmaxcel_sys::ibv_port_state::Type) -> String {
399    // SAFETY: We are calling a C function that returns a C string.
400    unsafe {
401        let c_str = rdmaxcel_sys::ibv_port_state_str(state);
402        if c_str.is_null() {
403            return "Unknown".to_string();
404        }
405        CStr::from_ptr(c_str).to_string_lossy().into_owned()
406    }
407}
408
/// Converts the given physical state to a human-readable string.
///
/// # Arguments
///
/// * `phys_state` - The physical state as a `u8`.
///
/// # Returns
///
/// The name for states `1..=7`; every other value (including `0`) yields
/// `"No state change"`.
pub fn get_port_phy_state_str(phys_state: u8) -> String {
    // Names for states 1..=7, stored at index `state - 1`.
    const STATE_NAMES: [&str; 7] = [
        "Sleep",
        "Polling",
        "Disabled",
        "PortConfigurationTraining",
        "LinkUp",
        "LinkErrorRecovery",
        "PhyTest",
    ];

    let label = phys_state
        .checked_sub(1)
        .and_then(|index| STATE_NAMES.get(usize::from(index)))
        .copied()
        .unwrap_or("No state change");
    label.to_owned()
}
430
/// Converts the given link layer type to a human-readable string.
///
/// # Arguments
///
/// * `link_layer` - The link layer type as a `u8`.
///
/// # Returns
///
/// `"InfiniBand"` for `1`, `"Ethernet"` for `2`, and `"Unknown"` for any other value.
pub fn get_link_layer_str(link_layer: u8) -> String {
    let label = if link_layer == 1 {
        "InfiniBand"
    } else if link_layer == 2 {
        "Ethernet"
    } else {
        "Unknown"
    };
    label.to_owned()
}
447
/// Formats a GID (Global Identifier) into a human-readable string.
///
/// # Arguments
///
/// * `gid` - A reference to a 16-byte array representing the GID.
///
/// # Returns
///
/// Eight colon-separated groups of four lowercase hex digits, one group per
/// consecutive byte pair, e.g. `fe80:0000:0000:0000:0211:22ff:fe33:4455`.
pub fn format_gid(gid: &[u8; 16]) -> String {
    let groups: Vec<String> = gid
        .chunks_exact(2)
        .map(|pair| format!("{:02x}{:02x}", pair[0], pair[1]))
        .collect();
    groups.join(":")
}
478
479/// Retrieves information about all available RDMA devices in the system.
480///
481/// This function queries the system for all available RDMA devices and returns
482/// detailed information about each device, including its capabilities, ports,
483/// and attributes.
484///
485/// # Returns
486///
487/// A vector of `RdmaDevice` structures, each representing an RDMA device in the system.
488/// Returns an empty vector if no devices are found or if there was an error querying
489/// the devices.
490pub fn get_all_devices() -> Vec<RdmaDevice> {
491    let mut devices = Vec::new();
492
493    // SAFETY: We are calling several C functions from libibverbs.
494    unsafe {
495        let mut num_devices = 0;
496        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
497        if device_list.is_null() || num_devices == 0 {
498            return devices;
499        }
500
501        for i in 0..num_devices {
502            let device = *device_list.add(i as usize);
503            if device.is_null() {
504                continue;
505            }
506
507            let context = rdmaxcel_sys::ibv_open_device(device);
508            if context.is_null() {
509                continue;
510            }
511
512            let device_name = CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(device))
513                .to_string_lossy()
514                .into_owned();
515
516            let mut device_attr = rdmaxcel_sys::ibv_device_attr::default();
517            if rdmaxcel_sys::ibv_query_device(context, &mut device_attr) != 0 {
518                rdmaxcel_sys::ibv_close_device(context);
519                continue;
520            }
521
522            let fw_ver = CStr::from_ptr(device_attr.fw_ver.as_ptr())
523                .to_string_lossy()
524                .into_owned();
525
526            let mut rdma_device = RdmaDevice {
527                name: device_name,
528                vendor_id: device_attr.vendor_id,
529                vendor_part_id: device_attr.vendor_part_id,
530                hw_ver: device_attr.hw_ver,
531                fw_ver,
532                node_guid: device_attr.node_guid,
533                ports: Vec::new(),
534                max_qp: device_attr.max_qp,
535                max_cq: device_attr.max_cq,
536                max_mr: device_attr.max_mr,
537                max_pd: device_attr.max_pd,
538                max_qp_wr: device_attr.max_qp_wr,
539                max_sge: device_attr.max_sge,
540            };
541
542            for port_num in 1..=device_attr.phys_port_cnt {
543                let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
544                if rdmaxcel_sys::ibv_query_port(
545                    context,
546                    port_num,
547                    &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
548                ) != 0
549                {
550                    continue;
551                }
552                let state = get_port_state_str(port_attr.state);
553                let physical_state = get_port_phy_state_str(port_attr.phys_state);
554
555                let link_layer = get_link_layer_str(port_attr.link_layer);
556
557                let mut gid = rdmaxcel_sys::ibv_gid::default();
558                let gid_str = if rdmaxcel_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
559                    format_gid(&gid.raw)
560                } else {
561                    "N/A".to_string()
562                };
563
564                let rdma_port = RdmaPort {
565                    port_num,
566                    state,
567                    physical_state,
568                    base_lid: port_attr.lid,
569                    lmc: port_attr.lmc,
570                    sm_lid: port_attr.sm_lid,
571                    capability_mask: port_attr.port_cap_flags,
572                    link_layer,
573                    gid: gid_str,
574                    gid_tbl_len: port_attr.gid_tbl_len,
575                };
576
577                rdma_device.ports.push(rdma_port);
578            }
579
580            devices.push(rdma_device);
581            rdmaxcel_sys::ibv_close_device(context);
582        }
583
584        rdmaxcel_sys::ibv_free_device_list(device_list);
585    }
586
587    devices
588}
589
590/// Checks if ibverbs devices can be retrieved successfully.
591///
592/// This function attempts to retrieve the list of RDMA devices using the
593/// `ibv_get_device_list` function from the ibverbs library. It returns `true`
594/// if devices are found, and `false` otherwise.
595///
596/// # Returns
597///
598/// `true` if devices are successfully retrieved, `false` otherwise.
599pub fn ibverbs_supported() -> bool {
600    // SAFETY: We are calling a C function from libibverbs.
601    unsafe {
602        let mut num_devices = 0;
603        let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
604        if !device_list.is_null() {
605            rdmaxcel_sys::ibv_free_device_list(device_list);
606        }
607        num_devices > 0
608    }
609}
610
/// Represents a view of a memory region that can be registered with an RDMA device.
///
/// This is a 'view' of a registered Memory Region, allowing multiple views into a single
/// large MR registration. This is commonly used with PyTorch's caching allocator, which
/// reserves large memory blocks and provides different data pointers into that space.
///
/// The view stores only addresses, a size and keys as plain integers; it does not
/// own or borrow the memory it describes.
///
/// # Example
/// PyTorch Caching Allocator creates a 16GB segment at virtual address `0x01000000`.
/// The underlying Memory Region registers 16GB but at RDMA address `0x0`.
/// To access virtual address `0x01100000`, we return a view at RDMA address `0x100000`.
///
/// # Safety
/// The caller must ensure the memory remains valid and is not freed, moved, or
/// overwritten while RDMA operations are in progress.
#[derive(
    Debug,
    PartialEq,
    Eq,
    std::hash::Hash,
    Serialize,
    Deserialize,
    Clone,
    Copy
)]
pub struct RdmaMemoryRegionView {
    /// Identifier of the view; it should be unique within a given RDMA manager.
    pub id: usize,
    /// Virtual address in the process address space.
    /// This is the pointer/address as seen by the local process.
    pub virtual_addr: usize,
    /// Memory address assigned after Memory Region (MR) registration.
    /// This address may be offset from the base address of the underlying MR.
    pub rdma_addr: usize,
    /// Size of the view. NOTE(review): presumably in bytes — confirm against callers.
    pub size: usize,
    /// Local key of the registration. NOTE(review): presumably the `lkey` of the
    /// underlying ibverbs MR — confirm.
    pub lkey: u32,
    /// Remote key of the registration. NOTE(review): presumably the `rkey` of the
    /// underlying ibverbs MR — confirm.
    pub rkey: u32,
}
649
// SAFETY: RdmaMemoryRegionView can be safely sent between threads because it only
// contains address and size information (plain `usize`/`u32` fields) without any
// thread-local state. However, this DOES NOT provide any protection against data
// races in the underlying memory. If one thread initiates an RDMA operation while
// another thread modifies the same memory region, undefined behavior will occur.
// The caller is responsible for proper synchronization of access to the underlying
// memory.
unsafe impl Send for RdmaMemoryRegionView {}
657
// SAFETY: RdmaMemoryRegionView is safe for concurrent access by multiple threads
// as it only provides a view into memory (plain `usize`/`u32` fields, no interior
// mutability) without modifying its own state. However, it provides NO PROTECTION
// against concurrent access to the underlying memory region.
// The caller must ensure proper synchronization when:
// 1. Initiating RDMA operations while local code reads/writes the same memory
// 2. Performing multiple overlapping RDMA operations on the same memory region
// 3. Freeing or reallocating memory that has in-flight RDMA operations
unsafe impl Sync for RdmaMemoryRegionView {}
666
impl RdmaMemoryRegionView {
    /// Creates a new `RdmaMemoryRegionView` from its six components.
    ///
    /// Every argument is stored unchanged in the field of the same name; no
    /// validation or memory registration happens here.
    pub fn new(
        id: usize,
        virtual_addr: usize,
        rdma_addr: usize,
        size: usize,
        lkey: u32,
        rkey: u32,
    ) -> Self {
        Self {
            id,
            virtual_addr,
            rdma_addr,
            size,
            lkey,
            rkey,
        }
    }
}
687
/// Enum representing the common RDMA operations.
///
/// This provides a more ergonomic interface to the underlying ibv_wr_opcode types.
/// RDMA operations allow for direct memory access between two machines without
/// involving the CPU of the target machine.
///
/// # Variants
///
/// * `Write` - Represents an RDMA write operation where data is written from the local
///   memory to a remote memory region.
/// * `WriteWithImm` - Represents an RDMA write posted with the
///   `IBV_WR_RDMA_WRITE_WITH_IMM` opcode.
/// * `Read` - Represents an RDMA read operation where data is read from a remote memory
///   region into the local memory.
/// * `Recv` - Represents a receive operation. It has no `ibv_wr_opcode` equivalent:
///   converting it to one panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RdmaOperation {
    /// RDMA write operation
    Write,
    /// RDMA write operation carrying immediate data
    WriteWithImm,
    /// RDMA read operation
    Read,
    /// RDMA recv operation
    Recv,
}
710
711impl From<RdmaOperation> for rdmaxcel_sys::ibv_wr_opcode::Type {
712    fn from(op: RdmaOperation) -> Self {
713        match op {
714            RdmaOperation::Write => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
715            RdmaOperation::WriteWithImm => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
716            RdmaOperation::Read => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
717            RdmaOperation::Recv => panic!("Invalid wr opcode"),
718        }
719    }
720}
721
722impl From<rdmaxcel_sys::ibv_wc_opcode::Type> for RdmaOperation {
723    fn from(op: rdmaxcel_sys::ibv_wc_opcode::Type) -> Self {
724        match op {
725            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write,
726            rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read,
727            _ => panic!("Unsupported operation type"),
728        }
729    }
730}
731
/// Contains information needed to establish an RDMA queue pair with a remote endpoint.
///
/// `RdmaQpInfo` encapsulates the information exchanged to connect a queue pair
/// with a remote RDMA device: the queue pair number, LID (Local Identifier),
/// an optional GID (Global Identifier), and the packet sequence number.
#[derive(Default, Named, Clone, serde::Serialize, serde::Deserialize)]
pub struct RdmaQpInfo {
    /// `qp_num` - Queue Pair Number, uniquely identifies a queue pair on the remote device
    pub qp_num: u32,
    /// `lid` - Local Identifier, used for addressing in InfiniBand subnet
    pub lid: u16,
    /// `gid` - Global Identifier, used for routing across subnets (similar to IPv6 address);
    /// `None` when no GID is provided
    pub gid: Option<Gid>,
    /// `psn` - Packet Sequence Number, used for ordering packets
    pub psn: u32,
}
748
749impl std::fmt::Debug for RdmaQpInfo {
750    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
751        write!(
752            f,
753            "RdmaQpInfo {{ qp_num: {}, lid: {}, gid: {:?}, psn: 0x{:x} }}",
754            self.qp_num, self.lid, self.gid, self.psn
755        )
756    }
757}
758
/// Wrapper around ibv_wc (ibverbs work completion).
///
/// This exposes only the public fields of rdmaxcel_sys::ibv_wc, allowing us to more easily
/// interact with it from Rust. Work completions are used to track the status of
/// RDMA operations and are generated when an operation completes.
///
/// Values are copied out of an `rdmaxcel_sys::ibv_wc` by the `From` impl, so an
/// `IbvWc` is an owned snapshot that can be cloned and serialized.
#[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)]
pub struct IbvWc {
    /// `wr_id` - Work Request ID, used to identify the completed operation
    wr_id: u64,
    /// `len` - Length of the data transferred
    len: usize,
    /// `valid` - Whether the work completion is valid
    valid: bool,
    /// `error` - Error information if the operation failed: the completion status
    /// paired with a second `u32` value reported by `ibv_wc::error()`
    error: Option<(rdmaxcel_sys::ibv_wc_status::Type, u32)>,
    /// `opcode` - Type of operation that completed (read, write, etc.)
    opcode: rdmaxcel_sys::ibv_wc_opcode::Type,
    /// `bytes` - Immediate data (if any)
    bytes: Option<u32>,
    /// `qp_num` - Queue Pair Number
    qp_num: u32,
    /// `src_qp` - Source Queue Pair Number
    src_qp: u32,
    /// `pkey_index` - Partition Key Index
    pkey_index: u16,
    /// `slid` - Source LID
    slid: u16,
    /// `sl` - Service Level
    sl: u8,
    /// `dlid_path_bits` - Destination LID Path Bits
    dlid_path_bits: u8,
}
791
792impl From<rdmaxcel_sys::ibv_wc> for IbvWc {
793    fn from(wc: rdmaxcel_sys::ibv_wc) -> Self {
794        IbvWc {
795            wr_id: wc.wr_id(),
796            len: wc.len(),
797            valid: wc.is_valid(),
798            error: wc.error(),
799            opcode: wc.opcode(),
800            bytes: wc.imm_data(),
801            qp_num: wc.qp_num,
802            src_qp: wc.src_qp,
803            pkey_index: wc.pkey_index,
804            slid: wc.slid,
805            sl: wc.sl,
806            dlid_path_bits: wc.dlid_path_bits,
807        }
808    }
809}
810
impl IbvWc {
    /// Returns the Work Request ID associated with this work completion.
    ///
    /// The Work Request ID is used to identify the specific operation that completed.
    /// It is set by the application when posting the work request and is returned
    /// unchanged in the work completion.
    pub fn wr_id(&self) -> u64 {
        self.wr_id
    }

    /// Returns whether this work completion is valid.
    ///
    /// This is the value reported by `ibv_wc::is_valid()` when the completion was
    /// converted into an `IbvWc`.
    /// A valid work completion indicates that the operation completed successfully.
    /// If false, the `error` field may contain additional information about the failure.
    pub fn is_valid(&self) -> bool {
        self.valid
    }
}
829
#[cfg(test)]
mod tests {
    use super::*;

    /// Enumerates the RDMA devices and sanity-checks the first one.
    ///
    /// Returns early (and therefore passes) when no RDMA device is found.
    #[test]
    fn test_get_all_devices() {
        let devices = get_all_devices();
        let device = match devices.first() {
            Some(device) => device,
            None => {
                println!("Skipping test: RDMA devices not available");
                return;
            }
        };

        assert!(!device.name().is_empty(), "device name should not be empty");
        assert!(
            !device.ports().is_empty(),
            "device should have at least one port"
        );
    }

    /// Checks that every getter of the device returned by
    /// `RdmaDevice::first_available` mirrors the field it exposes.
    ///
    /// Returns early (and therefore passes) when no device is available.
    #[test]
    fn test_first_available() {
        // Go through `first_available` itself so the constructor named by this
        // test is actually exercised.
        let dev = match RdmaDevice::first_available() {
            Some(dev) => dev,
            None => {
                println!("Skipping test: RDMA devices not available");
                return;
            }
        };

        // Verify getters return expected values
        assert_eq!(dev.vendor_id(), dev.vendor_id);
        assert_eq!(dev.vendor_part_id(), dev.vendor_part_id);
        assert_eq!(dev.hw_ver(), dev.hw_ver);
        assert_eq!(dev.fw_ver(), &dev.fw_ver);
        assert_eq!(dev.node_guid(), dev.node_guid);
        assert_eq!(dev.max_qp(), dev.max_qp);
        assert_eq!(dev.max_cq(), dev.max_cq);
        assert_eq!(dev.max_mr(), dev.max_mr);
        assert_eq!(dev.max_pd(), dev.max_pd);
        assert_eq!(dev.max_qp_wr(), dev.max_qp_wr);
        assert_eq!(dev.max_sge(), dev.max_sge);
    }

    /// Checks that the `Display` output of a device mentions its name and
    /// firmware version. Does nothing when no device is available.
    #[test]
    fn test_device_display() {
        let device = match RdmaDevice::first_available() {
            Some(device) => device,
            None => return,
        };

        let rendered = device.to_string();
        assert!(
            rendered.contains(&device.name),
            "display should include device name"
        );
        assert!(
            rendered.contains(&device.fw_ver),
            "display should include firmware version"
        );
    }

    /// Checks that the `Display` output of a port mentions its state and link
    /// layer. Does nothing when there is no device or the device has no ports.
    #[test]
    fn test_port_display() {
        let device = match RdmaDevice::first_available() {
            Some(device) => device,
            None => return,
        };

        if let Some(port) = device.ports().first() {
            let rendered = port.to_string();
            assert!(
                rendered.contains(&port.state),
                "display should include port state"
            );
            assert!(
                rendered.contains(&port.link_layer),
                "display should include link layer"
            );
        }
    }

    /// Checks both directions of the `RdmaOperation` <-> ibverbs opcode
    /// conversions for the read and write operations.
    #[test]
    fn test_rdma_operation_conversion() {
        // `RdmaOperation` into a work-request (`ibv_wr_opcode`) value.
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Write)
        );
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Read)
        );

        // Work-completion (`ibv_wc_opcode`) value back into `RdmaOperation`.
        assert_eq!(
            RdmaOperation::Write,
            RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE)
        );
        assert_eq!(
            RdmaOperation::Read,
            RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ)
        );
    }

    /// Checks the `Debug` rendering of `RdmaQpInfo`.
    #[test]
    fn test_rdma_endpoint() {
        let qp_info = RdmaQpInfo {
            qp_num: 42,
            lid: 123,
            gid: None,
            psn: 0x5678,
        };

        let rendered = format!("{:?}", qp_info);
        assert!(rendered.contains("qp_num: 42"));
        assert!(rendered.contains("lid: 123"));
        // The packet sequence number is expected in hexadecimal form.
        assert!(rendered.contains("psn: 0x5678"));
    }

    /// Builds a raw `ibv_wc` with a known work-request ID and a success status,
    /// converts it to `IbvWc`, and checks the accessors.
    #[test]
    fn test_ibv_wc() {
        let mut wc = rdmaxcel_sys::ibv_wc::default();

        // SAFETY: `wc` is a live local that is exclusively borrowed here, and both
        // writes go through raw pointers because the fields are private. This
        // relies on `wr_id` sitting at byte offset 0 and `status` at byte offset 8,
        // as in the C `struct ibv_wc` -- NOTE(review): confirm that rdmaxcel_sys
        // keeps this layout.
        unsafe {
            let base = (&mut wc as *mut rdmaxcel_sys::ibv_wc).cast::<u8>();

            // wr_id: 8 bytes at offset 0.
            base.cast::<u64>().write(42);

            // status: 4 bytes at offset 8, set to SUCCESS.
            base.add(8)
                .cast::<i32>()
                .write(rdmaxcel_sys::ibv_wc_status::IBV_WC_SUCCESS as i32);
        }

        let ibv_wc = IbvWc::from(wc);
        assert_eq!(ibv_wc.wr_id(), 42);
        assert!(ibv_wc.is_valid());
    }

    /// Checks that a 16-byte GID is formatted as eight colon-separated groups
    /// of four lowercase hex digits.
    #[test]
    fn test_format_gid() {
        let gid = [
            0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
            0x77, 0x88,
        ];

        assert_eq!(
            format_gid(&gid),
            "1234:5678:9abc:def0:1122:3344:5566:7788"
        );
    }
}