monarch_rdma/
device_selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module provides functionality to automatically pair compute devices with
10//! the best available RDMA NICs based on PCI topology distance.
11
12use std::collections::HashMap;
13use std::fs;
14use std::path::Path;
15
16use regex::Regex;
17
18use crate::ibverbs_primitives::RdmaDevice;
19
// ==== PCI TOPOLOGY DISTANCE CONSTANTS ====
//
// Penalty values used by `PCIDevice::calculate_cross_numa_distance`, the
// fallback metric `PCIDevice::distance_to` uses when two devices share no
// common PCI ancestor and hops therefore cannot be counted.

/// Base penalty for cross-NUMA communication, applied when both addresses parse
/// and lie in the same PCI domain. It is higher than typical intra-NUMA
/// distances (usually 0-8 hops) so that same-NUMA devices are preferred over
/// cross-NUMA devices.
const CROSS_NUMA_BASE_PENALTY: f64 = 20.0;
/// Penalty returned when either PCI address fails to parse, i.e. when the bus
/// relationship between the two devices cannot be determined. Being infinite,
/// it loses against every finite distance.
const ADDRESS_PARSE_FAILURE_PENALTY: f64 = f64::INFINITY;
/// Penalty returned when the two devices are in different PCI domains.
/// Different domains typically indicate completely separate I/O subsystems.
const CROSS_DOMAIN_PENALTY: f64 = 1000.0;
/// Scaling factor applied to the absolute bus-number difference that is added
/// on top of `CROSS_NUMA_BASE_PENALTY`; kept small so it only acts as a
/// tie-breaker between devices at different bus numbers.
const BUS_DISTANCE_SCALE: f64 = 0.1;
41
/// A node in the PCI device tree, identified by its PCI address.
///
/// Each device owns a boxed copy of its parent, so a single value carries its
/// whole ancestor chain up to the first device that has no parent.
#[derive(Debug, Clone)]
pub struct PCIDevice {
    /// PCI address string. `parse_pci_address` only accepts the hexadecimal
    /// form `domain:bus:device.function`.
    pub address: String,
    /// The parent device, or `None` when this device is the top of its chain.
    pub parent: Option<Box<PCIDevice>>,
}
47
48impl PCIDevice {
49    pub fn new(address: String) -> Self {
50        Self {
51            address,
52            parent: None,
53        }
54    }
55
56    pub fn get_path_to_root(&self) -> Vec<String> {
57        let mut path = vec![self.address.clone()];
58        let mut current = self;
59
60        while let Some(ref parent) = current.parent {
61            path.push(parent.address.clone());
62            current = parent;
63        }
64
65        path
66    }
67    pub fn get_numa_node(&self) -> Option<i32> {
68        let numa_file = format!("/sys/bus/pci/devices/{}/numa_node", self.address);
69        std::fs::read_to_string(numa_file).ok()?.trim().parse().ok()
70    }
71
72    pub fn distance_to(&self, other: &PCIDevice) -> f64 {
73        if self.address == other.address {
74            return 0.0;
75        }
76
77        // Get paths to root for both devices
78        let path1 = self.get_path_to_root();
79        let path2 = other.get_path_to_root();
80
81        // Find lowest common ancestor (first common element from the end)
82        let mut common_ancestor = None;
83        let min_len = path1.len().min(path2.len());
84
85        // Check from the root down to find the deepest common ancestor
86        for i in 1..=min_len {
87            if path1[path1.len() - i] == path2[path2.len() - i] {
88                common_ancestor = Some(&path1[path1.len() - i]);
89            } else {
90                break;
91            }
92        }
93
94        if let Some(ancestor) = common_ancestor {
95            let hops1 = path1.iter().position(|addr| addr == ancestor).unwrap_or(0);
96            let hops2 = path2.iter().position(|addr| addr == ancestor).unwrap_or(0);
97            (hops1 + hops2) as f64
98        } else {
99            self.calculate_cross_numa_distance(other)
100        }
101    }
102
103    /// Calculate distance between devices on different NUMA domains/root complexes
104    /// This handles cases where devices don't share a common PCI ancestor
105    fn calculate_cross_numa_distance(&self, other: &PCIDevice) -> f64 {
106        let self_parts = self.parse_pci_address();
107        let other_parts = other.parse_pci_address();
108
109        match (self_parts, other_parts) {
110            (Some((self_domain, self_bus, _, _)), Some((other_domain, other_bus, _, _))) => {
111                if self_domain != other_domain {
112                    return CROSS_DOMAIN_PENALTY;
113                }
114
115                let bus_distance = (self_bus as i32 - other_bus as i32).abs() as f64;
116                CROSS_NUMA_BASE_PENALTY + bus_distance * BUS_DISTANCE_SCALE
117            }
118            _ => ADDRESS_PARSE_FAILURE_PENALTY,
119        }
120    }
121
122    /// Parse PCI address into components (domain, bus, device, function)
123    fn parse_pci_address(&self) -> Option<(u16, u8, u8, u8)> {
124        let parts: Vec<&str> = self.address.split(':').collect();
125        if parts.len() != 3 {
126            return None;
127        }
128
129        let domain = u16::from_str_radix(parts[0], 16).ok()?;
130        let bus = u8::from_str_radix(parts[1], 16).ok()?;
131
132        let dev_func: Vec<&str> = parts[2].split('.').collect();
133        if dev_func.len() != 2 {
134            return None;
135        }
136
137        let device = u8::from_str_radix(dev_func[0], 16).ok()?;
138        let function = u8::from_str_radix(dev_func[1], 16).ok()?;
139
140        Some((domain, bus, device, function))
141    }
142
143    /// Find the index of the closest device from a list of candidates
144    pub fn find_closest(&self, candidate_devices: &[PCIDevice]) -> Option<usize> {
145        if candidate_devices.is_empty() {
146            return None;
147        }
148
149        let mut closest_idx = 0;
150        let mut min_distance = self.distance_to(&candidate_devices[0]);
151
152        for (idx, device) in candidate_devices.iter().enumerate().skip(1) {
153            let distance = self.distance_to(device);
154            if distance < min_distance {
155                min_distance = distance;
156                closest_idx = idx;
157            }
158        }
159
160        Some(closest_idx)
161    }
162}
163
/// Follows the chain of symlinks that starts at `path` and returns the path at
/// which the chain ends.
///
/// At each step only the path itself is passed to `fs::read_link`. A relative
/// link target is joined onto the directory containing the link; the joined
/// path is not lexically normalized, so `..` components are kept. The walk
/// stops at the first path `read_link` fails on, which covers paths that are
/// not symlinks as well as paths that cannot be read.
///
/// # Errors
///
/// Returns an `InvalidInput` error if the same path is reached twice.
fn realpath(path: &Path) -> Result<std::path::PathBuf, std::io::Error> {
    let mut visited = std::collections::HashSet::new();
    let mut cursor = path.to_path_buf();

    loop {
        // `insert` reports `false` when the path was already in the set.
        if !visited.insert(cursor.clone()) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Circular symlink detected",
            ));
        }

        let link_target = match fs::read_link(&cursor) {
            Ok(target) => target,
            // Not a symlink, or unreadable: this is where the chain ends.
            Err(_) => return Ok(cursor),
        };

        cursor = if link_target.is_relative() {
            cursor.parent().unwrap_or(Path::new("/")).join(link_target)
        } else {
            link_target
        };
    }
}
192
193pub fn parse_pci_topology() -> Result<HashMap<String, PCIDevice>, std::io::Error> {
194    let mut devices = HashMap::new();
195    let mut parent_addresses = HashMap::new();
196    let pci_devices_dir = "/sys/bus/pci/devices";
197
198    if !Path::new(pci_devices_dir).exists() {
199        return Ok(devices);
200    }
201
202    let pci_addr_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])$").unwrap();
203
204    // First pass: create all devices without parent references
205    for entry in fs::read_dir(pci_devices_dir)? {
206        let entry = entry?;
207        let pci_addr = entry.file_name().to_string_lossy().to_string();
208        let device_path = entry.path();
209
210        // Find parent device by following the device symlink and extracting PCI address from the path
211        let parent_addr = match realpath(&device_path) {
212            Ok(real_path) => {
213                if let Some(parent_path) = real_path.parent() {
214                    let parent_path_str = parent_path.to_string_lossy();
215                    pci_addr_regex
216                        .captures(&parent_path_str)
217                        .map(|captures| captures.get(1).unwrap().as_str().to_string())
218                } else {
219                    None
220                }
221            }
222            Err(_) => None,
223        };
224
225        devices.insert(pci_addr.clone(), PCIDevice::new(pci_addr.clone()));
226        if let Some(ref parent) = parent_addr {
227            if !devices.contains_key(parent) {
228                devices.insert(parent.clone(), PCIDevice::new(parent.clone()));
229            }
230        }
231        parent_addresses.insert(pci_addr, parent_addr);
232    }
233
234    // Second pass: set up parent references recursively
235    fn build_parent_chain(
236        devices: &mut HashMap<String, PCIDevice>,
237        parent_addresses: &HashMap<String, Option<String>>,
238        pci_addr: &str,
239        visited: &mut std::collections::HashSet<String>,
240    ) {
241        if visited.contains(pci_addr) {
242            return;
243        }
244        visited.insert(pci_addr.to_string());
245
246        if let Some(Some(parent_addr)) = parent_addresses.get(pci_addr) {
247            build_parent_chain(devices, parent_addresses, parent_addr, visited);
248
249            if let Some(parent_device) = devices.get(parent_addr).cloned() {
250                if let Some(device) = devices.get_mut(pci_addr) {
251                    device.parent = Some(Box::new(parent_device));
252                }
253            }
254        }
255    }
256
257    let mut visited = std::collections::HashSet::new();
258    for pci_addr in devices.keys().cloned().collect::<Vec<_>>() {
259        visited.clear();
260        build_parent_chain(&mut devices, &parent_addresses, &pci_addr, &mut visited);
261    }
262
263    Ok(devices)
264}
265
/// Splits a device hint of the form `<prefix>:<postfix>` at its `:`.
///
/// Returns `None` unless the input contains exactly one `:`. Either side may
/// be empty (`"cuda:"` yields `("cuda", "")`).
pub fn parse_device_string(device_str: &str) -> Option<(String, String)> {
    let mut pieces = device_str.split(':');
    match (pieces.next(), pieces.next(), pieces.next()) {
        (Some(prefix), Some(postfix), None) => Some((prefix.to_string(), postfix.to_string())),
        _ => None,
    }
}
274
275pub fn get_cuda_pci_address(device_idx: &str) -> Option<String> {
276    let idx: i32 = device_idx.parse().ok()?;
277    let gpu_proc_dir = "/proc/driver/nvidia/gpus";
278
279    if !Path::new(gpu_proc_dir).exists() {
280        return None;
281    }
282
283    for entry in fs::read_dir(gpu_proc_dir).ok()? {
284        let entry = entry.ok()?;
285        let pci_addr = entry.file_name().to_string_lossy().to_lowercase();
286        let info_file = entry.path().join("information");
287
288        if let Ok(content) = fs::read_to_string(&info_file) {
289            let minor_regex = Regex::new(r"Device Minor:\s*(\d+)").unwrap();
290            if let Some(captures) = minor_regex.captures(&content) {
291                if let Ok(device_minor) = captures.get(1).unwrap().as_str().parse::<i32>() {
292                    if device_minor == idx {
293                        return Some(pci_addr);
294                    }
295                }
296            }
297        }
298    }
299    None
300}
301
302pub fn get_numa_pci_address(numa_node: &str) -> Option<String> {
303    let node: i32 = numa_node.parse().ok()?;
304    let pci_devices = parse_pci_topology().ok()?;
305
306    let mut candidates = Vec::new();
307    for (pci_addr, device) in &pci_devices {
308        if let Some(device_numa) = device.get_numa_node() {
309            if device_numa == node {
310                candidates.push(pci_addr.clone());
311            }
312        }
313    }
314
315    if candidates.is_empty() {
316        return None;
317    }
318
319    let mut best_candidate = candidates[0].clone();
320    let mut shortest_path = usize::MAX;
321
322    for pci_addr in &candidates {
323        if let Some(device) = pci_devices.get(pci_addr) {
324            let path_length = device.get_path_to_root().len();
325            if path_length < shortest_path
326                || (path_length == shortest_path && pci_addr < &best_candidate)
327            {
328                shortest_path = path_length;
329                best_candidate = pci_addr.clone();
330            }
331        }
332    }
333
334    Some(best_candidate)
335}
336
337pub fn get_all_rdma_devices() -> Vec<(String, String)> {
338    let mut rdma_devices = Vec::new();
339    let ib_class_dir = "/sys/class/infiniband";
340
341    if !Path::new(ib_class_dir).exists() {
342        return rdma_devices;
343    }
344
345    let pci_regex = Regex::new(r"([0-9a-f]{4}:[0-9a-f]{2}:[0-9a-f]{2}\.[0-9])").unwrap();
346
347    if let Ok(entries) = fs::read_dir(ib_class_dir) {
348        let mut sorted_entries: Vec<_> = entries.collect::<Result<Vec<_>, _>>().unwrap_or_default();
349        sorted_entries.sort_by_key(|entry| entry.file_name());
350
351        for entry in sorted_entries {
352            let ib_dev = entry.file_name().to_string_lossy().to_string();
353            let device_path = entry.path().join("device");
354
355            if let Ok(real_path) = fs::read_link(&device_path) {
356                let real_path_str = real_path.to_string_lossy();
357                let pci_matches: Vec<&str> = pci_regex
358                    .find_iter(&real_path_str)
359                    .map(|m| m.as_str())
360                    .collect();
361
362                if let Some(&last_pci_addr) = pci_matches.last() {
363                    rdma_devices.push((ib_dev, last_pci_addr.to_string()));
364                }
365            }
366        }
367    }
368
369    rdma_devices
370}
371
372pub fn get_nic_pci_address(nic_name: &str) -> Option<String> {
373    let rdma_devices = get_all_rdma_devices();
374    for (name, pci_addr) in rdma_devices {
375        if name == nic_name {
376            return Some(pci_addr);
377        }
378    }
379    None
380}
381
382/// Step 1: Parse device string into prefix and postfix
383/// Step 2: Get PCI address from compute device
384/// Step 3: Get PCI address for all RDMA NIC devices
385/// Step 4: Calculate PCI distances and return closest RDMA NIC device
386pub fn select_optimal_rdma_device(device_hint: Option<&str>) -> Option<RdmaDevice> {
387    let device_hint = device_hint?;
388
389    let (prefix, postfix) = parse_device_string(device_hint)?;
390
391    match prefix.as_str() {
392        "nic" => {
393            let all_rdma_devices = crate::ibverbs_primitives::get_all_devices();
394            all_rdma_devices
395                .into_iter()
396                .find(|dev| dev.name() == &postfix)
397        }
398        "cuda" | "cpu" => {
399            let source_pci_addr = match prefix.as_str() {
400                "cuda" => get_cuda_pci_address(&postfix)?,
401                "cpu" => get_numa_pci_address(&postfix)?,
402                _ => unreachable!(),
403            };
404            let rdma_devices = get_all_rdma_devices();
405            if rdma_devices.is_empty() {
406                return RdmaDevice::first_available();
407            }
408            let pci_devices = parse_pci_topology().ok()?;
409            let source_device = pci_devices.get(&source_pci_addr)?;
410
411            let rdma_names: Vec<String> =
412                rdma_devices.iter().map(|(name, _)| name.clone()).collect();
413            let rdma_pci_devices: Vec<PCIDevice> = rdma_devices
414                .iter()
415                .filter_map(|(_, addr)| pci_devices.get(addr).cloned())
416                .collect();
417
418            if let Some(closest_idx) = source_device.find_closest(&rdma_pci_devices) {
419                if let Some(optimal_name) = rdma_names.get(closest_idx) {
420                    let all_rdma_devices = crate::ibverbs_primitives::get_all_devices();
421                    for device in all_rdma_devices {
422                        if *device.name() == *optimal_name {
423                            return Some(device);
424                        }
425                    }
426                }
427            }
428
429            // Fallback
430            RdmaDevice::first_available()
431        }
432        _ => {
433            // Direct device name lookup for backward compatibility
434            let rdma_devices = crate::ibverbs_primitives::get_all_devices();
435            rdma_devices
436                .into_iter()
437                .find(|dev| dev.name() == device_hint)
438        }
439    }
440}
441
442/// Creates a mapping from CUDA PCI addresses to optimal RDMA devices
443///
444/// This function discovers all available CUDA devices and determines the best
445/// RDMA device for each one using the device selection algorithm.
446///
447/// # Returns
448///
449/// * `HashMap<String, RdmaDevice>` - Map from CUDA PCI address to optimal RDMA device
450pub fn create_cuda_to_rdma_mapping() -> HashMap<String, RdmaDevice> {
451    let mut mapping = HashMap::new();
452
453    // Try to discover CUDA devices (GPU 0-8 should be sufficient for most systems)
454    for gpu_idx in 0..8 {
455        let gpu_idx_str = gpu_idx.to_string();
456        if let Some(cuda_pci_addr) = get_cuda_pci_address(&gpu_idx_str) {
457            let cuda_hint = format!("cuda:{}", gpu_idx);
458            if let Some(rdma_device) = select_optimal_rdma_device(Some(&cuda_hint)) {
459                mapping.insert(cuda_pci_addr, rdma_device);
460            }
461        }
462    }
463
464    mapping
465}
466
467/// Resolves RDMA device using auto-detection logic when needed
468///
469/// This function applies auto-detection for default devices, but otherwise  
470/// returns the device as-is. The main device selection logic happens in
471/// `select_optimal_rdma_device` and `IbverbsConfig::with_device_hint`.
472///
473/// # Arguments
474///
475/// * `device` - The RdmaDevice to potentially resolve
476///
477/// # Returns
478///
479/// * `Option<RdmaDevice>` - The resolved device, or None if resolution fails
480pub fn resolve_rdma_device(device: &RdmaDevice) -> Option<RdmaDevice> {
481    let device_name = device.name();
482
483    if device_name.starts_with("mlx") {
484        return Some(device.clone());
485    }
486
487    let all_devices = crate::ibverbs_primitives::get_all_devices();
488    let is_likely_default = if let Some(first_device) = all_devices.first() {
489        device_name == first_device.name()
490    } else {
491        false
492    };
493
494    if is_likely_default {
495        select_optimal_rdma_device(Some("cpu:0"))
496    } else {
497        Some(device.clone())
498    }
499}
500
501#[cfg(test)]
502mod tests {
503    use super::*;
504
505    #[test]
506    fn test_parse_device_string() {
507        assert_eq!(
508            parse_device_string("cuda:0"),
509            Some(("cuda".to_string(), "0".to_string()))
510        );
511        assert_eq!(
512            parse_device_string("cpu:1"),
513            Some(("cpu".to_string(), "1".to_string()))
514        );
515        assert_eq!(parse_device_string("invalid"), None);
516        assert_eq!(
517            parse_device_string("cuda:"),
518            Some(("cuda".to_string(), "".to_string()))
519        );
520    }
521
522    /// Detect if we're running on GT20 hardware by checking for expected RDMA device configuration
523    fn is_gt20_hardware() -> bool {
524        let rdma_devices = get_all_rdma_devices();
525        let device_names: std::collections::HashSet<String> =
526            rdma_devices.iter().map(|(name, _)| name.clone()).collect();
527
528        // GT20 hardware should have these specific RDMA devices
529        let expected_gt20_devices = [
530            "mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11",
531        ];
532
533        // Check if we have at least 8 GPUs (GT20 characteristic)
534        let gpu_count = (0..8)
535            .filter(|&i| get_cuda_pci_address(&i.to_string()).is_some())
536            .count();
537
538        // Must have expected RDMA devices AND 8 GPUs
539        let has_expected_rdma = expected_gt20_devices
540            .iter()
541            .all(|&device| device_names.contains(device));
542
543        has_expected_rdma && gpu_count == 8
544    }
545
    /// Diagnostic walk-through of the device-selection API, step by step - GT20 hardware only.
    ///
    /// Returns early (and therefore passes) when `is_gt20_hardware()` is false
    /// or the PCI topology cannot be parsed. Every step prints its findings;
    /// the function contains no `assert!` calls, so a GPU-to-NIC pairing that
    /// differs from the expected table is reported in the output but does not
    /// fail the test.
    #[test]
    fn test_gt20_hardware() {
        // Early exit if not on GT20 hardware
        if !is_gt20_hardware() {
            println!("⚠️  Skipping test_gt20_hardware: Not running on GT20 hardware");
            return;
        }

        println!("✓ Detected GT20 hardware - running full validation test");
        // Step 1: Test PCI topology parsing
        println!("\n1. PCI TOPOLOGY PARSING");
        let pci_devices = match parse_pci_topology() {
            Ok(devices) => {
                println!("✓ Found {} PCI devices", devices.len());
                devices
            }
            Err(e) => {
                println!("✗ Error: {}", e);
                return;
            }
        };

        // Step 2: Test unified RDMA device discovery
        println!("\n2. RDMA DEVICE DISCOVERY");
        let rdma_devices = get_all_rdma_devices();
        println!("✓ Found {} RDMA devices", rdma_devices.len());
        for (name, pci_addr) in &rdma_devices {
            println!("  RDMA {}: {}", name, pci_addr);
        }

        // Step 3: Test device string parsing
        println!("\n3. DEVICE STRING PARSING");
        let test_strings = ["cuda:0", "cuda:1", "cpu:0", "cpu:1"];
        for device_str in &test_strings {
            if let Some((prefix, postfix)) = parse_device_string(device_str) {
                println!(
                    "  '{}' -> prefix: '{}', postfix: '{}'",
                    device_str, prefix, postfix
                );
            } else {
                println!("  '{}' -> PARSE FAILED", device_str);
            }
        }

        // Step 4: Test CUDA PCI address resolution
        println!("\n4. CUDA PCI ADDRESS RESOLUTION");
        for gpu_idx in 0..8 {
            let gpu_idx_str = gpu_idx.to_string();
            match get_cuda_pci_address(&gpu_idx_str) {
                Some(pci_addr) => {
                    println!("  GPU {} -> PCI: {}", gpu_idx, pci_addr);
                }
                None => {
                    println!("  GPU {} -> PCI: NOT FOUND", gpu_idx);
                }
            }
        }

        // Step 5: Test CPU/NUMA PCI address resolution
        println!("\n5. CPU/NUMA PCI ADDRESS RESOLUTION");
        for numa_node in 0..4 {
            let numa_str = numa_node.to_string();
            match get_numa_pci_address(&numa_str) {
                Some(pci_addr) => {
                    println!("  NUMA {} -> PCI: {}", numa_node, pci_addr);
                }
                None => {
                    println!("  NUMA {} -> PCI: NOT FOUND", numa_node);
                }
            }
        }

        // Step 6: Test distance calculation for GPU 0
        println!("\n6. DISTANCE CALCULATION TEST (GPU 0)");
        if let Some(gpu0_pci_addr) = get_cuda_pci_address("0") {
            if let Some(gpu0_device) = pci_devices.get(&gpu0_pci_addr) {
                println!("GPU 0 PCI: {}", gpu0_pci_addr);
                println!("GPU 0 path to root: {:?}", gpu0_device.get_path_to_root());

                let mut all_distances = Vec::new();
                for (nic_name, nic_pci_addr) in &rdma_devices {
                    if let Some(nic_device) = pci_devices.get(nic_pci_addr) {
                        let distance = gpu0_device.distance_to(nic_device);
                        all_distances.push((distance, nic_name.clone(), nic_pci_addr.clone()));
                        println!("  {} ({}): distance = {}", nic_name, nic_pci_addr, distance);
                        println!("    NIC path to root: {:?}", nic_device.get_path_to_root());
                    }
                }

                // Find the minimum distance
                all_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
                if let Some((min_dist, min_nic, min_addr)) = all_distances.first() {
                    println!(
                        "  → CLOSEST: {} ({}) with distance {}",
                        min_nic, min_addr, min_dist
                    );
                }
            }
        }

        // Step 7: Test unified device selection interface
        println!("\n7. UNIFIED DEVICE SELECTION TEST");
        let test_cases = [
            ("cuda:0", "CUDA device 0"),
            ("cuda:1", "CUDA device 1"),
            ("cpu:0", "CPU/NUMA node 0"),
            ("cpu:1", "CPU/NUMA node 1"),
        ];

        for (device_hint, description) in &test_cases {
            let selected_device = select_optimal_rdma_device(Some(device_hint));
            match selected_device {
                Some(device) => {
                    println!("  {} ({}) -> {}", device_hint, description, device.name());
                }
                None => {
                    println!("  {} ({}) -> NOT FOUND", device_hint, description);
                }
            }
        }

        // Step 8: Test all 8 GPU mappings against expected GT20 hardware results
        println!("\n8. GPU-TO-RDMA MAPPING VALIDATION (ALL 8 GPUs)");

        // Expected results from original Python implementation on GT20 hardware
        let python_expected = [
            (0, "mlx5_0"),
            (1, "mlx5_3"),
            (2, "mlx5_4"),
            (3, "mlx5_5"),
            (4, "mlx5_6"),
            (5, "mlx5_9"),
            (6, "mlx5_10"),
            (7, "mlx5_11"),
        ];

        let mut rust_results = std::collections::HashMap::new();
        let mut all_match = true;

        // Test all 8 GPU mappings using new unified API
        for gpu_idx in 0..8 {
            let cuda_hint = format!("cuda:{}", gpu_idx);
            let selected_device = select_optimal_rdma_device(Some(&cuda_hint));

            match selected_device {
                Some(device) => {
                    let device_name = device.name().to_string();
                    rust_results.insert(gpu_idx, device_name.clone());
                    println!("  GPU {} -> {}", gpu_idx, device_name);
                }
                None => {
                    println!("  GPU {} -> NOT FOUND", gpu_idx);
                    rust_results.insert(gpu_idx, "NOT_FOUND".to_string());
                }
            }
        }

        // Compare against expected results
        println!("\n=== VALIDATION AGAINST EXPECTED RESULTS ===");
        for (gpu_idx, expected_nic) in python_expected {
            if let Some(actual_nic) = rust_results.get(&gpu_idx) {
                let matches = actual_nic == expected_nic;
                println!(
                    "  GPU {} -> {} {} (expected {})",
                    gpu_idx,
                    actual_nic,
                    if matches { "✓" } else { "✗" },
                    expected_nic
                );
                all_match = all_match && matches;
            } else {
                println!(
                    "  GPU {} -> NOT FOUND ✗ (expected {})",
                    gpu_idx, expected_nic
                );
                all_match = false;
            }
        }

        if all_match {
            println!("\n🎉 SUCCESS: All GPU-NIC pairings match expected GT20 hardware results!");
            println!("✓ New unified API produces identical results to proven algorithm");
        } else {
            println!("\n⚠️  WARNING: Some GPU-NIC pairings differ from expected results");
            println!("   This could indicate:");
            println!("   - Hardware configuration differences");
            println!("   - Algorithm implementation differences");
            println!("   - Environment setup differences");
        }

        // Step 9: Detailed CPU device selection analysis
        println!("\n9. DETAILED CPU DEVICE SELECTION ANALYSIS");

        // Check what representative PCI addresses we found for each NUMA node
        if let Some(numa0_addr) = get_numa_pci_address("0") {
            println!("  NUMA 0 representative PCI: {}", numa0_addr);
        } else {
            println!("  NUMA 0 representative PCI: NOT FOUND");
        }

        if let Some(numa1_addr) = get_numa_pci_address("1") {
            println!("  NUMA 1 representative PCI: {}", numa1_addr);
        } else {
            println!("  NUMA 1 representative PCI: NOT FOUND");
        }

        // Now test the actual selections
        let cpu0_device = select_optimal_rdma_device(Some("cpu:0"));
        let cpu1_device = select_optimal_rdma_device(Some("cpu:1"));

        match (
            cpu0_device.as_ref().map(|d| d.name()),
            cpu1_device.as_ref().map(|d| d.name()),
        ) {
            (Some(cpu0_name), Some(cpu1_name)) => {
                println!("\n  FINAL SELECTIONS:");
                println!("    CPU:0 -> {}", cpu0_name);
                println!("    CPU:1 -> {}", cpu1_name);
                if cpu0_name != cpu1_name {
                    println!("    ✓ Different NUMA nodes select different RDMA devices");
                } else {
                    println!("    ⚠️  Same RDMA device selected for both NUMA nodes");
                    println!("       This could indicate:");
                    println!(
                        "       - {} is genuinely closest to both NUMA nodes",
                        cpu0_name
                    );
                    println!("       - NUMA topology detection issue");
                    println!("       - Cross-NUMA penalty algorithm working correctly");
                }
            }
            _ => {
                println!("    ○ CPU device selection not available");
            }
        }

        println!("\n✓ GT20 hardware test completed");

        // We can't guarantee that the results always match the expected table on the
        // test infra, so mismatches are only printed, never asserted. The output is
        // still useful for diagnostic purposes / tracking.
    }
787}