monarch_rdma/
rdma_manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Manager Actor
10//!
11//! Manages RDMA connections and operations using `hyperactor` for asynchronous messaging.
12//!
13//! ## Architecture
14//!
15//! `RdmaManagerActor` is a per-host entity that:
16//! - Manages connections to multiple remote RdmaManagerActors (i.e. across the hosts in a Monarch cluster)
17//! - Handles memory registration, connection setup, and data transfer
18//! - Manages all RdmaBuffers in its associated host
19//!
20//! ## Core Operations
21//!
22//! - Connection establishment with partner actors
23//! - RDMA operations (put/write, get/read)
24//! - Completion polling
25//! - Memory region management
26//!
27//! ## Usage
28//!
29//! See test examples: `test_rdma_write_loopback` and `test_rdma_read_loopback`.
30use std::collections::HashMap;
31use std::collections::HashSet;
32use std::sync::Arc;
33use std::time::Duration;
34use std::time::Instant;
35
36use async_trait::async_trait;
37use futures::lock::Mutex;
38use hyperactor::Actor;
39use hyperactor::ActorId;
40use hyperactor::ActorRef;
41use hyperactor::Context;
42use hyperactor::HandleClient;
43use hyperactor::Handler;
44use hyperactor::Instance;
45use hyperactor::Named;
46use hyperactor::OncePortRef;
47use hyperactor::RefClient;
48use hyperactor::clock::Clock;
49use hyperactor::supervision::ActorSupervisionEvent;
50use serde::Deserialize;
51use serde::Serialize;
52
53use crate::ibverbs_primitives::IbverbsConfig;
54use crate::ibverbs_primitives::RdmaMemoryRegionView;
55use crate::ibverbs_primitives::RdmaQpInfo;
56use crate::ibverbs_primitives::ibverbs_supported;
57use crate::ibverbs_primitives::mlx5dv_supported;
58use crate::ibverbs_primitives::resolve_qp_type;
59use crate::rdma_components::RdmaBuffer;
60use crate::rdma_components::RdmaDomain;
61use crate::rdma_components::RdmaQueuePair;
62use crate::rdma_components::get_registered_cuda_segments;
63use crate::validate_execution_context;
64
65/// Helper function to get detailed error messages from RDMAXCEL error codes
66pub fn get_rdmaxcel_error_message(error_code: i32) -> String {
67    unsafe {
68        let c_str = rdmaxcel_sys::rdmaxcel_error_string(error_code);
69        std::ffi::CStr::from_ptr(c_str)
70            .to_string_lossy()
71            .into_owned()
72    }
73}
74
/// Messages understood by [`RdmaManagerActor`].
///
/// Each variant is served by the handler method of the same (snake_case) name on
/// the actor, e.g. `RequestBuffer` by `request_buffer`. Variants carrying a
/// `#[reply]` port send a value back to the caller; the others carry no reply port.
/// Queue-pair related variants identify a connection by the triple
/// (`self_device`, `other`, `other_device`).
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum RdmaManagerMessage {
    /// Registers `size` bytes starting at `addr` for RDMA and replies with a handle
    /// to the registered memory.
    RequestBuffer {
        /// Starting address of the memory to register.
        addr: usize,
        /// Length of the memory to register, in bytes.
        size: usize,
        #[reply]
        /// `reply` - Reply channel to return the RDMA buffer handle
        reply: OncePortRef<RdmaBuffer>,
    },
    /// Deregisters the memory region identified by `buffer.mr_id`.
    ReleaseBuffer {
        /// The buffer handle previously returned by `RequestBuffer`.
        buffer: RdmaBuffer,
    },
    /// Returns the queue pair for the given connection triple, establishing the
    /// connection first when no queue pair is cached yet.
    RequestQueuePair {
        /// `other` - Reference to the peer actor to communicate with
        other: ActorRef<RdmaManagerActor>,
        /// Name of the RDMA device to use on this actor's side.
        self_device: String,
        /// Name of the RDMA device to use on the peer's side.
        other_device: String,
        #[reply]
        /// `reply` - Reply channel to return the queue pair for communication
        reply: OncePortRef<RdmaQueuePair>,
    },
    /// Supplies the peer's connection information for the given connection triple.
    /// NOTE(review): the handler is outside this view; presumably it connects the
    /// local queue pair to `endpoint` — confirm against `connect`.
    Connect {
        /// `other` - Reference to the actor to connect to
        other: ActorRef<RdmaManagerActor>,
        /// Name of the RDMA device on this actor's side.
        self_device: String,
        /// Name of the RDMA device on the peer's side.
        other_device: String,
        /// `endpoint` - Connection information needed to establish the RDMA connection
        endpoint: RdmaQpInfo,
    },
    /// Asks the actor to prepare a queue pair for the given connection triple.
    InitializeQP {
        /// `other` - Reference to the peer actor the queue pair is for
        other: ActorRef<RdmaManagerActor>,
        /// Name of the RDMA device on this actor's side.
        self_device: String,
        /// Name of the RDMA device on the peer's side.
        other_device: String,
        #[reply]
        /// `reply` - Reply channel to return a `bool` status (`true` is returned when
        /// a queue pair for the triple already exists)
        reply: OncePortRef<bool>,
    },
    /// Requests the connection information of the queue pair for the given triple.
    ConnectionInfo {
        /// `other` - Reference to the actor to get connection info for
        other: ActorRef<RdmaManagerActor>,
        /// Name of the RDMA device on this actor's side.
        self_device: String,
        /// Name of the RDMA device on the peer's side.
        other_device: String,
        #[reply]
        /// `reply` - Reply channel to return the connection info
        reply: OncePortRef<RdmaQpInfo>,
    },
    /// Hands a queue pair back to the actor for the given connection triple.
    ReleaseQueuePair {
        /// `other` - Reference to the actor to release the queue pair for
        other: ActorRef<RdmaManagerActor>,
        /// Name of the RDMA device on this actor's side.
        self_device: String,
        /// Name of the RDMA device on the peer's side.
        other_device: String,
        /// `qp` - The queue pair to return (ownership transferred back)
        qp: RdmaQueuePair,
    },
    /// Queries the state of the queue pair for the given connection triple.
    GetQpState {
        /// `other` - Reference to the peer actor the queue pair belongs to
        other: ActorRef<RdmaManagerActor>,
        /// Name of the RDMA device on this actor's side.
        self_device: String,
        /// Name of the RDMA device on the peer's side.
        other_device: String,
        #[reply]
        /// `reply` - Reply channel to return the QP state (compared against
        /// `ibv_qp_state` values by the caller)
        reply: OncePortRef<u32>,
    },
}
140
/// Actor that owns the RDMA state tracked in this file: per-device domains,
/// queue pairs to peer actors, and the registered memory regions.
///
/// Domains and queue pairs are created lazily (see the field comments); all
/// tracked resources are released in the `Drop` implementation.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        RdmaManagerMessage,
    ],
)]
pub struct RdmaManagerActor {
    // Nested map: local_device -> (ActorId, remote_device) -> RdmaQueuePair
    device_qps: HashMap<String, HashMap<(ActorId, String), RdmaQueuePair>>,

