Skip to main content

monarch_rdma/backend/ibverbs/
manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! # Ibverbs Manager
//!
//! Contains ibverbs-specific RDMA logic.
//!
//! Manages ibverbs resources including:
//! - Memory registration (CPU and CUDA via dmabuf or segment scanning)
//! - Queue pair creation and connection establishment
//! - RDMA domain and protection domain management
//! - Device selection and PCI-to-RDMA device mapping
//!
//! ## Queue-pair lifecycle
//!
//! Bringing up a queue pair to a peer is a two-sided handshake (each
//! side has its own QP and must learn the other side's endpoint
//! before transitioning `INIT → RTR → RTS`). Doing all of that in
//! response to a single message would block our actor loop while
//! awaiting peer RPCs, and the peer's symmetric request would block
//! waiting for us — a deadlock.
//!
//! Instead, [`IbvManagerActor`] does only synchronous bookkeeping in
//! the handler and offloads the handshake to a per-QP child actor,
//! [`QueuePairInitializer`]. The store of QPs
//! ([`IbvManagerActor::qps`]) is keyed by [`QpKey`] and holds a
//! [`QpState`]: `Pending { info, initializer, waiters }` while the
//! handshake runs, `Ready(qp)` once this side is RTS and has observed
//! the peer's RTS, or `Failed(error)` as a tombstone after a fatal
//! error.
35
36use std::collections::HashMap;
37use std::sync::Arc;
38use std::sync::OnceLock;
39use std::time::Duration;
40use std::time::Instant;
41
42use anyhow::Result;
43use async_trait::async_trait;
44use backoff::ExponentialBackoff;
45use backoff::ExponentialBackoffBuilder;
46use backoff::backoff::Backoff;
47use hyperactor::Actor;
48use hyperactor::ActorHandle;
49use hyperactor::ActorRef;
50use hyperactor::Context;
51use hyperactor::Endpoint as _;
52use hyperactor::HandleClient;
53use hyperactor::Handler;
54use hyperactor::Instance;
55use hyperactor::OncePortHandle;
56use hyperactor::PortRef;
57use hyperactor::RefClient;
58use hyperactor::actor::Referable;
59use serde::Deserialize;
60use serde::Serialize;
61use typeuri::Named;
62
63use super::IbvBuffer;
64use super::IbvOp;
65use super::domain::IbvDomain;
66use super::primitives::IbvConfig;
67use super::primitives::IbvDevice;
68use super::primitives::IbvMemoryRegion;
69use super::primitives::IbvMemoryRegionView;
70use super::primitives::IbvQpInfo;
71use super::primitives::ibverbs_supported;
72use super::primitives::mlx5dv_supported;
73use super::primitives::resolve_qp_type;
74use super::queue_pair::IbvQueuePair;
75use super::queue_pair::PeerInfo;
76use super::queue_pair::PollCompletionError;
77use super::queue_pair::PollTarget;
78use super::queue_pair::QpGuard;
79use super::queue_pair::QpKey;
80use super::queue_pair::QueuePairInitializer;
81use super::queue_pair::destroy_qp;
82use crate::RdmaOp;
83use crate::RdmaOpType;
84use crate::RdmaTransportLevel;
85use crate::backend::RdmaBackend;
86use crate::local_memory::KeepaliveLocalMemory;
87use crate::rdma_components::get_registered_cuda_segments;
88use crate::rdma_manager_actor::GetIbvActorRefClient;
89use crate::rdma_manager_actor::RdmaManagerActor;
90use crate::validate_execution_context;
91
92/// Cross-proc message: peer asks for our endpoint, lazily creating
93/// the entry on our side if absent. Generic over the manager actor
94/// type so tests can swap in a mock.
95#[derive(Debug, Serialize, Deserialize, Named)]
96#[serde(bound(serialize = "", deserialize = ""))]
97pub(super) struct EnsureQueuePair<A: Referable> {
98    pub(super) sender: ActorRef<A>,
99    pub(super) sender_device: String,
100    pub(super) receiver_device: String,
101    pub(super) reply: PortRef<PeerInfo>,
102}
103wirevalue::register_type!(EnsureQueuePair<IbvManagerActor>);
104
105/// Per-QpKey state in [`IbvManagerActor::qps`].
106///
107/// `Pending` covers the entire handshake (an initializer is running);
108/// `Ready` is the terminal usable state; `Failed` is a tombstone that
109/// records the error so subsequent `RequestQueuePair` / `EnsureQueuePair`
110/// calls for the same key surface the same error rather than retrying
111/// or hanging.
112///
113/// TODO: add recovery — allow retries via an explicit message or after
114/// a backoff. For now the entry stays `Failed` for the life of the
115/// manager.
116#[derive(Debug)]
117enum QpState {
118    Pending {
119        /// Local endpoint, captured when the QP was first created so
120        /// repeated `EnsureQueuePair` calls don't have to re-extract it.
121        info: IbvQpInfo,
122        /// Child actor driving the handshake. Stopped on
123        /// `QpInitializerDone`/`QpInitializerFailed`.
124        initializer: ActorHandle<QueuePairInitializer<IbvManagerActor>>,
125        /// Local `RequestQueuePair` callers waiting for the QP. Drained
126        /// to `Ok(qp.clone())` on `Ready`, or `Err(_)` on failure.
127        waiters: Vec<OncePortHandle<Result<IbvQueuePair, String>>>,
128    },
129    Ready(IbvQueuePair),
130    Failed(String),
131}
132
133/// Cross-proc messages handled by [`IbvManagerActor`].
134///
135/// `EnsureQueuePair` is defined as a separate top-level message
136/// because it's generic over the manager actor type to allow
137/// mocking in tests.
138#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
139pub enum IbvManagerMessage {
140    /// Release a buffer registration by `remote_buf_id`. Fire-and-forget
141    /// (no reply port) to avoid blocking the caller during teardown.
142    ReleaseBuffer { remote_buf_id: usize },
143}
144wirevalue::register_type!(IbvManagerMessage);
145
146/// Local-only messages for [`IbvManagerActor`].
147#[derive(Handler, HandleClient, Debug)]
148pub enum IbvManagerLocalMessage {
149    /// Register a memory region, returning the MR view and device name.
150    RegisterMr {
151        addr: usize,
152        size: usize,
153        #[reply]
154        reply: OncePortHandle<Result<(IbvMemoryRegionView, String), String>>,
155    },
156    /// Register a remote-facing buffer's MR and return its
157    /// [`IbvBuffer`]. Called by
158    /// [`crate::rdma_manager_actor::RdmaManagerActor::request_buffer`]
159    /// at buffer-creation time.
160    ///
161    /// The MR lives in [`IbvManagerActor::buffer_registrations`] and
162    /// is deregistered on [`IbvManagerMessage::ReleaseBuffer`].
163    RegisterRemoteBuffer {
164        remote_buf_id: usize,
165        local: Arc<KeepaliveLocalMemory>,
166        #[reply]
167        reply: OncePortHandle<Result<IbvBuffer, String>>,
168    },
169    /// User-facing entry point: get a connected `IbvQueuePair` for
170    /// `(self_device, other actor's id, other_device)`. Lazily creates
171    /// the QP + initializer if absent; if a handshake is in flight,
172    /// the reply port is queued and answered when the QP becomes
173    /// `Ready` (or fails).
174    ///
175    /// No `#[reply]` because the handler may park `reply` on the
176    /// `Pending` entry and answer it later from [`QpInitializerDone`]/
177    /// [`QpInitializerFailed`].
178    RequestQueuePair {
179        other: ActorRef<IbvManagerActor>,
180        self_device: String,
181        other_device: String,
182        reply: OncePortHandle<Result<IbvQueuePair, String>>,
183    },
184}
185
186/// Local-only handshake-success report. The initializer sends this
187/// to its owning manager once both sides have reached RTS, handing
188/// over the freshly-connected [`QpGuard`].
189#[derive(Debug)]
190pub(super) struct QpInitializerDone {
191    pub(super) qp_key: QpKey,
192    pub(super) qp: QpGuard,
193}
194
195/// Local-only handshake-failure report. The initializer sends this
196/// to its owning manager when the handshake aborted; the underlying
197/// QP has already been dropped on the initializer side.
198#[derive(Debug)]
199pub(super) struct QpInitializerFailed {
200    pub(super) qp_key: QpKey,
201    pub(super) error: String,
202}
203
204/// Adaptive wait between completion polls.
205///
206/// While the elapsed time since [`Self::yield_now`] was first called
207/// is below `yield_window`, the policy yields cooperatively
208/// (`tokio::task::yield_now`) — keeping latency tight when the WR
209/// completes shortly after being posted. `tokio::time::sleep` has a
210/// minimum resolution of ~1ms (the timer wheel tick), so even a
211/// `sleep(Duration::from_micros(100))` would block that long; `yield_now` is
212/// sub-millisecond and lets the next poll fire as soon as the runtime
213/// schedules us. Past `yield_window` the policy switches to an
214/// exponential backoff (1ms initial, doubling, capped at 10ms) so
215/// long-running operations don't keep the runtime spinning.
216///
217/// `yield_window` is read from
218/// [`crate::config::RDMA_CQ_BUSY_POLL_WINDOW`]. When it's `None`
219/// (the default) the policy disables the cutoff and only ever
220/// yields, never sleeps.
221struct PollSleepPolicy {
222    yield_window: Option<Duration>,
223    started_at: Option<Instant>,
224    backoff: Option<ExponentialBackoff>,
225}
226
227impl PollSleepPolicy {
228    fn new() -> Self {
229        let yield_window = hyperactor_config::global::get(crate::config::RDMA_CQ_BUSY_POLL_WINDOW);
230        Self {
231            yield_window,
232            started_at: None,
233            backoff: None,
234        }
235    }
236
237    /// Suspend the current task before the next poll. If no yield
238    /// window is configured (the default), always yields. Otherwise,
239    /// yields while within the window and then walks an exponential
240    /// backoff up to 10ms past it.
241    async fn yield_now(&mut self) {
242        let Some(window) = self.yield_window else {
243            tokio::task::yield_now().await;
244            return;
245        };
246        let started = *self.started_at.get_or_insert_with(Instant::now);
247        if started.elapsed() < window {
248            tokio::task::yield_now().await;
249            return;
250        }
251        let backoff = self.backoff.get_or_insert_with(|| {
252            ExponentialBackoffBuilder::new()
253                .with_initial_interval(Duration::from_millis(1))
254                .with_max_interval(Duration::from_millis(10))
255                .with_multiplier(2.0)
256                .with_randomization_factor(0.0)
257                .with_max_elapsed_time(None)
258                .build()
259        });
260        match backoff.next_backoff() {
261            Some(delay) => tokio::time::sleep(delay).await,
262            None => tokio::task::yield_now().await,
263        }
264    }
265}
266
267/// Look up `(addr, size)` in a slice of registered CUDA segments
268/// and return a view into the matching mkey.
269///
270/// Bounded by `mr_size` (what the mkey actually covers), NOT by
271/// `phys_size` (the scanner-reported extent). They diverge when
272/// `register_segments` hits `max_sge` and stops growing the binding.
273/// Returning a view based on `phys_size` would hand out an
274/// `(lkey, offset)` past the bound and the WR would fail with
275/// `IBV_WC_LOC_PROT_ERR`; bounding by `mr_size` makes the gap a
276/// miss so the caller falls back to per-buffer dmabuf.
277///
278/// Free function so the boundary can be unit-tested without an actor.
279pub(super) fn lookup_segment_for_address(
280    segments: &[rdmaxcel_sys::rdma_segment_info_t],
281    addr: usize,
282    size: usize,
283) -> Option<SegmentInfo> {
284    for segment in segments {
285        let start_addr = segment.phys_address;
286        let end_addr = start_addr + segment.mr_size;
287        if start_addr <= addr && addr + size <= end_addr {
288            let offset = addr - start_addr;
289            let rdma_addr = segment.mr_addr + offset;
290            return Some(SegmentInfo {
291                rdma_addr,
292                size,
293                lkey: segment.lkey,
294                rkey: segment.rkey,
295            });
296        }
297    }
298    None
299}
300
301/// Result of a successful [`lookup_segment_for_address`] hit. Just
302/// the device-derived facts about the matched mkey; the caller
303/// composes these with whatever provenance it needs (mrv id,
304/// device name, refcounted owners) when materializing an
305/// [`IbvMemoryRegionView`].
306#[derive(Debug)]
307pub(super) struct SegmentInfo {
308    pub(super) rdma_addr: usize,
309    pub(super) size: usize,
310    pub(super) lkey: u32,
311    pub(super) rkey: u32,
312}
313
314/// Manages all ibverbs-specific RDMA resources and operations.
315///
316/// This struct handles memory registration, queue pair management,
317/// and connection establishment using the ibverbs API.
318#[derive(Debug)]
319#[hyperactor::export(
320    handlers = [
321        IbvManagerMessage,
322        EnsureQueuePair<IbvManagerActor>,
323    ],
324)]
325pub struct IbvManagerActor {
326    owner: OnceLock<ActorHandle<RdmaManagerActor>>,
327
328    /// Per-QP state, keyed from this manager's perspective. See [`QpKey`].
329    qps: HashMap<QpKey, QpState>,
330
331    /// Map of RDMA device names to their domains and loopback QPs.
332    /// Created lazily when memory is registered for a specific
333    /// device. `Arc<IbvDomain>` so every `IbvMemoryRegionView`
334    /// registered against the domain can hold a clone and keep the
335    /// PD alive until the last MR is dereg'd.
336    device_domains: HashMap<String, (Arc<IbvDomain>, Option<IbvQueuePair>)>,
337
338    config: IbvConfig,
339
340    mlx5dv_enabled: bool,
341
342    /// Singleton Arc owning the CUDA segment scanner state. `Some`
343    /// once `register_segments` succeeds; cloned into every
344    /// segment-backed view. `deregister_segments` runs from the
345    /// `IbvMemoryRegion::Segments` Drop when the last reference goes
346    /// away.
347    segments_mr: Option<Arc<IbvMemoryRegion>>,
348
349    /// Id for next mrv created.
350    mrv_id: usize,
351
352    /// Map from buffer_id to the buffer's `(IbvBuffer, view)`. The
353    /// view keeps the MR (and its PD) alive for the lifetime of the
354    /// registration; `ReleaseBuffer` drops the entry, and the FFI
355    /// resources are released by the `Arc`s' `Drop`s once no other
356    /// holder of those views remains.
357    buffer_registrations: HashMap<usize, (IbvBuffer, IbvMemoryRegionView)>,
358}
359
360#[async_trait]
361impl Actor for IbvManagerActor {
362    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
363        let owner = if let Some(owner) = this.parent_handle() {
364            owner
365        } else {
366            anyhow::bail!("RdmaManagerActor not found as parent of IbvManagerActor");
367        };
368        self.owner
369            .set(owner)
370            .expect("owner should only be set once during init");
371        Ok(())
372    }
373}
374
375impl Drop for IbvManagerActor {
376    fn drop(&mut self) {
377        // 1. Clean up QPs. `Pending` entries hold the qp via the
378        // initializer; signal the initializer to stop and let the
379        // runtime tear it down. `Failed` is a tombstone with no
380        // resources. Pending waiters won't be answered and their
381        // callers will observe the dropped reply ports as `Err(_)`.
382        for (_key, state) in self.qps.drain() {
383            match state {
384                QpState::Ready(_) => {
385                    // TODO(slurye): Proper cleanup of QPs. Currently there's no safe way to do this
386                    // because `IbvQueuePair` can have arbitrary clones and there's no way to guarantee
387                    // that none of them are still in use.
388                }
389                QpState::Pending { initializer, .. } => {
390                    let _ = initializer.drain_and_stop("IbvManagerActor dropped");
391                }
392                QpState::Failed(_) => {}
393            }
394        }
395
396        // 2. Drop buffer registrations. Each entry's
397        // `IbvMemoryRegionView` carries the only manager-side
398        // reference to that MR's `Arc<IbvMemoryRegion>` and its
399        // `Arc<IbvDomain>`. They run their FFI cleanup from their
400        // `Drop`s once no surviving view holds a clone.
401        self.buffer_registrations.clear();
402
403        // 3. Drop the segments-owner Arc. `deregister_segments`
404        // runs from `IbvMemoryRegion::Segments::Drop` when the last
405        // segment-backed view also goes away.
406        self.segments_mr.take();
407
408        // 4. Drop device domains (PDs + loopback QPs). Loopback QPs
409        // are tied to this map only; destroy them explicitly. PD
410        // teardown waits on the last `Arc<IbvDomain>` clone (held
411        // by any outstanding view) to drop.
412        for (_device_name, (domain, qp)) in self.device_domains.drain() {
413            if let Some(qp) = qp {
414                // SAFETY: `device_domains` is the only holder of
415                // these loopback QPs; we just drained it and the
416                // manager is being dropped, so no clones survive.
417                unsafe { destroy_qp(&qp) };
418            }
419            drop(domain);
420        }
421    }
422}
423
424impl IbvManagerActor {
425    /// Construct an [`ActorHandle`] for the [`IbvManagerActor`] co-located
426    /// with the caller by querying the local [`RdmaManagerActor`].
427    pub async fn local_handle(
428        client: &(impl hyperactor::context::Actor + Send + Sync),
429    ) -> Result<ActorHandle<Self>, anyhow::Error> {
430        let rdma_handle = RdmaManagerActor::local_handle(client);
431        let ibv_ref: ActorRef<IbvManagerActor> = rdma_handle
432            .get_ibv_actor_ref(client)
433            .await?
434            .ok_or_else(|| anyhow::anyhow!("local RdmaManagerActor has no ibverbs backend"))?;
435        ibv_ref
436            .downcast_handle(client)
437            .ok_or_else(|| anyhow::anyhow!("IbvManagerActor is not in the local process"))
438    }
439
440    /// Create a new IbvManagerActor with the given configuration.
441    pub async fn new(params: Option<IbvConfig>) -> Result<Self, anyhow::Error> {
442        if !ibverbs_supported() {
443            return Err(anyhow::anyhow!(
444                "Cannot create IbvManagerActor because RDMA is not supported on this machine"
445            ));
446        }
447
448        // Use provided config or default if none provided
449        let mut config = params.unwrap_or_default();
450        tracing::debug!("rdma is enabled, config device hint: {}", config.device);
451
452        let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
453
454        // check config and hardware support align
455        if config.use_gpu_direct {
456            match validate_execution_context().await {
457                Ok(_) => {
458                    tracing::info!("GPU Direct RDMA execution context validated successfully");
459                }
460                Err(e) => {
461                    tracing::warn!(
462                        "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
463                        e
464                    );
465                    config.use_gpu_direct = false;
466                }
467            }
468        }
469
470        let actor = Self {
471            owner: OnceLock::new(),
472            qps: HashMap::new(),
473            device_domains: HashMap::new(),
474            config,
475            mlx5dv_enabled,
476            segments_mr: None,
477            mrv_id: 0,
478            buffer_registrations: HashMap::new(),
479        };
480
481        Ok(actor)
482    }
483
484    /// Get or create a domain and loopback QP for the specified RDMA device
485    fn get_or_create_device_domain(
486        &mut self,
487        device_name: &str,
488        rdma_device: &IbvDevice,
489    ) -> Result<(Arc<IbvDomain>, Option<IbvQueuePair>), anyhow::Error> {
490        if let Some((domain, qp)) = self.device_domains.get(device_name) {
491            return Ok((Arc::clone(domain), qp.clone()));
492        }
493
494        // Create new domain for this device
495        let domain = Arc::new(IbvDomain::new(rdma_device.clone()).map_err(|e| {
496            anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
497        })?);
498
499        // Print device info if MONARCH_DEBUG_RDMA=1 is set (before initial QP creation)
500        crate::print_device_info_if_debug_enabled(domain.context);
501
502        // Create loopback QP for this domain if mlx5dv is supported (needed for segment registration)
503        // For EFA, we don't need a loopback QP for segment scanning
504        let qp = if mlx5dv_supported() && !crate::efa::is_efa_device() {
505            let mut qp = QpGuard::new(
506                IbvQueuePair::new(domain.context, domain.pd, self.config.clone()).map_err(|e| {
507                    anyhow::anyhow!(
508                        "could not create loopback QP for device {}: {}",
509                        device_name,
510                        e
511                    )
512                })?,
513            );
514
515            // Get connection info and connect to itself
516            let endpoint = qp.get_qp_info().map_err(|e| {
517                anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
518            })?;
519
520            qp.connect(&endpoint).map_err(|e| {
521                anyhow::anyhow!(
522                    "could not connect loopback QP for device {}: {}",
523                    device_name,
524                    e
525                )
526            })?;
527
528            Some(qp)
529        } else {
530            None
531        };
532
533        let qp = qp.map(|qp| qp.into_inner());
534        self.device_domains
535            .insert(device_name.to_string(), (Arc::clone(&domain), qp.clone()));
536        Ok((domain, qp))
537    }
538
539    /// Build parallel PD/QP arrays indexed by CUDA device ordinal
540    /// for the C++ register_segments call.
541    fn build_per_device_pd_qp_arrays(
542        &self,
543    ) -> (
544        Vec<*mut rdmaxcel_sys::ibv_pd>,
545        Vec<*mut rdmaxcel_sys::rdmaxcel_qp_t>,
546    ) {
547        let cuda_map = super::device_selection::get_cuda_device_to_ibv_device();
548        let mut pds = Vec::with_capacity(cuda_map.len());
549        let mut qps = Vec::with_capacity(cuda_map.len());
550        for maybe_device in cuda_map {
551            if let Some(device) = maybe_device {
552                if let Some((domain, qp)) = self.device_domains.get(device.name()) {
553                    pds.push(domain.pd);
554                    qps.push(
555                        qp.as_ref()
556                            .map(|q| q.qp as *mut rdmaxcel_sys::rdmaxcel_qp_t)
557                            .unwrap_or(std::ptr::null_mut()),
558                    );
559                } else {
560                    pds.push(std::ptr::null_mut());
561                    qps.push(std::ptr::null_mut());
562                }
563            } else {
564                pds.push(std::ptr::null_mut());
565                qps.push(std::ptr::null_mut());
566            }
567        }
568        (pds, qps)
569    }
570
571    fn find_cuda_segment_for_address(
572        &self,
573        addr: usize,
574        size: usize,
575        pd: *mut rdmaxcel_sys::ibv_pd,
576    ) -> Option<SegmentInfo> {
577        lookup_segment_for_address(&get_registered_cuda_segments(pd), addr, size)
578    }
579
580    fn register_mr_impl(
581        &mut self,
582        addr: usize,
583        size: usize,
584    ) -> Result<(IbvMemoryRegionView, String), anyhow::Error> {
585        unsafe {
586            let mut mem_type: i32 = 0;
587            let ptr = addr as rdmaxcel_sys::CUdeviceptr;
588            let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
589                &mut mem_type as *mut _ as *mut std::ffi::c_void,
590                rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
591                ptr,
592            );
593            let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
594
595            let mut selected_rdma_device = None;
596
597            if is_cuda {
598                // Get device ordinal from the CUDA pointer
599                let mut device_ordinal: i32 = -1;
600                let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
601                    &mut device_ordinal as *mut _ as *mut std::ffi::c_void,
602                    rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
603                    ptr,
604                );
605                if err == rdmaxcel_sys::CUDA_SUCCESS && device_ordinal >= 0 {
606                    selected_rdma_device = super::device_selection::get_cuda_device_to_ibv_device()
607                        .get(device_ordinal as usize)
608                        .and_then(|d| d.clone());
609                }
610            }
611
612            // Determine the RDMA device to use
613            let rdma_device = if let Some(device) = selected_rdma_device {
614                device
615            } else {
616                // Fallback to default device from config
617                self.config.device.clone()
618            };
619
620            let device_name = rdma_device.name().clone();
621            tracing::debug!(
622                "Using RDMA device: {} for memory at 0x{:x}",
623                device_name,
624                addr
625            );
626
627            // Get or create domain and loopback QP for this device
628            let (domain, _qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
629
630            let access = if crate::efa::is_efa_device() {
631                crate::efa::mr_access_flags()
632            } else {
633                rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
634                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
635                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
636                    | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC
637            };
638
639            let mrv;
640
641            if is_cuda {
642                // First, try to use segment scanning if mlx5dv is enabled
643                let mut segment_info = None;
644                if self.mlx5dv_enabled {
645                    // Try to find in already registered segments
646                    segment_info = self.find_cuda_segment_for_address(addr, size, domain.pd);
647
648                    // If not found, trigger a re-sync with the allocator and retry
649                    if segment_info.is_none() {
650                        let (mut pds, mut qps) = self.build_per_device_pd_qp_arrays();
651                        let err = rdmaxcel_sys::register_segments(
652                            pds.as_mut_ptr(),
653                            qps.as_mut_ptr(),
654                            pds.len() as i32,
655                            self.config.max_sge_override,
656                        );
657                        // Only retry if register_segments succeeded
658                        // If it fails (e.g., scanner returns 0 segments), we'll fall back to dmabuf
659                        if err == 0 {
660                            // The scanner just registered (or
661                            // re-synced) global segment state. Lazily
662                            // install the singleton `segments_mr` now,
663                            // independent of whether *this* address
664                            // matches a segment. Without this, a
665                            // subsequent retry that doesn't find a
666                            // segment would leak the newly-registered
667                            // global state on manager teardown.
668                            self.segments_mr
669                                .get_or_insert_with(|| Arc::new(IbvMemoryRegion::Segments));
670                            segment_info =
671                                self.find_cuda_segment_for_address(addr, size, domain.pd);
672                        }
673                    }
674                }
675
676                // Use segment if found, otherwise fall back to direct dmabuf registration
677                if let Some(info) = segment_info {
678                    let segments_mr = Arc::clone(
679                        self.segments_mr
680                            .get_or_insert_with(|| Arc::new(IbvMemoryRegion::Segments)),
681                    );
682                    let id = self.mrv_id;
683                    self.mrv_id += 1;
684                    mrv = IbvMemoryRegionView::new(
685                        id,
686                        addr,
687                        info.rdma_addr,
688                        info.size,
689                        info.lkey,
690                        info.rkey,
691                        device_name.clone(),
692                        segments_mr,
693                    );
694                } else {
695                    // Dmabuf path: used when mlx5dv is disabled OR scanner returns no segments
696                    let mut fd: i32 = -1;
697                    let cu_err = rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
698                        &mut fd,
699                        addr as rdmaxcel_sys::CUdeviceptr,
700                        size,
701                        rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
702                        0,
703                    );
704                    if cu_err != rdmaxcel_sys::CUDA_SUCCESS || fd < 0 {
705                        return Err(anyhow::anyhow!(
706                            "failed to get dmabuf handle for CUDA memory (addr: 0x{:x}, size: {}, cu_err: {}, fd: {})",
707                            addr,
708                            size,
709                            cu_err,
710                            fd
711                        ));
712                    }
713                    let mr =
714                        rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
715                    if mr.is_null() {
716                        return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
717                    }
718                    let id = self.mrv_id;
719                    self.mrv_id += 1;
720                    mrv = IbvMemoryRegionView::new(
721                        id,
722                        addr,
723                        (*mr).addr as usize,
724                        size,
725                        (*mr).lkey,
726                        (*mr).rkey,
727                        device_name.clone(),
728                        Arc::new(IbvMemoryRegion::Direct {
729                            mr,
730                            _domain: Arc::clone(&domain),
731                        }),
732                    );
733                }
734            } else {
735                // CPU memory path
736                let mr = rdmaxcel_sys::ibv_reg_mr(
737                    domain.pd,
738                    addr as *mut std::ffi::c_void,
739                    size,
740                    access.0 as i32,
741                );
742
743                if mr.is_null() {
744                    return Err(anyhow::anyhow!("failed to register standard MR"));
745                }
746
747                let id = self.mrv_id;
748                self.mrv_id += 1;
749                mrv = IbvMemoryRegionView::new(
750                    id,
751                    addr,
752                    (*mr).addr as usize,
753                    size,
754                    (*mr).lkey,
755                    (*mr).rkey,
756                    device_name.clone(),
757                    Arc::new(IbvMemoryRegion::Direct {
758                        mr,
759                        _domain: Arc::clone(&domain),
760                    }),
761                );
762            }
763            Ok((mrv, device_name))
764        }
765    }
766
767    /// Lazy QP creation: if `qp_key` is absent, create the local
768    /// `IbvQueuePair`, capture its `IbvQpInfo`, and spawn a
769    /// `QueuePairInitializer` to drive the handshake. Returns the
770    /// `QpState` entry — either the freshly-inserted `Pending` one,
771    /// or the existing `Pending`/`Ready`/`Failed`.
772    fn ensure_queue_pair_impl(
773        &mut self,
774        cx: &Context<'_, Self>,
775        other: ActorRef<IbvManagerActor>,
776        qp_key: &QpKey,
777    ) -> Result<&mut QpState, anyhow::Error> {
778        if !self.qps.contains_key(qp_key) {
779            let self_device = &qp_key.self_device;
780            let rdma_device = super::primitives::get_all_devices()
781                .into_iter()
782                .find(|d| d.name() == self_device)
783                .ok_or_else(|| anyhow::anyhow!("RDMA device '{}' not found", self_device))?;
784            let (domain, _) = self.get_or_create_device_domain(self_device, &rdma_device)?;
785            // Wrap the freshly-created QP in a `QpGuard` immediately
786            // so that any early-return path below (e.g. `get_qp_info`
787            // failing) destroys the underlying `rdmaxcel_qp_t` via
788            // the guard's `Drop` rather than leaking it.
789            let mut qp = QpGuard::new(
790                IbvQueuePair::new(domain.context, domain.pd, self.config.clone())
791                    .map_err(|e| anyhow::anyhow!("could not create IbvQueuePair: {}", e))?,
792            );
793            let info = qp
794                .get_qp_info()
795                .map_err(|e| anyhow::anyhow!("could not extract QP info: {}", e))?;
796            let initializer =
797                QueuePairInitializer::new(Instance::handle(cx), other, qp_key.clone(), qp)
798                    .spawn(cx)?;
799            self.qps.insert(
800                qp_key.clone(),
801                QpState::Pending {
802                    info,
803                    initializer,
804                    waiters: Vec::new(),
805                },
806            );
807        }
808        Ok(self
809            .qps
810            .get_mut(qp_key)
811            .expect("entry just inserted or pre-existing"))
812    }
813}
814
815#[async_trait]
816#[hyperactor::handle(IbvManagerMessage)]
817impl IbvManagerMessageHandler for IbvManagerActor {
818    async fn release_buffer(
819        &mut self,
820        _cx: &Context<Self>,
821        remote_buf_id: usize,
822    ) -> Result<(), anyhow::Error> {
823        // Dropping the entry releases the manager's `Arc` clones on
824        // the view's MR and PD; FFI cleanup happens via their `Drop`s
825        // once the last referencing view is gone.
826        self.buffer_registrations.remove(&remote_buf_id);
827        Ok(())
828    }
829}
830
831#[async_trait]
832impl Handler<EnsureQueuePair<IbvManagerActor>> for IbvManagerActor {
833    async fn handle(
834        &mut self,
835        cx: &Context<Self>,
836        msg: EnsureQueuePair<IbvManagerActor>,
837    ) -> Result<(), anyhow::Error> {
838        let EnsureQueuePair {
839            sender,
840            sender_device,
841            receiver_device,
842            reply,
843        } = msg;
844        let qp_key = QpKey {
845            self_device: receiver_device,
846            other_id: sender.actor_addr().id().clone(),
847            other_device: sender_device,
848        };
849        let state = match self.ensure_queue_pair_impl(cx, sender, &qp_key) {
850            Ok(state) => state,
851            Err(e) => {
852                reply.post(cx, PeerInfo(Err(e.to_string())));
853                return Ok(());
854            }
855        };
856        match state {
857            QpState::Pending {
858                info, initializer, ..
859            } => {
860                let notify_rts = initializer.bind::<QueuePairInitializer<Self>>().port();
861                reply.post(cx, PeerInfo(Ok((info.clone(), notify_rts))));
862            }
863            QpState::Ready(_) => {
864                // `Ready` means a prior handshake completed and the
865                // initializer was stopped — we can't hand back an
866                // initializer ref. Reaching here represents a logic
867                // error (peer is asking us to redo a handshake we've
868                // already finished); surface it as `Err`.
869                reply.post(
870                    cx,
871                    PeerInfo(Err(format!(
872                        "EnsureQueuePair on already-Ready entry {qp_key:?}"
873                    ))),
874                );
875            }
876            QpState::Failed(error) => {
877                reply.post(cx, PeerInfo(Err(error.clone())));
878            }
879        }
880        Ok(())
881    }
882}
883
884#[async_trait]
885#[hyperactor::handle(IbvManagerLocalMessage)]
886impl IbvManagerLocalMessageHandler for IbvManagerActor {
887    async fn register_mr(
888        &mut self,
889        _cx: &Context<Self>,
890        addr: usize,
891        size: usize,
892    ) -> Result<Result<(IbvMemoryRegionView, String), String>, anyhow::Error> {
893        Ok(self.register_mr_impl(addr, size).map_err(|e| e.to_string()))
894    }
895
896    async fn register_remote_buffer(
897        &mut self,
898        _cx: &Context<Self>,
899        remote_buf_id: usize,
900        local: Arc<KeepaliveLocalMemory>,
901    ) -> Result<Result<IbvBuffer, String>, anyhow::Error> {
902        if let Some((buf, _)) = self.buffer_registrations.get(&remote_buf_id) {
903            return Ok(Ok(buf.clone()));
904        }
905        let (mrv, device_name) = match self.register_mr_impl(local.addr(), local.size()) {
906            Ok(v) => v,
907            Err(e) => return Ok(Err(e.to_string())),
908        };
909        let buf = IbvBuffer {
910            mr_id: mrv.id,
911            lkey: mrv.lkey,
912            rkey: mrv.rkey,
913            addr: mrv.rdma_addr,
914            size: mrv.size,
915            device_name,
916        };
917        self.buffer_registrations
918            .insert(remote_buf_id, (buf.clone(), mrv));
919        Ok(Ok(buf))
920    }
921
922    async fn request_queue_pair(
923        &mut self,
924        cx: &Context<Self>,
925        other: ActorRef<IbvManagerActor>,
926        self_device: String,
927        other_device: String,
928        reply: OncePortHandle<Result<IbvQueuePair, String>>,
929    ) -> Result<(), anyhow::Error> {
930        let qp_key = QpKey {
931            self_device,
932            other_id: other.actor_addr().id().clone(),
933            other_device,
934        };
935        let state = match self.ensure_queue_pair_impl(cx, other, &qp_key) {
936            Ok(state) => state,
937            Err(e) => {
938                reply.post(cx, Err(e.to_string()));
939                return Ok(());
940            }
941        };
942        match state {
943            QpState::Pending { waiters, .. } => waiters.push(reply),
944            QpState::Ready(qp) => reply.post(cx, Ok(qp.clone())),
945            QpState::Failed(error) => reply.post(cx, Err(error.clone())),
946        }
947        Ok(())
948    }
949}
950
951#[async_trait]
952impl Handler<QpInitializerDone> for IbvManagerActor {
953    async fn handle(
954        &mut self,
955        cx: &Context<Self>,
956        msg: QpInitializerDone,
957    ) -> Result<(), anyhow::Error> {
958        let QpInitializerDone { qp_key, qp } = msg;
959        let qp = qp.into_inner();
960        // Take the entry out, transition to Ready, drain waiters,
961        // then stop the initializer.
962        let initializer = match self.qps.remove(&qp_key) {
963            Some(QpState::Pending {
964                waiters,
965                initializer,
966                ..
967            }) => {
968                for w in waiters {
969                    w.post(cx, Ok(qp.clone()));
970                }
971                initializer
972            }
973            other => {
974                unreachable!("QpInitializerDone received but state is {other:?}: {qp_key:?}")
975            }
976        };
977        self.qps.insert(qp_key.clone(), QpState::Ready(qp));
978        initializer.drain_and_stop("QpInitializerDone")?;
979        let status = initializer.await;
980        if status.is_failed() {
981            // The QP itself is already `Ready` and waiters have been
982            // drained, so a non-clean initializer shutdown is not
983            // user-visible — log and move on rather than crashing
984            // the manager.
985            tracing::error!(
986                "QueuePairInitializer for {qp_key:?} terminated with failure after Done: {status:?}"
987            );
988        }
989        Ok(())
990    }
991}
992
993#[async_trait]
994impl Handler<QpInitializerFailed> for IbvManagerActor {
995    async fn handle(
996        &mut self,
997        cx: &Context<Self>,
998        msg: QpInitializerFailed,
999    ) -> Result<(), anyhow::Error> {
1000        let QpInitializerFailed { qp_key, error } = msg;
1001        let initializer = match self.qps.remove(&qp_key) {
1002            Some(QpState::Pending {
1003                waiters,
1004                initializer,
1005                ..
1006            }) => {
1007                for w in waiters {
1008                    w.post(cx, Err(error.clone()));
1009                }
1010                initializer
1011            }
1012            other => {
1013                unreachable!("QpInitializerFailed received but state is {other:?}: {qp_key:?}")
1014            }
1015        };
1016        // Tombstone the entry: subsequent `RequestQueuePair` calls
1017        // for the same key surface the same error rather than
1018        // retrying or hanging. TODO: add recovery.
1019        self.qps.insert(qp_key.clone(), QpState::Failed(error));
1020        initializer.drain_and_stop("QpInitializerFailed")?;
1021        let status = initializer.await;
1022        if status.is_failed() {
1023            tracing::error!(
1024                "QueuePairInitializer for {qp_key:?} terminated with failure after Failed: {status:?}"
1025            );
1026        }
1027        Ok(())
1028    }
1029}
1030
1031/// Free helper around [`IbvManagerLocalMessage::RequestQueuePair`] — opens
1032/// a `OncePortHandle` for the reply, sends the message, and awaits the
1033/// answer. Exists because `RequestQueuePair` doesn't use `#[reply]`
1034/// (the handler may park the port until the QP becomes `Ready`), so
1035/// the auto-derived client method only does fire-and-forget.
1036pub(super) async fn request_queue_pair(
1037    actor: &ActorHandle<IbvManagerActor>,
1038    cx: &(impl hyperactor::context::Actor + Send + Sync),
1039    other: ActorRef<IbvManagerActor>,
1040    self_device: String,
1041    other_device: String,
1042) -> Result<Result<IbvQueuePair, String>, anyhow::Error> {
1043    let (reply, rx) = cx
1044        .mailbox()
1045        .open_once_port::<Result<IbvQueuePair, String>>();
1046    actor
1047        .request_queue_pair(cx, other, self_device, other_device, reply)
1048        .await?;
1049    rx.recv()
1050        .await
1051        .map_err(|e| anyhow::anyhow!("request_queue_pair port closed: {e}"))
1052}
1053
1054/// Wrapper around [`ActorHandle<IbvManagerActor>`] that moves the RDMA
1055/// data-plane (post send/recv, poll CQ) off the actor loop while keeping
1056/// state-mutating operations (MR registration/deregistration, QP management)
1057/// serialized through actor messages.
1058#[derive(Debug, Clone)]
1059pub struct IbvBackend(pub ActorHandle<IbvManagerActor>);
1060
1061impl std::ops::Deref for IbvBackend {
1062    type Target = ActorHandle<IbvManagerActor>;
1063    fn deref(&self) -> &Self::Target {
1064        &self.0
1065    }
1066}
1067
1068impl IbvBackend {
1069    /// Waits for the completion of RDMA operations.
1070    ///
1071    /// Polls the completion queue until all specified work requests complete
1072    /// or until the timeout is reached. Pure CQ polling — no actor state needed.
1073    async fn wait_for_completion(
1074        local_buf: &IbvBuffer,
1075        qp: &mut IbvQueuePair,
1076        poll_target: PollTarget,
1077        expected_wr_ids: &[u64],
1078        timeout: Duration,
1079    ) -> Result<(), anyhow::Error> {
1080        let start_time = std::time::Instant::now();
1081
1082        let mut remaining: std::collections::HashSet<u64> =
1083            expected_wr_ids.iter().copied().collect();
1084        let mut poll_policy = PollSleepPolicy::new();
1085
1086        while start_time.elapsed() < timeout {
1087            if remaining.is_empty() {
1088                return Ok(());
1089            }
1090
1091            let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
1092            match qp.poll_completion(poll_target, &wr_ids_to_poll) {
1093                Ok(completions) => {
1094                    for (wr_id, _wc) in completions {
1095                        remaining.remove(&wr_id);
1096                    }
1097                    if remaining.is_empty() {
1098                        return Ok(());
1099                    }
1100                    poll_policy.yield_now().await;
1101                }
1102                Err(e) => {
1103                    // When the returned error is WR_FLUSH_ERR, which is generally a
1104                    // secondary error, drain the remaining completions to find the
1105                    // original root cause error. WR_FLUSH_ERR means the QP entered
1106                    // error state due to a DIFFERENT WR's failure, so the actual root
1107                    // cause may be cached or still in the CQ.
1108                    let mut root_cause: Option<PollCompletionError> = None;
1109                    if e.is_wr_flush_err() {
1110                        for &wr_id in &wr_ids_to_poll {
1111                            if let Err(inner_err) = qp.poll_completion(poll_target, &[wr_id]) {
1112                                if !inner_err.is_wr_flush_err() {
1113                                    root_cause = Some(inner_err);
1114                                    break;
1115                                }
1116                            }
1117                        }
1118                    }
1119                    let error_detail = if let Some(cause) = root_cause {
1120                        format!(
1121                            "RDMA polling completion failed: {} (root cause: {})",
1122                            e, cause
1123                        )
1124                    } else {
1125                        format!("RDMA polling completion failed: {}", e)
1126                    };
1127                    return Err(anyhow::anyhow!(
1128                        "{} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
1129                        error_detail,
1130                        local_buf.lkey,
1131                        local_buf.rkey,
1132                        local_buf.addr,
1133                        local_buf.size
1134                    ));
1135                }
1136            }
1137        }
1138        tracing::error!(
1139            "timed out while waiting on request completion for wr_ids={:?}",
1140            remaining
1141        );
1142        Err(anyhow::anyhow!(
1143            "[ibv_buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
1144            local_buf,
1145            expected_wr_ids
1146        ))
1147    }
1148
1149    /// Core submit logic: registers a local MR via actor message,
1150    /// resolves the remote `IbvBuffer` lazily, and executes the op.
1151    /// `local_mrv` is kept in scope for the duration of the op so
1152    /// its `Arc<IbvMemoryRegion>` (and the PD it lives on) survive
1153    /// until completion; on drop, the FFI MR is deregistered.
1154    async fn execute_op(
1155        &self,
1156        cx: &(impl hyperactor::context::Actor + Send + Sync),
1157        op: IbvOp,
1158        timeout: Duration,
1159    ) -> Result<(), anyhow::Error> {
1160        let (local_mrv, local_device_name) = self
1161            .register_mr(cx, op.local_memory.addr(), op.local_memory.size())
1162            .await?
1163            .map_err(|e| anyhow::anyhow!(e))?;
1164
1165        let local_buffer = IbvBuffer {
1166            mr_id: local_mrv.id,
1167            lkey: local_mrv.lkey,
1168            rkey: local_mrv.rkey,
1169            addr: local_mrv.rdma_addr,
1170            size: local_mrv.size,
1171            device_name: local_device_name,
1172        };
1173
1174        let result = async {
1175            let mut qp = request_queue_pair(
1176                &self.0,
1177                cx,
1178                op.remote_manager.clone(),
1179                local_buffer.device_name.clone(),
1180                op.remote_buffer.device_name.clone(),
1181            )
1182            .await?
1183            .map_err(|e| anyhow::anyhow!(e))?;
1184
1185            let wr_id = match op.op_type {
1186                RdmaOpType::WriteFromLocal => qp.put(local_buffer.clone(), op.remote_buffer)?,
1187                RdmaOpType::ReadIntoLocal => qp.get(local_buffer.clone(), op.remote_buffer)?,
1188            };
1189
1190            Self::wait_for_completion(&local_buffer, &mut qp, PollTarget::Send, &wr_id, timeout)
1191                .await
1192        }
1193        .await;
1194
1195        drop(local_mrv);
1196        result
1197    }
1198}
1199
1200#[async_trait]
1201impl RdmaBackend for IbvBackend {
1202    type TransportInfo = ();
1203
1204    /// Submit a batch of RDMA operations.
1205    ///
1206    /// Resolves ibv ops, then executes each directly — registering/deregistering
1207    /// MRs via actor messages, while performing QP put/get and CQ polling locally.
1208    async fn submit(
1209        &mut self,
1210        cx: &(impl hyperactor::context::Actor + Send + Sync),
1211        ops: Vec<RdmaOp>,
1212        timeout: Duration,
1213    ) -> Result<(), anyhow::Error> {
1214        let mut ibv_ops = Vec::with_capacity(ops.len());
1215        for op in ops {
1216            let (remote_manager, remote_buffer) = op.remote.resolve_ibv().ok_or_else(|| {
1217                anyhow::anyhow!("ibverbs backend not found for buffer: {:?}", op.remote)
1218            })?;
1219            ibv_ops.push(IbvOp {
1220                op_type: op.op_type,
1221                local_memory: op.local.clone(),
1222                remote_buffer,
1223                remote_manager,
1224            });
1225        }
1226
1227        let deadline = Instant::now() + timeout;
1228        for op in ibv_ops {
1229            let remaining = deadline.saturating_duration_since(Instant::now());
1230            if remaining.is_zero() {
1231                return Err(anyhow::anyhow!("submit timed out"));
1232            }
1233            self.execute_op(cx, op, remaining).await?;
1234        }
1235        Ok(())
1236    }
1237
1238    fn transport_level(&self) -> RdmaTransportLevel {
1239        RdmaTransportLevel::Nic
1240    }
1241
1242    fn transport_info(&self) -> Option<Self::TransportInfo> {
1243        None
1244    }
1245}