monarch_rdma/ibverbs_primitives.rs
1/*
2 * Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9/*
10 * Sections of code adapted from
11 * Copyright (c) 2016 Jon Gjengset under MIT License (MIT)
12*/
13
14//! This file contains primitive data structures for interacting with ibverbs.
15//!
16//! Primitives:
17//! - `IbverbsConfig`: Represents ibverbs specific configurations, holding parameters required to establish and
18//! manage an RDMA connection, including settings for the RDMA device, queue pair attributes, and other
19//! connection-specific parameters.
//! - `RdmaDevice`: Represents an RDMA device, e.g. 'mlx5_0'. Contains information about the device, such as:
21//! its name, vendor ID, vendor part ID, hardware version, firmware version, node GUID, and capabilities.
22//! - `RdmaPort`: Represents information about the port of an RDMA device, including state, physical state,
23//! LID (Local Identifier), and GID (Global Identifier) information.
24//! - `RdmaMemoryRegionView`: Represents a memory region that can be registered with an RDMA device for direct
25//! memory access operations.
//! - `RdmaOperation`: Represents the type of RDMA operation to perform (Write, WriteWithImm, Read, or Recv).
27//! - `RdmaQpInfo`: Contains connection information needed to establish an RDMA connection with a remote endpoint.
28//! - `IbvWc`: Wrapper around ibverbs work completion structure, used to track the status of RDMA operations.
29use std::ffi::CStr;
30use std::fmt;
31use std::sync::OnceLock;
32
33use hyperactor::Named;
34use serde::Deserialize;
35use serde::Serialize;
36
37#[derive(
38 Default,
39 Copy,
40 Clone,
41 Debug,
42 Eq,
43 PartialEq,
44 Hash,
45 serde::Serialize,
46 serde::Deserialize
47)]
48#[repr(transparent)]
49pub struct Gid {
50 raw: [u8; 16],
51}
52
53impl Gid {
54 #[allow(dead_code)]
55 fn subnet_prefix(&self) -> u64 {
56 u64::from_be_bytes(self.raw[..8].try_into().unwrap())
57 }
58
59 #[allow(dead_code)]
60 fn interface_id(&self) -> u64 {
61 u64::from_be_bytes(self.raw[8..].try_into().unwrap())
62 }
63}
64impl From<rdmaxcel_sys::ibv_gid> for Gid {
65 fn from(gid: rdmaxcel_sys::ibv_gid) -> Self {
66 Self {
67 raw: unsafe { gid.raw },
68 }
69 }
70}
71
72impl From<Gid> for rdmaxcel_sys::ibv_gid {
73 fn from(mut gid: Gid) -> Self {
74 *gid.as_mut()
75 }
76}
77
78impl AsRef<rdmaxcel_sys::ibv_gid> for Gid {
79 fn as_ref(&self) -> &rdmaxcel_sys::ibv_gid {
80 unsafe { &*self.raw.as_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
81 }
82}
83
84impl AsMut<rdmaxcel_sys::ibv_gid> for Gid {
85 fn as_mut(&mut self) -> &mut rdmaxcel_sys::ibv_gid {
86 unsafe { &mut *self.raw.as_mut_ptr().cast::<rdmaxcel_sys::ibv_gid>() }
87 }
88}
89
90/// Queue pair type for RDMA operations.
91///
92/// Controls whether to use standard ibverbs queue pairs or mlx5dv extended queue pairs.
93/// Auto mode automatically selects based on device capabilities.
94#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
95pub enum RdmaQpType {
96 /// Auto-detect based on device capabilities
97 Auto,
98 /// Force standard ibverbs queue pair
99 Standard,
100 /// Force mlx5dv extended queue pair
101 Mlx5dv,
102}
103
104/// Converts `RdmaQpType` to the corresponding integer enum value in rdmaxcel_sys.
105pub fn resolve_qp_type(qp_type: RdmaQpType) -> u32 {
106 match qp_type {
107 RdmaQpType::Auto => {
108 if mlx5dv_supported() {
109 rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV
110 } else {
111 rdmaxcel_sys::RDMA_QP_TYPE_STANDARD
112 }
113 }
114 RdmaQpType::Standard => rdmaxcel_sys::RDMA_QP_TYPE_STANDARD,
115 RdmaQpType::Mlx5dv => rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV,
116 }
117}
118
119/// Represents ibverbs specific configurations.
120///
121/// This struct holds various parameters required to establish and manage an RDMA connection.
122/// It includes settings for the RDMA device, queue pair attributes, and other connection-specific
123/// parameters.
124#[derive(Debug, Named, Clone, Serialize, Deserialize)]
125pub struct IbverbsConfig {
126 /// `device` - The RDMA device to use for the connection.
127 pub device: RdmaDevice,
128 /// `cq_entries` - The number of completion queue entries.
129 pub cq_entries: i32,
130 /// `port_num` - The physical port number on the device.
131 pub port_num: u8,
132 /// `gid_index` - The GID index for the RDMA device.
133 pub gid_index: u8,
134 /// `max_send_wr` - The maximum number of outstanding send work requests.
135 pub max_send_wr: u32,
136 /// `max_recv_wr` - The maximum number of outstanding receive work requests.
137 pub max_recv_wr: u32,
138 /// `max_send_sge` - Te maximum number of scatter/gather elements in a send work request.
139 pub max_send_sge: u32,
140 /// `max_recv_sge` - The maximum number of scatter/gather elements in a receive work request.
141 pub max_recv_sge: u32,
142 /// `path_mtu` - The path MTU (Maximum Transmission Unit) for the connection.
143 pub path_mtu: u32,
144 /// `retry_cnt` - The number of retry attempts for a connection request.
145 pub retry_cnt: u8,
146 /// `rnr_retry` - The number of retry attempts for a receiver not ready (RNR) condition.
147 pub rnr_retry: u8,
148 /// `qp_timeout` - The timeout for a queue pair operation.
149 pub qp_timeout: u8,
150 /// `min_rnr_timer` - The minimum RNR timer value.
151 pub min_rnr_timer: u8,
152 /// `max_dest_rd_atomic` - The maximum number of outstanding RDMA read operations at the destination.
153 pub max_dest_rd_atomic: u8,
154 /// `max_rd_atomic` - The maximum number of outstanding RDMA read operations at the initiator.
155 pub max_rd_atomic: u8,
156 /// `pkey_index` - The partition key index.
157 pub pkey_index: u16,
158 /// `psn` - The packet sequence number.
159 pub psn: u32,
160 /// `use_gpu_direct` - Whether to enable GPU Direct RDMA support on init.
161 pub use_gpu_direct: bool,
162 /// `hw_init_delay_ms` - The delay in milliseconds before initializing the hardware.
163 /// This is used to allow the hardware to settle before starting the first transmission.
164 pub hw_init_delay_ms: u64,
165 /// `qp_type` - The type of queue pair to create (Auto, Standard, or Mlx5dv).
166 pub qp_type: RdmaQpType,
167}
168
169/// Default RDMA parameters below are based on common values from rdma-core examples
170/// For high-performance or production use, consider tuning
171/// based on ibv_query_device() results and workload characteristics
172impl Default for IbverbsConfig {
173 fn default() -> Self {
174 Self {
175 device: RdmaDevice::default(),
176 cq_entries: 1024,
177 port_num: 1,
178 gid_index: 3,
179 max_send_wr: 512,
180 max_recv_wr: 512,
181 max_send_sge: 30,
182 max_recv_sge: 30,
183 path_mtu: rdmaxcel_sys::IBV_MTU_4096,
184 retry_cnt: 7,
185 rnr_retry: 7,
186 qp_timeout: 14, // 4.096 μs * 2^14 = ~67 ms
187 min_rnr_timer: 12,
188 max_dest_rd_atomic: 16,
189 max_rd_atomic: 16,
190 pkey_index: 0,
191 psn: rand::random::<u32>() & 0xffffff,
192 use_gpu_direct: false, // nv_peermem enabled for cuda
193 hw_init_delay_ms: 2,
194 qp_type: RdmaQpType::Auto,
195 }
196 }
197}
198
199impl IbverbsConfig {
200 /// Create a new IbverbsConfig targeting a specific device
201 ///
202 /// Device targets use a unified "type:id" format:
203 /// - "cpu:N" -> finds RDMA device closest to NUMA node N
204 /// - "cuda:N" -> finds RDMA device closest to CUDA device N
205 /// - "nic:mlx5_N" -> returns the specified NIC directly
206 ///
207 /// Shortcuts:
208 /// - "cpu" -> defaults to "cpu:0"
209 /// - "cuda" -> defaults to "cuda:0"
210 ///
211 /// # Arguments
212 ///
213 /// * `target` - Target device specification
214 ///
215 /// # Returns
216 ///
217 /// * `IbverbsConfig` with resolved device, or default device if resolution fails
218 pub fn targeting(target: &str) -> Self {
219 // Normalize shortcuts
220 let normalized_target = match target {
221 "cpu" => "cpu:0",
222 "cuda" => "cuda:0",
223 _ => target,
224 };
225
226 let device = crate::device_selection::select_optimal_rdma_device(Some(normalized_target))
227 .unwrap_or_else(RdmaDevice::default);
228
229 Self {
230 device,
231 ..Default::default()
232 }
233 }
234}
235
236impl std::fmt::Display for IbverbsConfig {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 write!(
239 f,
240 "IbverbsConfig {{ device: {}, port_num: {}, gid_index: {}, max_send_wr: {}, max_recv_wr: {}, max_send_sge: {}, max_recv_sge: {}, path_mtu: {:?}, retry_cnt: {}, rnr_retry: {}, qp_timeout: {}, min_rnr_timer: {}, max_dest_rd_atomic: {}, max_rd_atomic: {}, pkey_index: {}, psn: 0x{:x} }}",
241 self.device.name(),
242 self.port_num,
243 self.gid_index,
244 self.max_send_wr,
245 self.max_recv_wr,
246 self.max_send_sge,
247 self.max_recv_sge,
248 self.path_mtu,
249 self.retry_cnt,
250 self.rnr_retry,
251 self.qp_timeout,
252 self.min_rnr_timer,
253 self.max_dest_rd_atomic,
254 self.max_rd_atomic,
255 self.pkey_index,
256 self.psn,
257 )
258 }
259}
260
261/// Represents an RDMA device in the system.
262///
263/// This struct encapsulates information about an RDMA device, including its hardware
264/// characteristics, capabilities, and port information. It provides access to device
265/// attributes such as vendor information, firmware version, and supported features.
266///
267/// # Examples
268///
269/// ```
270/// use monarch_rdma::get_all_devices;
271///
272/// let devices = get_all_devices();
273/// if let Some(device) = devices.first() {
274/// // Access device name and firmware version
275/// let device_name = device.name();
276/// let firmware_version = device.fw_ver();
277/// }
278/// ```
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct RdmaDevice {
281 /// `name` - The name of the RDMA device (e.g., "mlx5_0").
282 pub name: String,
283 /// `vendor_id` - The vendor ID of the device.
284 vendor_id: u32,
285 /// `vendor_part_id` - The vendor part ID of the device.
286 vendor_part_id: u32,
287 /// `hw_ver` - Hardware version of the device.
288 hw_ver: u32,
289 /// `fw_ver` - Firmware version of the device.
290 fw_ver: String,
291 /// `node_guid` - Node GUID (Globally Unique Identifier) of the device.
292 node_guid: u64,
293 /// `ports` - Vector of ports available on this device.
294 ports: Vec<RdmaPort>,
295 /// `max_qp` - Maximum number of queue pairs supported.
296 max_qp: i32,
297 /// `max_cq` - Maximum number of completion queues supported.
298 max_cq: i32,
299 /// `max_mr` - Maximum number of memory regions supported.
300 max_mr: i32,
301 /// `max_pd` - Maximum number of protection domains supported.
302 max_pd: i32,
303 /// `max_qp_wr` - Maximum number of work requests per queue pair.
304 max_qp_wr: i32,
305 /// `max_sge` - Maximum number of scatter/gather elements per work request.
306 max_sge: i32,
307}
308
309impl RdmaDevice {
310 /// Returns the name of the RDMA device.
311 pub fn name(&self) -> &String {
312 &self.name
313 }
314
315 /// Returns the first available RDMA device, if any.
316 pub fn first_available() -> Option<RdmaDevice> {
317 let devices = get_all_devices();
318 if devices.is_empty() {
319 None
320 } else {
321 Some(devices.into_iter().next().unwrap())
322 }
323 }
324
325 /// Returns the vendor ID of the RDMA device.
326 pub fn vendor_id(&self) -> u32 {
327 self.vendor_id
328 }
329
330 /// Returns the vendor part ID of the RDMA device.
331 pub fn vendor_part_id(&self) -> u32 {
332 self.vendor_part_id
333 }
334
335 /// Returns the hardware version of the RDMA device.
336 pub fn hw_ver(&self) -> u32 {
337 self.hw_ver
338 }
339
340 /// Returns the firmware version of the RDMA device.
341 pub fn fw_ver(&self) -> &String {
342 &self.fw_ver
343 }
344
345 /// Returns the node GUID of the RDMA device.
346 pub fn node_guid(&self) -> u64 {
347 self.node_guid
348 }
349
350 /// Returns a reference to the vector of ports available on the RDMA device.
351 pub fn ports(&self) -> &Vec<RdmaPort> {
352 &self.ports
353 }
354
355 /// Returns the maximum number of queue pairs supported by the RDMA device.
356 pub fn max_qp(&self) -> i32 {
357 self.max_qp
358 }
359
360 /// Returns the maximum number of completion queues supported by the RDMA device.
361 pub fn max_cq(&self) -> i32 {
362 self.max_cq
363 }
364
365 /// Returns the maximum number of memory regions supported by the RDMA device.
366 pub fn max_mr(&self) -> i32 {
367 self.max_mr
368 }
369
370 /// Returns the maximum number of protection domains supported by the RDMA device.
371 pub fn max_pd(&self) -> i32 {
372 self.max_pd
373 }
374
375 /// Returns the maximum number of work requests per queue pair supported by the RDMA device.
376 pub fn max_qp_wr(&self) -> i32 {
377 self.max_qp_wr
378 }
379
380 /// Returns the maximum number of scatter/gather elements per work request supported by the RDMA device.
381 pub fn max_sge(&self) -> i32 {
382 self.max_sge
383 }
384}
385
386impl Default for RdmaDevice {
387 fn default() -> Self {
388 // Try to get a smart default using device selection logic (defaults to cpu:0)
389 if let Some(device) = crate::device_selection::select_optimal_rdma_device(Some("cpu:0")) {
390 device
391 } else {
392 // Fallback to first available device
393 get_all_devices()
394 .into_iter()
395 .next()
396 .unwrap_or_else(|| panic!("No RDMA devices found"))
397 }
398 }
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
402pub struct RdmaPort {
403 /// `port_num` - The physical port number on the device.
404 port_num: u8,
405 /// `state` - The current state of the port.
406 state: String,
407 /// `physical_state` - The physical state of the port.
408 physical_state: String,
409 /// `base_lid` - Base Local Identifier for the port.
410 base_lid: u16,
411 /// `lmc` - LID Mask Control.
412 lmc: u8,
413 /// `sm_lid` - Subnet Manager Local Identifier.
414 sm_lid: u16,
415 /// `capability_mask` - Capability mask of the port.
416 capability_mask: u32,
417 /// `link_layer` - The link layer type (e.g., InfiniBand, Ethernet).
418 link_layer: String,
419 /// `gid` - Global Identifier for the port.
420 gid: String,
421 /// `gid_tbl_len` - Length of the GID table.
422 gid_tbl_len: i32,
423}
424
425impl fmt::Display for RdmaDevice {
426 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
427 writeln!(f, "{}", self.name)?;
428 writeln!(f, "\tNumber of ports: {}", self.ports.len())?;
429 writeln!(f, "\tFirmware version: {}", self.fw_ver)?;
430 writeln!(f, "\tHardware version: {}", self.hw_ver)?;
431 writeln!(f, "\tNode GUID: 0x{:016x}", self.node_guid)?;
432 writeln!(f, "\tVendor ID: 0x{:x}", self.vendor_id)?;
433 writeln!(f, "\tVendor part ID: {}", self.vendor_part_id)?;
434 writeln!(f, "\tMax QPs: {}", self.max_qp)?;
435 writeln!(f, "\tMax CQs: {}", self.max_cq)?;
436 writeln!(f, "\tMax MRs: {}", self.max_mr)?;
437 writeln!(f, "\tMax PDs: {}", self.max_pd)?;
438 writeln!(f, "\tMax QP WRs: {}", self.max_qp_wr)?;
439 writeln!(f, "\tMax SGE: {}", self.max_sge)?;
440
441 for port in &self.ports {
442 write!(f, "{}", port)?;
443 }
444
445 Ok(())
446 }
447}
448
449impl fmt::Display for RdmaPort {
450 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451 writeln!(f, "\tPort {}:", self.port_num)?;
452 writeln!(f, "\t\tState: {}", self.state)?;
453 writeln!(f, "\t\tPhysical state: {}", self.physical_state)?;
454 writeln!(f, "\t\tBase lid: {}", self.base_lid)?;
455 writeln!(f, "\t\tLMC: {}", self.lmc)?;
456 writeln!(f, "\t\tSM lid: {}", self.sm_lid)?;
457 writeln!(f, "\t\tCapability mask: 0x{:08x}", self.capability_mask)?;
458 writeln!(f, "\t\tLink layer: {}", self.link_layer)?;
459 writeln!(f, "\t\tGID: {}", self.gid)?;
460 writeln!(f, "\t\tGID table length: {}", self.gid_tbl_len)?;
461 Ok(())
462 }
463}
464
465/// Converts the given port state to a human-readable string.
466///
467/// # Arguments
468///
469/// * `state` - The port state as defined by `ffi::ibv_port_state::Type`.
470///
471/// # Returns
472///
473/// A string representation of the port state.
474pub fn get_port_state_str(state: rdmaxcel_sys::ibv_port_state::Type) -> String {
475 // SAFETY: We are calling a C function that returns a C string.
476 unsafe {
477 let c_str = rdmaxcel_sys::ibv_port_state_str(state);
478 if c_str.is_null() {
479 return "Unknown".to_string();
480 }
481 CStr::from_ptr(c_str).to_string_lossy().into_owned()
482 }
483}
484
/// Converts an InfiniBand port physical state code to a human-readable string.
///
/// # Arguments
///
/// * `phys_state` - The physical state code as a `u8` (the `phys_state` field of
///   `ibv_port_attr`).
///
/// # Returns
///
/// The state name for the defined codes 0-7 (0 means "No state change"). Any other
/// value is not a defined physical state and yields `"Unknown"`, matching
/// `get_link_layer_str`.
pub fn get_port_phy_state_str(phys_state: u8) -> String {
    let name = match phys_state {
        0 => "No state change",
        1 => "Sleep",
        2 => "Polling",
        3 => "Disabled",
        4 => "PortConfigurationTraining",
        5 => "LinkUp",
        6 => "LinkErrorRecovery",
        7 => "PhyTest",
        // Previously these fell into "No state change", mislabeling garbage values.
        _ => "Unknown",
    };
    name.to_string()
}
506
/// Converts a link layer code into a human-readable name.
///
/// # Arguments
///
/// * `link_layer` - The link layer type as a `u8`.
///
/// # Returns
///
/// `"InfiniBand"` for 1, `"Ethernet"` for 2, and `"Unknown"` for anything else.
pub fn get_link_layer_str(link_layer: u8) -> String {
    // Codes start at 1, so shift by one to index the name table.
    const NAMES: [&str; 2] = ["InfiniBand", "Ethernet"];
    link_layer
        .checked_sub(1)
        .and_then(|index| NAMES.get(usize::from(index)))
        .copied()
        .unwrap_or("Unknown")
        .to_string()
}
523
/// Renders a 16-byte GID (Global Identifier) as eight colon-separated groups of four
/// lowercase hex digits, IPv6-style but without zero compression.
///
/// # Arguments
///
/// * `gid` - The raw GID bytes.
///
/// # Returns
///
/// A string such as `"fe80:0000:0000:0000:0000:0000:0000:0001"`.
pub fn format_gid(gid: &[u8; 16]) -> String {
    gid.chunks_exact(2)
        .map(|pair| format!("{:02x}{:02x}", pair[0], pair[1]))
        .collect::<Vec<_>>()
        .join(":")
}
554
555/// Retrieves information about all available RDMA devices in the system.
556///
557/// This function queries the system for all available RDMA devices and returns
558/// detailed information about each device, including its capabilities, ports,
559/// and attributes.
560///
561/// # Returns
562///
563/// A vector of `RdmaDevice` structures, each representing an RDMA device in the system.
564/// Returns an empty vector if no devices are found or if there was an error querying
565/// the devices.
566pub fn get_all_devices() -> Vec<RdmaDevice> {
567 let mut devices = Vec::new();
568
569 // SAFETY: We are calling several C functions from libibverbs.
570 unsafe {
571 let mut num_devices = 0;
572 let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
573 if device_list.is_null() || num_devices == 0 {
574 return devices;
575 }
576
577 for i in 0..num_devices {
578 let device = *device_list.add(i as usize);
579 if device.is_null() {
580 continue;
581 }
582
583 let context = rdmaxcel_sys::ibv_open_device(device);
584 if context.is_null() {
585 continue;
586 }
587
588 let device_name = CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(device))
589 .to_string_lossy()
590 .into_owned();
591
592 let mut device_attr = rdmaxcel_sys::ibv_device_attr::default();
593 if rdmaxcel_sys::ibv_query_device(context, &mut device_attr) != 0 {
594 rdmaxcel_sys::ibv_close_device(context);
595 continue;
596 }
597
598 let fw_ver = CStr::from_ptr(device_attr.fw_ver.as_ptr())
599 .to_string_lossy()
600 .into_owned();
601
602 let mut rdma_device = RdmaDevice {
603 name: device_name,
604 vendor_id: device_attr.vendor_id,
605 vendor_part_id: device_attr.vendor_part_id,
606 hw_ver: device_attr.hw_ver,
607 fw_ver,
608 node_guid: device_attr.node_guid,
609 ports: Vec::new(),
610 max_qp: device_attr.max_qp,
611 max_cq: device_attr.max_cq,
612 max_mr: device_attr.max_mr,
613 max_pd: device_attr.max_pd,
614 max_qp_wr: device_attr.max_qp_wr,
615 max_sge: device_attr.max_sge,
616 };
617
618 for port_num in 1..=device_attr.phys_port_cnt {
619 let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
620 if rdmaxcel_sys::ibv_query_port(
621 context,
622 port_num,
623 &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
624 ) != 0
625 {
626 continue;
627 }
628 let state = get_port_state_str(port_attr.state);
629 let physical_state = get_port_phy_state_str(port_attr.phys_state);
630
631 let link_layer = get_link_layer_str(port_attr.link_layer);
632
633 let mut gid = rdmaxcel_sys::ibv_gid::default();
634 let gid_str = if rdmaxcel_sys::ibv_query_gid(context, port_num, 0, &mut gid) == 0 {
635 format_gid(&gid.raw)
636 } else {
637 "N/A".to_string()
638 };
639
640 let rdma_port = RdmaPort {
641 port_num,
642 state,
643 physical_state,
644 base_lid: port_attr.lid,
645 lmc: port_attr.lmc,
646 sm_lid: port_attr.sm_lid,
647 capability_mask: port_attr.port_cap_flags,
648 link_layer,
649 gid: gid_str,
650 gid_tbl_len: port_attr.gid_tbl_len,
651 };
652
653 rdma_device.ports.push(rdma_port);
654 }
655
656 devices.push(rdma_device);
657 rdmaxcel_sys::ibv_close_device(context);
658 }
659
660 rdmaxcel_sys::ibv_free_device_list(device_list);
661 }
662
663 devices
664}
665
666/// Cached result of mlx5dv support check.
667static MLX5DV_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
668
669/// Checks if mlx5dv (Mellanox device-specific verbs extension) is supported.
670///
671/// This function attempts to open the first available RDMA device and check if
672/// mlx5dv extensions can be initialized. The mlx5dv extensions are required for
673/// advanced features like GPU Direct RDMA and direct queue pair manipulation.
674///
675/// The result is cached after the first call, making subsequent calls essentially free.
676///
677/// # Returns
678///
679/// `true` if mlx5dv extensions are supported, `false` otherwise.
680pub fn mlx5dv_supported() -> bool {
681 *MLX5DV_SUPPORTED_CACHE.get_or_init(mlx5dv_supported_impl)
682}
683
684fn mlx5dv_supported_impl() -> bool {
685 // SAFETY: We are calling C functions from libibverbs and libmlx5.
686 unsafe {
687 let mut mlx5dv_supported = false;
688 let mut num_devices = 0;
689 let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
690 if !device_list.is_null() && num_devices > 0 {
691 let device = *device_list;
692 if !device.is_null() {
693 mlx5dv_supported = rdmaxcel_sys::mlx5dv_is_supported(device);
694 }
695 rdmaxcel_sys::ibv_free_device_list(device_list);
696 }
697 mlx5dv_supported
698 }
699}
700
701/// Cached result of ibverbs support check.
702static IBVERBS_SUPPORTED_CACHE: OnceLock<bool> = OnceLock::new();
703
704/// Checks if ibverbs devices can be retrieved successfully.
705///
706/// This function attempts to retrieve the list of RDMA devices using the
707/// `ibv_get_device_list` function from the ibverbs library. It returns `true`
708/// if devices are found, and `false` otherwise.
709///
710/// The result is cached after the first call, making subsequent calls essentially free.
711///
712/// # Returns
713///
714/// `true` if devices are successfully retrieved, `false` otherwise.
715pub fn ibverbs_supported() -> bool {
716 *IBVERBS_SUPPORTED_CACHE.get_or_init(ibverbs_supported_impl)
717}
718
719fn ibverbs_supported_impl() -> bool {
720 // SAFETY: We are calling a C function from libibverbs.
721 unsafe {
722 let mut num_devices = 0;
723 let device_list = rdmaxcel_sys::ibv_get_device_list(&mut num_devices);
724 if !device_list.is_null() {
725 rdmaxcel_sys::ibv_free_device_list(device_list);
726 }
727 num_devices > 0
728 }
729}
730
731/// Checks if RDMA is fully supported on this system.
732///
733/// This is the canonical function to check if RDMA can be used.
734pub fn rdma_supported() -> bool {
735 ibverbs_supported()
736}
737
738/// Represents a view of a memory region that can be registered with an RDMA device.
739///
740/// This is a 'view' of a registered Memory Region, allowing multiple views into a single
741/// large MR registration. This is commonly used with PyTorch's caching allocator, which
742/// reserves large memory blocks and provides different data pointers into that space.
743///
744/// # Example
745/// PyTorch Caching Allocator creates a 16GB segment at virtual address `0x01000000`.
746/// The underlying Memory Region registers 16GB but at RDMA address `0x0`.
747/// To access virtual address `0x01100000`, we return a view at RDMA address `0x100000`.
748///
749/// # Safety
750/// The caller must ensure the memory remains valid and is not freed, moved, or
751/// overwritten while RDMA operations are in progress.
752
753#[derive(
754 Debug,
755 PartialEq,
756 Eq,
757 std::hash::Hash,
758 Serialize,
759 Deserialize,
760 Clone,
761 Copy
762)]
763pub struct RdmaMemoryRegionView {
764 // id should be unique with a given rdmam manager
765 pub id: usize,
766 /// Virtual address in the process address space.
767 /// This is the pointer/address as seen by the local process.
768 pub virtual_addr: usize,
769 /// Memory address assigned after Memory Region (MR) registration.
770 /// This is the address may be offset a base MR addr.
771 pub rdma_addr: usize,
772 pub size: usize,
773 pub lkey: u32,
774 pub rkey: u32,
775}
776
777// SAFETY: RdmaMemoryRegionView can be safely sent between threads because it only
778// contains address and size information without any thread-local state. However,
779// this DOES NOT provide any protection against data races in the underlying memory.
780// If one thread initiates an RDMA operation while another thread modifies the same
781// memory region, undefined behavior will occur. The caller is responsible for proper
782// synchronization of access to the underlying memory.
783unsafe impl Send for RdmaMemoryRegionView {}
784
785// SAFETY: RdmaMemoryRegionView is safe for concurrent access by multiple threads
786// as it only provides a view into memory without modifying its own state. However,
787// it provides NO PROTECTION against concurrent access to the underlying memory region.
788// The caller must ensure proper synchronization when:
789// 1. Initiating RDMA operations while local code reads/writes the same memory
790// 2. Performing multiple overlapping RDMA operations on the same memory region
791// 3. Freeing or reallocating memory that has in-flight RDMA operations
792unsafe impl Sync for RdmaMemoryRegionView {}
793
794impl RdmaMemoryRegionView {
795 /// Creates a new `RdmaMemoryRegionView` with the given address and size.
796 pub fn new(
797 id: usize,
798 virtual_addr: usize,
799 rdma_addr: usize,
800 size: usize,
801 lkey: u32,
802 rkey: u32,
803 ) -> Self {
804 Self {
805 id,
806 virtual_addr,
807 rdma_addr,
808 size,
809 lkey,
810 rkey,
811 }
812 }
813}
814
/// The RDMA operations this crate issues, in a friendlier form than raw
/// `ibv_wr_opcode` values.
///
/// RDMA operations allow for direct memory access between two machines without
/// involving the CPU of the target machine.
///
/// # Variants
///
/// * `Write` - RDMA write: data moves from local memory into a remote memory region.
/// * `WriteWithImm` - RDMA write that also delivers immediate data to the remote side.
/// * `Read` - RDMA read: data moves from a remote memory region into local memory.
/// * `Recv` - a receive operation; it has no `ibv_wr_opcode` equivalent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RdmaOperation {
    /// RDMA write.
    Write,
    /// RDMA write with immediate data.
    WriteWithImm,
    /// RDMA read.
    Read,
    /// Receive.
    Recv,
}
837
838impl From<RdmaOperation> for rdmaxcel_sys::ibv_wr_opcode::Type {
839 fn from(op: RdmaOperation) -> Self {
840 match op {
841 RdmaOperation::Write => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
842 RdmaOperation::WriteWithImm => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE_WITH_IMM,
843 RdmaOperation::Read => rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
844 RdmaOperation::Recv => panic!("Invalid wr opcode"),
845 }
846 }
847}
848
849impl From<rdmaxcel_sys::ibv_wc_opcode::Type> for RdmaOperation {
850 fn from(op: rdmaxcel_sys::ibv_wc_opcode::Type) -> Self {
851 match op {
852 rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE => RdmaOperation::Write,
853 rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ => RdmaOperation::Read,
854 _ => panic!("Unsupported operation type"),
855 }
856 }
857}
858
859/// Contains information needed to establish an RDMA queue pair with a remote endpoint.
860///
861/// `RdmaQpInfo` encapsulates all the necessary information to establish a queue pair
862/// with a remote RDMA device. This includes queue pair number, LID (Local Identifier),
863/// GID (Global Identifier), remote memory address, remote key, and packet sequence number.
864#[derive(Default, Named, Clone, serde::Serialize, serde::Deserialize)]
865pub struct RdmaQpInfo {
866 /// `qp_num` - Queue Pair Number, uniquely identifies a queue pair on the remote device
867 pub qp_num: u32,
868 /// `lid` - Local Identifier, used for addressing in InfiniBand subnet
869 pub lid: u16,
870 /// `gid` - Global Identifier, used for routing across subnets (similar to IPv6 address)
871 pub gid: Option<Gid>,
872 /// `psn` - Packet Sequence Number, used for ordering packets
873 pub psn: u32,
874}
875
876impl std::fmt::Debug for RdmaQpInfo {
877 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
878 write!(
879 f,
880 "RdmaQpInfo {{ qp_num: {}, lid: {}, gid: {:?}, psn: 0x{:x} }}",
881 self.qp_num, self.lid, self.gid, self.psn
882 )
883 }
884}
885
886/// Wrapper around ibv_wc (ibverbs work completion).
887///
888/// This exposes only the public fields of rdmaxcel_sys::ibv_wc, allowing us to more easily
889/// interact with it from Rust. Work completions are used to track the status of
890/// RDMA operations and are generated when an operation completes.
891#[derive(Debug, Named, Clone, serde::Serialize, serde::Deserialize)]
892pub struct IbvWc {
893 /// `wr_id` - Work Request ID, used to identify the completed operation
894 wr_id: u64,
895 /// `len` - Length of the data transferred
896 len: usize,
897 /// `valid` - Whether the work completion is valid
898 valid: bool,
899 /// `error` - Error information if the operation failed
900 error: Option<(rdmaxcel_sys::ibv_wc_status::Type, u32)>,
901 /// `opcode` - Type of operation that completed (read, write, etc.)
902 opcode: rdmaxcel_sys::ibv_wc_opcode::Type,
903 /// `bytes` - Immediate data (if any)
904 bytes: Option<u32>,
905 /// `qp_num` - Queue Pair Number
906 qp_num: u32,
907 /// `src_qp` - Source Queue Pair Number
908 src_qp: u32,
909 /// `pkey_index` - Partition Key Index
910 pkey_index: u16,
911 /// `slid` - Source LID
912 slid: u16,
913 /// `sl` - Service Level
914 sl: u8,
915 /// `dlid_path_bits` - Destination LID Path Bits
916 dlid_path_bits: u8,
917}
918
919impl From<rdmaxcel_sys::ibv_wc> for IbvWc {
920 fn from(wc: rdmaxcel_sys::ibv_wc) -> Self {
921 IbvWc {
922 wr_id: wc.wr_id(),
923 len: wc.len(),
924 valid: wc.is_valid(),
925 error: wc.error(),
926 opcode: wc.opcode(),
927 bytes: wc.imm_data(),
928 qp_num: wc.qp_num,
929 src_qp: wc.src_qp,
930 pkey_index: wc.pkey_index,
931 slid: wc.slid,
932 sl: wc.sl,
933 dlid_path_bits: wc.dlid_path_bits,
934 }
935 }
936}
937
938impl IbvWc {
939 /// Returns the Work Request ID associated with this work completion.
940 ///
941 /// The Work Request ID is used to identify the specific operation that completed.
942 /// It is set by the application when posting the work request and is returned
943 /// unchanged in the work completion.
944 pub fn wr_id(&self) -> u64 {
945 self.wr_id
946 }
947
948 /// Returns whether this work completion is valid.
949 ///
950 /// A valid work completion indicates that the operation completed successfully.
951 /// If false, the `error` field may contain additional information about the failure.
952 pub fn is_valid(&self) -> bool {
953 self.valid
954 }
955}
956
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_all_devices() {
        // Skip test if RDMA devices are not available
        let devices = get_all_devices();
        if devices.is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }
        // Basic validation of first device
        let device = &devices[0];
        assert!(!device.name().is_empty(), "device name should not be empty");
        assert!(
            !device.ports().is_empty(),
            "device should have at least one port"
        );
    }

    #[test]
    fn test_first_available() {
        // Exercise `RdmaDevice::first_available` itself (the function this test is
        // named after); skip when the host exposes no usable RDMA device.
        let Some(dev) = RdmaDevice::first_available() else {
            println!("Skipping test: RDMA devices not available");
            return;
        };
        // Every getter must agree with the field it exposes.
        assert_eq!(dev.vendor_id(), dev.vendor_id);
        assert_eq!(dev.vendor_part_id(), dev.vendor_part_id);
        assert_eq!(dev.hw_ver(), dev.hw_ver);
        assert_eq!(dev.fw_ver(), &dev.fw_ver);
        assert_eq!(dev.node_guid(), dev.node_guid);
        assert_eq!(dev.max_qp(), dev.max_qp);
        assert_eq!(dev.max_cq(), dev.max_cq);
        assert_eq!(dev.max_mr(), dev.max_mr);
        assert_eq!(dev.max_pd(), dev.max_pd);
        assert_eq!(dev.max_qp_wr(), dev.max_qp_wr);
        assert_eq!(dev.max_sge(), dev.max_sge);
    }

    #[test]
    fn test_device_display() {
        if let Some(device) = RdmaDevice::first_available() {
            let display_output = format!("{}", device);
            assert!(
                display_output.contains(&device.name),
                "display should include device name"
            );
            assert!(
                display_output.contains(&device.fw_ver),
                "display should include firmware version"
            );
        }
    }

    #[test]
    fn test_port_display() {
        if let Some(device) = RdmaDevice::first_available() {
            if !device.ports().is_empty() {
                let port = &device.ports()[0];
                let display_output = format!("{}", port);
                assert!(
                    display_output.contains(&port.state),
                    "display should include port state"
                );
                assert!(
                    display_output.contains(&port.link_layer),
                    "display should include link layer"
                );
            }
        }
    }

    #[test]
    fn test_rdma_operation_conversion() {
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_WRITE,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Write)
        );
        assert_eq!(
            rdmaxcel_sys::ibv_wr_opcode::IBV_WR_RDMA_READ,
            rdmaxcel_sys::ibv_wr_opcode::Type::from(RdmaOperation::Read)
        );

        assert_eq!(
            RdmaOperation::Write,
            RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_WRITE)
        );
        assert_eq!(
            RdmaOperation::Read,
            RdmaOperation::from(rdmaxcel_sys::ibv_wc_opcode::IBV_WC_RDMA_READ)
        );
    }

    #[test]
    fn test_rdma_endpoint() {
        let endpoint = RdmaQpInfo {
            qp_num: 42,
            lid: 123,
            gid: None,
            psn: 0x5678,
        };

        let debug_str = format!("{:?}", endpoint);
        assert!(debug_str.contains("qp_num: 42"));
        assert!(debug_str.contains("lid: 123"));
        assert!(debug_str.contains("psn: 0x5678"));
    }

    #[test]
    fn test_ibv_wc() {
        let mut wc = rdmaxcel_sys::ibv_wc::default();

        // `wr_id` and `status` are not settable through the bindings' public API, so
        // write them through a raw pointer using the C `struct ibv_wc` layout:
        // `uint64_t wr_id` at offset 0, then `enum ibv_wc_status status` at offset 8.
        let wc_ptr = &mut wc as *mut rdmaxcel_sys::ibv_wc as *mut u8;
        // SAFETY: `wc_ptr` points to a live, exclusively borrowed `ibv_wc`, and no other
        // reference to `wc` exists while it is used. Offsets 0 and 8 lie within the struct
        // and are the naturally aligned locations of `wr_id` and `status`, assuming the
        // bindings mirror the C layout (as the original offsets did). `status` is written
        // with its own type, `ibv_wc_status::Type`, so no signed/unsigned cast is needed.
        unsafe {
            wc_ptr.cast::<u64>().write(42);
            wc_ptr
                .add(8)
                .cast::<rdmaxcel_sys::ibv_wc_status::Type>()
                .write(rdmaxcel_sys::ibv_wc_status::IBV_WC_SUCCESS);
        }
        let ibv_wc = IbvWc::from(wc);
        assert_eq!(ibv_wc.wr_id(), 42);
        assert!(ibv_wc.is_valid());
    }

    #[test]
    fn test_format_gid() {
        let gid = [
            0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
            0x77, 0x88,
        ];

        let formatted = format_gid(&gid);
        assert_eq!(formatted, "1234:5678:9abc:def0:1122:3344:5566:7788");
    }

    #[test]
    fn test_mlx5dv_supported_basic() {
        // The test just verifies the function doesn't panic
        let mlx5dv_support = mlx5dv_supported();
        println!("mlx5dv_supported: {}", mlx5dv_support);
    }
}