monarch_rdma/
device_selection.rs1use std::collections::HashMap;
14use std::fs;
15use std::path::Path;
16
17use regex::Regex;
18
/// Base distance for two devices whose parent chains share no common ancestor.
/// Despite the name, the NUMA node itself is not consulted when applying it.
const CROSS_NUMA_BASE_PENALTY: f64 = 20.0;
/// Distance reported when either PCI address cannot be parsed, so such devices
/// never win a comparison in [`PCIDevice::find_closest`] against a parseable one.
const ADDRESS_PARSE_FAILURE_PENALTY: f64 = f64::INFINITY;
/// Distance reported for two unrelated devices in different PCI domains.
const CROSS_DOMAIN_PENALTY: f64 = 1000.0;
/// Weight applied to the absolute bus-number difference between two unrelated
/// devices in the same PCI domain.
const BUS_DISTANCE_SCALE: f64 = 0.1;

/// A node in the PCI topology: its `domain:bus:device.function` address plus an
/// owned copy of its upstream parent chain (if known).
#[derive(Debug, Clone)]
pub struct PCIDevice {
    /// PCI address, in the form used as the directory name under
    /// `/sys/bus/pci/devices` (e.g. `0000:3b:00.0`).
    pub address: String,
    /// Upstream parent device, or `None` for the topmost known device.
    pub parent: Option<Box<PCIDevice>>,
}

impl PCIDevice {
    /// Creates a device with the given address and no parent.
    pub fn new(address: String) -> Self {
        Self {
            address,
            parent: None,
        }
    }

    /// Returns this device's address followed by each ancestor's address,
    /// ending with the topmost known ancestor.
    pub fn get_path_to_root(&self) -> Vec<String> {
        std::iter::successors(Some(self), |device| device.parent.as_deref())
            .map(|device| device.address.clone())
            .collect()
    }

    /// Reads the NUMA node for this device from
    /// `/sys/bus/pci/devices/<address>/numa_node`.
    ///
    /// Returns `None` if the file cannot be read or does not hold an integer.
    /// Any integer (including a negative "no affinity" value) is returned as-is.
    pub fn get_numa_node(&self) -> Option<i32> {
        let numa_file = format!("/sys/bus/pci/devices/{}/numa_node", self.address);
        std::fs::read_to_string(numa_file).ok()?.trim().parse().ok()
    }

    /// Returns a heuristic distance between `self` and `other`:
    ///
    /// * `0.0` when both have the same address;
    /// * when the parent chains share a common ancestor, the number of hops from
    ///   each device up to the deepest common ancestor, summed;
    /// * otherwise the fallback from [`Self::calculate_cross_numa_distance`].
    pub fn distance_to(&self, other: &PCIDevice) -> f64 {
        if self.address == other.address {
            return 0.0;
        }

        let path1 = self.get_path_to_root();
        let path2 = other.get_path_to_root();

        // Both paths end at their topmost ancestor, so walk them from the end in
        // lock-step and count how many trailing entries they have in common.
        let shared = path1
            .iter()
            .rev()
            .zip(path2.iter().rev())
            .take_while(|(a, b)| a == b)
            .count();

        if shared == 0 {
            return self.calculate_cross_numa_distance(other);
        }

        // The deepest common ancestor is `shared` entries from the end of each
        // path; every entry before it is one hop.
        ((path1.len() - shared) + (path2.len() - shared)) as f64
    }

    /// Fallback distance for two devices with no common ancestor.
    ///
    /// Different domains cost [`CROSS_DOMAIN_PENALTY`]; the same domain costs
    /// [`CROSS_NUMA_BASE_PENALTY`] plus the scaled bus-number difference.
    /// If either address fails to parse, [`ADDRESS_PARSE_FAILURE_PENALTY`].
    fn calculate_cross_numa_distance(&self, other: &PCIDevice) -> f64 {
        match (self.parse_pci_address(), other.parse_pci_address()) {
            (Some((self_domain, self_bus, _, _)), Some((other_domain, other_bus, _, _))) => {
                if self_domain != other_domain {
                    CROSS_DOMAIN_PENALTY
                } else {
                    let bus_distance = f64::from(self_bus.abs_diff(other_bus));
                    CROSS_NUMA_BASE_PENALTY + bus_distance * BUS_DISTANCE_SCALE
                }
            }
            _ => ADDRESS_PARSE_FAILURE_PENALTY,
        }
    }

