1use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15
16use regex::Regex;
17
18use crate::ibverbs_primitives::RdmaDevice;
19
/// Base distance for two devices in the same PCI domain that share no common
/// ancestor in the parsed topology.
const CROSS_NUMA_BASE_PENALTY: f64 = 20.0;
/// Distance reported when either device address cannot be parsed.
const ADDRESS_PARSE_FAILURE_PENALTY: f64 = f64::INFINITY;
/// Distance reported for two devices that live in different PCI domains.
const CROSS_DOMAIN_PENALTY: f64 = 1000.0;
/// Weight applied to the absolute bus-number difference when no common
/// ancestor exists.
const BUS_DISTANCE_SCALE: f64 = 0.1;

/// A PCI device identified by its `domain:bus:device.function` address,
/// optionally carrying an owned copy of its parent (and, transitively, of
/// every ancestor up to the root).
#[derive(Debug, Clone)]
pub struct PCIDevice {
    /// PCI address string, e.g. `0000:3b:00.0`.
    pub address: String,
    /// Owned copy of the parent device; `None` for a device without a known parent.
    pub parent: Option<Box<PCIDevice>>,
}

impl PCIDevice {
    /// Creates a device with the given address and no parent.
    pub fn new(address: String) -> Self {
        Self {
            address,
            parent: None,
        }
    }

    /// Returns the addresses from this device up to the root.
    ///
    /// The first element is this device's own address, the last one is the
    /// address of the top-most known ancestor.
    pub fn get_path_to_root(&self) -> Vec<String> {
        std::iter::successors(Some(self), |device| device.parent.as_deref())
            .map(|device| device.address.clone())
            .collect()
    }

    /// Reads `/sys/bus/pci/devices/<address>/numa_node` and parses it as an
    /// integer. Returns `None` when the file cannot be read or parsed.
    pub fn get_numa_node(&self) -> Option<i32> {
        let numa_file = format!("/sys/bus/pci/devices/{}/numa_node", self.address);
        std::fs::read_to_string(numa_file).ok()?.trim().parse().ok()
    }

    /// Computes a topology distance between `self` and `other`.
    ///
    /// * Identical addresses have distance `0.0`.
    /// * If both devices share an ancestor, the distance is the number of
    ///   hops from each device up to the nearest shared ancestor, summed.
    /// * Otherwise the address-based fallback of
    ///   `calculate_cross_numa_distance` is used.
    pub fn distance_to(&self, other: &PCIDevice) -> f64 {
        if self.address == other.address {
            return 0.0;
        }

        let own_path = self.get_path_to_root();
        let other_path = other.get_path_to_root();

        // Both paths end at their root, so walk them back-to-front and count
        // how many trailing entries agree.
        let shared = own_path
            .iter()
            .rev()
            .zip(other_path.iter().rev())
            .take_while(|(a, b)| a == b)
            .count();

        if shared == 0 {
            return self.calculate_cross_numa_distance(other);
        }

        // The nearest common ancestor sits `shared` entries from the end of
        // each path; its index is the number of hops needed to reach it.
        let own_hops = own_path.len() - shared;
        let other_hops = other_path.len() - shared;
        (own_hops + other_hops) as f64
    }

    /// Address-based fallback distance for devices without a common ancestor.
    ///
    /// Returns `ADDRESS_PARSE_FAILURE_PENALTY` if either address is malformed,
    /// `CROSS_DOMAIN_PENALTY` if the domains differ, and otherwise
    /// `CROSS_NUMA_BASE_PENALTY` plus the scaled bus-number difference.
    fn calculate_cross_numa_distance(&self, other: &PCIDevice) -> f64 {
        match (self.parse_pci_address(), other.parse_pci_address()) {
            (Some((self_domain, self_bus, _, _)), Some((other_domain, other_bus, _, _))) => {
                if self_domain != other_domain {
                    return CROSS_DOMAIN_PENALTY;
                }

                let bus_distance = (i32::from(self_bus) - i32::from(other_bus)).abs();
                CROSS_NUMA_BASE_PENALTY + f64::from(bus_distance) * BUS_DISTANCE_SCALE
            }
            _ => ADDRESS_PARSE_FAILURE_PENALTY,
        }
    }

