monarch_rdma/backend/ibverbs/
manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # Ibverbs Manager
10//!
11//! Contains ibverbs-specific RDMA logic.
12//!
13//! Manages ibverbs resources including:
14//! - Memory registration (CPU and CUDA via dmabuf or segment scanning)
15//! - Queue pair creation and connection establishment
16//! - RDMA domain and protection domain management
17//! - Device selection and PCI-to-RDMA device mapping
18
19use std::collections::HashMap;
20use std::collections::HashSet;
21use std::sync::Arc;
22use std::sync::OnceLock;
23use std::time::Duration;
24use std::time::Instant;
25
26use anyhow::Result;
27use async_trait::async_trait;
28use futures::lock::Mutex;
29use hyperactor::Actor;
30use hyperactor::ActorHandle;
31use hyperactor::Context;
32use hyperactor::HandleClient;
33use hyperactor::Handler;
34use hyperactor::Instance;
35use hyperactor::OncePortHandle;
36use hyperactor::RefClient;
37use hyperactor::reference;
38use serde::Deserialize;
39use serde::Serialize;
40use typeuri::Named;
41
42use super::IbvBuffer;
43use super::IbvOp;
44use super::domain::IbvDomain;
45use super::primitives::IbvConfig;
46use super::primitives::IbvDevice;
47use super::primitives::IbvMemoryRegionView;
48use super::primitives::IbvQpInfo;
49use super::primitives::ibverbs_supported;
50use super::primitives::mlx5dv_supported;
51use super::primitives::resolve_qp_type;
52use super::queue_pair::IbvQueuePair;
53use super::queue_pair::PollTarget;
54use crate::RdmaOp;
55use crate::RdmaOpType;
56use crate::RdmaTransportLevel;
57use crate::backend::RdmaBackend;
58use crate::rdma_components::get_registered_cuda_segments;
59use crate::rdma_manager_actor::GetIbvActorRefClient;
60use crate::rdma_manager_actor::RdmaManagerActor;
61use crate::rdma_manager_actor::RdmaManagerMessageClient;
62use crate::rdma_manager_actor::get_rdmaxcel_error_message;
63use crate::validate_execution_context;
64
/// Messages handled by [`IbvManagerActor`].
///
/// The queue-pair variants all identify a connection by the triple
/// (`self_device`, `other`, `other_device`): the local RDMA device name, the
/// peer [`IbvManagerActor`], and the peer's RDMA device name. This mirrors
/// how the actor keys its queue-pair map.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum IbvManagerMessage {
    /// Register the MR for a buffer identified by `remote_buf_id`. Resolves
    /// the local memory via the parent [`RdmaManagerActor`]'s
    /// `RequestLocalMemory`, registers it as an ibverbs MR, and returns
    /// the resulting [`IbvBuffer`].
    ///
    /// Returns `None` if the buffer has already been released or does not
    /// exist.
    RequestBuffer {
        remote_buf_id: usize,
        #[reply]
        reply: reference::OncePortRef<Option<IbvBuffer>>,
    },
    /// Release a buffer registration by `remote_buf_id`.
    ReleaseBuffer {
        remote_buf_id: usize,
        #[reply]
        reply: reference::OncePortRef<()>,
    },
    /// Return the queue pair for the connection to `other`, creating and
    /// connecting it first if it is not already cached.
    RequestQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<IbvQueuePair>,
    },
    /// Carries the peer's connection info (`endpoint`) for a connection.
    /// NOTE(review): handler not visible here — presumably connects the
    /// local queue pair for this connection to `endpoint`; confirm.
    Connect {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        endpoint: IbvQpInfo,
    },
    /// Used as the first step of connection setup on both sides.
    /// NOTE(review): handler not visible here — presumably creates the local
    /// queue pair for this connection; the meaning of the `bool` reply should
    /// be confirmed against the handler.
    InitializeQP {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<bool>,
    },
    /// Replies with the [`IbvQpInfo`] for a connection; during setup the
    /// result is passed to the other side's `Connect`.
    ConnectionInfo {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<IbvQpInfo>,
    },
    /// Hands a previously requested queue pair back to the manager.
    /// NOTE(review): handler not visible here — exact release semantics
    /// should be confirmed.
    ReleaseQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        qp: IbvQueuePair,
    },
    /// Replies with the state of the queue pair for a connection as a raw
    /// `u32`. NOTE(review): presumably an `ibv_qp_state` value — confirm.
    GetQpState {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<u32>,
    },
}
wirevalue::register_type!(IbvManagerMessage);
128
/// Local-only submit message for the [`IbvManagerActor`].
///
/// Not serializable because [`IbvOp`] contains `Arc<dyn RdmaLocalMemory>`.
/// Can only be sent within the same process.
#[derive(Handler, HandleClient, Debug)]
pub struct IbvSubmit {
    /// The RDMA operations to execute.
    pub ops: Vec<IbvOp>,
    /// Completion timeout.
    /// NOTE(review): the handler is not visible here — presumably applied to
    /// each op individually rather than to the whole batch; confirm.
    pub timeout: Duration,
    /// Port on which the outcome is delivered; errors are reported as strings.
    #[reply]
    pub reply: OncePortHandle<Result<(), String>>,
}
140
/// Manages all ibverbs-specific RDMA resources and operations.
///
/// This struct handles memory registration, queue pair management,
/// and connection establishment using the ibverbs API.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        IbvManagerMessage,
    ],
)]
pub struct IbvManagerActor {
    // Handle to the parent `RdmaManagerActor`. Set exactly once in
    // `Actor::init` and used to resolve local memory when registering buffers.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,