    // Track QPs currently being created to prevent duplicate creation.
    // Keyed by (local_device, ActorId, remote_device).
    // Wrapped in Arc<Mutex> to allow safe concurrent access
    pending_qp_creation: Arc<Mutex<HashSet<(String, ActorId, String)>>>,

    // Map of RDMA device names to their domains and loopback QPs
    // Created lazily when memory is registered for a specific device.
    // The loopback QP is `None` when mlx5dv is not supported.
    device_domains: HashMap<String, (RdmaDomain, Option<RdmaQueuePair>)>,

    // ibverbs configuration supplied at spawn time (or the default).
    config: IbverbsConfig,

    // Flag indicating PyTorch CUDA allocator compatibility
    // True if both C10 CUDA allocator is enabled AND expandable segments are enabled
    pt_cuda_alloc: bool,

    // True when the configured QP type resolves to the mlx5dv QP type.
    mlx5dv_enabled: bool,

    // Map from memory-region-view id (`RdmaMemoryRegionView::id`) to the `ibv_mr*`
    // stored as an integer. For CUDA memory served from PyTorch allocator segments
    // the value is 0 (null), since those segments are registered and deregistered
    // independently. Only used for registration/deregistration purposes
    mr_map: HashMap<usize, usize>,
    // Id for next mrv created
    mrv_id: usize,

