monarch_rdma/backend/ibverbs/
device_selection.rs1use std::collections::HashMap;
13
14use super::primitives::IbvDevice;
15use super::primitives::get_all_devices;
16use crate::device_selection::PCIDevice;
17use crate::device_selection::get_all_rdma_devices;
18use crate::device_selection::get_cuda_pci_address;
19use crate::device_selection::get_numa_pci_address;
20use crate::device_selection::parse_device_string;
21use crate::device_selection::parse_pci_topology;
22
23pub fn select_optimal_ibv_device(device_hint: Option<&str>) -> Option<IbvDevice> {
28 let device_hint = device_hint?;
29
30 let (prefix, postfix) = parse_device_string(device_hint)?;
31
32 match prefix.as_str() {
33 "nic" => {
34 let all_rdma_devices = get_all_devices();
35 all_rdma_devices
36 .into_iter()
37 .find(|dev| dev.name() == &postfix)
38 }
39 "cuda" | "cpu" => {
40 let source_pci_addr = match prefix.as_str() {
41 "cuda" => get_cuda_pci_address(&postfix)?,
42 "cpu" => get_numa_pci_address(&postfix)?,
43 _ => unreachable!(),
44 };
45 let rdma_devices = get_all_rdma_devices();
46 if rdma_devices.is_empty() {
47 return IbvDevice::first_available();
48 }
49 let pci_devices = parse_pci_topology().ok()?;
50 let source_device = pci_devices.get(&source_pci_addr)?;
51
52 let rdma_names: Vec<String> =
53 rdma_devices.iter().map(|(name, _)| name.clone()).collect();
54 let rdma_pci_devices: Vec<PCIDevice> = rdma_devices
55 .iter()
56 .filter_map(|(_, addr)| pci_devices.get(addr).cloned())
57 .collect();
58
59 if let Some(closest_idx) = source_device.find_closest(&rdma_pci_devices) {
60 if let Some(optimal_name) = rdma_names.get(closest_idx) {
61 let all_rdma_devices = get_all_devices();
62 for device in all_rdma_devices {
63 if *device.name() == *optimal_name {
64 return Some(device);
65 }
66 }
67 }
68 }
69
70 IbvDevice::first_available()
72 }
73 _ => {
74 let rdma_devices = get_all_devices();
76 rdma_devices
77 .into_iter()
78 .find(|dev| dev.name() == device_hint)
79 }
80 }
81}
82
83pub fn create_cuda_to_ibv_mapping() -> HashMap<String, IbvDevice> {
88 let mut mapping = HashMap::new();
89
90 for gpu_idx in 0..8 {
92 let gpu_idx_str = gpu_idx.to_string();
93 if let Some(cuda_pci_addr) = get_cuda_pci_address(&gpu_idx_str) {
94 let cuda_hint = format!("cuda:{}", gpu_idx);
95 if let Some(rdma_device) = select_optimal_ibv_device(Some(&cuda_hint)) {
96 mapping.insert(cuda_pci_addr, rdma_device);
97 }
98 }
99 }
100
101 mapping
102}
103
104pub fn resolve_ibv_device(device: &IbvDevice) -> Option<IbvDevice> {
109 let device_name = device.name();
110
111 if device_name.starts_with("mlx") {
112 return Some(device.clone());
113 }
114
115 let all_devices = get_all_devices();
116 let is_likely_default = if let Some(first_device) = all_devices.first() {
117 device_name == first_device.name()
118 } else {
119 false
120 };
121
122 if is_likely_default {
123 select_optimal_ibv_device(Some("cpu:0"))
124 } else {
125 Some(device.clone())
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected GPU index -> RDMA NIC pairings on GT20 hardware.
    ///
    /// Used both to detect GT20 hardware and to validate the selection
    /// results, so the two lists cannot drift apart.
    const GT20_EXPECTED_PAIRINGS: [(usize, &str); 8] = [
        (0, "mlx5_0"),
        (1, "mlx5_3"),
        (2, "mlx5_4"),
        (3, "mlx5_5"),
        (4, "mlx5_6"),
        (5, "mlx5_9"),
        (6, "mlx5_10"),
        (7, "mlx5_11"),
    ];

    /// Returns `true` when the host looks like GT20 hardware. That requires
    /// every NIC in `GT20_EXPECTED_PAIRINGS` to be reported by
    /// `get_all_rdma_devices`, and a CUDA PCI address to resolve for every
    /// expected GPU index.
    fn is_gt20_hardware() -> bool {
        let rdma_devices = get_all_rdma_devices();
        let device_names: std::collections::HashSet<String> =
            rdma_devices.iter().map(|(name, _)| name.clone()).collect();

        let all_gpus_present = GT20_EXPECTED_PAIRINGS
            .iter()
            .all(|(gpu_idx, _)| get_cuda_pci_address(&gpu_idx.to_string()).is_some());

        let has_expected_rdma = GT20_EXPECTED_PAIRINGS
            .iter()
            .all(|(_, nic)| device_names.contains(*nic));

        has_expected_rdma && all_gpus_present
    }

    /// End-to-end diagnostic of device selection on GT20 hardware.
    ///
    /// On any other host the test is skipped and passes. On GT20 it prints
    /// each stage of the selection pipeline and compares GPU -> NIC selections
    /// against `GT20_EXPECTED_PAIRINGS`. Mismatches are reported as warnings,
    /// but a PCI topology that cannot be parsed fails the test.
    #[test]
    fn test_gt20_hardware() {
        if !is_gt20_hardware() {
            println!("⚠️ Skipping test_gt20_hardware: Not running on GT20 hardware");
            return;
        }

        println!("✓ Detected GT20 hardware - running full validation test");
        println!("\n1. PCI TOPOLOGY PARSING");
        // GT20 hardware was confirmed above, so a topology that cannot be
        // parsed is a real failure. Returning early would let it pass silently.
        let pci_devices = match parse_pci_topology() {
            Ok(devices) => {
                println!("✓ Found {} PCI devices", devices.len());
                devices
            }
            Err(e) => panic!("✗ Error: {}", e),
        };

        println!("\n2. RDMA DEVICE DISCOVERY");
        let rdma_devices = get_all_rdma_devices();
        println!("✓ Found {} RDMA devices", rdma_devices.len());
        for (name, pci_addr) in &rdma_devices {
            println!(" RDMA {}: {}", name, pci_addr);
        }

        println!("\n3. DEVICE STRING PARSING");
        let test_strings = ["cuda:0", "cuda:1", "cpu:0", "cpu:1"];
        for device_str in &test_strings {
            if let Some((prefix, postfix)) = parse_device_string(device_str) {
                println!(
                    " '{}' -> prefix: '{}', postfix: '{}'",
                    device_str, prefix, postfix
                );
            } else {
                println!(" '{}' -> PARSE FAILED", device_str);
            }
        }

        println!("\n4. CUDA PCI ADDRESS RESOLUTION");
        for gpu_idx in 0..GT20_EXPECTED_PAIRINGS.len() {
            let gpu_idx_str = gpu_idx.to_string();
            match get_cuda_pci_address(&gpu_idx_str) {
                Some(pci_addr) => {
                    println!(" GPU {} -> PCI: {}", gpu_idx, pci_addr);
                }
                None => {
                    println!(" GPU {} -> PCI: NOT FOUND", gpu_idx);
                }
            }
        }

        println!("\n5. CPU/NUMA PCI ADDRESS RESOLUTION");
        for numa_node in 0..4 {
            let numa_str = numa_node.to_string();
            match get_numa_pci_address(&numa_str) {
                Some(pci_addr) => {
                    println!(" NUMA {} -> PCI: {}", numa_node, pci_addr);
                }
                None => {
                    println!(" NUMA {} -> PCI: NOT FOUND", numa_node);
                }
            }
        }

        println!("\n6. DISTANCE CALCULATION TEST (GPU 0)");
        if let Some(gpu0_pci_addr) = get_cuda_pci_address("0") {
            if let Some(gpu0_device) = pci_devices.get(&gpu0_pci_addr) {
                println!("GPU 0 PCI: {}", gpu0_pci_addr);
                println!("GPU 0 path to root: {:?}", gpu0_device.get_path_to_root());

                let mut all_distances = Vec::new();
                for (nic_name, nic_pci_addr) in &rdma_devices {
                    if let Some(nic_device) = pci_devices.get(nic_pci_addr) {
                        let distance = gpu0_device.distance_to(nic_device);
                        all_distances.push((distance, nic_name.clone(), nic_pci_addr.clone()));
                        println!(" {} ({}): distance = {}", nic_name, nic_pci_addr, distance);
                        println!(" NIC path to root: {:?}", nic_device.get_path_to_root());
                    }
                }

                all_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
                if let Some((min_dist, min_nic, min_addr)) = all_distances.first() {
                    println!(
                        " → CLOSEST: {} ({}) with distance {}",
                        min_nic, min_addr, min_dist
                    );
                }
            }
        }

        println!("\n7. UNIFIED DEVICE SELECTION TEST");
        let test_cases = [
            ("cuda:0", "CUDA device 0"),
            ("cuda:1", "CUDA device 1"),
            ("cpu:0", "CPU/NUMA node 0"),
            ("cpu:1", "CPU/NUMA node 1"),
        ];

        for (device_hint, description) in &test_cases {
            let selected_device = select_optimal_ibv_device(Some(device_hint));
            match selected_device {
                Some(device) => {
                    println!(" {} ({}) -> {}", device_hint, description, device.name());
                }
                None => {
                    println!(" {} ({}) -> NOT FOUND", device_hint, description);
                }
            }
        }

        println!("\n8. GPU-TO-RDMA MAPPING VALIDATION (ALL 8 GPUs)");

        // Selected NIC name per GPU, in the same order as GT20_EXPECTED_PAIRINGS.
        let mut rust_results: Vec<String> = Vec::with_capacity(GT20_EXPECTED_PAIRINGS.len());
        for &(gpu_idx, _) in &GT20_EXPECTED_PAIRINGS {
            let cuda_hint = format!("cuda:{}", gpu_idx);
            let device_name = match select_optimal_ibv_device(Some(&cuda_hint)) {
                Some(device) => {
                    let device_name = device.name().to_string();
                    println!(" GPU {} -> {}", gpu_idx, device_name);
                    device_name
                }
                None => {
                    println!(" GPU {} -> NOT FOUND", gpu_idx);
                    "NOT_FOUND".to_string()
                }
            };
            rust_results.push(device_name);
        }

        println!("\n=== VALIDATION AGAINST EXPECTED RESULTS ===");
        let mut all_match = true;
        for (&(gpu_idx, expected_nic), actual_nic) in
            GT20_EXPECTED_PAIRINGS.iter().zip(&rust_results)
        {
            let matches = actual_nic == expected_nic;
            println!(
                " GPU {} -> {} {} (expected {})",
                gpu_idx,
                actual_nic,
                if matches { "✓" } else { "✗" },
                expected_nic
            );
            all_match &= matches;
        }

        if all_match {
            println!("\n🎉 SUCCESS: All GPU-NIC pairings match expected GT20 hardware results!");
            println!("✓ New unified API produces identical results to proven algorithm");
        } else {
            println!("\n⚠️ WARNING: Some GPU-NIC pairings differ from expected results");
            println!(" This could indicate:");
            println!(" - Hardware configuration differences");
            println!(" - Algorithm implementation differences");
            println!(" - Environment setup differences");
        }

        println!("\n9. DETAILED CPU DEVICE SELECTION ANALYSIS");

        if let Some(numa0_addr) = get_numa_pci_address("0") {
            println!(" NUMA 0 representative PCI: {}", numa0_addr);
        } else {
            println!(" NUMA 0 representative PCI: NOT FOUND");
        }

        if let Some(numa1_addr) = get_numa_pci_address("1") {
            println!(" NUMA 1 representative PCI: {}", numa1_addr);
        } else {
            println!(" NUMA 1 representative PCI: NOT FOUND");
        }

        let cpu0_device = select_optimal_ibv_device(Some("cpu:0"));
        let cpu1_device = select_optimal_ibv_device(Some("cpu:1"));

        match (
            cpu0_device.as_ref().map(|d| d.name()),
            cpu1_device.as_ref().map(|d| d.name()),
        ) {
            (Some(cpu0_name), Some(cpu1_name)) => {
                println!("\n FINAL SELECTIONS:");
                println!(" CPU:0 -> {}", cpu0_name);
                println!(" CPU:1 -> {}", cpu1_name);
                if cpu0_name != cpu1_name {
                    println!(" ✓ Different NUMA nodes select different RDMA devices");
                } else {
                    println!(" ⚠️ Same RDMA device selected for both NUMA nodes");
                    println!(" This could indicate:");
                    println!(
                        " - {} is genuinely closest to both NUMA nodes",
                        cpu0_name
                    );
                    println!(" - NUMA topology detection issue");
                    println!(" - Cross-NUMA penalty algorithm working correctly");
                }
            }
            _ => {
                println!(" ○ CPU device selection not available");
            }
        }

        println!("\n✓ GT20 hardware test completed");
    }
}