monarch_rdma/backend/ibverbs/
device_selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! ibverbs-specific device selection logic that pairs compute devices
10//! with the best available RDMA NICs based on PCI topology distance.
11
12use std::collections::HashMap;
13
14use super::primitives::IbvDevice;
15use super::primitives::get_all_devices;
16use crate::device_selection::PCIDevice;
17use crate::device_selection::get_all_rdma_devices;
18use crate::device_selection::get_cuda_pci_address;
19use crate::device_selection::get_numa_pci_address;
20use crate::device_selection::parse_device_string;
21use crate::device_selection::parse_pci_topology;
22
23/// Step 1: Parse device string into prefix and postfix
24/// Step 2: Get PCI address from compute device
25/// Step 3: Get PCI address for all RDMA NIC devices
26/// Step 4: Calculate PCI distances and return closest RDMA NIC device
27pub fn select_optimal_ibv_device(device_hint: Option<&str>) -> Option<IbvDevice> {
28    let device_hint = device_hint?;
29
30    let (prefix, postfix) = parse_device_string(device_hint)?;
31
32    match prefix.as_str() {
33        "nic" => {
34            let all_rdma_devices = get_all_devices();
35            all_rdma_devices
36                .into_iter()
37                .find(|dev| dev.name() == &postfix)
38        }
39        "cuda" | "cpu" => {
40            let source_pci_addr = match prefix.as_str() {
41                "cuda" => get_cuda_pci_address(&postfix)?,
42                "cpu" => get_numa_pci_address(&postfix)?,
43                _ => unreachable!(),
44            };
45            let rdma_devices = get_all_rdma_devices();
46            if rdma_devices.is_empty() {
47                return IbvDevice::first_available();
48            }
49            let pci_devices = parse_pci_topology().ok()?;
50            let source_device = pci_devices.get(&source_pci_addr)?;
51
52            let rdma_names: Vec<String> =
53                rdma_devices.iter().map(|(name, _)| name.clone()).collect();
54            let rdma_pci_devices: Vec<PCIDevice> = rdma_devices
55                .iter()
56                .filter_map(|(_, addr)| pci_devices.get(addr).cloned())
57                .collect();
58
59            if let Some(closest_idx) = source_device.find_closest(&rdma_pci_devices) {
60                if let Some(optimal_name) = rdma_names.get(closest_idx) {
61                    let all_rdma_devices = get_all_devices();
62                    for device in all_rdma_devices {
63                        if *device.name() == *optimal_name {
64                            return Some(device);
65                        }
66                    }
67                }
68            }
69
70            // Fallback
71            IbvDevice::first_available()
72        }
73        _ => {
74            // Direct device name lookup for backward compatibility
75            let rdma_devices = get_all_devices();
76            rdma_devices
77                .into_iter()
78                .find(|dev| dev.name() == device_hint)
79        }
80    }
81}
82
83/// Creates a mapping from CUDA PCI addresses to optimal RDMA devices.
84///
85/// Discovers all available CUDA devices and determines the best
86/// RDMA device for each one using the device selection algorithm.
87pub fn create_cuda_to_ibv_mapping() -> HashMap<String, IbvDevice> {
88    let mut mapping = HashMap::new();
89
90    // Try to discover CUDA devices (GPU 0-8 should be sufficient for most systems)
91    for gpu_idx in 0..8 {
92        let gpu_idx_str = gpu_idx.to_string();
93        if let Some(cuda_pci_addr) = get_cuda_pci_address(&gpu_idx_str) {
94            let cuda_hint = format!("cuda:{}", gpu_idx);
95            if let Some(rdma_device) = select_optimal_ibv_device(Some(&cuda_hint)) {
96                mapping.insert(cuda_pci_addr, rdma_device);
97            }
98        }
99    }
100
101    mapping
102}
103
104/// Resolves RDMA device using auto-detection logic when needed.
105///
106/// Applies auto-detection for default devices, but otherwise
107/// returns the device as-is.
108pub fn resolve_ibv_device(device: &IbvDevice) -> Option<IbvDevice> {
109    let device_name = device.name();
110
111    if device_name.starts_with("mlx") {
112        return Some(device.clone());
113    }
114
115    let all_devices = get_all_devices();
116    let is_likely_default = if let Some(first_device) = all_devices.first() {
117        device_name == first_device.name()
118    } else {
119        false
120    };
121
122    if is_likely_default {
123        select_optimal_ibv_device(Some("cpu:0"))
124    } else {
125        Some(device.clone())
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when this host looks like GT20 hardware: all eight
    /// expected `mlx5_*` RDMA devices are present and CUDA indices 0-7 all
    /// resolve to a PCI address.
    fn is_gt20_hardware() -> bool {
        let rdma_devices = get_all_rdma_devices();
        let device_names: std::collections::HashSet<String> =
            rdma_devices.iter().map(|(name, _)| name.clone()).collect();

        // GT20 hardware should have these specific RDMA devices
        let expected_gt20_devices = [
            "mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11",
        ];

        // Only indices 0..8 are probed, so a count of 8 means GPUs 0-7 all exist.
        let gpu_count = (0..8)
            .filter(|&i| get_cuda_pci_address(&i.to_string()).is_some())
            .count();

        // Must have expected RDMA devices AND 8 GPUs
        let has_expected_rdma = expected_gt20_devices
            .iter()
            .all(|&device| device_names.contains(device));

        has_expected_rdma && gpu_count == 8
    }

    /// Step-by-step diagnostic walkthrough of the device-selection pipeline
    /// (GT20 hardware only; skipped elsewhere).
    ///
    /// Mismatches against the expected GT20 pairings are reported as warnings
    /// rather than failures. We can't guarantee the pairings will always match
    /// on the test infrastructure, but the output is useful for diagnostics
    /// and tracking.
    #[test]
    fn test_gt20_hardware() {
        // Early exit if not on GT20 hardware
        if !is_gt20_hardware() {
            println!("⚠️  Skipping test_gt20_hardware: Not running on GT20 hardware");
            return;
        }

        println!("✓ Detected GT20 hardware - running full validation test");
        // Step 1: Test PCI topology parsing
        println!("\n1. PCI TOPOLOGY PARSING");
        let pci_devices = match parse_pci_topology() {
            Ok(devices) => {
                println!("✓ Found {} PCI devices", devices.len());
                devices
            }
            Err(e) => {
                println!("✗ Error: {}", e);
                return;
            }
        };

        // Step 2: Test unified RDMA device discovery
        println!("\n2. RDMA DEVICE DISCOVERY");
        let rdma_devices = get_all_rdma_devices();
        println!("✓ Found {} RDMA devices", rdma_devices.len());
        for (name, pci_addr) in &rdma_devices {
            println!("  RDMA {}: {}", name, pci_addr);
        }

        // Step 3: Test device string parsing
        println!("\n3. DEVICE STRING PARSING");
        let test_strings = ["cuda:0", "cuda:1", "cpu:0", "cpu:1"];
        for device_str in &test_strings {
            if let Some((prefix, postfix)) = parse_device_string(device_str) {
                println!(
                    "  '{}' -> prefix: '{}', postfix: '{}'",
                    device_str, prefix, postfix
                );
            } else {
                println!("  '{}' -> PARSE FAILED", device_str);
            }
        }

        // Step 4: Test CUDA PCI address resolution
        println!("\n4. CUDA PCI ADDRESS RESOLUTION");
        for gpu_idx in 0..8 {
            let gpu_idx_str = gpu_idx.to_string();
            match get_cuda_pci_address(&gpu_idx_str) {
                Some(pci_addr) => {
                    println!("  GPU {} -> PCI: {}", gpu_idx, pci_addr);
                }
                None => {
                    println!("  GPU {} -> PCI: NOT FOUND", gpu_idx);
                }
            }
        }

        // Step 5: Test CPU/NUMA PCI address resolution
        println!("\n5. CPU/NUMA PCI ADDRESS RESOLUTION");
        for numa_node in 0..4 {
            let numa_str = numa_node.to_string();
            match get_numa_pci_address(&numa_str) {
                Some(pci_addr) => {
                    println!("  NUMA {} -> PCI: {}", numa_node, pci_addr);
                }
                None => {
                    println!("  NUMA {} -> PCI: NOT FOUND", numa_node);
                }
            }
        }

        // Step 6: Test distance calculation for GPU 0
        println!("\n6. DISTANCE CALCULATION TEST (GPU 0)");
        if let Some(gpu0_pci_addr) = get_cuda_pci_address("0") {
            if let Some(gpu0_device) = pci_devices.get(&gpu0_pci_addr) {
                println!("GPU 0 PCI: {}", gpu0_pci_addr);
                println!("GPU 0 path to root: {:?}", gpu0_device.get_path_to_root());

                let mut all_distances = Vec::new();
                for (nic_name, nic_pci_addr) in &rdma_devices {
                    if let Some(nic_device) = pci_devices.get(nic_pci_addr) {
                        let distance = gpu0_device.distance_to(nic_device);
                        all_distances.push((distance, nic_name.clone(), nic_pci_addr.clone()));
                        println!("  {} ({}): distance = {}", nic_name, nic_pci_addr, distance);
                        println!("    NIC path to root: {:?}", nic_device.get_path_to_root());
                    }
                }

                // Find the minimum distance. `min_by` returns the first of equal
                // minima (matching a stable sort + first); incomparable distances
                // are treated as equal instead of panicking.
                let closest = all_distances
                    .iter()
                    .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
                if let Some((min_dist, min_nic, min_addr)) = closest {
                    println!(
                        "  → CLOSEST: {} ({}) with distance {}",
                        min_nic, min_addr, min_dist
                    );
                }
            }
        }

        // Step 7: Test unified device selection interface
        println!("\n7. UNIFIED DEVICE SELECTION TEST");
        let test_cases = [
            ("cuda:0", "CUDA device 0"),
            ("cuda:1", "CUDA device 1"),
            ("cpu:0", "CPU/NUMA node 0"),
            ("cpu:1", "CPU/NUMA node 1"),
        ];

        for (device_hint, description) in &test_cases {
            let selected_device = select_optimal_ibv_device(Some(device_hint));
            match selected_device {
                Some(device) => {
                    println!("  {} ({}) -> {}", device_hint, description, device.name());
                }
                None => {
                    println!("  {} ({}) -> NOT FOUND", device_hint, description);
                }
            }
        }

        // Step 8: Test all 8 GPU mappings against expected GT20 hardware results
        println!("\n8. GPU-TO-RDMA MAPPING VALIDATION (ALL 8 GPUs)");

        // Expected results from original Python implementation on GT20 hardware
        let python_expected = [
            (0, "mlx5_0"),
            (1, "mlx5_3"),
            (2, "mlx5_4"),
            (3, "mlx5_5"),
            (4, "mlx5_6"),
            (5, "mlx5_9"),
            (6, "mlx5_10"),
            (7, "mlx5_11"),
        ];

        let mut rust_results = std::collections::HashMap::new();
        let mut all_match = true;

        // Test all 8 GPU mappings using new unified API
        for gpu_idx in 0..8 {
            let cuda_hint = format!("cuda:{}", gpu_idx);
            let selected_device = select_optimal_ibv_device(Some(&cuda_hint));

            match selected_device {
                Some(device) => {
                    let device_name = device.name().to_string();
                    rust_results.insert(gpu_idx, device_name.clone());
                    println!("  GPU {} -> {}", gpu_idx, device_name);
                }
                None => {
                    println!("  GPU {} -> NOT FOUND", gpu_idx);
                    rust_results.insert(gpu_idx, "NOT_FOUND".to_string());
                }
            }
        }

        // Compare against expected results
        println!("\n=== VALIDATION AGAINST EXPECTED RESULTS ===");
        for (gpu_idx, expected_nic) in python_expected {
            if let Some(actual_nic) = rust_results.get(&gpu_idx) {
                let matches = actual_nic == expected_nic;
                println!(
                    "  GPU {} -> {} {} (expected {})",
                    gpu_idx,
                    actual_nic,
                    if matches { "✓" } else { "✗" },
                    expected_nic
                );
                all_match = all_match && matches;
            } else {
                println!(
                    "  GPU {} -> NOT FOUND ✗ (expected {})",
                    gpu_idx, expected_nic
                );
                all_match = false;
            }
        }

        if all_match {
            println!("\n🎉 SUCCESS: All GPU-NIC pairings match expected GT20 hardware results!");
            println!("✓ New unified API produces identical results to proven algorithm");
        } else {
            println!("\n⚠️  WARNING: Some GPU-NIC pairings differ from expected results");
            println!("   This could indicate:");
            println!("   - Hardware configuration differences");
            println!("   - Algorithm implementation differences");
            println!("   - Environment setup differences");
        }

        // Step 9: Detailed CPU device selection analysis
        println!("\n9. DETAILED CPU DEVICE SELECTION ANALYSIS");

        // Check what representative PCI addresses we found for each NUMA node
        if let Some(numa0_addr) = get_numa_pci_address("0") {
            println!("  NUMA 0 representative PCI: {}", numa0_addr);
        } else {
            println!("  NUMA 0 representative PCI: NOT FOUND");
        }

        if let Some(numa1_addr) = get_numa_pci_address("1") {
            println!("  NUMA 1 representative PCI: {}", numa1_addr);
        } else {
            println!("  NUMA 1 representative PCI: NOT FOUND");
        }

        // Now test the actual selections
        let cpu0_device = select_optimal_ibv_device(Some("cpu:0"));
        let cpu1_device = select_optimal_ibv_device(Some("cpu:1"));

        match (
            cpu0_device.as_ref().map(|d| d.name()),
            cpu1_device.as_ref().map(|d| d.name()),
        ) {
            (Some(cpu0_name), Some(cpu1_name)) => {
                println!("\n  FINAL SELECTIONS:");
                println!("    CPU:0 -> {}", cpu0_name);
                println!("    CPU:1 -> {}", cpu1_name);
                if cpu0_name != cpu1_name {
                    println!("    ✓ Different NUMA nodes select different RDMA devices");
                } else {
                    println!("    ⚠️  Same RDMA device selected for both NUMA nodes");
                    println!("       This could indicate:");
                    println!(
                        "       - {} is genuinely closest to both NUMA nodes",
                        cpu0_name
                    );
                    println!("       - NUMA topology detection issue");
                    println!("       - Cross-NUMA penalty algorithm working correctly");
                }
            }
            _ => {
                println!("    ○ CPU device selection not available");
            }
        }

        println!("\n✓ GT20 hardware test completed");
    }
}