    // Map of PCI addresses to their optimal RDMA devices
    // This is populated during actor initialization using the device selection algorithm
    pci_to_device: HashMap<String, crate::ibverbs_primitives::RdmaDevice>,
}
178
179impl Drop for RdmaManagerActor {
180    fn drop(&mut self) {
181        // Helper function to manually destroy QP and CQs
182        // We can't use Drop on RdmaQueuePair because it derives Clone
183        fn destroy_queue_pair(qp: &RdmaQueuePair, context: &str) {
184            unsafe {
185                if qp.qp != 0 {
186                    let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
187                    rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
188                }
189                if qp.send_cq != 0 {
190                    let result =
191                        rdmaxcel_sys::ibv_destroy_cq(qp.send_cq as *mut rdmaxcel_sys::ibv_cq);
192                    if result != 0 {
193                        tracing::debug!(
194                            "ibv_destroy_cq (send) returned {} for {} (may be busy during shutdown)",
195                            result,
196                            context
197                        );
198                    }
199                }
200                if qp.recv_cq != 0 {
201                    let result =
202                        rdmaxcel_sys::ibv_destroy_cq(qp.recv_cq as *mut rdmaxcel_sys::ibv_cq);
203                    if result != 0 {
204                        tracing::debug!(
205                            "ibv_destroy_cq (recv) returned {} for {} (may be busy during shutdown)",
206                            result,
207                            context
208                        );
209                    }
210                }
211            }
212        }
213
214        // 1. Clean up all queue pairs (both regular and loopback)
215        for (_device_name, device_map) in self.device_qps.drain() {
216            for ((actor_id, _remote_device), qp) in device_map {
217                destroy_queue_pair(&qp, &format!("actor {:?}", actor_id));
218            }
219        }
220
221        // 2. Clean up device domains (which contain PDs and loopback QPs)
222        for (device_name, (domain, qp)) in self.device_domains.drain() {
223            if let Some(qp) = qp {
224                destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name));
225            }
226            drop(domain);
227        }
228
229        // 3. Clean up memory regions
230        let _mr_count = self.mr_map.len();
231        for (id, mr_ptr) in self.mr_map.drain() {
232            if mr_ptr != 0 {
233                unsafe {
234                    let result = rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
235                    if result != 0 {
236                        tracing::error!(
237                            "Failed to deregister MR with id {}: error code {}",
238                            id,
239                            result
240                        );
241                    }
242                }
243            }
244        }
245
246        // 4. Deregister all CUDA segments (if using PyTorch CUDA allocator)
247        if self.cuda_pt_alloc_enabled() {
248            unsafe {
249                let result = rdmaxcel_sys::deregister_segments();
250                if result != 0 {
251                    let error_msg = get_rdmaxcel_error_message(result);
252                    tracing::error!(
253                        "Failed to deregister CUDA segments: {} (error code: {})",
254                        error_msg,
255                        result
256                    );
257                }
258            }
259        }
260    }
261}
262
263impl RdmaManagerActor {
264    /// Whether to register all memory regions allocated by the PyTorch CUDA allocator
265    /// True if both `pt_cuda_alloc` and `mlx5dv_enabled` are true
266    fn cuda_pt_alloc_enabled(&self) -> bool {
267        self.pt_cuda_alloc && self.mlx5dv_enabled
268    }
269    /// Get or create a domain and loopback QP for the specified RDMA device
270    fn get_or_create_device_domain(
271        &mut self,
272        device_name: &str,
273        rdma_device: &crate::ibverbs_primitives::RdmaDevice,
274    ) -> Result<(RdmaDomain, Option<RdmaQueuePair>), anyhow::Error> {
275        // Check if we already have a domain for this device
276        if let Some((domain, qp)) = self.device_domains.get(device_name) {
277            return Ok((domain.clone(), qp.clone()));
278        }
279
280        // Create new domain for this device
281        let domain = RdmaDomain::new(rdma_device.clone()).map_err(|e| {
282            anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
283        })?;
284
285        // Print device info if MONARCH_DEBUG_RDMA=1 is set (before initial QP creation)
286        crate::print_device_info_if_debug_enabled(domain.context);
287
288        // Create loopback QP for this domain if mlx5dv is supported
289        let qp = if mlx5dv_supported() {
290            let mut qp = RdmaQueuePair::new(domain.context, domain.pd, self.config.clone())
291                .map_err(|e| {
292                    anyhow::anyhow!(
293                        "could not create loopback QP for device {}: {}",
294                        device_name,
295                        e
296                    )
297                })?;
298
299            // Get connection info and connect to itself
300            let endpoint = qp.get_qp_info().map_err(|e| {
301                anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
302            })?;
303
304            qp.connect(&endpoint).map_err(|e| {
305                anyhow::anyhow!(
306                    "could not connect loopback QP for device {}: {}",
307                    device_name,
308                    e
309                )
310            })?;
311
312            Some(qp)
313        } else {
314            None
315        };
316
317        self.device_domains
318            .insert(device_name.to_string(), (domain.clone(), qp.clone()));
319        Ok((domain, qp))
320    }
321
322    fn find_cuda_segment_for_address(
323        &mut self,
324        addr: usize,
325        size: usize,
326    ) -> Option<RdmaMemoryRegionView> {
327        let registered_segments = get_registered_cuda_segments();
328        for segment in registered_segments {
329            let start_addr = segment.phys_address;
330            let end_addr = start_addr + segment.phys_size;
331            if start_addr <= addr && addr + size <= end_addr {
332                let offset = addr - start_addr;
333                let rdma_addr = segment.mr_addr + offset;
334
335                let mrv = RdmaMemoryRegionView {
336                    id: self.mrv_id,
337                    virtual_addr: addr,
338                    rdma_addr,
339                    size,
340                    lkey: segment.lkey,
341                    rkey: segment.rkey,
342                };
343                self.mrv_id += 1;
344                return Some(mrv);
345            }
346        }
347        None
348    }
349
350    fn register_mr(
351        &mut self,
352        addr: usize,
353        size: usize,
354    ) -> Result<(RdmaMemoryRegionView, String), anyhow::Error> {
355        unsafe {
356            let mut mem_type: i32 = 0;
357            let ptr = addr as rdmaxcel_sys::CUdeviceptr;
358            let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
359                &mut mem_type as *mut _ as *mut std::ffi::c_void,
360                rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
361                ptr,
362            );
363            let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
364
365            let mut selected_rdma_device = None;
366
367            if is_cuda {
368                // Use rdmaxcel utility to get PCI address from CUDA pointer
369                let mut pci_addr_buf: [std::os::raw::c_char; 16] = [0; 16]; // Enough space for "ffff:ff:ff.0\0"
370                let err = rdmaxcel_sys::get_cuda_pci_address_from_ptr(
371                    addr as u64,
372                    pci_addr_buf.as_mut_ptr(),
373                    pci_addr_buf.len(),
374                );
375                if err != 0 {
376                    let error_msg = get_rdmaxcel_error_message(err);
377                    return Err(anyhow::anyhow!(
378                        "RdmaXcel get_cuda_pci_address_from_ptr failed (addr: 0x{:x}, size: {}): {}",
379                        addr,
380                        size,
381                        error_msg
382                    ));
383                }
384
385                // Convert C string to Rust string
386                let pci_addr = std::ffi::CStr::from_ptr(pci_addr_buf.as_ptr())
387                    .to_str()
388                    .unwrap();
389                selected_rdma_device = self.pci_to_device.get(pci_addr).cloned();
390            }
391
392            // Determine the RDMA device to use
393            let rdma_device = if let Some(device) = selected_rdma_device {
394                device
395            } else {
396                // Fallback to default device from config
397                self.config.device.clone()
398            };
399
400            let device_name = rdma_device.name().clone();
401            tracing::debug!(
402                "Using RDMA device: {} for memory at 0x{:x}",
403                device_name,
404                addr
405            );
406
407            // Get or create domain and loopback QP for this device
408            let (domain, qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
409
410            let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
411                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
412                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
413                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
414
415            let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
416            let mrv;
417
418            if is_cuda && self.cuda_pt_alloc_enabled() {
419                // Get registered segments and check if our memory range is covered
420                let mut maybe_mrv = self.find_cuda_segment_for_address(addr, size);
421                // not found, lets re-sync with caching allocator  and retry
422                if maybe_mrv.is_none() {
423                    let err = rdmaxcel_sys::register_segments(
424                        domain.pd,
425                        qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t,
426                    );
427                    if err != 0 {
428                        let error_msg = get_rdmaxcel_error_message(err);
429                        return Err(anyhow::anyhow!(
430                            "RdmaXcel register_segments failed (addr: 0x{:x}, size: {}): {}",
431                            addr,
432                            size,
433                            error_msg
434                        ));
435                    }
436
437                    maybe_mrv = self.find_cuda_segment_for_address(addr, size);
438                }
439                // if still not found, throw exception
440                if maybe_mrv.is_none() {
441                    return Err(anyhow::anyhow!(
442                        "MR registration failed for cuda (addr: 0x{:x}, size: {}), unable to find segment in CudaCachingAllocator",
443                        addr,
444                        size
445                    ));
446                }
447                mrv = maybe_mrv.unwrap();
448            } else if is_cuda {
449                let mut fd: i32 = -1;
450                rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
451                    &mut fd,
452                    addr as rdmaxcel_sys::CUdeviceptr,
453                    size,
454                    rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
455                    0,
456                );
457                mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
458                if mr.is_null() {
459                    return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
460                }
461                mrv = RdmaMemoryRegionView {
462                    id: self.mrv_id,
463                    virtual_addr: addr,
464                    rdma_addr: (*mr).addr as usize,
465                    size,
466                    lkey: (*mr).lkey,
467                    rkey: (*mr).rkey,
468                };
469                self.mrv_id += 1;
470            } else {
471                // CPU memory path
472                mr = rdmaxcel_sys::ibv_reg_mr(
473                    domain.pd,
474                    addr as *mut std::ffi::c_void,
475                    size,
476                    access.0 as i32,
477                );
478
479                if mr.is_null() {
480                    return Err(anyhow::anyhow!("failed to register standard MR"));
481                }
482
483                mrv = RdmaMemoryRegionView {
484                    id: self.mrv_id,
485                    virtual_addr: addr,
486                    rdma_addr: (*mr).addr as usize,
487                    size,
488                    lkey: (*mr).lkey,
489                    rkey: (*mr).rkey,
490                };
491                self.mrv_id += 1;
492            }
493            self.mr_map.insert(mrv.id, mr as usize);
494            Ok((mrv, device_name))
495        }
496    }
497
498    fn deregister_mr(&mut self, id: usize) -> Result<(), anyhow::Error> {
499        if let Some(mr_ptr) = self.mr_map.remove(&id) {
500            if mr_ptr != 0 {
501                unsafe {
502                    rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
503                }
504            }
505        }
506        Ok(())
507    }
508}
509
510#[async_trait]
511impl Actor for RdmaManagerActor {
512    type Params = Option<IbverbsConfig>;
513
514    async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
515        if !ibverbs_supported() {
516            return Err(anyhow::anyhow!(
517                "Cannot create RdmaManagerActor because RDMA is not supported on this machine"
518            ));
519        }
520
521        // Use provided config or default if none provided
522        let mut config = params.unwrap_or_default();
523        tracing::debug!("rdma is enabled, config device hint: {}", config.device);
524
525        let pt_cuda_alloc = crate::rdma_components::pt_cuda_allocator_compatibility();
526
527        let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
528
529        // check config and hardware support align
530        if config.use_gpu_direct {
531            match validate_execution_context().await {
532                Ok(_) => {
533                    tracing::info!("GPU Direct RDMA execution context validated successfully");
534                }
535                Err(e) => {
536                    tracing::warn!(
537                        "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
538                        e
539                    );
540                    config.use_gpu_direct = false;
541                }
542            }
543        }
544
545        // Build the CUDA to RDMA device mapping using device selection algorithm
546        let pci_to_device = crate::device_selection::create_cuda_to_rdma_mapping();
547        tracing::debug!(
548            "Built CUDA to RDMA device mapping with {} entries",
549            pci_to_device.len()
550        );
551
552        Ok(Self {
553            device_qps: HashMap::new(),
554            pending_qp_creation: Arc::new(Mutex::new(HashSet::new())),
555            device_domains: HashMap::new(),
556            config,
557            pt_cuda_alloc,
558            mlx5dv_enabled,
559            mr_map: HashMap::new(),
560            mrv_id: 0,
561            pci_to_device,
562        })
563    }
564
    /// Actor initialization hook.
    ///
    /// Performs no RDMA setup: it only emits a debug log line and returns `Ok(())`,
    /// since domains and queue pairs are created lazily by later requests.
    async fn init(&mut self, _this: &Instance<Self>) -> Result<(), anyhow::Error> {
        tracing::debug!("RdmaManagerActor initialized with lazy domain/QP creation");
        Ok(())
    }
