monarch_rdma/backend/ibverbs/
manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # Ibverbs Manager
10//!
11//! Contains ibverbs-specific RDMA logic.
12//!
13//! Manages ibverbs resources including:
14//! - Memory registration (CPU and CUDA via dmabuf or segment scanning)
15//! - Queue pair creation and connection establishment
16//! - RDMA domain and protection domain management
17//! - Device selection and PCI-to-RDMA device mapping
18
19use std::collections::HashMap;
20use std::collections::HashSet;
21use std::sync::Arc;
22use std::sync::OnceLock;
23use std::time::Duration;
24use std::time::Instant;
25
26use anyhow::Result;
27use async_trait::async_trait;
28use futures::lock::Mutex;
29use hyperactor::Actor;
30use hyperactor::ActorHandle;
31use hyperactor::Context;
32use hyperactor::HandleClient;
33use hyperactor::Handler;
34use hyperactor::Instance;
35use hyperactor::OncePortHandle;
36use hyperactor::RefClient;
37use hyperactor::reference;
38use serde::Deserialize;
39use serde::Serialize;
40use typeuri::Named;
41
42use super::IbvBuffer;
43use super::IbvOp;
44use super::domain::IbvDomain;
45use super::primitives::IbvConfig;
46use super::primitives::IbvDevice;
47use super::primitives::IbvMemoryRegionView;
48use super::primitives::IbvQpInfo;
49use super::primitives::ibverbs_supported;
50use super::primitives::mlx5dv_supported;
51use super::primitives::resolve_qp_type;
52use super::queue_pair::IbvQueuePair;
53use super::queue_pair::PollCompletionError;
54use super::queue_pair::PollTarget;
55use crate::RdmaOp;
56use crate::RdmaOpType;
57use crate::RdmaTransportLevel;
58use crate::backend::RdmaBackend;
59use crate::rdma_components::get_registered_cuda_segments;
60use crate::rdma_manager_actor::GetIbvActorRefClient;
61use crate::rdma_manager_actor::RdmaManagerActor;
62use crate::rdma_manager_actor::RdmaManagerMessageClient;
63use crate::rdma_manager_actor::get_rdmaxcel_error_message;
64use crate::validate_execution_context;
65
/// Messages handled by [`IbvManagerActor`].
///
/// Variants carrying a `#[reply]` port are request/response; the others are
/// fire-and-forget. In every queue-pair variant, `self_device` names the
/// RDMA device on the receiving actor and `other_device` names the device on
/// the `other` actor.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum IbvManagerMessage {
    /// Register the MR for a buffer identified by `remote_buf_id`. Resolves
    /// the local memory via the parent [`RdmaManagerActor`]'s
    /// `RequestLocalMemory`, registers it as an ibverbs MR, and returns
    /// the resulting [`IbvBuffer`].
    ///
    /// Returns `None` if the buffer has already been released or does not
    /// exist.
    RequestBuffer {
        remote_buf_id: usize,
        #[reply]
        reply: reference::OncePortRef<Option<IbvBuffer>>,
    },
    /// Release a buffer registration by `remote_buf_id`.
    /// IMPORTANT: This needs to be fire-and-forget (no reply port)
    /// to avoid a circular deadlock where RdmaManagerActor waits for
    /// IbvManagerMessage::ReleaseBuffer while IbvManagerActor waits for
    /// RdmaManagerMessage::RequestLocalMemory.
    ReleaseBuffer { remote_buf_id: usize },
    /// Return the queue pair linking `self_device` to `other_device` on
    /// `other`, establishing the connection first if no queue pair is cached
    /// for that key. Failures are reported to the caller as a `String`.
    RequestQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<Result<IbvQueuePair, String>>,
    },
    /// Connect the already-initialized queue pair for
    /// (`self_device`, `other`, `other_device`) to the peer described by
    /// `endpoint`.
    Connect {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        endpoint: IbvQpInfo,
    },
    /// First step of connection setup, sent before `ConnectionInfo` and
    /// `Connect`.
    // NOTE(review): the handler is outside this view; presumably it creates
    // the local queue pair entry and the reply reports success — confirm.
    InitializeQP {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<bool>,
    },
    /// Return the [`IbvQpInfo`] endpoint description of the queue pair for
    /// (`self_device`, `other`, `other_device`); the peer passes it to
    /// `Connect`.
    ConnectionInfo {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<IbvQpInfo>,
    },
    /// Hand a queue pair back to the manager.
    // NOTE(review): the handler is outside this view; the exact release
    // semantics should be confirmed there.
    ReleaseQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        qp: IbvQueuePair,
    },
    /// Return the state of the queue pair for
    /// (`self_device`, `other`, `other_device`) as a raw `u32`. Connection
    /// setup compares the reply against `ibv_qp_state::IBV_QPS_RTS`.
    GetQpState {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<u32>,
    },
}
128wirevalue::register_type!(IbvManagerMessage);
129
/// Local-only messages for MR registration/deregistration.
///
/// This enum derives neither `Serialize` nor `RefClient` and replies through
/// [`OncePortHandle`], so it can only be used with an in-process handle.
#[derive(Handler, HandleClient, Debug)]
pub enum IbvManagerLocalMessage {
    /// Register a memory region, returning the MR view and device name.
    ///
    /// `addr` and `size` describe the memory range to register. Failures are
    /// reported as a `String` in the reply.
    RegisterMr {
        addr: usize,
        size: usize,
        #[reply]
        reply: OncePortHandle<Result<(IbvMemoryRegionView, String), String>>,
    },
    /// Deregister a memory region by its MR view id.
    ///
    /// `id` is the `id` field of the [`IbvMemoryRegionView`] returned by
    /// `RegisterMr`.
    DeregisterMr {
        id: usize,
        #[reply]
        reply: OncePortHandle<Result<(), String>>,
    },
}
147
/// Manages all ibverbs-specific RDMA resources and operations.
///
/// This struct handles memory registration, queue pair management,
/// and connection establishment using the ibverbs API.
///
/// The raw resources tracked here (queue pairs, memory regions, CUDA
/// segments) are released by this type's `Drop` implementation.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        IbvManagerMessage,
    ],
)]
pub struct IbvManagerActor {
    // Handle to the parent RdmaManagerActor. Set exactly once in `init` and
    // used to resolve buffer ids to local memory.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,

