monarch_rdma/
device_selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! PCI topology parsing and device discovery utilities for RDMA device selection.
10//!
11//! ibverbs-specific selection logic lives in [`crate::backend::ibverbs::device_selection`].
12
13use std::collections::HashMap;
14use std::fs;
15use std::path::Path;
16
17use regex::Regex;
18
19// ==== PCI TOPOLOGY DISTANCE CONSTANTS ====
20//
21// These constants define penalty values for cross-NUMA communication in PCI topology:
22//
23// - CROSS_NUMA_BASE_PENALTY (20.0): Base penalty for cross-NUMA communication.
24//   This value is higher than typical intra-NUMA distances (usually 0-8 hops)
25//   to ensure same-NUMA devices are always preferred over cross-NUMA devices.
26//
27// - ADDRESS_PARSE_FAILURE_PENALTY (Inf): Penalty when PCI address parsing fails.
28//   Used as fallback when we can't determine bus relationships between devices.
29//
30// - CROSS_DOMAIN_PENALTY (1000.0): Very high penalty for different PCI domains.
31//   Different domains typically indicate completely separate I/O subsystems.
32//
33// - BUS_DISTANCE_SCALE (0.1): Scaling factor for bus distance in cross-NUMA penalty.
34//   Small factor to provide tie-breaking between devices at different bus numbers.
35
const CROSS_NUMA_BASE_PENALTY: f64 = 20.0;
const ADDRESS_PARSE_FAILURE_PENALTY: f64 = f64::INFINITY;
const CROSS_DOMAIN_PENALTY: f64 = 1000.0;
const BUS_DISTANCE_SCALE: f64 = 0.1;

/// A node in the PCI device tree, identified by its address string
/// (`domain:bus:device.function`, each field hexadecimal).
#[derive(Debug, Clone)]
pub struct PCIDevice {
    /// PCI address of this device, e.g. `0000:3b:00.0`.
    pub address: String,
    /// Boxed copy of the upstream device, or `None` when no parent is known.
    pub parent: Option<Box<PCIDevice>>,
}

impl PCIDevice {
    /// Create a device with the given address and no parent.
    pub fn new(address: String) -> Self {
        Self {
            address,
            parent: None,
        }
    }

    /// Return the addresses on the way up the `parent` chain, starting with
    /// this device itself and ending with its topmost ancestor.
    pub fn get_path_to_root(&self) -> Vec<String> {
        std::iter::successors(Some(self), |device| device.parent.as_deref())
            .map(|device| device.address.clone())
            .collect()
    }

    /// Read this device's NUMA node from
    /// `/sys/bus/pci/devices/<address>/numa_node`.
    ///
    /// Returns `None` when the file cannot be read or its trimmed content does
    /// not parse as an `i32`.
    pub fn get_numa_node(&self) -> Option<i32> {
        let numa_file = format!("/sys/bus/pci/devices/{}/numa_node", self.address);
        std::fs::read_to_string(numa_file).ok()?.trim().parse().ok()
    }

    /// Compute a topology distance between this device and `other`.
    ///
    /// * Identical addresses yield `0.0`.
    /// * When both parent chains share an ancestor, the result is the number
    ///   of hops from each device up to the deepest shared ancestor, summed.
    /// * Otherwise the address-based penalty from
    ///   `calculate_cross_numa_distance` is returned.
    pub fn distance_to(&self, other: &PCIDevice) -> f64 {
        if self.address == other.address {
            return 0.0;
        }

        let self_path = self.get_path_to_root();
        let other_path = other.get_path_to_root();

        // Both paths end at their topmost ancestor, so the shared part of the
        // two chains is their common suffix.
        let shared = self_path
            .iter()
            .rev()
            .zip(other_path.iter().rev())
            .take_while(|(a, b)| a == b)
            .count();

        if shared == 0 {
            return self.calculate_cross_numa_distance(other);
        }

        // The deepest shared ancestor sits `shared` entries from the end of
        // each path; everything before it is a hop.
        let hops = (self_path.len() - shared) + (other_path.len() - shared);
        hops as f64
    }

    /// Distance fallback for two devices whose parent chains share no ancestor.
    ///
    /// Different domains cost `CROSS_DOMAIN_PENALTY`; within one domain the
    /// cost is `CROSS_NUMA_BASE_PENALTY` plus the absolute bus-number
    /// difference scaled by `BUS_DISTANCE_SCALE`. If either address cannot be
    /// parsed, `ADDRESS_PARSE_FAILURE_PENALTY` is returned.
    fn calculate_cross_numa_distance(&self, other: &PCIDevice) -> f64 {
        match (self.parse_pci_address(), other.parse_pci_address()) {
            (Some((self_domain, self_bus, _, _)), Some((other_domain, other_bus, _, _))) => {
                if self_domain != other_domain {
                    return CROSS_DOMAIN_PENALTY;
                }

                let bus_distance = (i32::from(self_bus) - i32::from(other_bus)).abs();
                CROSS_NUMA_BASE_PENALTY + f64::from(bus_distance) * BUS_DISTANCE_SCALE
            }
            _ => ADDRESS_PARSE_FAILURE_PENALTY,
        }
    }