    /// Parses `domain:bus:device.function` (all hexadecimal) into numeric parts,
    /// returning `None` for any other shape.
    ///
    /// The domain is parsed as `u32` because Linux can expose PCI domains wider
    /// than 16 bits (e.g. Intel VMD domains such as `10000:e1:00.0`).
    fn parse_pci_address(&self) -> Option<(u32, u8, u8, u8)> {
        let mut fields = self.address.split(':');
        let (domain, bus, dev_func) =
            match (fields.next(), fields.next(), fields.next(), fields.next()) {
                (Some(domain), Some(bus), Some(dev_func), None) => (domain, bus, dev_func),
                _ => return None,
            };
        let (device, function) = dev_func.split_once('.')?;

        Some((
            u32::from_str_radix(domain, 16).ok()?,
            u8::from_str_radix(bus, 16).ok()?,
            u8::from_str_radix(device, 16).ok()?,
            u8::from_str_radix(function, 16).ok()?,
        ))
    }

    /// Returns the index of the candidate with the smallest
    /// [`Self::distance_to`], or `None` if `candidate_devices` is empty.
    /// Ties are resolved in favour of the earliest candidate.
    pub fn find_closest(&self, candidate_devices: &[PCIDevice]) -> Option<usize> {
        candidate_devices
            .iter()
            .map(|candidate| self.distance_to(candidate))
            .enumerate()
            // `min_by` returns the first of several equal minima.
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(idx, _)| idx)
    }
}
162
/// Resolves `path` to an absolute canonical path.
///
/// Every symlink, including those in intermediate components, is followed, and
/// `.`/`..` components are removed. Delegates to [`fs::canonicalize`], which is
/// `realpath(3)` on Unix.
///
/// # Errors
///
/// Returns an error if `path` (or a symlink target along the way) does not
/// exist, cannot be accessed, or if symlink resolution loops.
fn realpath(path: &Path) -> Result<std::path::PathBuf, std::io::Error> {
    fs::canonicalize(path)
}
191
192pub fn parse_pci_topology() -> Result<HashMap<String, PCIDevice>, std::io::Error> {
193 let mut devices = HashMap::new();
194 let mut parent_addresses = HashMap::new();
195 let pci_devices_dir = "/sys/bus/pci/devices";
196
197 if !Path::new(pci_devices_dir).exists() {
198 return Ok(devices);
199 }
200
201 let pci_addr_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])$").unwrap();
202
203 for entry in fs::read_dir(pci_devices_dir)? {
205 let entry = entry?;
206 let pci_addr = entry.file_name().to_string_lossy().to_string();
207 let device_path = entry.path();
208
209 let parent_addr = match realpath(&device_path) {
211 Ok(real_path) => {
212 if let Some(parent_path) = real_path.parent() {
213 let parent_path_str = parent_path.to_string_lossy();
214 pci_addr_regex
215 .captures(&parent_path_str)
216 .map(|captures| captures.get(1).unwrap().as_str().to_string())
217 } else {
218 None
219 }
220 }
221 Err(_) => None,
222 };
223
224 devices.insert(pci_addr.clone(), PCIDevice::new(pci_addr.clone()));
225 if let Some(ref parent) = parent_addr {
226 if !devices.contains_key(parent) {
227 devices.insert(parent.clone(), PCIDevice::new(parent.clone()));
228 }
229 }
230 parent_addresses.insert(pci_addr, parent_addr);
231 }
232
233 fn build_parent_chain(
235 devices: &mut HashMap<String, PCIDevice>,
236 parent_addresses: &HashMap<String, Option<String>>,
237 pci_addr: &str,
238 visited: &mut std::collections::HashSet<String>,
239 ) {
240 if visited.contains(pci_addr) {
241 return;
242 }
243 visited.insert(pci_addr.to_string());
244
245 if let Some(Some(parent_addr)) = parent_addresses.get(pci_addr) {
246 build_parent_chain(devices, parent_addresses, parent_addr, visited);
247
248 if let Some(parent_device) = devices.get(parent_addr).cloned() {
249 if let Some(device) = devices.get_mut(pci_addr) {
250 device.parent = Some(Box::new(parent_device));
251 }
252 }
253 }
254 }
255
256 let mut visited = std::collections::HashSet::new();
257 for pci_addr in devices.keys().cloned().collect::<Vec<_>>() {
258 visited.clear();
259 build_parent_chain(&mut devices, &parent_addresses, &pci_addr, &mut visited);
260 }
261
262 Ok(devices)
263}
264
/// Splits a `"<kind>:<index>"` device string such as `"cuda:0"` into its two
/// halves.
///
/// Returns `None` unless the string contains exactly one `:`. Either half may
/// be empty, so `"cuda:"` yields `("cuda", "")`.
pub fn parse_device_string(device_str: &str) -> Option<(String, String)> {
    let mut pieces = device_str.split(':');
    match (pieces.next(), pieces.next(), pieces.next()) {
        (Some(kind), Some(index), None) => Some((kind.to_owned(), index.to_owned())),
        _ => None,
    }
}
273
274pub fn get_cuda_pci_address(device_idx: &str) -> Option<String> {
275 let idx: i32 = device_idx.parse().ok()?;
276 let gpu_proc_dir = "/proc/driver/nvidia/gpus";
277
278 if !Path::new(gpu_proc_dir).exists() {
279 return None;
280 }
281
282 for entry in fs::read_dir(gpu_proc_dir).ok()? {
283 let entry = entry.ok()?;
284 let pci_addr = entry.file_name().to_string_lossy().to_lowercase();
285 let info_file = entry.path().join("information");
286
287 if let Ok(content) = fs::read_to_string(&info_file) {
288 let minor_regex = Regex::new(r"Device Minor:\s*(\d+)").unwrap();
289 if let Some(captures) = minor_regex.captures(&content) {
290 if let Ok(device_minor) = captures.get(1).unwrap().as_str().parse::<i32>() {
291 if device_minor == idx {
292 return Some(pci_addr);
293 }
294 }
295 }
296 }
297 }
298 None
299}
300
301pub fn get_numa_pci_address(numa_node: &str) -> Option<String> {
302 let node: i32 = numa_node.parse().ok()?;
303 let pci_devices = parse_pci_topology().ok()?;
304
305 let mut candidates = Vec::new();
306 for (pci_addr, device) in &pci_devices {
307 if let Some(device_numa) = device.get_numa_node() {
308 if device_numa == node {
309 candidates.push(pci_addr.clone());
310 }
311 }
312 }
313
314 if candidates.is_empty() {
315 return None;
316 }
317
318 let mut best_candidate = candidates[0].clone();
319 let mut shortest_path = usize::MAX;
320
321 for pci_addr in &candidates {
322 if let Some(device) = pci_devices.get(pci_addr) {
323 let path_length = device.get_path_to_root().len();
324 if path_length < shortest_path
325 || (path_length == shortest_path && pci_addr < &best_candidate)
326 {
327 shortest_path = path_length;
328 best_candidate = pci_addr.clone();
329 }
330 }
331 }
332
333 Some(best_candidate)
334}
335
336pub fn get_all_rdma_devices() -> Vec<(String, String)> {
337 let mut rdma_devices = Vec::new();
338 let ib_class_dir = "/sys/class/infiniband";
339
340 if !Path::new(ib_class_dir).exists() {
341 return rdma_devices;
342 }
343
344 let pci_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])").unwrap();
345
346 if let Ok(entries) = fs::read_dir(ib_class_dir) {
347 let mut sorted_entries: Vec<_> = entries.collect::<Result<Vec<_>, _>>().unwrap_or_default();
348 sorted_entries.sort_by_key(|entry| entry.file_name());
349
350 for entry in sorted_entries {
351 let ib_dev = entry.file_name().to_string_lossy().to_string();
352 let device_path = entry.path().join("device");
353
354 if let Ok(real_path) = fs::read_link(&device_path) {
355 let real_path_str = real_path.to_string_lossy();
356 let pci_matches: Vec<&str> = pci_regex
357 .find_iter(&real_path_str)
358 .map(|m| m.as_str())
359 .collect();
360
361 if let Some(&last_pci_addr) = pci_matches.last() {
362 rdma_devices.push((ib_dev, last_pci_addr.to_string()));
363 }
364 }
365 }
366 }
367
368 rdma_devices
369}
370
371pub fn get_nic_pci_address(nic_name: &str) -> Option<String> {
372 let rdma_devices = get_all_rdma_devices();
373 for (name, pci_addr) in rdma_devices {
374 if name == nic_name {
375 return Some(pci_addr);
376 }
377 }
378 None
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a device whose parent is a copy of `parent`.
    fn child_of(address: &str, parent: &PCIDevice) -> PCIDevice {
        let mut device = PCIDevice::new(address.to_string());
        device.parent = Some(Box::new(parent.clone()));
        device
    }

    #[test]
    fn test_parse_device_string() {
        assert_eq!(
            parse_device_string("cuda:0"),
            Some(("cuda".to_string(), "0".to_string()))
        );
        assert_eq!(
            parse_device_string("cpu:1"),
            Some(("cpu".to_string(), "1".to_string()))
        );
        assert_eq!(parse_device_string("invalid"), None);
        assert_eq!(
            parse_device_string("cuda:"),
            Some(("cuda".to_string(), "".to_string()))
        );
        // Exactly one separator is required.
        assert_eq!(parse_device_string("cuda:0:1"), None);
        assert_eq!(parse_device_string(""), None);
    }

    #[test]
    fn test_distance_within_tree() {
        let root = PCIDevice::new("0000:00:01.0".to_string());
        let bridge = child_of("0000:01:00.0", &root);
        let nic = child_of("0000:02:00.0", &bridge);
        let gpu = child_of("0000:02:01.0", &bridge);

        assert_eq!(nic.distance_to(&nic), 0.0);
        assert_eq!(nic.distance_to(&gpu), 2.0);
        assert_eq!(nic.distance_to(&root), 2.0);
        assert_eq!(bridge.distance_to(&gpu), 1.0);
    }

    #[test]
    fn test_distance_without_common_ancestor() {
        let a = PCIDevice::new("0000:10:00.0".to_string());
        let b = PCIDevice::new("0000:20:00.0".to_string());
        assert!((a.distance_to(&b) - 21.6).abs() < 1e-9);

        let other_domain = PCIDevice::new("0001:10:00.0".to_string());
        assert_eq!(a.distance_to(&other_domain), 1000.0);

        let bogus = PCIDevice::new("bogus".to_string());
        assert_eq!(a.distance_to(&bogus), f64::INFINITY);
    }

    #[test]
    fn test_find_closest() {
        let root = PCIDevice::new("0000:00:01.0".to_string());
        let bridge = child_of("0000:01:00.0", &root);
        let nic = child_of("0000:02:00.0", &bridge);
        let gpu = child_of("0000:02:01.0", &bridge);
        let far = PCIDevice::new("0000:80:00.0".to_string());

        assert_eq!(nic.find_closest(&[]), None);
        assert_eq!(nic.find_closest(&[far, gpu.clone()]), Some(1));
        // Ties resolve to the earliest candidate.
        assert_eq!(nic.find_closest(&[gpu.clone(), gpu]), Some(0));
    }
}