    // Nested map: local_device -> (ActorId, remote_device) -> IbvQueuePair
    device_qps: HashMap<String, HashMap<(reference::ActorId, String), IbvQueuePair>>,

    // Track QPs currently being created to prevent duplicate creation.
    // Keyed by (local_device, remote ActorId, remote_device).
    // Wrapped in Arc<Mutex> to allow safe concurrent access
    pending_qp_creation: Arc<Mutex<HashSet<(String, reference::ActorId, String)>>>,

    // Map of RDMA device names to their domains and loopback QPs.
    // Created lazily when memory is registered for a specific device.
    // The loopback QP is `None` when it was not created for that device.
    device_domains: HashMap<String, (IbvDomain, Option<IbvQueuePair>)>,

    // Configuration supplied at construction (or the default).
    config: IbvConfig,

    // True when the configured QP type resolves to mlx5dv; enables CUDA
    // segment scanning during registration and segment deregistration on drop.
    mlx5dv_enabled: bool,

    // Map of IbvMemoryRegionView id to the raw `ibv_mr*` stored as usize.
    // The value is 0 (null) when the view was carved out of a pre-registered
    // CUDA segment, because that memory is registered and deregistered
    // independently. Only used for registration/deregistration purposes.
    mr_map: HashMap<usize, usize>,

    // Id assigned to the next IbvMemoryRegionView created; incremented after
    // each use.
    mrv_id: usize,