    /// Parse the address into `(domain, bus, device, function)`.
    ///
    /// The domain is parsed as `u32` rather than `u16` so that addresses whose
    /// domain has more than four hex digits (e.g. `10000:00:01.0`) are
    /// accepted instead of being treated as unparseable.
    ///
    /// Returns `None` unless the address has exactly three `:`-separated
    /// fields, the last of which has exactly two `.`-separated hex fields.
    fn parse_pci_address(&self) -> Option<(u32, u8, u8, u8)> {
        let mut fields = self.address.split(':');
        let (domain, bus, dev_func) = (fields.next()?, fields.next()?, fields.next()?);
        if fields.next().is_some() {
            return None;
        }

        let domain = u32::from_str_radix(domain, 16).ok()?;
        let bus = u8::from_str_radix(bus, 16).ok()?;

        // A second '.' ends up inside `function` and fails the hex parse below.
        let (device, function) = dev_func.split_once('.')?;
        let device = u8::from_str_radix(device, 16).ok()?;
        let function = u8::from_str_radix(function, 16).ok()?;

        Some((domain, bus, device, function))
    }

    /// Find the index of the closest device from a list of candidates.
    ///
    /// Returns `None` for an empty list. When several candidates are equally
    /// close, the one that appears first wins.
    pub fn find_closest(&self, candidate_devices: &[PCIDevice]) -> Option<usize> {
        candidate_devices
            .iter()
            .map(|candidate| self.distance_to(candidate))
            .enumerate()
            // `min_by` keeps the first of several equal minima.
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
    }
}
162
/// Follow a chain of symbolic links starting at `path` until reaching a path
/// that `fs::read_link` cannot read (not a symlink, missing, or any other
/// error), and return that path.
///
/// Only the path as a whole is tested for being a link on each step. Relative
/// link targets are joined onto the directory containing the link, and the
/// result is not normalised, so `..` components are kept as-is.
///
/// # Errors
///
/// Returns `InvalidInput` when following the links leads back to a path that
/// was already visited.
fn realpath(path: &Path) -> Result<std::path::PathBuf, std::io::Error> {
    let mut visited = std::collections::HashSet::new();
    let mut resolved = path.to_path_buf();

    // `insert` returns false once a path repeats, which ends the loop and
    // falls through to the cycle error below.
    while visited.insert(resolved.clone()) {
        let target = match fs::read_link(&resolved) {
            Ok(target) => target,
            // Not a symlink or error reading: this is the final answer.
            Err(_) => return Ok(resolved),
        };

        resolved = if target.is_relative() {
            match resolved.parent() {
                Some(dir) => dir.join(target),
                None => Path::new("/").join(target),
            }
        } else {
            target
        };
    }

    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        "Circular symlink detected",
    ))
}
191
192pub fn parse_pci_topology() -> Result<HashMap<String, PCIDevice>, std::io::Error> {
193    let mut devices = HashMap::new();
194    let mut parent_addresses = HashMap::new();
195    let pci_devices_dir = "/sys/bus/pci/devices";
196
197    if !Path::new(pci_devices_dir).exists() {
198        return Ok(devices);
199    }
200
201    let pci_addr_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])$").unwrap();
202
203    // First pass: create all devices without parent references
204    for entry in fs::read_dir(pci_devices_dir)? {
205        let entry = entry?;
206        let pci_addr = entry.file_name().to_string_lossy().to_string();
207        let device_path = entry.path();
208
209        // Find parent device by following the device symlink and extracting PCI address from the path
210        let parent_addr = match realpath(&device_path) {
211            Ok(real_path) => {
212                if let Some(parent_path) = real_path.parent() {
213                    let parent_path_str = parent_path.to_string_lossy();
214                    pci_addr_regex
215                        .captures(&parent_path_str)
216                        .map(|captures| captures.get(1).unwrap().as_str().to_string())
217                } else {
218                    None
219                }
220            }
221            Err(_) => None,
222        };
223
224        devices.insert(pci_addr.clone(), PCIDevice::new(pci_addr.clone()));
225        if let Some(ref parent) = parent_addr {
226            if !devices.contains_key(parent) {
227                devices.insert(parent.clone(), PCIDevice::new(parent.clone()));
228            }
229        }
230        parent_addresses.insert(pci_addr, parent_addr);
231    }
232
233    // Second pass: set up parent references recursively
234    fn build_parent_chain(
235        devices: &mut HashMap<String, PCIDevice>,
236        parent_addresses: &HashMap<String, Option<String>>,
237        pci_addr: &str,
238        visited: &mut std::collections::HashSet<String>,
239    ) {
240        if visited.contains(pci_addr) {
241            return;
242        }
243        visited.insert(pci_addr.to_string());
244
245        if let Some(Some(parent_addr)) = parent_addresses.get(pci_addr) {
246            build_parent_chain(devices, parent_addresses, parent_addr, visited);
247
248            if let Some(parent_device) = devices.get(parent_addr).cloned() {
249                if let Some(device) = devices.get_mut(pci_addr) {
250                    device.parent = Some(Box::new(parent_device));
251                }
252            }
253        }
254    }
255
256    let mut visited = std::collections::HashSet::new();
257    for pci_addr in devices.keys().cloned().collect::<Vec<_>>() {
258        visited.clear();
259        build_parent_chain(&mut devices, &parent_addresses, &pci_addr, &mut visited);
260    }
261
262    Ok(devices)
263}
264
/// Split a device string of the form `"<kind>:<index>"` into its two halves.
///
/// Returns `None` unless the input contains exactly one `:`. Either half may
/// be empty (`"cuda:"` yields `("cuda", "")`).
pub fn parse_device_string(device_str: &str) -> Option<(String, String)> {
    let mut fields = device_str.split(':');
    match (fields.next(), fields.next(), fields.next()) {
        // Exactly two fields: anything fewer or more is rejected.
        (Some(kind), Some(index), None) => Some((kind.to_owned(), index.to_owned())),
        _ => None,
    }
}
273
274pub fn get_cuda_pci_address(device_idx: &str) -> Option<String> {
275    let idx: i32 = device_idx.parse().ok()?;
276    let gpu_proc_dir = "/proc/driver/nvidia/gpus";
277
278    if !Path::new(gpu_proc_dir).exists() {
279        return None;
280    }
281
282    for entry in fs::read_dir(gpu_proc_dir).ok()? {
283        let entry = entry.ok()?;
284        let pci_addr = entry.file_name().to_string_lossy().to_lowercase();
285        let info_file = entry.path().join("information");
286
287        if let Ok(content) = fs::read_to_string(&info_file) {
288            let minor_regex = Regex::new(r"Device Minor:\s*(\d+)").unwrap();
289            if let Some(captures) = minor_regex.captures(&content) {
290                if let Ok(device_minor) = captures.get(1).unwrap().as_str().parse::<i32>() {
291                    if device_minor == idx {
292                        return Some(pci_addr);
293                    }
294                }
295            }
296        }
297    }
298    None
299}
300
301pub fn get_numa_pci_address(numa_node: &str) -> Option<String> {
302    let node: i32 = numa_node.parse().ok()?;
303    let pci_devices = parse_pci_topology().ok()?;
304
305    let mut candidates = Vec::new();
306    for (pci_addr, device) in &pci_devices {
307        if let Some(device_numa) = device.get_numa_node() {
308            if device_numa == node {
309                candidates.push(pci_addr.clone());
310            }
311        }
312    }
313
314    if candidates.is_empty() {
315        return None;
316    }
317
318    let mut best_candidate = candidates[0].clone();
319    let mut shortest_path = usize::MAX;
320
321    for pci_addr in &candidates {
322        if let Some(device) = pci_devices.get(pci_addr) {
323            let path_length = device.get_path_to_root().len();
324            if path_length < shortest_path
325                || (path_length == shortest_path && pci_addr < &best_candidate)
326            {
327                shortest_path = path_length;
328                best_candidate = pci_addr.clone();
329            }
330        }
331    }
332
333    Some(best_candidate)
334}
335
336pub fn get_all_rdma_devices() -> Vec<(String, String)> {
337    let mut rdma_devices = Vec::new();
338    let ib_class_dir = "/sys/class/infiniband";
339
340    if !Path::new(ib_class_dir).exists() {
341        return rdma_devices;
342    }
343
344    let pci_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])").unwrap();
345
346    if let Ok(entries) = fs::read_dir(ib_class_dir) {
347        let mut sorted_entries: Vec<_> = entries.collect::<Result<Vec<_>, _>>().unwrap_or_default();
348        sorted_entries.sort_by_key(|entry| entry.file_name());
349
350        for entry in sorted_entries {
351            let ib_dev = entry.file_name().to_string_lossy().to_string();
352            let device_path = entry.path().join("device");
353
354            if let Ok(real_path) = fs::read_link(&device_path) {
355                let real_path_str = real_path.to_string_lossy();
356                let pci_matches: Vec<&str> = pci_regex
357                    .find_iter(&real_path_str)
358                    .map(|m| m.as_str())
359                    .collect();
360
361                if let Some(&last_pci_addr) = pci_matches.last() {
362                    rdma_devices.push((ib_dev, last_pci_addr.to_string()));
363                }
364            }
365        }
366    }
367
368    rdma_devices
369}
370
371pub fn get_nic_pci_address(nic_name: &str) -> Option<String> {
372    let rdma_devices = get_all_rdma_devices();
373    for (name, pci_addr) in rdma_devices {
374        if name == nic_name {
375            return Some(pci_addr);
376        }
377    }
378    None
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse_device_string` accepts exactly one `:` separator and allows an
    /// empty index; input without a separator is rejected.
    #[test]
    fn test_parse_device_string() {
        let pair = |kind: &str, index: &str| Some((kind.to_string(), index.to_string()));

        let cases = vec![
            ("cuda:0", pair("cuda", "0")),
            ("cpu:1", pair("cpu", "1")),
            ("invalid", None),
            ("cuda:", pair("cuda", "")),
        ];

        for (input, expected) in cases {
            assert_eq!(parse_device_string(input), expected, "input: {:?}", input);
        }
    }
}