monarch_rdma/backend/ibverbs/device_selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! ibverbs-specific device selection logic that pairs compute devices
10//! with the best available RDMA NICs based on PCI topology distance.
11
12use std::sync::OnceLock;
13
14use super::primitives::IbvDevice;
15use super::primitives::get_all_devices;
16use crate::device_selection::PCIDevice;
17use crate::device_selection::get_all_rdma_devices;
18use crate::device_selection::get_cuda_pci_address;
19use crate::device_selection::get_numa_pci_address;
20use crate::device_selection::parse_device_string;
21use crate::device_selection::parse_pci_topology;
22
23/// Step 1: Parse device string into prefix and postfix
24/// Step 2: Get PCI address from compute device
25/// Step 3: Get PCI address for all RDMA NIC devices
26/// Step 4: Calculate PCI distances and return closest RDMA NIC device
27pub fn select_optimal_ibv_device(device_hint: Option<&str>) -> Option<IbvDevice> {
28    let device_hint = device_hint?;
29
30    let (prefix, postfix) = parse_device_string(device_hint)?;
31
32    match prefix.as_str() {
33        "nic" => {
34            let all_rdma_devices = get_all_devices();
35            all_rdma_devices
36                .into_iter()
37                .find(|dev| dev.name() == &postfix)
38        }
39        "cuda" | "cpu" => {
40            let source_pci_addr = match prefix.as_str() {
41                "cuda" => get_cuda_pci_address(&postfix)?,
42                "cpu" => get_numa_pci_address(&postfix)?,
43                _ => unreachable!(),
44            };
45            let rdma_devices = get_all_rdma_devices();
46            if rdma_devices.is_empty() {
47                return IbvDevice::first_available();
48            }
49            let pci_devices = parse_pci_topology().ok()?;
50            let source_device = pci_devices.get(&source_pci_addr)?;
51
52            let rdma_names: Vec<String> =
53                rdma_devices.iter().map(|(name, _)| name.clone()).collect();
54            let rdma_pci_devices: Vec<PCIDevice> = rdma_devices
55                .iter()
56                .filter_map(|(_, addr)| pci_devices.get(addr).cloned())
57                .collect();
58
59            if let Some(closest_idx) = source_device.find_closest(&rdma_pci_devices) {
60                if let Some(optimal_name) = rdma_names.get(closest_idx) {
61                    let all_rdma_devices = get_all_devices();
62                    for device in all_rdma_devices {
63                        if *device.name() == *optimal_name {
64                            return Some(device);
65                        }
66                    }
67                }
68            }
69
70            // Fallback
71            IbvDevice::first_available()
72        }
73        _ => {
74            // Direct device name lookup for backward compatibility
75            let rdma_devices = get_all_devices();
76            rdma_devices
77                .into_iter()
78                .find(|dev| dev.name() == device_hint)
79        }
80    }
81}
82
83/// Returns a reference to the process-wide lazily-initialized Vec mapping
84/// CUDA device ordinal → optimal RDMA NIC (`None` if no NIC is mapped).
85///
86/// Computed at most once per process on the first RDMA operation involving
87/// CUDA memory. CPU-only workloads pay no initialization cost.
88pub fn get_cuda_device_to_ibv_device() -> &'static Vec<Option<IbvDevice>> {
89    static CUDA_DEVICE_TO_IBV: OnceLock<Vec<Option<IbvDevice>>> = OnceLock::new();
90    CUDA_DEVICE_TO_IBV.get_or_init(|| {
91        let count = unsafe {
92            let mut c: i32 = 0;
93            rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut c);
94            c.max(0) as usize
95        };
96        (0..count)
97            .map(|ordinal| select_optimal_ibv_device(Some(&format!("cuda:{}", ordinal))))
98            .collect()
99    })
100}
101
102/// Resolves RDMA device using auto-detection logic when needed.
103///
104/// Applies auto-detection for default devices, but otherwise
105/// returns the device as-is.
106pub fn resolve_ibv_device(device: &IbvDevice) -> Option<IbvDevice> {
107    let device_name = device.name();
108
109    if device_name.starts_with("mlx") {
110        return Some(device.clone());
111    }
112
113    let all_devices = get_all_devices();
114    let is_likely_default = if let Some(first_device) = all_devices.first() {
115        device_name == first_device.name()
116    } else {
117        false
118    };
119
120    if is_likely_default {
121        select_optimal_ibv_device(Some("cpu:0"))
122    } else {
123        Some(device.clone())
124    }
125}
126
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::collections::HashSet;

    use super::*;

    /// RDMA NICs a GT20 host is expected to expose. Entry `i` is also the NIC
    /// the original Python implementation paired with CUDA device `i`.
    const GT20_EXPECTED_NICS: [&str; 8] = [
        "mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11",
    ];

    /// Formats `value` for the diagnostic report, printing `NOT FOUND` for `None`.
    fn or_not_found<T: std::fmt::Display>(value: Option<T>) -> String {
        match value {
            Some(v) => v.to_string(),
            None => "NOT FOUND".to_string(),
        }
    }

    /// Heuristically detects GT20 hardware: every NIC in `GT20_EXPECTED_NICS`
    /// must be reported by `get_all_rdma_devices`, and CUDA devices 0 through 7
    /// must all resolve to a PCI address.
    fn is_gt20_hardware() -> bool {
        let present_nics: HashSet<String> = get_all_rdma_devices()
            .iter()
            .map(|(name, _)| name.clone())
            .collect();

        let resolvable_gpus = (0..8)
            .filter(|ordinal| get_cuda_pci_address(&ordinal.to_string()).is_some())
            .count();

        let has_all_nics = GT20_EXPECTED_NICS
            .iter()
            .all(|nic| present_nics.contains(*nic));

        has_all_nics && resolvable_gpus == 8
    }

    /// Diagnostic walk-through of the device-selection pipeline on GT20 hosts.
    ///
    /// Prints the output of every stage (topology parsing, NIC discovery, hint
    /// parsing, PCI address resolution, distance calculation, selection) and
    /// compares the CUDA-to-NIC mapping with the known GT20 pairing. Mismatches
    /// are only reported, never asserted: the test infrastructure cannot
    /// guarantee the mapping, so this exists for diagnostics and tracking. The
    /// test returns early (passing) on non-GT20 hosts or when the PCI topology
    /// cannot be parsed.
    #[test]
    fn test_gt20_hardware() {
        if !is_gt20_hardware() {
            println!("⚠️  Skipping test_gt20_hardware: Not running on GT20 hardware");
            return;
        }
        println!("✓ Detected GT20 hardware - running full validation test");

        // Stage 1: PCI topology.
        println!("\n1. PCI TOPOLOGY PARSING");
        let pci_devices = match parse_pci_topology() {
            Err(err) => {
                println!("✗ Error: {}", err);
                return;
            }
            Ok(parsed) => parsed,
        };
        println!("✓ Found {} PCI devices", pci_devices.len());

        // Stage 2: RDMA NIC discovery.
        println!("\n2. RDMA DEVICE DISCOVERY");
        let rdma_devices = get_all_rdma_devices();
        println!("✓ Found {} RDMA devices", rdma_devices.len());
        for (nic_name, nic_addr) in &rdma_devices {
            println!("  RDMA {}: {}", nic_name, nic_addr);
        }

        // Stage 3: device hint parsing.
        println!("\n3. DEVICE STRING PARSING");
        for hint in ["cuda:0", "cuda:1", "cpu:0", "cpu:1"] {
            match parse_device_string(hint) {
                Some((prefix, postfix)) => println!(
                    "  '{}' -> prefix: '{}', postfix: '{}'",
                    hint, prefix, postfix
                ),
                None => println!("  '{}' -> PARSE FAILED", hint),
            }
        }

        // Stage 4: CUDA device -> PCI address.
        println!("\n4. CUDA PCI ADDRESS RESOLUTION");
        for gpu_idx in 0..8 {
            let resolved = get_cuda_pci_address(&gpu_idx.to_string());
            println!("  GPU {} -> PCI: {}", gpu_idx, or_not_found(resolved));
        }

        // Stage 5: NUMA node -> representative PCI address.
        println!("\n5. CPU/NUMA PCI ADDRESS RESOLUTION");
        for numa_node in 0..4 {
            let resolved = get_numa_pci_address(&numa_node.to_string());
            println!("  NUMA {} -> PCI: {}", numa_node, or_not_found(resolved));
        }

        // Stage 6: raw distances from GPU 0 to every NIC in the topology.
        println!("\n6. DISTANCE CALCULATION TEST (GPU 0)");
        if let Some(gpu0_addr) = get_cuda_pci_address("0") {
            if let Some(gpu0) = pci_devices.get(&gpu0_addr) {
                println!("GPU 0 PCI: {}", gpu0_addr);
                println!("GPU 0 path to root: {:?}", gpu0.get_path_to_root());

                let mut measured = Vec::new();
                for (nic_name, nic_addr) in &rdma_devices {
                    let Some(nic) = pci_devices.get(nic_addr) else {
                        continue;
                    };
                    let distance = gpu0.distance_to(nic);
                    println!("  {} ({}): distance = {}", nic_name, nic_addr, distance);
                    println!("    NIC path to root: {:?}", nic.get_path_to_root());
                    measured.push((distance, nic_name, nic_addr));
                }

                // `min_by` yields the first of several equal minima, like taking
                // the head of a stable sort.
                let closest = measured
                    .iter()
                    .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
                if let Some((min_dist, min_nic, min_addr)) = closest {
                    println!(
                        "  → CLOSEST: {} ({}) with distance {}",
                        min_nic, min_addr, min_dist
                    );
                }
            }
        }

        // Stage 7: end-to-end selection for a few hints.
        println!("\n7. UNIFIED DEVICE SELECTION TEST");
        let selection_cases = [
            ("cuda:0", "CUDA device 0"),
            ("cuda:1", "CUDA device 1"),
            ("cpu:0", "CPU/NUMA node 0"),
            ("cpu:1", "CPU/NUMA node 1"),
        ];
        for (hint, description) in selection_cases {
            let selected = select_optimal_ibv_device(Some(hint));
            let shown = or_not_found(selected.as_ref().map(|device| device.name()));
            println!("  {} ({}) -> {}", hint, description, shown);
        }

        // Stage 8: full GPU -> NIC mapping versus the known GT20 pairing.
        println!("\n8. GPU-TO-RDMA MAPPING VALIDATION (ALL 8 GPUs)");
        let mut rust_results: HashMap<usize, String> = HashMap::new();
        for gpu_idx in 0..GT20_EXPECTED_NICS.len() {
            let selected = select_optimal_ibv_device(Some(&format!("cuda:{}", gpu_idx)));
            let selected_name = selected.as_ref().map(|device| device.name().to_string());
            println!("  GPU {} -> {}", gpu_idx, or_not_found(selected_name.as_deref()));
            rust_results.insert(
                gpu_idx,
                selected_name.unwrap_or_else(|| "NOT_FOUND".to_string()),
            );
        }

        println!("\n=== VALIDATION AGAINST EXPECTED RESULTS ===");
        let mut all_match = true;
        for (gpu_idx, &expected_nic) in GT20_EXPECTED_NICS.iter().enumerate() {
            match rust_results.get(&gpu_idx) {
                Some(actual_nic) => {
                    let matches = actual_nic == expected_nic;
                    let mark = if matches { "✓" } else { "✗" };
                    println!(
                        "  GPU {} -> {} {} (expected {})",
                        gpu_idx, actual_nic, mark, expected_nic
                    );
                    all_match &= matches;
                }
                None => {
                    println!(
                        "  GPU {} -> NOT FOUND ✗ (expected {})",
                        gpu_idx, expected_nic
                    );
                    all_match = false;
                }
            }
        }

        if !all_match {
            println!("\n⚠️  WARNING: Some GPU-NIC pairings differ from expected results");
            println!("   This could indicate:");
            println!("   - Hardware configuration differences");
            println!("   - Algorithm implementation differences");
            println!("   - Environment setup differences");
        } else {
            println!("\n🎉 SUCCESS: All GPU-NIC pairings match expected GT20 hardware results!");
            println!("✓ New unified API produces identical results to proven algorithm");
        }

        // Stage 9: CPU (NUMA) selections in detail.
        println!("\n9. DETAILED CPU DEVICE SELECTION ANALYSIS");
        for node in ["0", "1"] {
            println!(
                "  NUMA {} representative PCI: {}",
                node,
                or_not_found(get_numa_pci_address(node))
            );
        }

        let cpu0_device = select_optimal_ibv_device(Some("cpu:0"));
        let cpu1_device = select_optimal_ibv_device(Some("cpu:1"));
        if let (Some(cpu0), Some(cpu1)) = (cpu0_device.as_ref(), cpu1_device.as_ref()) {
            let (cpu0_name, cpu1_name) = (cpu0.name(), cpu1.name());
            println!("\n  FINAL SELECTIONS:");
            println!("    CPU:0 -> {}", cpu0_name);
            println!("    CPU:1 -> {}", cpu1_name);
            if cpu0_name == cpu1_name {
                println!("    ⚠️  Same RDMA device selected for both NUMA nodes");
                println!("       This could indicate:");
                println!(
                    "       - {} is genuinely closest to both NUMA nodes",
                    cpu0_name
                );
                println!("       - NUMA topology detection issue");
                println!("       - Cross-NUMA penalty algorithm working correctly");
            } else {
                println!("    ✓ Different NUMA nodes select different RDMA devices");
            }
        } else {
            println!("    ○ CPU device selection not available");
        }

        println!("\n✓ GT20 hardware test completed");
    }
}