569
570    async fn handle_supervision_event(
571        &mut self,
572        _cx: &Instance<Self>,
573        _event: &ActorSupervisionEvent,
574    ) -> Result<bool, anyhow::Error> {
575        tracing::error!("rdmaManagerActor supervision event: {:?}", _event);
576        tracing::error!("rdmaManagerActor error occurred, stop the worker process, exit code: 1");
577        std::process::exit(1);
578    }
579}
580
581#[async_trait]
582#[hyperactor::forward(RdmaManagerMessage)]
583impl RdmaManagerMessageHandler for RdmaManagerActor {
584    /// Requests a buffer to be registered with the RDMA domain.
585    ///
586    /// This function registers a memory region with the RDMA domain and returns an `RdmaBuffer`
587    /// that encapsulates the necessary information for RDMA operations.
588    ///
589    /// # Arguments
590    ///
591    /// * `this` - The context of the actor requesting the buffer.
592    /// * `addr` - The starting address of the memory region to be registered.
593    /// * `size` - The size of the memory region to be registered.
594    ///
595    /// # Returns
596    ///
597    /// * `Result<RdmaBuffer, anyhow::Error>` - On success, returns an `RdmaBuffer` containing
598    ///   the registered memory region's details. On failure, returns an error.
599    async fn request_buffer(
600        &mut self,
601        cx: &Context<Self>,
602        addr: usize,
603        size: usize,
604    ) -> Result<RdmaBuffer, anyhow::Error> {
605        let (mrv, device_name) = self.register_mr(addr, size)?;
606
607        Ok(RdmaBuffer {
608            owner: cx.bind().clone(),
609            mr_id: mrv.id,
610            addr: mrv.rdma_addr,
611            size: mrv.size,
612            rkey: mrv.rkey,
613            lkey: mrv.lkey,
614            device_name,
615        })
616    }
617
618    /// Deregisters a buffer from the RDMA domain.
619    ///
620    /// This function removes the specified `RdmaBuffer` from the RDMA domain,
621    /// effectively releasing the resources associated with it.
622    ///
623    /// # Arguments
624    ///
625    /// * `_this` - The context of the actor releasing the buffer.
626    /// * `buffer` - The `RdmaBuffer` to be deregistered.
627    ///
628    /// # Returns
629    ///
630    /// * `Result<(), anyhow::Error>` - On success, returns `Ok(())`. On failure, returns an error.
631    async fn release_buffer(
632        &mut self,
633        _cx: &Context<Self>,
634        buffer: RdmaBuffer,
635    ) -> Result<(), anyhow::Error> {
636        self.deregister_mr(buffer.mr_id)
637            .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
638        Ok(())
639    }
640
    /// Returns the queue pair used to talk to `other`, creating and connecting it
    /// on first request.
    ///
    /// Queue pairs are cached in `device_qps` under
    /// `self_device -> (other's ActorId, other_device)`; a cache hit returns a clone.
    /// On a miss the key is recorded in `pending_qp_creation`, the connection
    /// handshake runs, and the key is removed again whether the handshake succeeded
    /// or failed. If the key is already marked pending, this call polls the cache
    /// every 200 microseconds and gives up after one second.
    ///
    /// The handshake depends on whether the target is this very actor on the same
    /// device (loopback) or a different actor/device (remote); see the comments in
    /// the body for the exact message order.
    ///
    /// # Arguments
    ///
    /// * `cx` - The context of the actor requesting the queue pair.
    /// * `other` - The ActorRef of the RDMA manager actor to communicate with.
    /// * `self_device` - Name of the RDMA device to use on this actor's side.
    /// * `other_device` - Name of the RDMA device to use on the peer's side.
    ///
    /// # Returns
    ///
    /// * `Result<RdmaQueuePair, anyhow::Error>` - On success, returns the queue pair for communication.
    ///   On failure, returns an error: any handshake step failed, the remote queue
    ///   pair did not reach the RTS state, no queue pair was found in the cache after
    ///   the handshake, or waiting on a pending creation timed out.
    async fn request_queue_pair(
        &mut self,
        cx: &Context<Self>,
        other: ActorRef<RdmaManagerActor>,
        self_device: String,
        other_device: String,
    ) -> Result<RdmaQueuePair, anyhow::Error> {
        let other_id = other.actor_id().clone();

        // Use the nested map structure: local_device -> (actor_id, remote_device) -> RdmaQueuePair
        let inner_key = (other_id.clone(), other_device.clone());

        // Check if queue pair exists in map
        if let Some(device_map) = self.device_qps.get(&self_device) {
            if let Some(qp) = device_map.get(&inner_key) {
                return Ok(qp.clone());
            }
        }

        // Try to acquire lock and mark as pending (hold lock only once!)
        let pending_key = (self_device.clone(), other_id.clone(), other_device.clone());
        let mut pending = self.pending_qp_creation.lock().await;

        if pending.contains(&pending_key) {
            // Another task is creating this QP, release lock and wait
            drop(pending);

            // Loop checking device_qps until QP is created (no more locks needed)
            // Timeout after 1 second
            // NOTE(review): the sleep goes through `cx.clock()` while the timeout is
            // measured with `std::time::Instant`; confirm this mix is intended.
            let start = Instant::now();
            let timeout = Duration::from_secs(1);

            loop {
                cx.clock().sleep(Duration::from_micros(200)).await;

                // Check if QP was created while we waited
                if let Some(device_map) = self.device_qps.get(&self_device) {
                    if let Some(qp) = device_map.get(&inner_key) {
                        return Ok(qp.clone());
                    }
                }

                // Check for timeout
                if start.elapsed() >= timeout {
                    return Err(anyhow::anyhow!(
                        "Timeout waiting for QP creation (device {} -> actor {} device {}). \
                         Another task is creating it but hasn't completed in 1 second",
                        self_device,
                        other_id,
                        other_device
                    ));
                }
            }
        } else {
            // Not pending, add to set and proceed with creation
            pending.insert(pending_key.clone());
            drop(pending);
            // Fall through to create QP
        }

        // Queue pair doesn't exist - need to create connection.
        // The handshake runs inside an async block so that a `?` failure only ends
        // the block and the pending-key cleanup below still executes.
        let result = async {
            // Loopback means: same actor and same device on both ends.
            let is_loopback = other_id == cx.bind::<RdmaManagerActor>().actor_id().clone()
                && self_device == other_device;

            if is_loopback {
                // Loopback connection setup: initialize the QP, fetch its own
                // endpoint and connect it to that endpoint.
                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                // The device arguments are passed as (other_device, self_device)
                // here; both names are equal in the loopback case.
                let endpoint = self
                    .connection_info(cx, other.clone(), other_device.clone(), self_device.clone())
                    .await?;
                self.connect(
                    cx,
                    other.clone(),
                    self_device.clone(),
                    other_device.clone(),
                    endpoint,
                )
                .await?;
            } else {
                // Remote connection setup
                // 1. Initialize the QP on both sides. Calls on `other` use the
                //    mirrored triple (this actor, other_device, self_device).
                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                other
                    .initialize_qp(
                        cx,
                        cx.bind().clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;
                // 2. Fetch the remote endpoint and connect the local QP to it.
                let other_endpoint: RdmaQpInfo = other
                    .connection_info(
                        cx,
                        cx.bind().clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;
                self.connect(
                    cx,
                    other.clone(),
                    self_device.clone(),
                    other_device.clone(),
                    other_endpoint,
                )
                .await?;
                // 3. Fetch the local endpoint and have the remote side connect to it.
                let local_endpoint = self
                    .connection_info(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                other
                    .connect(
                        cx,
                        cx.bind().clone(),
                        other_device.clone(),
                        self_device.clone(),
                        local_endpoint,
                    )
                    .await?;

                // BARRIER: Ensure remote side has completed its connection and is ready
                let remote_state = other
                    .get_qp_state(
                        cx,
                        cx.bind().clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;

                if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS {
                    return Err(anyhow::anyhow!(
                        "Remote QP not in RTS state after connection setup. \
                         Local is ready but remote is in state {}. \
                         This indicates a synchronization issue in connection setup.",
                        remote_state
                    ));
                }
            }

            // Now that connection is established, get and clone the queue pair
            if let Some(device_map) = self.device_qps.get(&self_device) {
                if let Some(qp) = device_map.get(&inner_key) {
                    Ok(qp.clone())
                } else {
                    Err(anyhow::anyhow!(
                        "Failed to create connection for actor {} on device {}",
                        other_id,
                        other_device
                    ))
                }
            } else {
                Err(anyhow::anyhow!(
                    "Failed to create connection for actor {} on device {} - no device map",
                    other_id,
                    other_device
                ))
            }
        }
        .await;

        // Always remove from pending set when done (success or failure)
        let mut pending = self.pending_qp_creation.lock().await;
        pending.remove(&pending_key);
        drop(pending);

        result
    }
823
824    async fn initialize_qp(
825        &mut self,
826        _cx: &Context<Self>,
827        other: ActorRef<RdmaManagerActor>,
828        self_device: String,
829        other_device: String,
830    ) -> Result<bool, anyhow::Error> {
831        let other_id = other.actor_id().clone();
832        let inner_key = (other_id.clone(), other_device.clone());
833
834        // Check if QP already exists in nested structure
835        if let Some(device_map) = self.device_qps.get(&self_device) {
836            if device_map.contains_key(&inner_key) {
837                return Ok(true);
838            }
839        }
840
841        // Resolve the RDMA device for the local device
842        let rdma_device = self
843            .pci_to_device
844            .iter()
845            .find(|(_, device)| device.name() == &self_device)
846            .map(|(_, device)| device.clone())
847            .unwrap_or_else(|| {
848                // Fallback to default device from config
849                crate::device_selection::resolve_rdma_device(&self.config.device)
850                    .unwrap_or_else(|| self.config.device.clone())
851            });
852
853        // Get or create domain and extract pointers to avoid borrowing issues
854        let (domain_context, domain_pd) = {
855            // Check if we already have a domain for the device
856            let (domain, _) = self.get_or_create_device_domain(&self_device, &rdma_device)?;
857            (domain.context, domain.pd)
858        };
859
860        let qp = RdmaQueuePair::new(domain_context, domain_pd, self.config.clone())
861            .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?;
862
863        // Insert the QP into the nested map structure
864        self.device_qps
865            .entry(self_device.clone())
866            .or_insert_with(HashMap::new)
867            .insert(inner_key, qp);
868
869        tracing::debug!(
870            "successfully created a connection with {:?} for local device {} -> remote device {}",
871            other,
872            self_device,
873            other_device
874        );
875
876        Ok(true)
877    }
878
879    /// Establishes a connection with another actor
880    ///
881    /// # Arguments
882    /// * `other` - The ActorRef of the actor to connect to
883    /// * `endpoint` - Connection information needed to establish the RDMA connection
884    async fn connect(
885        &mut self,
886        _cx: &Context<Self>,
887        other: ActorRef<RdmaManagerActor>,
888        self_device: String,
889        other_device: String,
890        endpoint: RdmaQpInfo,
891    ) -> Result<(), anyhow::Error> {
892        tracing::debug!("connecting with {:?}", other);
893        let other_id = other.actor_id().clone();
894
895        let inner_key = (other_id.clone(), other_device.clone());
896
897        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
898            match device_map.get_mut(&inner_key) {
899                Some(qp) => {
900                    qp.connect(&endpoint).map_err(|e| {
901                        anyhow::anyhow!("could not connect to RDMA endpoint: {}", e)
902                    })?;
903                    Ok(())
904                }
905                None => Err(anyhow::anyhow!(
906                    "No connection found for actor {}",
907                    other_id
908                )),
909            }
910        } else {
911            Err(anyhow::anyhow!(
912                "No device map found for device {}",
913                self_device
914            ))
915        }
916    }
917
918    /// Gets connection information for establishing an RDMA connection
919    ///
920    /// # Arguments
921    /// * `other` - The ActorRef to get connection info for
922    ///
923    /// # Returns
924    /// * `RdmaQpInfo` - Connection information needed for the RDMA connection
925    async fn connection_info(
926        &mut self,
927        _cx: &Context<Self>,
928        other: ActorRef<RdmaManagerActor>,
929        self_device: String,
930        other_device: String,
931    ) -> Result<RdmaQpInfo, anyhow::Error> {
932        tracing::debug!("getting connection info with {:?}", other);
933        let other_id = other.actor_id().clone();
934
935        let inner_key = (other_id.clone(), other_device.clone());
936
937        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
938            match device_map.get_mut(&inner_key) {
939                Some(qp) => {
940                    let connection_info = qp.get_qp_info()?;
941                    Ok(connection_info)
942                }
943                None => Err(anyhow::anyhow!(
944                    "No connection found for actor {}",
945                    other_id
946                )),
947            }
948        } else {
949            Err(anyhow::anyhow!(
950                "No device map found for self device {}",
951                self_device
952            ))
953        }
954    }
955
956    /// Releases a queue pair back to the HashMap
957    ///
958    /// This method is now a no-op since RdmaQueuePair is Clone and can be safely shared.
959    /// The queue pair is not actually checked out, so there's nothing to release.
960    /// This method is kept for API compatibility.
961    ///
962    /// # Arguments
963    /// * `remote` - The ActorRef of the remote actor to return the queue pair for
964    /// * `qp` - The queue pair to release (ignored)
965    async fn release_queue_pair(
966        &mut self,
967        _cx: &Context<Self>,
968        _other: ActorRef<RdmaManagerActor>,
969        _self_device: String,
970        _other_device: String,
971        _qp: RdmaQueuePair,
972    ) -> Result<(), anyhow::Error> {
973        // No-op: Queue pairs are now cloned and shared via atomic counters
974        // Nothing needs to be released
975        Ok(())
976    }
977
978    /// Gets the state of a queue pair
979    ///
980    /// # Arguments
981    /// * `other` - The ActorRef to get the QP state for
982    /// * `self_device` - Local device name
983    /// * `other_device` - Remote device name
984    ///
985    /// # Returns
986    /// * `u32` - The QP state (e.g., IBV_QPS_RTS = Ready To Send)
987    async fn get_qp_state(
988        &mut self,
989        _cx: &Context<Self>,
990        other: ActorRef<RdmaManagerActor>,
991        self_device: String,
992        other_device: String,
993    ) -> Result<u32, anyhow::Error> {
994        let other_id = other.actor_id().clone();
995        let inner_key = (other_id.clone(), other_device.clone());
996
997        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
998            match device_map.get_mut(&inner_key) {
999                Some(qp) => qp.state(),
1000                None => Err(anyhow::anyhow!(
1001                    "No connection found for actor {} on device {}",
1002                    other_id,
1003                    other_device
1004                )),
1005            }
1006        } else {
1007            Err(anyhow::anyhow!(
1008                "No device map found for self device {}",
1009                self_device
1010            ))
1011        }
1012    }
1013}