    // Map from buffer_id to registration details.
    buffer_registrations: HashMap<usize, IbvBuffer>,
}
186
187#[async_trait]
188impl Actor for IbvManagerActor {
189    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
190        let owner = if let Some(owner) = this.parent_handle() {
191            owner
192        } else {
193            anyhow::bail!("RdmaManagerActor not found as parent of IbvManagerActor");
194        };
195        self.owner
196            .set(owner)
197            .expect("owner should only be set once during init");
198        Ok(())
199    }
200}
201
202impl Drop for IbvManagerActor {
203    fn drop(&mut self) {
204        // Helper function to destroy QP resources
205        // We can't use Drop on IbvQueuePair because it derives Clone
206        // Note: rdmaxcel_qp_destroy handles destroying both the QP and its CQs internally,
207        // so we must NOT call ibv_destroy_cq separately (would cause double-free/SIGSEGV)
208        fn destroy_queue_pair(qp: &IbvQueuePair, _context: &str) {
209            unsafe {
210                if qp.qp != 0 {
211                    let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
212                    rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
213                }
214            }
215        }
216
217        // 1. Clean up all queue pairs (both regular and loopback)
218        for (_device_name, device_map) in self.device_qps.drain() {
219            for ((actor_id, _remote_device), qp) in device_map {
220                destroy_queue_pair(&qp, &format!("actor {:?}", actor_id));
221            }
222        }
223
224        // 2. Clean up device domains (which contain PDs and loopback QPs)
225        for (device_name, (domain, qp)) in self.device_domains.drain() {
226            if let Some(qp) = qp {
227                destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name));
228            }
229            drop(domain);
230        }
231
232        // 3. Clean up memory regions
233        let _mr_count = self.mr_map.len();
234        for (id, mr_ptr) in self.mr_map.drain() {
235            if mr_ptr != 0 {
236                unsafe {
237                    let result = rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
238                    if result != 0 {
239                        tracing::error!(
240                            "Failed to deregister MR with id {}: error code {}",
241                            id,
242                            result
243                        );
244                    }
245                }
246            }
247        }
248
249        // 4. Deregister all CUDA segments (if using mlx5dv)
250        // The segment scanner in Python handles compatibility checks
251        if self.mlx5dv_enabled {
252            unsafe {
253                let result = rdmaxcel_sys::deregister_segments();
254                if result != 0 {
255                    let error_msg = get_rdmaxcel_error_message(result);
256                    tracing::error!(
257                        "Failed to deregister CUDA segments: {} (error code: {})",
258                        error_msg,
259                        result
260                    );
261                }
262            }
263        }
264    }
265}
266
267impl IbvManagerActor {
268    /// Construct an [`ActorHandle`] for the [`IbvManagerActor`] co-located
269    /// with the caller by querying the local [`RdmaManagerActor`].
270    pub async fn local_handle(
271        client: &(impl hyperactor::context::Actor + Send + Sync),
272    ) -> Result<ActorHandle<Self>, anyhow::Error> {
273        let rdma_handle = RdmaManagerActor::local_handle(client);
274        let ibv_ref = rdma_handle
275            .get_ibv_actor_ref(client)
276            .await?
277            .ok_or_else(|| anyhow::anyhow!("local RdmaManagerActor has no ibverbs backend"))?;
278        ibv_ref
279            .downcast_handle(client)
280            .ok_or_else(|| anyhow::anyhow!("IbvManagerActor is not in the local process"))
281    }
282
283    /// Create a new IbvManagerActor with the given configuration.
284    pub async fn new(params: Option<IbvConfig>) -> Result<Self, anyhow::Error> {
285        if !ibverbs_supported() {
286            return Err(anyhow::anyhow!(
287                "Cannot create IbvManagerActor because RDMA is not supported on this machine"
288            ));
289        }
290
291        // Use provided config or default if none provided
292        let mut config = params.unwrap_or_default();
293        tracing::debug!("rdma is enabled, config device hint: {}", config.device);
294
295        let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
296
297        // check config and hardware support align
298        if config.use_gpu_direct {
299            match validate_execution_context().await {
300                Ok(_) => {
301                    tracing::info!("GPU Direct RDMA execution context validated successfully");
302                }
303                Err(e) => {
304                    tracing::warn!(
305                        "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
306                        e
307                    );
308                    config.use_gpu_direct = false;
309                }
310            }
311        }
312
313        let actor = Self {
314            owner: OnceLock::new(),
315            device_qps: HashMap::new(),
316            pending_qp_creation: Arc::new(Mutex::new(HashSet::new())),
317            device_domains: HashMap::new(),
318            config,
319            mlx5dv_enabled,
320            mr_map: HashMap::new(),
321            mrv_id: 0,
322            buffer_registrations: HashMap::new(),
323        };
324
325        Ok(actor)
326    }
327
328    /// Get or create a domain and loopback QP for the specified RDMA device
329    fn get_or_create_device_domain(
330        &mut self,
331        device_name: &str,
332        rdma_device: &IbvDevice,
333    ) -> Result<(IbvDomain, Option<IbvQueuePair>), anyhow::Error> {
334        if let Some((domain, qp)) = self.device_domains.get(device_name) {
335            return Ok((domain.clone(), qp.clone()));
336        }
337
338        // Create new domain for this device
339        let domain = IbvDomain::new(rdma_device.clone()).map_err(|e| {
340            anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
341        })?;
342
343        // Print device info if MONARCH_DEBUG_RDMA=1 is set (before initial QP creation)
344        crate::print_device_info_if_debug_enabled(domain.context);
345
346        // Create loopback QP for this domain if mlx5dv is supported (needed for segment registration)
347        // For EFA, we don't need a loopback QP for segment scanning
348        let qp = if mlx5dv_supported() && !crate::efa::is_efa_device() {
349            let mut qp = IbvQueuePair::new(domain.context, domain.pd, self.config.clone())
350                .map_err(|e| {
351                    anyhow::anyhow!(
352                        "could not create loopback QP for device {}: {}",
353                        device_name,
354                        e
355                    )
356                })?;
357
358            // Get connection info and connect to itself
359            let endpoint = qp.get_qp_info().map_err(|e| {
360                anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
361            })?;
362
363            qp.connect(&endpoint).map_err(|e| {
364                anyhow::anyhow!(
365                    "could not connect loopback QP for device {}: {}",
366                    device_name,
367                    e
368                )
369            })?;
370
371            Some(qp)
372        } else {
373            None
374        };
375
376        self.device_domains
377            .insert(device_name.to_string(), (domain.clone(), qp.clone()));
378        Ok((domain, qp))
379    }
380
381    /// Build parallel PD/QP arrays indexed by CUDA device ordinal
382    /// for the C++ register_segments call.
383    fn build_per_device_pd_qp_arrays(
384        &self,
385    ) -> (
386        Vec<*mut rdmaxcel_sys::ibv_pd>,
387        Vec<*mut rdmaxcel_sys::rdmaxcel_qp_t>,
388    ) {
389        let cuda_map = super::device_selection::get_cuda_device_to_ibv_device();
390        let mut pds = Vec::with_capacity(cuda_map.len());
391        let mut qps = Vec::with_capacity(cuda_map.len());
392        for maybe_device in cuda_map {
393            if let Some(device) = maybe_device {
394                if let Some((domain, qp)) = self.device_domains.get(device.name()) {
395                    pds.push(domain.pd);
396                    qps.push(
397                        qp.as_ref()
398                            .map(|q| q.qp as *mut rdmaxcel_sys::rdmaxcel_qp_t)
399                            .unwrap_or(std::ptr::null_mut()),
400                    );
401                } else {
402                    pds.push(std::ptr::null_mut());
403                    qps.push(std::ptr::null_mut());
404                }
405            } else {
406                pds.push(std::ptr::null_mut());
407                qps.push(std::ptr::null_mut());
408            }
409        }
410        (pds, qps)
411    }
412
413    fn find_cuda_segment_for_address(
414        &mut self,
415        addr: usize,
416        size: usize,
417        pd: *mut rdmaxcel_sys::ibv_pd,
418    ) -> Option<IbvMemoryRegionView> {
419        let registered_segments = get_registered_cuda_segments(pd);
420        for segment in registered_segments {
421            let start_addr = segment.phys_address;
422            let end_addr = start_addr + segment.phys_size;
423            if start_addr <= addr && addr + size <= end_addr {
424                let offset = addr - start_addr;
425                let rdma_addr = segment.mr_addr + offset;
426
427                let mrv = IbvMemoryRegionView {
428                    id: self.mrv_id,
429                    virtual_addr: addr,
430                    rdma_addr,
431                    size,
432                    lkey: segment.lkey,
433                    rkey: segment.rkey,
434                };
435                self.mrv_id += 1;
436                return Some(mrv);
437            }
438        }
439        None
440    }
441
442    fn register_mr_impl(
443        &mut self,
444        addr: usize,
445        size: usize,
446    ) -> Result<(IbvMemoryRegionView, String), anyhow::Error> {
447        unsafe {
448            let mut mem_type: i32 = 0;
449            let ptr = addr as rdmaxcel_sys::CUdeviceptr;
450            let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
451                &mut mem_type as *mut _ as *mut std::ffi::c_void,
452                rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
453                ptr,
454            );
455            let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
456
457            let mut selected_rdma_device = None;
458
459            if is_cuda {
460                // Get device ordinal from the CUDA pointer
461                let mut device_ordinal: i32 = -1;
462                let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
463                    &mut device_ordinal as *mut _ as *mut std::ffi::c_void,
464                    rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
465                    ptr,
466                );
467                if err == rdmaxcel_sys::CUDA_SUCCESS && device_ordinal >= 0 {
468                    selected_rdma_device = super::device_selection::get_cuda_device_to_ibv_device()
469                        .get(device_ordinal as usize)
470                        .and_then(|d| d.clone());
471                }
472            }
473
474            // Determine the RDMA device to use
475            let rdma_device = if let Some(device) = selected_rdma_device {
476                device
477            } else {
478                // Fallback to default device from config
479                self.config.device.clone()
480            };
481
482            let device_name = rdma_device.name().clone();
483            tracing::debug!(
484                "Using RDMA device: {} for memory at 0x{:x}",
485                device_name,
486                addr
487            );
488
489            // Get or create domain and loopback QP for this device
490            let (domain, _qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
491
492            let access = if crate::efa::is_efa_device() {
493                crate::efa::mr_access_flags()
494            } else {
495                rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
496                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
497                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
498                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC
499            };
500
501            let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
502            let mrv;
503
504            if is_cuda {
505                // First, try to use segment scanning if mlx5dv is enabled
506                let mut segment_mrv = None;
507                if self.mlx5dv_enabled {
508                    // Try to find in already registered segments
509                    segment_mrv = self.find_cuda_segment_for_address(addr, size, domain.pd);
510
511                    // If not found, trigger a re-sync with the allocator and retry
512                    if segment_mrv.is_none() {
513                        let (mut pds, mut qps) = self.build_per_device_pd_qp_arrays();
514                        let err = rdmaxcel_sys::register_segments(
515                            pds.as_mut_ptr(),
516                            qps.as_mut_ptr(),
517                            pds.len() as i32,
518                        );
519                        // Only retry if register_segments succeeded
520                        // If it fails (e.g., scanner returns 0 segments), we'll fall back to dmabuf
521                        if err == 0 {
522                            segment_mrv = self.find_cuda_segment_for_address(addr, size, domain.pd);
523                        }
524                    }
525                }
526
527                // Use segment if found, otherwise fall back to direct dmabuf registration
528                if let Some(mrv_from_segment) = segment_mrv {
529                    mrv = mrv_from_segment;
530                } else {
531                    // Dmabuf path: used when mlx5dv is disabled OR scanner returns no segments
532                    let mut fd: i32 = -1;
533                    let cu_err = rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
534                        &mut fd,
535                        addr as rdmaxcel_sys::CUdeviceptr,
536                        size,
537                        rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
538                        0,
539                    );
540                    if cu_err != rdmaxcel_sys::CUDA_SUCCESS || fd < 0 {
541                        return Err(anyhow::anyhow!(
542                            "failed to get dmabuf handle for CUDA memory (addr: 0x{:x}, size: {}, cu_err: {}, fd: {})",
543                            addr,
544                            size,
545                            cu_err,
546                            fd
547                        ));
548                    }
549                    mr =
550                        rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
551                    if mr.is_null() {
552                        return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
553                    }
554                    mrv = IbvMemoryRegionView {
555                        id: self.mrv_id,
556                        virtual_addr: addr,
557                        rdma_addr: (*mr).addr as usize,
558                        size,
559                        lkey: (*mr).lkey,
560                        rkey: (*mr).rkey,
561                    };
562                    self.mrv_id += 1;
563                }
564            } else {
565                // CPU memory path
566                mr = rdmaxcel_sys::ibv_reg_mr(
567                    domain.pd,
568                    addr as *mut std::ffi::c_void,
569                    size,
570                    access.0 as i32,
571                );
572
573                if mr.is_null() {
574                    return Err(anyhow::anyhow!("failed to register standard MR"));
575                }
576
577                mrv = IbvMemoryRegionView {
578                    id: self.mrv_id,
579                    virtual_addr: addr,
580                    rdma_addr: (*mr).addr as usize,
581                    size,
582                    lkey: (*mr).lkey,
583                    rkey: (*mr).rkey,
584                };
585                self.mrv_id += 1;
586            }
587            self.mr_map.insert(mrv.id, mr as usize);
588            Ok((mrv, device_name))
589        }
590    }
591
592    fn deregister_mr_impl(&mut self, id: usize) -> Result<(), anyhow::Error> {
593        if let Some(mr_ptr) = self.mr_map.remove(&id) {
594            if mr_ptr != 0 {
595                unsafe {
596                    rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
597                }
598            }
599        }
600        Ok(())
601    }
602
    /// Return the queue pair linking `self_device` on this actor to
    /// `other_device` on `other`, establishing the connection if needed.
    ///
    /// Flow:
    /// 1. Return a clone of the cached queue pair if `device_qps` has one.
    /// 2. Otherwise consult `pending_qp_creation`. If the key is already
    ///    pending, poll `device_qps` every 200 microseconds for up to one
    ///    second. If not, mark the key pending and continue.
    /// 3. Run the handshake. For loopback (same actor and same device) this
    ///    is initialize, fetch endpoint, connect. For a remote peer both
    ///    sides initialize, exchange endpoints and connect, and the remote
    ///    queue pair state is then checked to be RTS.
    /// 4. Clear the pending marker whether the handshake succeeded or not,
    ///    and return the result.
    ///
    /// # Errors
    /// Fails on a timeout while waiting for a pending creation, on any
    /// handshake step failure, when the remote queue pair is not in RTS
    /// afterwards, or when no queue pair is found in `device_qps` once the
    /// handshake completes.
    async fn request_queue_pair_impl(
        &mut self,
        cx: &Context<'_, Self>,
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
    ) -> Result<IbvQueuePair, anyhow::Error> {
        let self_ref: reference::ActorRef<IbvManagerActor> = cx.bind();
        let other_id = other.actor_id().clone();

        // Use the nested map structure: local_device -> (actor_id, remote_device) -> IbvQueuePair
        let inner_key = (other_id.clone(), other_device.clone());

        // Check if queue pair exists in map
        if let Some(device_map) = self.device_qps.get(&self_device) {
            if let Some(qp) = device_map.get(&inner_key) {
                return Ok(qp.clone());
            }
        }

        // Try to acquire lock and mark as pending (hold lock only once!)
        let pending_key = (self_device.clone(), other_id.clone(), other_device.clone());
        let mut pending = self.pending_qp_creation.lock().await;

        if pending.contains(&pending_key) {
            // Another task is creating this QP, release lock and wait
            drop(pending);

            // Loop checking device_qps until QP is created (no more locks needed)
            // Timeout after 1 second
            // NOTE(review): this method holds `&mut self` for the whole wait,
            // so it is unclear from this file how another handler could
            // insert into `device_qps` in the meantime — confirm against the
            // actor's concurrency model.
            let start = Instant::now();
            let timeout = Duration::from_secs(1);

            loop {
                tokio::time::sleep(Duration::from_micros(200)).await;

                // Check if QP was created while we waited
                if let Some(device_map) = self.device_qps.get(&self_device) {
                    if let Some(qp) = device_map.get(&inner_key) {
                        return Ok(qp.clone());
                    }
                }

                // Check for timeout
                if start.elapsed() >= timeout {
                    return Err(anyhow::anyhow!(
                        "Timeout waiting for QP creation (device {} -> actor {} device {}). \
                         Another task is creating it but hasn't completed in 1 second",
                        self_device,
                        other_id,
                        other_device
                    ));
                }
            }
        } else {
            // Not pending, add to set and proceed with creation
            pending.insert(pending_key.clone());
            drop(pending);
            // Fall through to create QP
        }

        // Queue pair doesn't exist - need to create connection.
        // The handshake runs inside an async block so that `?` exits only
        // this block and the pending marker below is always cleared.
        let result = async {
            let is_loopback = other_id == *self_ref.actor_id() && self_device == other_device;

            if is_loopback {
                // Loopback connection setup
                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                // The device arguments are passed in swapped order here; for
                // loopback both names are equal, so the lookup key is the same.
                let endpoint = self
                    .connection_info(cx, other.clone(), other_device.clone(), self_device.clone())
                    .await?;
                self.connect(
                    cx,
                    other.clone(),
                    self_device.clone(),
                    other_device.clone(),
                    endpoint,
                )
                .await?;
            } else {
                // Remote connection setup
                // Step 1: create the queue pair on both sides. Calls on
                // `other` pass the device names swapped, because from the
                // peer's point of view `other_device` is its own device.
                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                other
                    .initialize_qp(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;
                // Step 2: fetch the peer's endpoint and connect the local side.
                let other_endpoint: IbvQpInfo = other
                    .connection_info(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;
                self.connect(
                    cx,
                    other.clone(),
                    self_device.clone(),
                    other_device.clone(),
                    other_endpoint,
                )
                .await?;
                // Step 3: send the local endpoint so the peer can connect back.
                let local_endpoint = self
                    .connection_info(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                other
                    .connect(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                        local_endpoint,
                    )
                    .await?;

                // BARRIER: Ensure remote side has completed its connection and is ready
                let remote_state = other
                    .get_qp_state(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;

                if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS {
                    return Err(anyhow::anyhow!(
                        "Remote QP not in RTS state after connection setup. \
                         Local is ready but remote is in state {}. \
                         This indicates a synchronization issue in connection setup.",
                        remote_state
                    ));
                }
            }

            // Now that connection is established, get and clone the queue pair
            if let Some(device_map) = self.device_qps.get(&self_device) {
                if let Some(qp) = device_map.get(&inner_key) {
                    Ok(qp.clone())
                } else {
                    Err(anyhow::anyhow!(
                        "Failed to create connection for actor {} on device {}",
                        other_id,
                        other_device
                    ))
                }
            } else {
                Err(anyhow::anyhow!(
                    "Failed to create connection for actor {} on device {} - no device map",
                    other_id,
                    other_device
                ))
            }
        }
        .await;

        // Always remove from pending set when done (success or failure)
        let mut pending = self.pending_qp_creation.lock().await;
        pending.remove(&pending_key);
        drop(pending);

        result
    }
772}
773
774#[async_trait]
775#[hyperactor::handle(IbvManagerMessage)]
776impl IbvManagerMessageHandler for IbvManagerActor {
777    async fn request_buffer(
778        &mut self,
779        cx: &Context<Self>,
780        remote_buf_id: usize,
781    ) -> Result<Option<IbvBuffer>, anyhow::Error> {
782        // If already registered, return it
783        if let Some(buf) = self.buffer_registrations.get(&remote_buf_id) {
784            return Ok(Some(buf.clone()));
785        }
786
787        // Resolve local memory from the parent RdmaManagerActor.
788        // Returns None if the buffer has already been released or does
789        // not exist.
790        let owner = self.owner.get().unwrap();
791        let mem = match owner.request_local_memory(cx, remote_buf_id).await? {
792            Some(mem) => mem,
793            None => return Ok(None),
794        };
795
796        let (mrv, device_name) = self.register_mr_impl(mem.addr(), mem.size())?;
797
798        let buf = IbvBuffer {
799            mr_id: mrv.id,
800            lkey: mrv.lkey,
801            rkey: mrv.rkey,
802            addr: mrv.rdma_addr,
803            size: mrv.size,
804            device_name,
805        };
806
807        self.buffer_registrations.insert(remote_buf_id, buf.clone());
808
809        Ok(Some(buf))
810    }
811
812    async fn release_buffer(
813        &mut self,
814        _cx: &Context<Self>,
815        remote_buf_id: usize,
816    ) -> Result<(), anyhow::Error> {
817        if let Some(buf) = self.buffer_registrations.remove(&remote_buf_id) {
818            self.deregister_mr_impl(buf.mr_id)
819                .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
820        }
821        Ok(())
822    }
823
824    async fn request_queue_pair(
825        &mut self,
826        cx: &Context<Self>,
827        other: reference::ActorRef<IbvManagerActor>,
828        self_device: String,
829        other_device: String,
830    ) -> Result<Result<IbvQueuePair, String>, anyhow::Error> {
831        Ok(self
832            .request_queue_pair_impl(cx, other, self_device, other_device)
833            .await
834            .map_err(|e| e.to_string()))
835    }
836
837    async fn connect(
838        &mut self,
839        _cx: &Context<Self>,
840        other: reference::ActorRef<IbvManagerActor>,
841        self_device: String,
842        other_device: String,
843        endpoint: IbvQpInfo,
844    ) -> Result<(), anyhow::Error> {
845        tracing::debug!("connecting with {:?}", other);
846        let other_id = other.actor_id().clone();
847
848        let inner_key = (other_id.clone(), other_device.clone());
849
850        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
851            match device_map.get_mut(&inner_key) {
852                Some(qp) => {
853                    qp.connect(&endpoint).map_err(|e| {
854                        anyhow::anyhow!("could not connect to RDMA endpoint: {}", e)
855                    })?;
856                    Ok(())
857                }
858                None => Err(anyhow::anyhow!(
859                    "No connection found for actor {}",
860                    other_id
861                )),
862            }
863        } else {
864            Err(anyhow::anyhow!(
865                "No device map found for device {}",
866                self_device
867            ))
868        }
869    }
870
871    async fn initialize_qp(
872        &mut self,
873        _cx: &Context<Self>,
874        other: reference::ActorRef<IbvManagerActor>,
875        self_device: String,
876        other_device: String,
877    ) -> Result<bool, anyhow::Error> {
878        let other_id = other.actor_id().clone();
879        let inner_key = (other_id.clone(), other_device.clone());
880
881        // Check if QP already exists in nested structure
882        if let Some(device_map) = self.device_qps.get(&self_device) {
883            if device_map.contains_key(&inner_key) {
884                return Ok(true);
885            }
886        }
887
888        // The domain is guaranteed to exist here: register_mr is always called before
889        // initialize_qp, either in execute_op (for the local actor) or via resolve_ibv
890        // (for the remote actor), and register_mr always calls get_or_create_device_domain.
891        let (domain, _) = self.device_domains.get(&self_device).ok_or_else(|| {
892            anyhow::anyhow!(
893                "device domain for '{}' not found; register_mr must be called before initialize_qp",
894                self_device
895            )
896        })?;
897        let (domain_context, domain_pd) = (domain.context, domain.pd);
898
899        let qp = IbvQueuePair::new(domain_context, domain_pd, self.config.clone())
900            .map_err(|e| anyhow::anyhow!("could not create IbvQueuePair: {}", e))?;
901
902        // Insert the QP into the nested map structure
903        self.device_qps
904            .entry(self_device.clone())
905            .or_insert_with(HashMap::new)
906            .insert(inner_key, qp);
907
908        tracing::debug!(
909            "successfully created a connection with {:?} for local device {} -> remote device {}",
910            other,
911            self_device,
912            other_device
913        );
914
915        Ok(true)
916    }
917
918    async fn connection_info(
919        &mut self,
920        _cx: &Context<Self>,
921        other: reference::ActorRef<IbvManagerActor>,
922        self_device: String,
923        other_device: String,
924    ) -> Result<IbvQpInfo, anyhow::Error> {
925        tracing::debug!("getting connection info with {:?}", other);
926        let other_id = other.actor_id().clone();
927
928        let inner_key = (other_id.clone(), other_device.clone());
929
930        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
931            match device_map.get_mut(&inner_key) {
932                Some(qp) => {
933                    let connection_info = qp.get_qp_info()?;
934                    Ok(connection_info)
935                }
936                None => Err(anyhow::anyhow!(
937                    "No connection found for actor {}",
938                    other_id
939                )),
940            }
941        } else {
942            Err(anyhow::anyhow!(
943                "No device map found for self device {}",
944                self_device
945            ))
946        }
947    }
948
949    async fn release_queue_pair(
950        &mut self,
951        _cx: &Context<Self>,
952        _other: reference::ActorRef<IbvManagerActor>,
953        _self_device: String,
954        _other_device: String,
955        _qp: IbvQueuePair,
956    ) -> Result<(), anyhow::Error> {
957        Ok(())
958    }
959
960    async fn get_qp_state(
961        &mut self,
962        _cx: &Context<Self>,
963        other: reference::ActorRef<IbvManagerActor>,
964        self_device: String,
965        other_device: String,
966    ) -> Result<u32, anyhow::Error> {
967        let other_id = other.actor_id().clone();
968        let inner_key = (other_id.clone(), other_device.clone());
969
970        if let Some(device_map) = self.device_qps.get_mut(&self_device) {
971            match device_map.get_mut(&inner_key) {
972                Some(qp) => qp.state(),
973                None => Err(anyhow::anyhow!(
974                    "No connection found for actor {} on device {}",
975                    other_id,
976                    other_device
977                )),
978            }
979        } else {
980            Err(anyhow::anyhow!(
981                "No device map found for self device {}",
982                self_device
983            ))
984        }
985    }
986}
987
988#[async_trait]
989#[hyperactor::handle(IbvManagerLocalMessage)]
990impl IbvManagerLocalMessageHandler for IbvManagerActor {
991    async fn register_mr(
992        &mut self,
993        _cx: &Context<Self>,
994        addr: usize,
995        size: usize,
996    ) -> Result<Result<(IbvMemoryRegionView, String), String>, anyhow::Error> {
997        Ok(self.register_mr_impl(addr, size).map_err(|e| e.to_string()))
998    }
999
1000    async fn deregister_mr(
1001        &mut self,
1002        _cx: &Context<Self>,
1003        id: usize,
1004    ) -> Result<Result<(), String>, anyhow::Error> {
1005        Ok(self.deregister_mr_impl(id).map_err(|e| e.to_string()))
1006    }
1007}
1008
1009/// Wrapper around [`ActorHandle<IbvManagerActor>`] that moves the RDMA
1010/// data-plane (post send/recv, poll CQ) off the actor loop while keeping
1011/// state-mutating operations (MR registration/deregistration, QP management)
1012/// serialized through actor messages.
1013#[derive(Debug, Clone)]
1014pub struct IbvBackend(pub ActorHandle<IbvManagerActor>);
1015
1016impl std::ops::Deref for IbvBackend {
1017    type Target = ActorHandle<IbvManagerActor>;
1018    fn deref(&self) -> &Self::Target {
1019        &self.0
1020    }
1021}
1022
1023impl IbvBackend {
1024    /// Waits for the completion of RDMA operations.
1025    ///
1026    /// Polls the completion queue until all specified work requests complete
1027    /// or until the timeout is reached. Pure CQ polling — no actor state needed.
1028    async fn wait_for_completion(
1029        local_buf: &IbvBuffer,
1030        qp: &mut IbvQueuePair,
1031        poll_target: PollTarget,
1032        expected_wr_ids: &[u64],
1033        timeout: Duration,
1034    ) -> Result<(), anyhow::Error> {
1035        let start_time = std::time::Instant::now();
1036
1037        let mut remaining: std::collections::HashSet<u64> =
1038            expected_wr_ids.iter().copied().collect();
1039
1040        while start_time.elapsed() < timeout {
1041            if remaining.is_empty() {
1042                return Ok(());
1043            }
1044
1045            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
1046            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
1047                Ok(completions) => {
1048                    for (wr_id, _wc) in completions {
1049                        remaining.remove(&wr_id);
1050                    }
1051                    if remaining.is_empty() {
1052                        return Ok(());
1053                    }
1054                    tokio::time::sleep(Duration::from_millis(1)).await;
1055                }
1056                Err(e) => {
1057                    // When the returned error is WR_FLUSH_ERR, which is generally a
1058                    // secondary error, drain the remaining completions to find the
1059                    // original root cause error. WR_FLUSH_ERR means the QP entered
1060                    // error state due to a DIFFERENT WR's failure, so the actual root
1061                    // cause may be cached or still in the CQ.
1062                    let mut root_cause: Option<PollCompletionError> = None;
1063                    if e.is_wr_flush_err() {
1064                        for &wr_id in &wr_ids_to_poll {
1065                            if let Err(inner_err) = qp.poll_completion(poll_target, &[wr_id]) {
1066                                if !inner_err.is_wr_flush_err() {
1067                                    root_cause = Some(inner_err);
1068                                    break;
1069                                }
1070                            }
1071                        }
1072                    }
1073                    let error_detail = if let Some(cause) = root_cause {
1074                        format!(
1075                            "RDMA polling completion failed: {} (root cause: {})",
1076                            e, cause
1077                        )
1078                    } else {
1079                        format!("RDMA polling completion failed: {}", e)
1080                    };
1081                    return Err(anyhow::anyhow!(
1082                        "{} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
1083                        error_detail,
1084                        local_buf.lkey,
1085                        local_buf.rkey,
1086                        local_buf.addr,
1087                        local_buf.size
1088                    ));
1089                }
1090            }
1091        }
1092        tracing::error!(
1093            "timed out while waiting on request completion for wr_ids={:?}",
1094            remaining
1095        );
1096        Err(anyhow::anyhow!(
1097            "[ibv_buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
1098            local_buf,
1099            expected_wr_ids
1100        ))
1101    }
1102
1103    /// Core submit logic: registers local MR via actor message, resolves remote
1104    /// IbvBuffer lazily, executes the op locally, and deregisters local MR.
1105    async fn execute_op(
1106        &self,
1107        cx: &(impl hyperactor::context::Actor + Send + Sync),
1108        op: IbvOp,
1109        timeout: Duration,
1110    ) -> Result<(), anyhow::Error> {
1111        // Register the local memory via actor message
1112        let (local_mrv, local_device_name) = self
1113            .register_mr(cx, op.local_memory.addr(), op.local_memory.size())
1114            .await?
1115            .map_err(|e| anyhow::anyhow!(e))?;
1116
1117        let local_buffer = IbvBuffer {
1118            mr_id: local_mrv.id,
1119            lkey: local_mrv.lkey,
1120            rkey: local_mrv.rkey,
1121            addr: local_mrv.rdma_addr,
1122            size: local_mrv.size,
1123            device_name: local_device_name,
1124        };
1125
1126        let op_result = async {
1127            let mut qp = self
1128                .request_queue_pair(
1129                    cx,
1130                    op.remote_manager.clone(),
1131                    local_buffer.device_name.clone(),
1132                    op.remote_buffer.device_name.clone(),
1133                )
1134                .await?
1135                .map_err(|e| anyhow::anyhow!(e))?;
1136
1137            let wr_id = match op.op_type {
1138                RdmaOpType::WriteFromLocal => qp.put(local_buffer.clone(), op.remote_buffer)?,
1139                RdmaOpType::ReadIntoLocal => qp.get(local_buffer.clone(), op.remote_buffer)?,
1140            };
1141
1142            Self::wait_for_completion(&local_buffer, &mut qp, PollTarget::Send, &wr_id, timeout)
1143                .await
1144        }
1145        .await;
1146
1147        // Always deregister the locally registered MR via actor message
1148        let dereg_result = self
1149            .deregister_mr(cx, local_buffer.mr_id)
1150            .await?
1151            .map_err(|e| anyhow::anyhow!(e));
1152
1153        match (op_result, dereg_result) {
1154            (Ok(()), Ok(())) => Ok(()),
1155            (Err(e), Ok(())) => Err(e),
1156            (Ok(()), Err(e)) => Err(e),
1157            (Err(op_err), Err(dereg_err)) => Err(anyhow::anyhow!(
1158                "deregister MR error: {}; op error: {}",
1159                dereg_err,
1160                op_err
1161            )),
1162        }
1163    }
1164}
1165
1166#[async_trait]
1167impl RdmaBackend for IbvBackend {
1168    type TransportInfo = ();
1169
1170    /// Submit a batch of RDMA operations.
1171    ///
1172    /// Resolves ibv ops, then executes each directly — registering/deregistering
1173    /// MRs via actor messages, while performing QP put/get and CQ polling locally.
1174    async fn submit(
1175        &mut self,
1176        cx: &(impl hyperactor::context::Actor + Send + Sync),
1177        ops: Vec<RdmaOp>,
1178        timeout: Duration,
1179    ) -> Result<(), anyhow::Error> {
1180        let mut ibv_ops = Vec::with_capacity(ops.len());
1181        for op in ops {
1182            let (remote_ibv_mgr, remote_ibv_buffer) =
1183                op.remote.resolve_ibv(cx).await.ok_or_else(|| {
1184                    anyhow::anyhow!("ibverbs backend not found for buffer: {:?}", op.remote)
1185                })??;
1186
1187            ibv_ops.push(IbvOp {
1188                op_type: op.op_type,
1189                local_memory: op.local.clone(),
1190                remote_buffer: remote_ibv_buffer,
1191                remote_manager: remote_ibv_mgr,
1192            });
1193        }
1194
1195        let deadline = Instant::now() + timeout;
1196        for op in ibv_ops {
1197            let remaining = deadline.saturating_duration_since(Instant::now());
1198            if remaining.is_zero() {
1199                return Err(anyhow::anyhow!("submit timed out"));
1200            }
1201            self.execute_op(cx, op, remaining).await?;
1202        }
1203        Ok(())
1204    }
1205
1206    fn transport_level(&self) -> RdmaTransportLevel {
1207        RdmaTransportLevel::Nic
1208    }
1209
1210    fn transport_info(&self) -> Option<Self::TransportInfo> {
1211        None
1212    }
1213}