    /// Splits the address into `(domain, bus, device, function)`, all parsed
    /// as hexadecimal.
    ///
    /// The domain is parsed as `u32` because Linux can expose domains wider
    /// than 16 bits (for example `10000:00:02.0`); such addresses must not be
    /// treated as parse failures.
    fn parse_pci_address(&self) -> Option<(u32, u8, u8, u8)> {
        let mut fields = self.address.split(':');
        let (domain, bus, dev_func) =
            match (fields.next(), fields.next(), fields.next(), fields.next()) {
                (Some(domain), Some(bus), Some(dev_func), None) => (domain, bus, dev_func),
                _ => return None,
            };

        let domain = u32::from_str_radix(domain, 16).ok()?;
        let bus = u8::from_str_radix(bus, 16).ok()?;

        let mut dev_func_fields = dev_func.split('.');
        let (device, function) = match (
            dev_func_fields.next(),
            dev_func_fields.next(),
            dev_func_fields.next(),
        ) {
            (Some(device), Some(function), None) => (device, function),
            _ => return None,
        };

        let device = u8::from_str_radix(device, 16).ok()?;
        let function = u8::from_str_radix(function, 16).ok()?;

        Some((domain, bus, device, function))
    }

    /// Returns the index of the candidate with the smallest `distance_to`
    /// value, or `None` when `candidate_devices` is empty. On ties the
    /// earliest candidate wins.
    pub fn find_closest(&self, candidate_devices: &[PCIDevice]) -> Option<usize> {
        candidate_devices
            .iter()
            .map(|candidate| self.distance_to(candidate))
            .enumerate()
            // `min_by` keeps the first of several equal minima, matching a
            // strict `<` scan from the front.
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
    }
}
163
/// Follows `path` while it is a symbolic link and returns the first path that
/// `fs::read_link` cannot read as a link.
///
/// Relative link targets are joined onto the parent directory of the link;
/// the result is not normalised (`..` components are kept as-is).
///
/// # Errors
/// Returns `InvalidInput` when the same path is visited twice.
fn realpath(path: &Path) -> Result<std::path::PathBuf, std::io::Error> {
    let mut visited = std::collections::HashSet::new();
    let mut resolved = path.to_path_buf();

    loop {
        // `insert` reports `false` for a path that was already recorded.
        if !visited.insert(resolved.clone()) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Circular symlink detected",
            ));
        }

        // Any read_link failure (not a link, missing, ...) ends the walk.
        let target = match fs::read_link(&resolved) {
            Ok(target) => target,
            Err(_) => return Ok(resolved),
        };

        resolved = if target.is_relative() {
            resolved.parent().unwrap_or(Path::new("/")).join(target)
        } else {
            target
        };
    }
}
192
193pub fn parse_pci_topology() -> Result<HashMap<String, PCIDevice>, std::io::Error> {
194 let mut devices = HashMap::new();
195 let mut parent_addresses = HashMap::new();
196 let pci_devices_dir = "/sys/bus/pci/devices";
197
198 if !Path::new(pci_devices_dir).exists() {
199 return Ok(devices);
200 }
201
202 let pci_addr_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])$").unwrap();
203
204 for entry in fs::read_dir(pci_devices_dir)? {
206 let entry = entry?;
207 let pci_addr = entry.file_name().to_string_lossy().to_string();
208 let device_path = entry.path();
209
210 let parent_addr = match realpath(&device_path) {
212 Ok(real_path) => {
213 if let Some(parent_path) = real_path.parent() {
214 let parent_path_str = parent_path.to_string_lossy();
215 pci_addr_regex
216 .captures(&parent_path_str)
217 .map(|captures| captures.get(1).unwrap().as_str().to_string())
218 } else {
219 None
220 }
221 }
222 Err(_) => None,
223 };
224
225 devices.insert(pci_addr.clone(), PCIDevice::new(pci_addr.clone()));
226 if let Some(ref parent) = parent_addr {
227 if !devices.contains_key(parent) {
228 devices.insert(parent.clone(), PCIDevice::new(parent.clone()));
229 }
230 }
231 parent_addresses.insert(pci_addr, parent_addr);
232 }
233
234 fn build_parent_chain(
236 devices: &mut HashMap<String, PCIDevice>,
237 parent_addresses: &HashMap<String, Option<String>>,
238 pci_addr: &str,
239 visited: &mut std::collections::HashSet<String>,
240 ) {
241 if visited.contains(pci_addr) {
242 return;
243 }
244 visited.insert(pci_addr.to_string());
245
246 if let Some(Some(parent_addr)) = parent_addresses.get(pci_addr) {
247 build_parent_chain(devices, parent_addresses, parent_addr, visited);
248
249 if let Some(parent_device) = devices.get(parent_addr).cloned() {
250 if let Some(device) = devices.get_mut(pci_addr) {
251 device.parent = Some(Box::new(parent_device));
252 }
253 }
254 }
255 }
256
257 let mut visited = std::collections::HashSet::new();
258 for pci_addr in devices.keys().cloned().collect::<Vec<_>>() {
259 visited.clear();
260 build_parent_chain(&mut devices, &parent_addresses, &pci_addr, &mut visited);
261 }
262
263 Ok(devices)
264}
265
/// Splits a device hint of the form `prefix:postfix` into its two parts.
///
/// Returns `None` unless the input contains exactly one `:`. Either part may
/// be empty (`"cuda:"` yields `("cuda", "")`).
pub fn parse_device_string(device_str: &str) -> Option<(String, String)> {
    let mut fields = device_str.split(':');
    match (fields.next(), fields.next(), fields.next()) {
        (Some(prefix), Some(postfix), None) => Some((prefix.to_string(), postfix.to_string())),
        _ => None,
    }
}
274
275pub fn get_cuda_pci_address(device_idx: &str) -> Option<String> {
276 let idx: i32 = device_idx.parse().ok()?;
277 let gpu_proc_dir = "/proc/driver/nvidia/gpus";
278
279 if !Path::new(gpu_proc_dir).exists() {
280 return None;
281 }
282
283 for entry in fs::read_dir(gpu_proc_dir).ok()? {
284 let entry = entry.ok()?;
285 let pci_addr = entry.file_name().to_string_lossy().to_lowercase();
286 let info_file = entry.path().join("information");
287
288 if let Ok(content) = fs::read_to_string(&info_file) {
289 let minor_regex = Regex::new(r"Device Minor:\s*(\d+)").unwrap();
290 if let Some(captures) = minor_regex.captures(&content) {
291 if let Ok(device_minor) = captures.get(1).unwrap().as_str().parse::<i32>() {
292 if device_minor == idx {
293 return Some(pci_addr);
294 }
295 }
296 }
297 }
298 }
299 None
300}
301
302pub fn get_numa_pci_address(numa_node: &str) -> Option<String> {
303 let node: i32 = numa_node.parse().ok()?;
304 let pci_devices = parse_pci_topology().ok()?;
305
306 let mut candidates = Vec::new();
307 for (pci_addr, device) in &pci_devices {
308 if let Some(device_numa) = device.get_numa_node() {
309 if device_numa == node {
310 candidates.push(pci_addr.clone());
311 }
312 }
313 }
314
315 if candidates.is_empty() {
316 return None;
317 }
318
319 let mut best_candidate = candidates[0].clone();
320 let mut shortest_path = usize::MAX;
321
322 for pci_addr in &candidates {
323 if let Some(device) = pci_devices.get(pci_addr) {
324 let path_length = device.get_path_to_root().len();
325 if path_length < shortest_path
326 || (path_length == shortest_path && pci_addr < &best_candidate)
327 {
328 shortest_path = path_length;
329 best_candidate = pci_addr.clone();
330 }
331 }
332 }
333
334 Some(best_candidate)
335}
336
337pub fn get_all_rdma_devices() -> Vec<(String, String)> {
338 let mut rdma_devices = Vec::new();
339 let ib_class_dir = "/sys/class/infiniband";
340
341 if !Path::new(ib_class_dir).exists() {
342 return rdma_devices;
343 }
344
345 let pci_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])").unwrap();
346
347 if let Ok(entries) = fs::read_dir(ib_class_dir) {
348 let mut sorted_entries: Vec<_> = entries.collect::<Result<Vec<_>, _>>().unwrap_or_default();
349 sorted_entries.sort_by_key(|entry| entry.file_name());
350
351 for entry in sorted_entries {
352 let ib_dev = entry.file_name().to_string_lossy().to_string();
353 let device_path = entry.path().join("device");
354
355 if let Ok(real_path) = fs::read_link(&device_path) {
356 let real_path_str = real_path.to_string_lossy();
357 let pci_matches: Vec<&str> = pci_regex
358 .find_iter(&real_path_str)
359 .map(|m| m.as_str())
360 .collect();
361
362 if let Some(&last_pci_addr) = pci_matches.last() {
363 rdma_devices.push((ib_dev, last_pci_addr.to_string()));
364 }
365 }
366 }
367 }
368
369 rdma_devices
370}
371
372pub fn get_nic_pci_address(nic_name: &str) -> Option<String> {
373 let rdma_devices = get_all_rdma_devices();
374 for (name, pci_addr) in rdma_devices {
375 if name == nic_name {
376 return Some(pci_addr);
377 }
378 }
379 None
380}
381
382pub fn select_optimal_rdma_device(device_hint: Option<&str>) -> Option<RdmaDevice> {
387 let device_hint = device_hint?;
388
389 let (prefix, postfix) = parse_device_string(device_hint)?;
390
391 match prefix.as_str() {
392 "nic" => {
393 let all_rdma_devices = crate::ibverbs_primitives::get_all_devices();
394 all_rdma_devices
395 .into_iter()
396 .find(|dev| dev.name() == &postfix)
397 }
398 "cuda" | "cpu" => {
399 let source_pci_addr = match prefix.as_str() {
400 "cuda" => get_cuda_pci_address(&postfix)?,
401 "cpu" => get_numa_pci_address(&postfix)?,
402 _ => unreachable!(),
403 };
404 let rdma_devices = get_all_rdma_devices();
405 if rdma_devices.is_empty() {
406 return RdmaDevice::first_available();
407 }
408 let pci_devices = parse_pci_topology().ok()?;
409 let source_device = pci_devices.get(&source_pci_addr)?;
410
411 let rdma_names: Vec<String> =
412 rdma_devices.iter().map(|(name, _)| name.clone()).collect();
413 let rdma_pci_devices: Vec<PCIDevice> = rdma_devices
414 .iter()
415 .filter_map(|(_, addr)| pci_devices.get(addr).cloned())
416 .collect();
417
418 if let Some(closest_idx) = source_device.find_closest(&rdma_pci_devices) {
419 if let Some(optimal_name) = rdma_names.get(closest_idx) {
420 let all_rdma_devices = crate::ibverbs_primitives::get_all_devices();
421 for device in all_rdma_devices {
422 if *device.name() == *optimal_name {
423 return Some(device);
424 }
425 }
426 }
427 }
428
429 RdmaDevice::first_available()
431 }
432 _ => {
433 let rdma_devices = crate::ibverbs_primitives::get_all_devices();
435 rdma_devices
436 .into_iter()
437 .find(|dev| dev.name() == device_hint)
438 }
439 }
440}
441
442pub fn create_cuda_to_rdma_mapping() -> HashMap<String, RdmaDevice> {
451 let mut mapping = HashMap::new();
452
453 for gpu_idx in 0..8 {
455 let gpu_idx_str = gpu_idx.to_string();
456 if let Some(cuda_pci_addr) = get_cuda_pci_address(&gpu_idx_str) {
457 let cuda_hint = format!("cuda:{}", gpu_idx);
458 if let Some(rdma_device) = select_optimal_rdma_device(Some(&cuda_hint)) {
459 mapping.insert(cuda_pci_addr, rdma_device);
460 }
461 }
462 }
463
464 mapping
465}
466
467pub fn resolve_rdma_device(device: &RdmaDevice) -> Option<RdmaDevice> {
481 let device_name = device.name();
482
483 if device_name.starts_with("mlx") {
484 return Some(device.clone());
485 }
486
487 let all_devices = crate::ibverbs_primitives::get_all_devices();
488 let is_likely_default = if let Some(first_device) = all_devices.first() {
489 device_name == first_device.name()
490 } else {
491 false
492 };
493
494 if is_likely_default {
495 select_optimal_rdma_device(Some("cpu:0"))
496 } else {
497 Some(device.clone())
498 }
499}
500
501#[cfg(test)]
502mod tests {
503 use super::*;
504
505 #[test]
506 fn test_parse_device_string() {
507 assert_eq!(
508 parse_device_string("cuda:0"),
509 Some(("cuda".to_string(), "0".to_string()))
510 );
511 assert_eq!(
512 parse_device_string("cpu:1"),
513 Some(("cpu".to_string(), "1".to_string()))
514 );
515 assert_eq!(parse_device_string("invalid"), None);
516 assert_eq!(
517 parse_device_string("cuda:"),
518 Some(("cuda".to_string(), "".to_string()))
519 );
520 }
521
522 fn is_gt20_hardware() -> bool {
524 let rdma_devices = get_all_rdma_devices();
525 let device_names: std::collections::HashSet<String> =
526 rdma_devices.iter().map(|(name, _)| name.clone()).collect();
527
528 let expected_gt20_devices = [
530 "mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11",
531 ];
532
533 let gpu_count = (0..8)
535 .filter(|&i| get_cuda_pci_address(&i.to_string()).is_some())
536 .count();
537
538 let has_expected_rdma = expected_gt20_devices
540 .iter()
541 .all(|&device| device_names.contains(device));
542
543 has_expected_rdma && gpu_count == 8
544 }
545
    /// End-to-end diagnostic run on GT20 hardware.
    ///
    /// The test returns early (and therefore passes) when `is_gt20_hardware`
    /// reports a different machine. It prints the result of every discovery
    /// step; mismatches against the expected GPU-to-NIC table are reported as
    /// a warning only and do not fail the test.
    #[test]
    fn test_gt20_hardware() {
        if !is_gt20_hardware() {
            println!("⚠️ Skipping test_gt20_hardware: Not running on GT20 hardware");
            return;
        }

        println!("✓ Detected GT20 hardware - running full validation test");
        println!("\n1. PCI TOPOLOGY PARSING");
        // A topology read error ends the test without failing it.
        let pci_devices = match parse_pci_topology() {
            Ok(devices) => {
                println!("✓ Found {} PCI devices", devices.len());
                devices
            }
            Err(e) => {
                println!("✗ Error: {}", e);
                return;
            }
        };

        println!("\n2. RDMA DEVICE DISCOVERY");
        let rdma_devices = get_all_rdma_devices();
        println!("✓ Found {} RDMA devices", rdma_devices.len());
        for (name, pci_addr) in &rdma_devices {
            println!(" RDMA {}: {}", name, pci_addr);
        }

        println!("\n3. DEVICE STRING PARSING");
        let test_strings = ["cuda:0", "cuda:1", "cpu:0", "cpu:1"];
        for device_str in &test_strings {
            if let Some((prefix, postfix)) = parse_device_string(device_str) {
                println!(
                    " '{}' -> prefix: '{}', postfix: '{}'",
                    device_str, prefix, postfix
                );
            } else {
                println!(" '{}' -> PARSE FAILED", device_str);
            }
        }

        println!("\n4. CUDA PCI ADDRESS RESOLUTION");
        for gpu_idx in 0..8 {
            let gpu_idx_str = gpu_idx.to_string();
            match get_cuda_pci_address(&gpu_idx_str) {
                Some(pci_addr) => {
                    println!(" GPU {} -> PCI: {}", gpu_idx, pci_addr);
                }
                None => {
                    println!(" GPU {} -> PCI: NOT FOUND", gpu_idx);
                }
            }
        }

        println!("\n5. CPU/NUMA PCI ADDRESS RESOLUTION");
        for numa_node in 0..4 {
            let numa_str = numa_node.to_string();
            match get_numa_pci_address(&numa_str) {
                Some(pci_addr) => {
                    println!(" NUMA {} -> PCI: {}", numa_node, pci_addr);
                }
                None => {
                    println!(" NUMA {} -> PCI: NOT FOUND", numa_node);
                }
            }
        }

        println!("\n6. DISTANCE CALCULATION TEST (GPU 0)");
        if let Some(gpu0_pci_addr) = get_cuda_pci_address("0") {
            if let Some(gpu0_device) = pci_devices.get(&gpu0_pci_addr) {
                println!("GPU 0 PCI: {}", gpu0_pci_addr);
                println!("GPU 0 path to root: {:?}", gpu0_device.get_path_to_root());

                // Collect (distance, NIC name, NIC address) for every NIC
                // that is present in the parsed topology.
                let mut all_distances = Vec::new();
                for (nic_name, nic_pci_addr) in &rdma_devices {
                    if let Some(nic_device) = pci_devices.get(nic_pci_addr) {
                        let distance = gpu0_device.distance_to(nic_device);
                        all_distances.push((distance, nic_name.clone(), nic_pci_addr.clone()));
                        println!(" {} ({}): distance = {}", nic_name, nic_pci_addr, distance);
                        println!(" NIC path to root: {:?}", nic_device.get_path_to_root());
                    }
                }

                // Ascending by distance, so the first element is the closest NIC.
                all_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
                if let Some((min_dist, min_nic, min_addr)) = all_distances.first() {
                    println!(
                        " → CLOSEST: {} ({}) with distance {}",
                        min_nic, min_addr, min_dist
                    );
                }
            }
        }

        println!("\n7. UNIFIED DEVICE SELECTION TEST");
        let test_cases = [
            ("cuda:0", "CUDA device 0"),
            ("cuda:1", "CUDA device 1"),
            ("cpu:0", "CPU/NUMA node 0"),
            ("cpu:1", "CPU/NUMA node 1"),
        ];

        for (device_hint, description) in &test_cases {
            let selected_device = select_optimal_rdma_device(Some(device_hint));
            match selected_device {
                Some(device) => {
                    println!(" {} ({}) -> {}", device_hint, description, device.name());
                }
                None => {
                    println!(" {} ({}) -> NOT FOUND", device_hint, description);
                }
            }
        }

        println!("\n8. GPU-TO-RDMA MAPPING VALIDATION (ALL 8 GPUs)");

        // Expected (GPU index, NIC name) pairs the selection is compared with.
        let python_expected = [
            (0, "mlx5_0"),
            (1, "mlx5_3"),
            (2, "mlx5_4"),
            (3, "mlx5_5"),
            (4, "mlx5_6"),
            (5, "mlx5_9"),
            (6, "mlx5_10"),
            (7, "mlx5_11"),
        ];

        let mut rust_results = std::collections::HashMap::new();
        let mut all_match = true;

        for gpu_idx in 0..8 {
            let cuda_hint = format!("cuda:{}", gpu_idx);
            let selected_device = select_optimal_rdma_device(Some(&cuda_hint));

            match selected_device {
                Some(device) => {
                    let device_name = device.name().to_string();
                    rust_results.insert(gpu_idx, device_name.clone());
                    println!(" GPU {} -> {}", gpu_idx, device_name);
                }
                None => {
                    println!(" GPU {} -> NOT FOUND", gpu_idx);
                    // A placeholder keeps the GPU in the comparison below.
                    rust_results.insert(gpu_idx, "NOT_FOUND".to_string());
                }
            }
        }

        println!("\n=== VALIDATION AGAINST EXPECTED RESULTS ===");
        for (gpu_idx, expected_nic) in python_expected {
            if let Some(actual_nic) = rust_results.get(&gpu_idx) {
                let matches = actual_nic == expected_nic;
                println!(
                    " GPU {} -> {} {} (expected {})",
                    gpu_idx,
                    actual_nic,
                    if matches { "✓" } else { "✗" },
                    expected_nic
                );
                all_match = all_match && matches;
            } else {
                println!(
                    " GPU {} -> NOT FOUND ✗ (expected {})",
                    gpu_idx, expected_nic
                );
                all_match = false;
            }
        }

        // A mismatch is only reported; nothing is asserted here.
        if all_match {
            println!("\n🎉 SUCCESS: All GPU-NIC pairings match expected GT20 hardware results!");
            println!("✓ New unified API produces identical results to proven algorithm");
        } else {
            println!("\n⚠️ WARNING: Some GPU-NIC pairings differ from expected results");
            println!(" This could indicate:");
            println!(" - Hardware configuration differences");
            println!(" - Algorithm implementation differences");
            println!(" - Environment setup differences");
        }

        println!("\n9. DETAILED CPU DEVICE SELECTION ANALYSIS");

        if let Some(numa0_addr) = get_numa_pci_address("0") {
            println!(" NUMA 0 representative PCI: {}", numa0_addr);
        } else {
            println!(" NUMA 0 representative PCI: NOT FOUND");
        }

        if let Some(numa1_addr) = get_numa_pci_address("1") {
            println!(" NUMA 1 representative PCI: {}", numa1_addr);
        } else {
            println!(" NUMA 1 representative PCI: NOT FOUND");
        }

        let cpu0_device = select_optimal_rdma_device(Some("cpu:0"));
        let cpu1_device = select_optimal_rdma_device(Some("cpu:1"));

        // Compare the two selections by name when both are available.
        match (
            cpu0_device.as_ref().map(|d| d.name()),
            cpu1_device.as_ref().map(|d| d.name()),
        ) {
            (Some(cpu0_name), Some(cpu1_name)) => {
                println!("\n FINAL SELECTIONS:");
                println!(" CPU:0 -> {}", cpu0_name);
                println!(" CPU:1 -> {}", cpu1_name);
                if cpu0_name != cpu1_name {
                    println!(" ✓ Different NUMA nodes select different RDMA devices");
                } else {
                    println!(" ⚠️ Same RDMA device selected for both NUMA nodes");
                    println!(" This could indicate:");
                    println!(
                        " - {} is genuinely closest to both NUMA nodes",
                        cpu0_name
                    );
                    println!(" - NUMA topology detection issue");
                    println!(" - Cross-NUMA penalty algorithm working correctly");
                }
            }
            _ => {
                println!(" ○ CPU device selection not available");
            }
        }

        println!("\n✓ GT20 hardware test completed");

    }
787}