monarch_rdma/backend/ibverbs/
device_selection.rs1use std::sync::OnceLock;
13
14use super::primitives::IbvDevice;
15use super::primitives::get_all_devices;
16use crate::device_selection::PCIDevice;
17use crate::device_selection::get_all_rdma_devices;
18use crate::device_selection::get_cuda_pci_address;
19use crate::device_selection::get_numa_pci_address;
20use crate::device_selection::parse_device_string;
21use crate::device_selection::parse_pci_topology;
22
23pub fn select_optimal_ibv_device(device_hint: Option<&str>) -> Option<IbvDevice> {
28 let device_hint = device_hint?;
29
30 let (prefix, postfix) = parse_device_string(device_hint)?;
31
32 match prefix.as_str() {
33 "nic" => {
34 let all_rdma_devices = get_all_devices();
35 all_rdma_devices
36 .into_iter()
37 .find(|dev| dev.name() == &postfix)
38 }
39 "cuda" | "cpu" => {
40 let source_pci_addr = match prefix.as_str() {
41 "cuda" => get_cuda_pci_address(&postfix)?,
42 "cpu" => get_numa_pci_address(&postfix)?,
43 _ => unreachable!(),
44 };
45 let rdma_devices = get_all_rdma_devices();
46 if rdma_devices.is_empty() {
47 return IbvDevice::first_available();
48 }
49 let pci_devices = parse_pci_topology().ok()?;
50 let source_device = pci_devices.get(&source_pci_addr)?;
51
52 let rdma_names: Vec<String> =
53 rdma_devices.iter().map(|(name, _)| name.clone()).collect();
54 let rdma_pci_devices: Vec<PCIDevice> = rdma_devices
55 .iter()
56 .filter_map(|(_, addr)| pci_devices.get(addr).cloned())
57 .collect();
58
59 if let Some(closest_idx) = source_device.find_closest(&rdma_pci_devices) {
60 if let Some(optimal_name) = rdma_names.get(closest_idx) {
61 let all_rdma_devices = get_all_devices();
62 for device in all_rdma_devices {
63 if *device.name() == *optimal_name {
64 return Some(device);
65 }
66 }
67 }
68 }
69
70 IbvDevice::first_available()
72 }
73 _ => {
74 let rdma_devices = get_all_devices();
76 rdma_devices
77 .into_iter()
78 .find(|dev| dev.name() == device_hint)
79 }
80 }
81}
82
83pub fn get_cuda_device_to_ibv_device() -> &'static Vec<Option<IbvDevice>> {
89 static CUDA_DEVICE_TO_IBV: OnceLock<Vec<Option<IbvDevice>>> = OnceLock::new();
90 CUDA_DEVICE_TO_IBV.get_or_init(|| {
91 let count = unsafe {
92 let mut c: i32 = 0;
93 rdmaxcel_sys::rdmaxcel_cuDeviceGetCount(&mut c);
94 c.max(0) as usize
95 };
96 (0..count)
97 .map(|ordinal| select_optimal_ibv_device(Some(&format!("cuda:{}", ordinal))))
98 .collect()
99 })
100}
101
102pub fn resolve_ibv_device(device: &IbvDevice) -> Option<IbvDevice> {
107 let device_name = device.name();
108
109 if device_name.starts_with("mlx") {
110 return Some(device.clone());
111 }
112
113 let all_devices = get_all_devices();
114 let is_likely_default = if let Some(first_device) = all_devices.first() {
115 device_name == first_device.name()
116 } else {
117 false
118 };
119
120 if is_likely_default {
121 select_optimal_ibv_device(Some("cpu:0"))
122 } else {
123 Some(device.clone())
124 }
125}
126
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Number of CUDA devices a GT20 host is expected to expose.
    const GT20_GPU_COUNT: usize = 8;

    /// NIC expected to be selected for each CUDA ordinal on GT20 hardware
    /// (entry `i` is the expected NIC for `cuda:i`). The same list identifies a
    /// GT20 host: all of these NICs must be present.
    const GT20_EXPECTED_NICS: [&str; GT20_GPU_COUNT] = [
        "mlx5_0", "mlx5_3", "mlx5_4", "mlx5_5", "mlx5_6", "mlx5_9", "mlx5_10", "mlx5_11",
    ];

    /// Heuristic GT20 detection. It requires two things:
    /// * every NIC in `GT20_EXPECTED_NICS` is reported by `get_all_rdma_devices`;
    /// * every CUDA ordinal in `0..8` resolves to a PCI address.
    fn is_gt20_hardware() -> bool {
        let present_nics: HashSet<String> = get_all_rdma_devices()
            .into_iter()
            .map(|(name, _)| name)
            .collect();

        let resolvable_gpus = (0..GT20_GPU_COUNT)
            .map(|ordinal| get_cuda_pci_address(&ordinal.to_string()))
            .filter(Option::is_some)
            .count();

        let has_expected_nics = GT20_EXPECTED_NICS
            .iter()
            .all(|nic| present_nics.contains(*nic));

        has_expected_nics && resolvable_gpus == GT20_GPU_COUNT
    }

    /// Section 3: shows how sample hint strings are split by `parse_device_string`.
    fn report_device_string_parsing() {
        println!("\n3. DEVICE STRING PARSING");
        for hint in ["cuda:0", "cuda:1", "cpu:0", "cpu:1"] {
            match parse_device_string(hint) {
                Some((prefix, postfix)) => println!(
                    " '{}' -> prefix: '{}', postfix: '{}'",
                    hint, prefix, postfix
                ),
                None => println!(" '{}' -> PARSE FAILED", hint),
            }
        }
    }

    /// Section 4: PCI address resolved for each of the 8 CUDA ordinals.
    fn report_cuda_pci_addresses() {
        println!("\n4. CUDA PCI ADDRESS RESOLUTION");
        for gpu in 0..GT20_GPU_COUNT {
            match get_cuda_pci_address(&gpu.to_string()) {
                Some(pci_addr) => println!(" GPU {} -> PCI: {}", gpu, pci_addr),
                None => println!(" GPU {} -> PCI: NOT FOUND", gpu),
            }
        }
    }

    /// Section 5: PCI address resolved for NUMA nodes 0..4.
    fn report_numa_pci_addresses() {
        println!("\n5. CPU/NUMA PCI ADDRESS RESOLUTION");
        for node in 0..4 {
            match get_numa_pci_address(&node.to_string()) {
                Some(pci_addr) => println!(" NUMA {} -> PCI: {}", node, pci_addr),
                None => println!(" NUMA {} -> PCI: NOT FOUND", node),
            }
        }
    }

    /// Section 7: device picked by `select_optimal_ibv_device` for sample hints.
    fn report_unified_selection() {
        println!("\n7. UNIFIED DEVICE SELECTION TEST");
        let cases = [
            ("cuda:0", "CUDA device 0"),
            ("cuda:1", "CUDA device 1"),
            ("cpu:0", "CPU/NUMA node 0"),
            ("cpu:1", "CPU/NUMA node 1"),
        ];
        for (hint, description) in cases {
            match select_optimal_ibv_device(Some(hint)) {
                Some(device) => println!(" {} ({}) -> {}", hint, description, device.name()),
                None => println!(" {} ({}) -> NOT FOUND", hint, description),
            }
        }
    }

    /// Section 8: compares the NIC picked for each GPU with `GT20_EXPECTED_NICS`.
    /// Mismatches are only reported; they do not fail the test.
    fn report_gpu_to_nic_mapping() {
        println!("\n8. GPU-TO-RDMA MAPPING VALIDATION (ALL 8 GPUs)");

        let mut selected: Vec<String> = Vec::with_capacity(GT20_GPU_COUNT);
        for gpu in 0..GT20_GPU_COUNT {
            let hint = format!("cuda:{}", gpu);
            let nic_name = match select_optimal_ibv_device(Some(&hint)) {
                Some(device) => {
                    let nic_name = device.name().to_string();
                    println!(" GPU {} -> {}", gpu, nic_name);
                    nic_name
                }
                None => {
                    println!(" GPU {} -> NOT FOUND", gpu);
                    "NOT_FOUND".to_string()
                }
            };
            selected.push(nic_name);
        }

        println!("\n=== VALIDATION AGAINST EXPECTED RESULTS ===");
        let mut all_match = true;
        for (gpu, &expected) in GT20_EXPECTED_NICS.iter().enumerate() {
            match selected.get(gpu) {
                Some(actual) => {
                    let matches = actual == expected;
                    println!(
                        " GPU {} -> {} {} (expected {})",
                        gpu,
                        actual,
                        if matches { "✓" } else { "✗" },
                        expected
                    );
                    all_match &= matches;
                }
                None => {
                    println!(" GPU {} -> NOT FOUND ✗ (expected {})", gpu, expected);
                    all_match = false;
                }
            }
        }

        if all_match {
            println!("\n🎉 SUCCESS: All GPU-NIC pairings match expected GT20 hardware results!");
            println!("✓ New unified API produces identical results to proven algorithm");
        } else {
            println!("\n⚠️ WARNING: Some GPU-NIC pairings differ from expected results");
            println!(" This could indicate:");
            println!(" - Hardware configuration differences");
            println!(" - Algorithm implementation differences");
            println!(" - Environment setup differences");
        }
    }

    /// Section 9: NUMA representative PCI addresses and the NICs chosen for
    /// `cpu:0` / `cpu:1`.
    fn report_cpu_selection() {
        println!("\n9. DETAILED CPU DEVICE SELECTION ANALYSIS");

        match get_numa_pci_address("0") {
            Some(numa0_addr) => println!(" NUMA 0 representative PCI: {}", numa0_addr),
            None => println!(" NUMA 0 representative PCI: NOT FOUND"),
        }
        match get_numa_pci_address("1") {
            Some(numa1_addr) => println!(" NUMA 1 representative PCI: {}", numa1_addr),
            None => println!(" NUMA 1 representative PCI: NOT FOUND"),
        }

        let cpu0_device = select_optimal_ibv_device(Some("cpu:0"));
        let cpu1_device = select_optimal_ibv_device(Some("cpu:1"));
        let (Some(cpu0), Some(cpu1)) = (cpu0_device.as_ref(), cpu1_device.as_ref()) else {
            println!(" ○ CPU device selection not available");
            return;
        };
        let (cpu0_name, cpu1_name) = (cpu0.name(), cpu1.name());

        println!("\n FINAL SELECTIONS:");
        println!(" CPU:0 -> {}", cpu0_name);
        println!(" CPU:1 -> {}", cpu1_name);
        if cpu0_name == cpu1_name {
            println!(" ⚠️ Same RDMA device selected for both NUMA nodes");
            println!(" This could indicate:");
            println!(" - {} is genuinely closest to both NUMA nodes", cpu0_name);
            println!(" - NUMA topology detection issue");
            println!(" - Cross-NUMA penalty algorithm working correctly");
        } else {
            println!(" ✓ Different NUMA nodes select different RDMA devices");
        }
    }

    /// Diagnostic walkthrough of device selection on GT20 hosts.
    ///
    /// The test is skipped (and passes) on other hardware. Most sections only
    /// print their findings. It also returns early, without failing, if the PCI
    /// topology cannot be parsed.
    #[test]
    fn test_gt20_hardware() {
        if !is_gt20_hardware() {
            println!("⚠️ Skipping test_gt20_hardware: Not running on GT20 hardware");
            return;
        }
        println!("✓ Detected GT20 hardware - running full validation test");

        println!("\n1. PCI TOPOLOGY PARSING");
        let pci_devices = match parse_pci_topology() {
            Ok(devices) => devices,
            Err(e) => {
                println!("✗ Error: {}", e);
                return;
            }
        };
        println!("✓ Found {} PCI devices", pci_devices.len());

        println!("\n2. RDMA DEVICE DISCOVERY");
        let rdma_devices = get_all_rdma_devices();
        println!("✓ Found {} RDMA devices", rdma_devices.len());
        for (name, pci_addr) in &rdma_devices {
            println!(" RDMA {}: {}", name, pci_addr);
        }

        report_device_string_parsing();
        report_cuda_pci_addresses();
        report_numa_pci_addresses();

        // Section 6 needs the topology and RDMA list gathered above, so it stays inline.
        println!("\n6. DISTANCE CALCULATION TEST (GPU 0)");
        if let Some(gpu0_pci_addr) = get_cuda_pci_address("0") {
            if let Some(gpu0_device) = pci_devices.get(&gpu0_pci_addr) {
                println!("GPU 0 PCI: {}", gpu0_pci_addr);
                println!("GPU 0 path to root: {:?}", gpu0_device.get_path_to_root());

                let mut distances = Vec::new();
                for (nic_name, nic_pci_addr) in &rdma_devices {
                    let Some(nic_device) = pci_devices.get(nic_pci_addr) else {
                        continue;
                    };
                    let distance = gpu0_device.distance_to(nic_device);
                    println!(" {} ({}): distance = {}", nic_name, nic_pci_addr, distance);
                    println!(" NIC path to root: {:?}", nic_device.get_path_to_root());
                    distances.push((distance, nic_name, nic_pci_addr));
                }

                // `min_by` yields the first of several equal minima, matching what a
                // stable sort followed by `first()` would pick.
                let closest = distances
                    .iter()
                    .min_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
                if let Some((min_dist, min_nic, min_addr)) = closest {
                    println!(
                        " → CLOSEST: {} ({}) with distance {}",
                        min_nic, min_addr, min_dist
                    );
                }
            }
        }

        report_unified_selection();
        report_gpu_to_nic_mapping();
        report_cpu_selection();

        println!("\n✓ GT20 hardware test completed");
    }
}