Skip to main content

monarch_rdma/
rdma_components.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
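//! ## Core Components
//!
//! * `IbvDomain` - Manages RDMA resources including context, protection domain, and memory region
//! * `IbvQueuePair` - Handles communication between endpoints via queue pairs and completion queues
//!
//! Items defined in this module:
//!
//! * `RdmaRemoteBuffer` - Serializable handle to a registered remote buffer; reads and writes go
//!   through an ibverbs backend, or through TCP when ibverbs is unavailable and fallback is enabled
//! * `validate_execution_context` - Checks for the `nvidia_peermem` module and
//!   `PeerMappingOverride=1`
//! * `get_registered_cuda_segments` / `register_segment_scanner` - Query registered CUDA memory
//!   segments and install the scanner used to discover them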
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `IbvDomain` with `new()`
36//! 2. Create an `IbvQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
42
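// NOTE(review): orphaned doc comment with no matching item in this file (it previously
// attached to `use std::fs;`): "Maximum size for a single RDMA operation in bytes (1 GiB)".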
44use std::fs;
45use std::result::Result;
46use std::sync::Arc;
47use std::time::Duration;
48
49use hyperactor::ActorRef;
50use hyperactor::context;
51use serde::Deserialize;
52use serde::Serialize;
53use typeuri::Named;
54
55use crate::RdmaManagerActor;
56use crate::RdmaOp;
57use crate::RdmaOpType;
58use crate::ReleaseBufferClient;
59use crate::backend::RdmaBackend;
60use crate::backend::RdmaRemoteBackendContext;
61use crate::backend::ibverbs::IbvBuffer;
62use crate::backend::ibverbs::manager_actor::IbvBackend;
63use crate::backend::ibverbs::manager_actor::IbvManagerActor;
64use crate::backend::tcp::manager_actor::TcpBackend;
65use crate::backend::tcp::manager_actor::TcpManagerActor;
66use crate::local_memory::KeepaliveLocalMemory;
67
68/// Lightweight handle representing a registered RDMA buffer.
69///
70/// Contains an id for the buffer registration, the buffer size, a reference
71/// to the owning [`RdmaManagerActor`], and backend-specific contexts for
72/// performing RDMA operations.
73#[derive(Debug, Named, Clone, Serialize, Deserialize)]
74pub struct RdmaRemoteBuffer {
75    pub id: usize,
76    pub size: usize,
77    pub owner: ActorRef<RdmaManagerActor>,
78    pub backends: Vec<RdmaRemoteBackendContext>,
79}
80wirevalue::register_type!(RdmaRemoteBuffer);
81
82/// Backend handle returned by [`RdmaRemoteBuffer::choose_backend`].
83///
84/// `RdmaBackend` is not object-safe (associated type + generic parameter
85/// on `submit`), so we use an enum that delegates to the concrete handle.
86#[derive(Debug)]
87pub enum RdmaLocalBackend {
88    Ibv(IbvBackend),
89    Tcp(TcpBackend),
90}
91
92impl RdmaLocalBackend {
93    async fn submit(
94        &mut self,
95        cx: &(impl context::Actor + Send + Sync),
96        ops: Vec<RdmaOp>,
97        timeout: Duration,
98    ) -> Result<(), anyhow::Error> {
99        match self {
100            RdmaLocalBackend::Ibv(h) => h.submit(cx, ops, timeout).await,
101            RdmaLocalBackend::Tcp(h) => h.submit(cx, ops, timeout).await,
102        }
103    }
104}
105
106impl RdmaRemoteBuffer {
107    /// Choose the best available backend for this buffer.
108    ///
109    /// Prefers ibverbs when both the local and remote sides support it.
110    /// Falls back to TCP when ibverbs is unavailable and
111    /// [`RDMA_ALLOW_TCP_FALLBACK`](crate::config::RDMA_ALLOW_TCP_FALLBACK)
112    /// is enabled.
113    pub async fn choose_backend(
114        &self,
115        client: &(impl context::Actor + Send + Sync),
116    ) -> Result<RdmaLocalBackend, anyhow::Error> {
117        if self.has_ibverbs_backend() {
118            if let Ok(ibv_handle) = IbvManagerActor::local_handle(client).await {
119                return Ok(RdmaLocalBackend::Ibv(IbvBackend(ibv_handle)));
120            }
121
122            return self
123                .tcp_fallback_or_bail("no ibverbs backend on the local side", client)
124                .await;
125        }
126
127        self.tcp_fallback_or_bail(
128            &format!(
129                "no ibverbs backend on the remote side (owner={})",
130                self.owner.actor_addr()
131            ),
132            client,
133        )
134        .await
135    }
136
137    /// Push data from local memory into this remote buffer (local->remote).
138    pub async fn write_from_local(
139        &self,
140        client: &(impl context::Actor + Send + Sync),
141        local: Arc<KeepaliveLocalMemory>,
142        timeout: u64,
143    ) -> Result<bool, anyhow::Error> {
144        let mut backend = self.choose_backend(client).await?;
145        backend
146            .submit(
147                client,
148                vec![RdmaOp {
149                    op_type: RdmaOpType::WriteFromLocal,
150                    local,
151                    remote: self.clone(),
152                }],
153                Duration::from_secs(timeout),
154            )
155            .await?;
156        Ok(true)
157    }
158
159    /// Pull data from this remote buffer into local memory (remote->local).
160    pub async fn read_into_local(
161        &self,
162        client: &(impl context::Actor + Send + Sync),
163        local: Arc<KeepaliveLocalMemory>,
164        timeout: u64,
165    ) -> Result<bool, anyhow::Error> {
166        let mut backend = self.choose_backend(client).await?;
167        backend
168            .submit(
169                client,
170                vec![RdmaOp {
171                    op_type: RdmaOpType::ReadIntoLocal,
172                    local,
173                    remote: self.clone(),
174                }],
175                Duration::from_secs(timeout),
176            )
177            .await?;
178        Ok(true)
179    }
180
181    /// Get a TCP backend handle, or bail if TCP fallback is disabled.
182    async fn tcp_fallback_or_bail(
183        &self,
184        reason: &str,
185        client: &(impl context::Actor + Send + Sync),
186    ) -> Result<RdmaLocalBackend, anyhow::Error> {
187        if !hyperactor_config::global::get(crate::config::RDMA_ALLOW_TCP_FALLBACK) {
188            anyhow::bail!(
189                "{reason}, and TCP fallback is disabled; \
190                 enable it with monarch.configure(rdma_allow_tcp_fallback=True)"
191            );
192        }
193
194        tracing::warn!("falling back to TCP transport ({reason})");
195
196        let tcp_handle = TcpManagerActor::local_handle(client).await?;
197        Ok(RdmaLocalBackend::Tcp(TcpBackend(tcp_handle)))
198    }
199
200    /// Drop the buffer and release remote handles.
201    pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
202        tracing::debug!("[buffer] dropping buffer id={}", self.id);
203        self.owner.release_buffer(client, self.id).await?;
204        Ok(())
205    }
206
207    /// Whether this buffer has an ibverbs backend context.
208    fn has_ibverbs_backend(&self) -> bool {
209        self.backends
210            .iter()
211            .any(|b| matches!(b, RdmaRemoteBackendContext::Ibverbs(..)))
212    }
213
214    /// Extract the ibverbs backend context for this buffer.
215    ///
216    /// Returns `None` if the buffer has no ibverbs backend context
217    /// (i.e., the remote side was created without ibverbs).
218    pub fn resolve_ibv(&self) -> Option<(ActorRef<IbvManagerActor>, IbvBuffer)> {
219        self.backends.iter().find_map(|b| match b {
220            RdmaRemoteBackendContext::Ibverbs(mgr, buf) => Some((mgr.clone(), buf.clone())),
221            _ => None,
222        })
223    }
224
225    /// Extract the TCP backend context from this buffer.
226    pub fn resolve_tcp(&self) -> Result<(ActorRef<TcpManagerActor>, usize), anyhow::Error> {
227        self.backends
228            .iter()
229            .find_map(|b| match b {
230                RdmaRemoteBackendContext::Tcp(tcp_ref) => Some((tcp_ref.clone(), self.id)),
231                _ => None,
232            })
233            .ok_or_else(|| anyhow::anyhow!("tcp backend not found for buffer: {:?}", self))
234    }
235}
236
237/// Utility to validate execution context.
238///
239/// Remote Execution environments do not always have access to the nvidia_peermem module
240/// and/or set the PeerMappingOverride parameter due to security. This function can be
241/// used to validate that the execution context when running operations that need this
242/// functionality (ie. cudaHostRegisterIoMemory).
243///
244/// # Returns
245///
246/// * `Ok(())` if the execution context is valid
247/// * `Err(anyhow::Error)` if the execution context is invalid
248pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
249    // Check for nvidia peermem
250    match fs::read_to_string("/proc/modules") {
251        Ok(contents) => {
252            if !contents.contains("nvidia_peermem") {
253                return Err(anyhow::anyhow!(
254                    "nvidia_peermem module not found in /proc/modules"
255                ));
256            }
257        }
258        Err(e) => {
259            return Err(anyhow::anyhow!(e));
260        }
261    }
262
263    // Test file access to nvidia params
264    match fs::read_to_string("/proc/driver/nvidia/params") {
265        Ok(contents) => {
266            if !contents.contains("PeerMappingOverride=1") {
267                return Err(anyhow::anyhow!(
268                    "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
269                ));
270            }
271        }
272        Err(e) => {
273            return Err(anyhow::anyhow!(e));
274        }
275    }
276    Ok(())
277}
278
279/// Get all segments that have been registered with MRs for the given PD.
280///
281/// Each protection domain maintains independent segment registrations, so
282/// callers must pass the PD whose lkeys they intend to use.
283pub fn get_registered_cuda_segments(
284    pd: *mut rdmaxcel_sys::ibv_pd,
285) -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
286    unsafe {
287        let segment_count = rdmaxcel_sys::rdma_get_active_segment_count(pd);
288        if segment_count <= 0 {
289            return Vec::new();
290        }
291
292        let mut segments = vec![
293            std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
294                .assume_init();
295            segment_count as usize
296        ];
297        let actual_count = rdmaxcel_sys::rdma_get_all_registered_segment_info(
298            pd,
299            segments.as_mut_ptr(),
300            segment_count,
301        );
302
303        if actual_count > 0 {
304            segments.truncate(actual_count as usize);
305            segments
306        } else {
307            Vec::new()
308        }
309    }
310}
311
312/// Segment scanner callback type alias for convenience.
313pub type SegmentScannerFn = rdmaxcel_sys::RdmaxcelSegmentScannerFn;
314
315/// Register a segment scanner callback.
316///
317/// The scanner callback is called during RDMA segment registration to discover
318/// CUDA memory segments. The callback should fill the provided buffer with
319/// segment information and return the total count of segments found.
320///
321/// If the returned count exceeds the buffer size, the caller will allocate
322/// a larger buffer and retry.
323///
324/// Pass `None` to unregister the scanner.
325///
326/// # Safety
327///
328/// The provided callback function must be safe to call from C code and must
329/// properly handle the segment buffer.
330pub fn register_segment_scanner(scanner: SegmentScannerFn) {
331    // SAFETY: We are registering a callback function pointer with rdmaxcel.
332    unsafe { rdmaxcel_sys::rdmaxcel_register_segment_scanner(scanner) }
333}