    // Nested map: local_device -> (ActorId, remote_device) -> IbvQueuePair
    device_qps: HashMap<String, HashMap<(reference::ActorId, String), IbvQueuePair>>,

    // Track QPs currently being created to prevent duplicate creation.
    // Keyed by (local_device, ActorId, remote_device).
    // Wrapped in Arc<Mutex> to allow safe concurrent access
    pending_qp_creation: Arc<Mutex<HashSet<(String, reference::ActorId, String)>>>,

    // Map of RDMA device names to their domains and loopback QPs.
    // Created lazily when memory is registered for a specific device.
    // The loopback QP is `None` when none was created for the device.
    device_domains: HashMap<String, (IbvDomain, Option<IbvQueuePair>)>,

    // Configuration; `config.device` is the fallback device when no
    // PCI-specific device is found for a memory address.
    config: IbvConfig,

    // True when the resolved QP type is mlx5dv; enables CUDA segment scanning.
    mlx5dv_enabled: bool,

    // Map of IbvMemoryRegionView id to the `ibv_mr*` (stored as usize) backing it.
    // The value is 0 (null) when the view was served from a registered CUDA
    // segment, whose MR is managed independently; such entries are skipped on
    // deregistration. Only used for registration/deregistration purposes.
    mr_map: HashMap<usize, usize>,

    // Id for the next IbvMemoryRegionView created
    mrv_id: usize,

    // Map of PCI addresses to their optimal RDMA devices.
    // This is populated during actor construction using the device selection algorithm
    pci_to_device: HashMap<String, IbvDevice>,

    // Map from buffer_id to registration details.
    buffer_registrations: HashMap<usize, IbvBuffer>,
}
183
184#[async_trait]
185impl Actor for IbvManagerActor {
186    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
187        let owner = if let Some(owner) = this.parent_handle() {
188            owner
189        } else {
190            anyhow::bail!("RdmaManagerActor not found as parent of IbvManagerActor");
191        };
192        self.owner
193            .set(owner)
194            .expect("owner should only be set once during init");
195        Ok(())
196    }
197}
198
199impl Drop for IbvManagerActor {
200    fn drop(&mut self) {
201        // Helper function to destroy QP resources
202        // We can't use Drop on IbvQueuePair because it derives Clone
203        // Note: rdmaxcel_qp_destroy handles destroying both the QP and its CQs internally,
204        // so we must NOT call ibv_destroy_cq separately (would cause double-free/SIGSEGV)
205        fn destroy_queue_pair(qp: &IbvQueuePair, _context: &str) {
206            unsafe {
207                if qp.qp != 0 {
208                    let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
209                    rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
210                }
211            }
212        }
213
214        // 1. Clean up all queue pairs (both regular and loopback)
215        for (_device_name, device_map) in self.device_qps.drain() {
216            for ((actor_id, _remote_device), qp) in device_map {
217                destroy_queue_pair(&qp, &format!("actor {:?}", actor_id));
218            }
219        }
220
221        // 2. Clean up device domains (which contain PDs and loopback QPs)
222        for (device_name, (domain, qp)) in self.device_domains.drain() {
223            if let Some(qp) = qp {
224                destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name));
225            }
226            drop(domain);
227        }
228
229        // 3. Clean up memory regions
230        let _mr_count = self.mr_map.len();
231        for (id, mr_ptr) in self.mr_map.drain() {
232            if mr_ptr != 0 {
233                unsafe {
234                    let result = rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
235                    if result != 0 {
236                        tracing::error!(
237                            "Failed to deregister MR with id {}: error code {}",
238                            id,
239                            result
240                        );
241                    }
242                }
243            }
244        }
245
246        // 4. Deregister all CUDA segments (if using mlx5dv)
247        // The segment scanner in Python handles compatibility checks
248        if self.mlx5dv_enabled {
249            unsafe {
250                let result = rdmaxcel_sys::deregister_segments();
251                if result != 0 {
252                    let error_msg = get_rdmaxcel_error_message(result);
253                    tracing::error!(
254                        "Failed to deregister CUDA segments: {} (error code: {})",
255                        error_msg,
256                        result
257                    );
258                }
259            }
260        }
261    }
262}
263
264impl IbvManagerActor {
265    /// Construct an [`ActorHandle`] for the [`IbvManagerActor`] co-located
266    /// with the caller by querying the local [`RdmaManagerActor`].
267    pub async fn local_handle(
268        client: &(impl hyperactor::context::Actor + Send + Sync),
269    ) -> Result<ActorHandle<Self>, anyhow::Error> {
270        let rdma_handle = RdmaManagerActor::local_handle(client);
271        let ibv_ref = rdma_handle
272            .get_ibv_actor_ref(client)
273            .await?
274            .ok_or_else(|| anyhow::anyhow!("local RdmaManagerActor has no ibverbs backend"))?;
275        ibv_ref
276            .downcast_handle(client)
277            .ok_or_else(|| anyhow::anyhow!("IbvManagerActor is not in the local process"))
278    }
279
280    /// Create a new IbvManagerActor with the given configuration.
281    pub async fn new(params: Option<IbvConfig>) -> Result<Self, anyhow::Error> {
282        if !ibverbs_supported() {
283            return Err(anyhow::anyhow!(
284                "Cannot create IbvManagerActor because RDMA is not supported on this machine"
285            ));
286        }
287
288        // Use provided config or default if none provided
289        let mut config = params.unwrap_or_default();
290        tracing::debug!("rdma is enabled, config device hint: {}", config.device);
291
292        let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
293
294        // check config and hardware support align
295        if config.use_gpu_direct {
296            match validate_execution_context().await {
297                Ok(_) => {
298                    tracing::info!("GPU Direct RDMA execution context validated successfully");
299                }
300                Err(e) => {
301                    tracing::warn!(
302                        "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
303                        e
304                    );
305                    config.use_gpu_direct = false;
306                }
307            }
308        }
309
310        // Build the CUDA to RDMA device mapping using device selection algorithm
311        let pci_to_device = super::device_selection::create_cuda_to_ibv_mapping();
312        tracing::debug!(
313            "Built CUDA to RDMA device mapping with {} entries",
314            pci_to_device.len()
315        );
316
317        Ok(Self {
318            owner: OnceLock::new(),
319            device_qps: HashMap::new(),
320            pending_qp_creation: Arc::new(Mutex::new(HashSet::new())),
321            device_domains: HashMap::new(),
322            config,
323            mlx5dv_enabled,
324            mr_map: HashMap::new(),
325            mrv_id: 0,
326            pci_to_device,
327            buffer_registrations: HashMap::new(),
328        })
329    }
330
331    /// Get or create a domain and loopback QP for the specified RDMA device
332    fn get_or_create_device_domain(
333        &mut self,
334        device_name: &str,
335        rdma_device: &IbvDevice,
336    ) -> Result<(IbvDomain, Option<IbvQueuePair>), anyhow::Error> {
337        if let Some((domain, qp)) = self.device_domains.get(device_name) {
338            return Ok((domain.clone(), qp.clone()));
339        }
340
341        // Create new domain for this device
342        let domain = IbvDomain::new(rdma_device.clone()).map_err(|e| {
343            anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
344        })?;
345
346        // Print device info if MONARCH_DEBUG_RDMA=1 is set (before initial QP creation)
347        crate::print_device_info_if_debug_enabled(domain.context);
348
349        // Create loopback QP for this domain if mlx5dv is supported (needed for segment registration)
350        // For EFA, we don't need a loopback QP for segment scanning
351        let qp = if mlx5dv_supported() && !crate::efa::is_efa_device() {
352            let mut qp = IbvQueuePair::new(domain.context, domain.pd, self.config.clone())
353                .map_err(|e| {
354                    anyhow::anyhow!(
355                        "could not create loopback QP for device {}: {}",
356                        device_name,
357                        e
358                    )
359                })?;
360
361            // Get connection info and connect to itself
362            let endpoint = qp.get_qp_info().map_err(|e| {
363                anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
364            })?;
365
366            qp.connect(&endpoint).map_err(|e| {
367                anyhow::anyhow!(
368                    "could not connect loopback QP for device {}: {}",
369                    device_name,
370                    e
371                )
372            })?;
373
374            Some(qp)
375        } else {
376            None
377        };
378
379        self.device_domains
380            .insert(device_name.to_string(), (domain.clone(), qp.clone()));
381        Ok((domain, qp))
382    }
383
384    fn find_cuda_segment_for_address(
385        &mut self,
386        addr: usize,
387        size: usize,
388    ) -> Option<IbvMemoryRegionView> {
389        let registered_segments = get_registered_cuda_segments();
390        for segment in registered_segments {
391            let start_addr = segment.phys_address;
392            let end_addr = start_addr + segment.phys_size;
393            if start_addr <= addr && addr + size <= end_addr {
394                let offset = addr - start_addr;
395                let rdma_addr = segment.mr_addr + offset;
396
397                let mrv = IbvMemoryRegionView {
398                    id: self.mrv_id,
399                    virtual_addr: addr,
400                    rdma_addr,
401                    size,
402                    lkey: segment.lkey,
403                    rkey: segment.rkey,
404                };
405                self.mrv_id += 1;
406                return Some(mrv);
407            }
408        }
409        None
410    }
411
412    fn register_mr(
413        &mut self,
414        addr: usize,
415        size: usize,
416    ) -> Result<(IbvMemoryRegionView, String), anyhow::Error> {
417        unsafe {
418            let mut mem_type: i32 = 0;
419            let ptr = addr as rdmaxcel_sys::CUdeviceptr;
420            let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
421                &mut mem_type as *mut _ as *mut std::ffi::c_void,
422                rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
423                ptr,
424            );
425            let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
426
427            let mut selected_rdma_device = None;
428
429            if is_cuda {
430                // Use rdmaxcel utility to get PCI address from CUDA pointer
431                let mut pci_addr_buf: [std::os::raw::c_char; 16] = [0; 16]; // Enough space for "ffff:ff:ff.0\0"
432                let err = rdmaxcel_sys::get_cuda_pci_address_from_ptr(
433                    addr as rdmaxcel_sys::CUdeviceptr,
434                    pci_addr_buf.as_mut_ptr(),
435                    pci_addr_buf.len(),
436                );
437                if err != 0 {
438                    let error_msg = get_rdmaxcel_error_message(err);
439                    return Err(anyhow::anyhow!(
440                        "RdmaXcel get_cuda_pci_address_from_ptr failed (addr: 0x{:x}, size: {}): {}",
441                        addr,
442                        size,
443                        error_msg
444                    ));
445                }
446
447                // Convert C string to Rust string
448                let pci_addr = std::ffi::CStr::from_ptr(pci_addr_buf.as_ptr())
449                    .to_str()
450                    .unwrap();
451                selected_rdma_device = self.pci_to_device.get(pci_addr).cloned();
452            }
453
454            // Determine the RDMA device to use
455            let rdma_device = if let Some(device) = selected_rdma_device {
456                device
457            } else {
458                // Fallback to default device from config
459                self.config.device.clone()
460            };
461
462            let device_name = rdma_device.name().clone();
463            tracing::debug!(
464                "Using RDMA device: {} for memory at 0x{:x}",
465                device_name,
466                addr
467            );
468
469            // Get or create domain and loopback QP for this device
470            let (domain, qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
471
472            let access = if crate::efa::is_efa_device() {
473                crate::efa::mr_access_flags()
474            } else {
475                rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
476                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
477                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
478                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC
479            };
480
481            let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
482            let mrv;
483
484            if is_cuda {
485                // First, try to use segment scanning if mlx5dv is enabled
486                let mut segment_mrv = None;
487                if self.mlx5dv_enabled {
488                    // Try to find in already registered segments
489                    segment_mrv = self.find_cuda_segment_for_address(addr, size);
490
491                    // If not found, trigger a re-sync with the allocator and retry
492                    if segment_mrv.is_none() {
493                        let err = rdmaxcel_sys::register_segments(
494                            domain.pd,
495                            qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t,
496                        );
497                        // Only retry if register_segments succeeded
498                        // If it fails (e.g., scanner returns 0 segments), we'll fall back to dmabuf
499                        if err == 0 {
500                            segment_mrv = self.find_cuda_segment_for_address(addr, size);
501                        }
502                    }
503                }
504
505                // Use segment if found, otherwise fall back to direct dmabuf registration
506                if let Some(mrv_from_segment) = segment_mrv {
507                    mrv = mrv_from_segment;
508                } else {
509                    // Dmabuf path: used when mlx5dv is disabled OR scanner returns no segments
510                    let mut fd: i32 = -1;
511                    rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
512                        &mut fd,
513                        addr as rdmaxcel_sys::CUdeviceptr,
514                        size,
515                        rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
516                        0,
517                    );
518                    mr =
519                        rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
520                    if mr.is_null() {
521                        return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
522                    }
523                    mrv = IbvMemoryRegionView {
524                        id: self.mrv_id,
525                        virtual_addr: addr,
526                        rdma_addr: (*mr).addr as usize,
527                        size,
528                        lkey: (*mr).lkey,
529                        rkey: (*mr).rkey,
530                    };
531                    self.mrv_id += 1;
532                }
533            } else {
534                // CPU memory path
535                mr = rdmaxcel_sys::ibv_reg_mr(
536                    domain.pd,
537                    addr as *mut std::ffi::c_void,
538                    size,
539                    access.0 as i32,
540                );
541
542                if mr.is_null() {
543                    return Err(anyhow::anyhow!("failed to register standard MR"));
544                }
545
546                mrv = IbvMemoryRegionView {
547                    id: self.mrv_id,
548                    virtual_addr: addr,
549                    rdma_addr: (*mr).addr as usize,
550                    size,
551                    lkey: (*mr).lkey,
552                    rkey: (*mr).rkey,
553                };
554                self.mrv_id += 1;
555            }
556            self.mr_map.insert(mrv.id, mr as usize);
557            Ok((mrv, device_name))
558        }
559    }
560
561    fn deregister_mr(&mut self, id: usize) -> Result<(), anyhow::Error> {
562        if let Some(mr_ptr) = self.mr_map.remove(&id) {
563            if mr_ptr != 0 {
564                unsafe {
565                    rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
566                }
567            }
568        }
569        Ok(())
570    }
571
572    /// Waits for the completion of RDMA operations.
573    ///
574    /// Polls the completion queue until all specified work requests complete
575    /// or until the timeout is reached.
576    ///
577    /// # Arguments
578    /// * `local_buf` - The ibv buffer describing the local side of the RDMA ops
579    /// * `qp` - The RDMA Queue Pair to poll for completion
580    /// * `poll_target` - Which CQ to poll (Send or Recv)
581    /// * `expected_wr_ids` - The work request IDs to wait for
582    /// * `timeout` - Timeout for the RDMA operation to complete.
583    async fn wait_for_completion(
584        &self,
585        local_buf: &IbvBuffer,
586        qp: &mut IbvQueuePair,
587        poll_target: PollTarget,
588        expected_wr_ids: &[u64],
589        timeout: Duration,
590    ) -> Result<(), anyhow::Error> {
591        let start_time = std::time::Instant::now();
592
593        let mut remaining: std::collections::HashSet<u64> =
594            expected_wr_ids.iter().copied().collect();
595
596        while start_time.elapsed() < timeout {
597            if remaining.is_empty() {
598                return Ok(());
599            }
600
601            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
602            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
603                Ok(completions) => {
604                    for (wr_id, _wc) in completions {
605                        remaining.remove(&wr_id);
606                    }
607                    if remaining.is_empty() {
608                        return Ok(());
609                    }
610                    tokio::time::sleep(Duration::from_millis(1)).await;
611                }
612                Err(e) => {
613                    return Err(anyhow::anyhow!(
614                        "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
615                        e,
616                        local_buf.lkey,
617                        local_buf.rkey,
618                        local_buf.addr,
619                        local_buf.size
620                    ));
621                }
622            }
623        }
624        tracing::error!(
625            "timed out while waiting on request completion for wr_ids={:?}",
626            remaining
627        );
628        Err(anyhow::anyhow!(
629            "[ibv_buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
630            local_buf,
631            expected_wr_ids
632        ))
633    }
634
635    /// Core submit logic: registers local MR, resolves remote
636    /// IbvBuffer lazily, executes the op, and deregisters local MR.
637    async fn execute_op(
638        &mut self,
639        cx: &Context<'_, Self>,
640        op: IbvOp,
641        timeout: Duration,
642    ) -> Result<(), anyhow::Error> {
643        // Register the local memory to get an IbvBuffer
644        let (local_mrv, local_device_name) =
645            self.register_mr(op.local_memory.addr(), op.local_memory.size())?;
646        let local_buffer = IbvBuffer {
647            mr_id: local_mrv.id,
648            lkey: local_mrv.lkey,
649            rkey: local_mrv.rkey,
650            addr: local_mrv.rdma_addr,
651            size: local_mrv.size,
652            device_name: local_device_name,
653        };
654
655        let op_result = async {
656            let mut qp = self
657                .request_queue_pair(
658                    cx,
659                    op.remote_manager.clone(),
660                    local_buffer.device_name.clone(),
661                    op.remote_buffer.device_name.clone(),
662                )
663                .await?;
664
665            let wr_id = match op.op_type {
666                RdmaOpType::WriteFromLocal => qp.put(local_buffer.clone(), op.remote_buffer)?,
667                RdmaOpType::ReadIntoLocal => qp.get(local_buffer.clone(), op.remote_buffer)?,
668            };
669
670            self.wait_for_completion(&local_buffer, &mut qp, PollTarget::Send, &wr_id, timeout)
671                .await
672        }
673        .await;
674
675        // Always deregister the locally registered MR
676        let dereg_result = self.deregister_mr(local_buffer.mr_id);
677
678        match (op_result, dereg_result) {
679            (Ok(()), Ok(())) => Ok(()),
680            (Err(e), Ok(())) => Err(e),
681            (Ok(()), Err(e)) => Err(e),
682            (Err(op_err), Err(dereg_err)) => Err(anyhow::anyhow!(
683                "deregister MR error: {}; op error: {}",
684                dereg_err,
685                op_err
686            )),
687        }
688    }
689}
690
691#[async_trait]
692#[hyperactor::handle(IbvManagerMessage)]
693impl IbvManagerMessageHandler for IbvManagerActor {
694    async fn request_buffer(
695        &mut self,
696        cx: &Context<Self>,
697        remote_buf_id: usize,
698    ) -> Result<Option<IbvBuffer>, anyhow::Error> {
699        // If already registered, return it
700        if let Some(buf) = self.buffer_registrations.get(&remote_buf_id) {
701            return Ok(Some(buf.clone()));
702        }
703
704        // Resolve local memory from the parent RdmaManagerActor.
705        // Returns None if the buffer has already been released or does
706        // not exist.
707        let owner = self.owner.get().unwrap();
708        let mem = match owner.request_local_memory(cx, remote_buf_id).await? {
709            Some(mem) => mem,
710            None => return Ok(None),
711        };
712
713        let (mrv, device_name) = self.register_mr(mem.addr(), mem.size())?;
714
715        let buf = IbvBuffer {
716            mr_id: mrv.id,
717            lkey: mrv.lkey,
718            rkey: mrv.rkey,
719            addr: mrv.rdma_addr,
720            size: mrv.size,
721            device_name,
722        };
723
724        self.buffer_registrations.insert(remote_buf_id, buf.clone());
725
726        Ok(Some(buf))
727    }
728
729    async fn release_buffer(
730        &mut self,
731        _cx: &Context<Self>,
732        remote_buf_id: usize,
733    ) -> Result<(), anyhow::Error> {
734        if let Some(buf) = self.buffer_registrations.remove(&remote_buf_id) {
735            self.deregister_mr(buf.mr_id)
736                .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
737        }
738        Ok(())
739    }
740
741    async fn request_queue_pair(
742        &mut self,
743        cx: &Context<Self>,
744        other: reference::ActorRef<IbvManagerActor>,
745        self_device: String,
746        other_device: String,
747    ) -> Result<IbvQueuePair, anyhow::Error> {
748        let self_ref: reference::ActorRef<IbvManagerActor> = cx.bind();
749        let other_id = other.actor_id().clone();
750
751        // Use the nested map structure: local_device -> (actor_id, remote_device) -> IbvQueuePair
752        let inner_key = (other_id.clone(), other_device.clone());
753
754        // Check if queue pair exists in map
755        if let Some(device_map) = self.device_qps.get(&self_device) {
756            if let Some(qp) = device_map.get(&inner_key) {
757                return Ok(qp.clone());
758            }
759        }
760
761        // Try to acquire lock and mark as pending (hold lock only once!)
762        let pending_key = (self_device.clone(), other_id.clone(), other_device.clone());
763        let mut pending = self.pending_qp_creation.lock().await;
764
765        if pending.contains(&pending_key) {
766            // Another task is creating this QP, release lock and wait
767            drop(pending);
768
769            // Loop checking device_qps until QP is created (no more locks needed)
770            // Timeout after 1 second
771            let start = Instant::now();
772            let timeout = Duration::from_secs(1);
773
774            loop {
775                tokio::time::sleep(Duration::from_micros(200)).await;
776
777                // Check if QP was created while we waited
778                if let Some(device_map) = self.device_qps.get(&self_device) {
779                    if let Some(qp) = device_map.get(&inner_key) {
780                        return Ok(qp.clone());
781                    }
782                }
783
784                // Check for timeout
785                if start.elapsed() >= timeout {
786                    return Err(anyhow::anyhow!(
787                        "Timeout waiting for QP creation (device {} -> actor {} device {}). \
788                         Another task is creating it but hasn't completed in 1 second",
789                        self_device,
790                        other_id,
791                        other_device
792                    ));
793                }
794            }
795        } else {
796            // Not pending, add to set and proceed with creation
797            pending.insert(pending_key.clone());
798            drop(pending);
799            // Fall through to create QP
800        }
801
802        // Queue pair doesn't exist - need to create connection
803        let result = async {
804            let is_loopback = other_id == *self_ref.actor_id() && self_device == other_device;
805
806            if is_loopback {
807                // Loopback connection setup
808                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
809                    .await?;
810                let endpoint = self
811                    .connection_info(cx, other.clone(), other_device.clone(), self_device.clone())
812                    .await?;
813                self.connect(
814                    cx,
815                    other.clone(),
816                    self_device.clone(),
817                    other_device.clone(),
818                    endpoint,
819                )
820                .await?;
821            } else {
822                // Remote connection setup
823                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
824                    .await?;
825                other
826                    .initialize_qp(
827                        cx,
828                        self_ref.clone(),
829                        other_device.clone(),
830                        self_device.clone(),
831                    )
832                    .await?;
833                let other_endpoint: IbvQpInfo = other
834                    .connection_info(
835                        cx,
836                        self_ref.clone(),
837                        other_device.clone(),
838                        self_device.clone(),
839                    )
840                    .await?;
841                self.connect(
842                    cx,
843                    other.clone(),
844                    self_device.clone(),
845                    other_device.clone(),
846                    other_endpoint,
847                )
848                .await?;
849                let local_endpoint = self
850                    .connection_info(cx, other.clone(), self_device.clone(), other_device.clone())
851                    .await?;
852                other
853                    .connect(
854                        cx,
855                        self_ref.clone(),
856                        other_device.clone(),
857                        self_device.clone(),
858                        local_endpoint,
859                    )
860                    .await?;
861
862                // BARRIER: Ensure remote side has completed its connection and is ready
863                let remote_state = other
864                    .get_qp_state(
865                        cx,
866                        self_ref.clone(),
867                        other_device.clone(),
868                        self_device.clone(),
869                    )
870                    .await?;
871
872                if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS {
873                    return Err(anyhow::anyhow!(
874                        "Remote QP not in RTS state after connection setup. \
875                         Local is ready but remote is in state {}. \
876                         This indicates a synchronization issue in connection setup.",
877                        remote_state
878                    ));
879                }
880            }
881
882            // Now that connection is established, get and clone the queue pair
883            if let Some(device_map) = self.device_qps.get(&self_device) {
884                if let Some(qp) = device_map.get(&inner_key) {
885                    Ok(qp.clone())
886                } else {
887                    Err(anyhow::anyhow!(
888                        "Failed to create connection for actor {} on device {}",
889                        other_id,
890                        other_device
891                    ))
892                }
893            } else {
894                Err(anyhow::anyhow!(
895                    "Failed to create connection for actor {} on device {} - no device map",
896                    other_id,
897                    other_device
898                ))
899            }
900        }
901        .await;
902
903        // Always remove from pending set when done (success or failure)
904        let mut pending = self.pending_qp_creation.lock().await;
905        pending.remove(&pending_key);
906        drop(pending);
907
908        result
909    }
910
911    async fn connect(
912        &mut self,
913        _cx: &Context<Self>,
914        other: reference::ActorRef<IbvManagerActor>,
915        self_device: String,
916        other_device: String,
917        endpoint: IbvQpInfo,
918    ) -> Result<(), anyhow::Error> {
919        tracing::debug!("connecting with {:?}", other);
920        let other_id = other.actor_id().clone();
921
922        let inner_key = (other_id.clone(), other_device.clone());
923
924        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
925            match device_map.get_mut(&inner_key) {
926                Some(qp) => {
927                    qp.connect(&endpoint).map_err(|e| {
928                        anyhow::anyhow!("could not connect to RDMA endpoint: {}", e)
929                    })?;
930                    Ok(())
931                }
932                None => Err(anyhow::anyhow!(
933                    "No connection found for actor {}",
934                    other_id
935                )),
936            }
937        } else {
938            Err(anyhow::anyhow!(
939                "No device map found for device {}",
940                self_device
941            ))
942        }
943    }
944
945    async fn initialize_qp(
946        &mut self,
947        _cx: &Context<Self>,
948        other: reference::ActorRef<IbvManagerActor>,
949        self_device: String,
950        other_device: String,
951    ) -> Result<bool, anyhow::Error> {
952        let other_id = other.actor_id().clone();
953        let inner_key = (other_id.clone(), other_device.clone());
954
955        // Check if QP already exists in nested structure
956        if let Some(device_map) = self.device_qps.get(&self_device) {
957            if device_map.contains_key(&inner_key) {
958                return Ok(true);
959            }
960        }
961
962        // Resolve the RDMA device for the local device
963        let rdma_device = self
964            .pci_to_device
965            .iter()
966            .find(|(_, device)| device.name() == &self_device)
967            .map(|(_, device)| device.clone())
968            .unwrap_or_else(|| {
969                // Fallback to default device from config
970                super::device_selection::resolve_ibv_device(&self.config.device)
971                    .unwrap_or_else(|| self.config.device.clone())
972            });
973
974        // Get or create domain and extract pointers to avoid borrowing issues
975        let (domain_context, domain_pd) = {
976            // Check if we already have a domain for the device
977            let (domain, _) = self.get_or_create_device_domain(&self_device, &rdma_device)?;
978            (domain.context, domain.pd)
979        };
980
981        let qp = IbvQueuePair::new(domain_context, domain_pd, self.config.clone())
982            .map_err(|e| anyhow::anyhow!("could not create IbvQueuePair: {}", e))?;
983
984        // Insert the QP into the nested map structure
985        self.device_qps
986            .entry(self_device.clone())
987            .or_insert_with(HashMap::new)
988            .insert(inner_key, qp);
989
990        tracing::debug!(
991            "successfully created a connection with {:?} for local device {} -> remote device {}",
992            other,
993            self_device,
994            other_device
995        );
996
997        Ok(true)
998    }
999
1000    async fn connection_info(
1001        &mut self,
1002        _cx: &Context<Self>,
1003        other: reference::ActorRef<IbvManagerActor>,
1004        self_device: String,
1005        other_device: String,
1006    ) -> Result<IbvQpInfo, anyhow::Error> {
1007        tracing::debug!("getting connection info with {:?}", other);
1008        let other_id = other.actor_id().clone();
1009
1010        let inner_key = (other_id.clone(), other_device.clone());
1011
1012        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
1013            match device_map.get_mut(&inner_key) {
1014                Some(qp) => {
1015                    let connection_info = qp.get_qp_info()?;
1016                    Ok(connection_info)
1017                }
1018                None => Err(anyhow::anyhow!(
1019                    "No connection found for actor {}",
1020                    other_id
1021                )),
1022            }
1023        } else {
1024            Err(anyhow::anyhow!(
1025                "No device map found for self device {}",
1026                self_device
1027            ))
1028        }
1029    }
1030
1031    async fn release_queue_pair(
1032        &mut self,
1033        _cx: &Context<Self>,
1034        _other: reference::ActorRef<IbvManagerActor>,
1035        _self_device: String,
1036        _other_device: String,
1037        _qp: IbvQueuePair,
1038    ) -> Result<(), anyhow::Error> {
1039        Ok(())
1040    }
1041
1042    async fn get_qp_state(
1043        &mut self,
1044        _cx: &Context<Self>,
1045        other: reference::ActorRef<IbvManagerActor>,
1046        self_device: String,
1047        other_device: String,
1048    ) -> Result<u32, anyhow::Error> {
1049        let other_id = other.actor_id().clone();
1050        let inner_key = (other_id.clone(), other_device.clone());
1051
1052        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
1053            match device_map.get_mut(&inner_key) {
1054                Some(qp) => qp.state(),
1055                None => Err(anyhow::anyhow!(
1056                    "No connection found for actor {} on device {}",
1057                    other_id,
1058                    other_device
1059                )),
1060            }
1061        } else {
1062            Err(anyhow::anyhow!(
1063                "No device map found for self device {}",
1064                self_device
1065            ))
1066        }
1067    }
1068}
1069
1070#[async_trait]
1071#[hyperactor::handle(IbvSubmit)]
1072impl IbvSubmitHandler for IbvManagerActor {
1073    async fn ibv_submit(
1074        &mut self,
1075        cx: &Context<Self>,
1076        ops: Vec<IbvOp>,
1077        timeout: Duration,
1078    ) -> Result<Result<(), String>, anyhow::Error> {
1079        let deadline = Instant::now() + timeout;
1080        let mut result = Ok(());
1081        for op in ops {
1082            let remaining = deadline.saturating_duration_since(Instant::now());
1083            if remaining.is_zero() {
1084                result = Err("submit timed out".to_string());
1085                break;
1086            }
1087            if let Err(e) = self.execute_op(cx, op, remaining).await {
1088                result = Err(e.to_string());
1089                break;
1090            }
1091        }
1092        Ok(result)
1093    }
1094}
1095
1096#[async_trait]
1097impl RdmaBackend for ActorHandle<IbvManagerActor> {
1098    type TransportInfo = ();
1099
1100    /// Submit a batch of RDMA operations.
1101    ///
1102    /// Delegates to the [`IbvManagerActor`], which manages QPs and connections
1103    /// internally. Each op is routed based on its type (read/write).
1104    /// Registers local memory on the fly; remote buffers are lazily
1105    /// resolved via [`IbvManagerMessageClient::request_buffer`].
1106    async fn submit(
1107        &mut self,
1108        cx: &(impl hyperactor::context::Actor + Send + Sync),
1109        ops: Vec<RdmaOp>,
1110        timeout: Duration,
1111    ) -> Result<(), anyhow::Error> {
1112        let mut ibv_ops = Vec::with_capacity(ops.len());
1113        for op in ops {
1114            let (remote_ibv_mgr, remote_ibv_buffer) =
1115                op.remote.resolve_ibv(cx).await.ok_or_else(|| {
1116                    anyhow::anyhow!("ibverbs backend not found for buffer: {:?}", op.remote)
1117                })??;
1118
1119            ibv_ops.push(IbvOp {
1120                op_type: op.op_type,
1121                local_memory: op.local.clone(),
1122                remote_buffer: remote_ibv_buffer,
1123                remote_manager: remote_ibv_mgr,
1124            });
1125        }
1126
1127        <Self as IbvSubmitClient>::ibv_submit(self, cx, ibv_ops, timeout)
1128            .await?
1129            .map_err(|e: String| anyhow::anyhow!(e))
1130    }
1131
1132    fn transport_level(&self) -> RdmaTransportLevel {
1133        RdmaTransportLevel::Nic
1134    }
1135
1136    fn transport_info(&self) -> Option<Self::TransportInfo> {
1137        None
1138    }
1139}