// monarch_rdma/local_memory.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Local memory abstractions for RDMA operations.
10//!
11//! This module defines the [`RdmaLocalMemory`] trait and its implementations:
12//!
13//! - [`KeepaliveLocalMemory`] – wraps a raw pointer with a keepalive guard
14//!   and dispatches reads/writes to CPU or CUDA paths.
15//! - [`UnsafeLocalMemory`] – raw pointer-based handle where the caller is
16//!   responsible for lifetime management.
17
18use std::fmt::Debug;
19use std::sync::Arc;
20use std::sync::RwLock;
21
22use serde::Deserialize;
23use serde::Serialize;
24
25/// Returns `true` when `addr` is a CUDA device pointer.
26///
27/// Probes the CUDA driver via `cuPointerGetAttribute`; returns `false`
28/// when CUDA is unavailable or the pointer is not device memory.
29pub fn is_device_ptr(addr: usize) -> bool {
30    // SAFETY: FFI call that queries pointer metadata without accessing
31    // the pointed-to memory.
32    unsafe {
33        let mut mem_type: u32 = 0;
34        let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
35            &mut mem_type as *mut _ as *mut std::ffi::c_void,
36            rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
37            addr as rdmaxcel_sys::CUdeviceptr,
38        );
39        err == rdmaxcel_sys::CUDA_SUCCESS && mem_type == rdmaxcel_sys::CU_MEMORYTYPE_DEVICE
40    }
41}
42
/// RAII guard that restores the previous CUDA context on drop and, if a
/// primary context was retained, releases it.
pub(crate) struct CudaCtxGuard {
    /// Context that was current before the guard's creator switched
    /// contexts; re-installed on drop.
    prev: rdmaxcel_sys::CUcontext,
    /// Set when the fallback path called `cuDevicePrimaryCtxRetain`.
    retained_device: Option<rdmaxcel_sys::CUdevice>,
}
50
impl Drop for CudaCtxGuard {
    fn drop(&mut self) {
        // SAFETY: FFI calls that manipulate thread-local CUDA context
        // state. Error codes are intentionally ignored: there is no
        // reasonable way to report failure from drop.
        unsafe {
            // Re-install whatever context was current before the switch.
            rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.prev);
            if let Some(device) = self.retained_device {
                // Balance the `cuDevicePrimaryCtxRetain` performed by the
                // fallback path that created this guard.
                rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRelease(device);
            }
        }
    }
}
61
62/// Make the CUDA context that owns `addr` current on the calling
63/// thread, returning a guard that restores the previous context on
64/// drop.
65///
66/// First tries `CU_POINTER_ATTRIBUTE_CONTEXT` to get the exact context
67/// the allocation belongs to.  When that returns null (runtime-API or
68/// memory-pool allocations such as PyTorch's caching allocator), falls
69/// back to the device's primary context via
70/// `CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL` + `cuDevicePrimaryCtxRetain`.
71///
72/// # Safety
73///
74/// `addr` must be a valid CUDA device pointer.
75pub(crate) unsafe fn set_ctx_for_ptr(addr: usize) -> Result<CudaCtxGuard, anyhow::Error> {
76    let mut prev: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
77    unsafe {
78        rdmaxcel_sys::rdmaxcel_cuCtxGetCurrent(&mut prev);
79    }
80
81    let mut ctx: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
82    let rc = unsafe {
83        rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
84            &mut ctx as *mut _ as *mut std::ffi::c_void,
85            rdmaxcel_sys::CU_POINTER_ATTRIBUTE_CONTEXT,
86            addr as rdmaxcel_sys::CUdeviceptr,
87        )
88    };
89
90    // Null context: allocation came from the runtime API or a memory
91    // pool.  Fall back to the owning device's primary context.
92    let mut retained_device = None;
93    if rc != rdmaxcel_sys::CUDA_SUCCESS || ctx.is_null() {
94        let mut ordinal: i32 = -1;
95        let rc = unsafe {
96            rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
97                &mut ordinal as *mut _ as *mut std::ffi::c_void,
98                rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
99                addr as rdmaxcel_sys::CUdeviceptr,
100            )
101        };
102        anyhow::ensure!(
103            rc == rdmaxcel_sys::CUDA_SUCCESS,
104            "cuPointerGetAttribute(DEVICE_ORDINAL) failed with error code {rc}"
105        );
106
107        let mut device: rdmaxcel_sys::CUdevice = 0;
108        let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, ordinal) };
109        anyhow::ensure!(
110            rc == rdmaxcel_sys::CUDA_SUCCESS,
111            "cuDeviceGet({ordinal}) failed with error code {rc}"
112        );
113
114        let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRetain(&mut ctx, device) };
115        anyhow::ensure!(
116            rc == rdmaxcel_sys::CUDA_SUCCESS,
117            "cuDevicePrimaryCtxRetain failed with error code {rc}"
118        );
119        retained_device = Some(device);
120    }
121
122    let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(ctx) };
123    anyhow::ensure!(
124        rc == rdmaxcel_sys::CUDA_SUCCESS,
125        "cuCtxSetCurrent failed with error code {rc}"
126    );
127
128    Ok(CudaCtxGuard {
129        prev,
130        retained_device,
131    })
132}
133
/// Handle to a contiguous region of local memory.
///
/// Implementations must guarantee the underlying allocation is valid for the
/// lifetime of the implementor.
pub trait RdmaLocalMemory: Send + Sync + Debug {
    /// Starting virtual address of the memory region.
    fn addr(&self) -> usize;

    /// Size of the memory region in bytes.
    fn size(&self) -> usize;

    /// Copy `dst.len()` bytes from this memory region starting at `offset` into `dst`.
    ///
    /// Implementations in this module return an error when
    /// `offset + dst.len()` exceeds [`size`](Self::size).
    fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error>;

    /// Copy `src.len()` bytes from `src` into this memory region starting at `offset`.
    ///
    /// Implementations in this module return an error when
    /// `offset + src.len()` exceeds [`size`](Self::size).
    fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error>;
}
151
152/// Verify that an access at `offset` with `len` bytes fits within `size`.
153fn check_bounds(offset: usize, len: usize, size: usize) -> Result<(), anyhow::Error> {
154    anyhow::ensure!(
155        offset.checked_add(len).is_some_and(|end| end <= size),
156        "access at offset {offset} with length {len} exceeds region size {size}"
157    );
158    Ok(())
159}
160
/// Copy `dst.len()` bytes from host memory at `addr + offset` into `dst`.
///
/// # Safety
///
/// The caller must ensure that `addr` points to a valid host allocation of
/// at least `offset + dst.len()` bytes.
unsafe fn read_cpu(addr: usize, offset: usize, dst: &mut [u8]) {
    let src = (addr + offset) as *const u8;
    // SAFETY: the caller guarantees the source range is live for
    // `dst.len()` bytes, and `dst` is an exclusive Rust slice, so the
    // two ranges cannot alias.
    unsafe {
        dst.copy_from_slice(std::slice::from_raw_parts(src, dst.len()));
    }
}
172
/// Copy `src.len()` bytes from `src` into host memory at `addr + offset`.
///
/// # Safety
///
/// The caller must ensure that `addr` points to a valid host allocation of
/// at least `offset + src.len()` bytes.
unsafe fn write_cpu(addr: usize, offset: usize, src: &[u8]) {
    let dst = (addr + offset) as *mut u8;
    // SAFETY: the caller guarantees the destination range is live and
    // writable for `src.len()` bytes; `src` is a shared Rust slice that
    // the caller cannot simultaneously alias mutably.
    unsafe {
        std::slice::from_raw_parts_mut(dst, src.len()).copy_from_slice(src);
    }
}
184
185/// Copy `dst.len()` bytes from device memory at `addr + offset` into `dst`.
186///
187/// # Safety
188///
189/// The caller must ensure that `addr` is a valid CUDA device pointer to an
190/// allocation of at least `offset + dst.len()` bytes.
191unsafe fn read_gpu(addr: usize, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
192    let _guard = unsafe { set_ctx_for_ptr(addr)? };
193    let rc = unsafe {
194        rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
195            dst.as_mut_ptr() as *mut std::ffi::c_void,
196            (addr + offset) as rdmaxcel_sys::CUdeviceptr,
197            dst.len(),
198        )
199    };
200    anyhow::ensure!(
201        rc == rdmaxcel_sys::CUDA_SUCCESS,
202        "cuMemcpyDtoH failed with error code {rc}"
203    );
204    Ok(())
205}
206
207/// Copy `src.len()` bytes from `src` into device memory at `addr + offset`.
208///
209/// # Safety
210///
211/// The caller must ensure that `addr` is a valid CUDA device pointer to an
212/// allocation of at least `offset + src.len()` bytes.
213unsafe fn write_gpu(addr: usize, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
214    let _guard = unsafe { set_ctx_for_ptr(addr)? };
215    let rc = unsafe {
216        rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
217            (addr + offset) as rdmaxcel_sys::CUdeviceptr,
218            src.as_ptr() as *const std::ffi::c_void,
219            src.len(),
220        )
221    };
222    anyhow::ensure!(
223        rc == rdmaxcel_sys::CUDA_SUCCESS,
224        "cuMemcpyHtoD failed with error code {rc}"
225    );
226    Ok(())
227}
228
/// Marker trait: the implementor keeps a backing memory allocation alive.
///
/// As long as a value implementing this trait exists, the memory region
/// described by the containing [`KeepaliveLocalMemory`] is guaranteed to
/// remain valid.  Implementors need no methods; ownership of the
/// allocation (e.g. a `Vec<u8>`) is the whole contract.
pub trait Keepalive: Send + Sync {}
235
/// Local memory handle that keeps its backing allocation alive via an
/// [`Arc<dyn Keepalive>`].
///
/// Detects at construction time whether the address is a CUDA device
/// pointer and dispatches `read_at`/`write_at` accordingly.
///
/// The `direct_access_host_bandwidth` and `direct_access_device_bandwidth`
/// fields indicate the speed of reading the memory via pointer dereference
/// on a host or device thread, respectively. A value of `None` means the
/// memory is not directly accessible from that context.
#[derive(Clone)]
pub struct KeepaliveLocalMemory {
    // Starting virtual address of the region.
    addr: usize,
    // Region size in bytes; all accesses are bounds-checked against it.
    size: usize,
    /// Bandwidth (bytes/s) for direct host-thread pointer access, or `None`
    /// if the memory is not host-accessible.
    direct_access_host_bandwidth: Option<u64>,
    /// Bandwidth (bytes/s) for direct device-thread pointer access, or
    /// `None` if the memory is not device-accessible.
    direct_access_device_bandwidth: Option<u64>,
    // Held only to keep the backing allocation alive; never read.
    _keepalive: Arc<dyn Keepalive>,
    // Reader/writer lock taken by `read_at`/`write_at`; shared by all
    // clones of this handle, making writes exclusive relative to reads.
    guard: Arc<RwLock<()>>,
}
259
260impl Debug for KeepaliveLocalMemory {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        f.debug_struct("KeepaliveLocalMemory")
263            .field("addr", &self.addr)
264            .field("size", &self.size)
265            .field(
266                "direct_access_host_bandwidth",
267                &self.direct_access_host_bandwidth,
268            )
269            .field(
270                "direct_access_device_bandwidth",
271                &self.direct_access_device_bandwidth,
272            )
273            .finish_non_exhaustive()
274    }
275}
276
277impl KeepaliveLocalMemory {
278    /// Create a new handle. Probes the CUDA driver to determine whether
279    /// `addr` is a device pointer and sets the bandwidth fields
280    /// accordingly.
281    pub fn new(addr: usize, size: usize, keepalive: Arc<dyn Keepalive>) -> Self {
282        // TODO(slurye): Using placeholder values for now. Fill in with real values.
283        let (host_bw, device_bw) = if is_device_ptr(addr) {
284            (None, Some(1))
285        } else {
286            (Some(1), None)
287        };
288        Self {
289            addr,
290            size,
291            direct_access_host_bandwidth: host_bw,
292            direct_access_device_bandwidth: device_bw,
293            _keepalive: keepalive,
294            guard: Arc::new(RwLock::new(())),
295        }
296    }
297}
298
299impl RdmaLocalMemory for KeepaliveLocalMemory {
300    fn addr(&self) -> usize {
301        self.addr
302    }
303
304    fn size(&self) -> usize {
305        self.size
306    }
307
308    fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
309        let _lock = self.guard.read().expect("lock poisoned");
310        check_bounds(offset, dst.len(), self.size)?;
311        // SAFETY: The keepalive guard guarantees the allocation is live, and
312        // check_bounds verified the access is in range.
313        unsafe {
314            if self.direct_access_host_bandwidth.is_some() {
315                read_cpu(self.addr, offset, dst);
316                Ok(())
317            } else {
318                read_gpu(self.addr, offset, dst)
319            }
320        }
321    }
322
323    fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
324        let _lock = self.guard.write().expect("lock poisoned");
325        check_bounds(offset, src.len(), self.size)?;
326        // SAFETY: The keepalive guard guarantees the allocation is live, and
327        // check_bounds verified the access is in range.
328        unsafe {
329            if self.direct_access_host_bandwidth.is_some() {
330                write_cpu(self.addr, offset, src);
331                Ok(())
332            } else {
333                write_gpu(self.addr, offset, src)
334            }
335        }
336    }
337}
338
/// Raw pointer-based local memory handle that supports both CPU and GPU memory.
///
/// Wraps a virtual address and size. The caller is responsible for
/// ensuring the underlying allocation outlives this handle. Uses
/// `is_device_ptr` to dispatch reads/writes to the appropriate CPU or CUDA
/// path, just like [`KeepaliveLocalMemory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnsafeLocalMemory {
    // Starting virtual address; serialized as a plain integer, so a
    // deserialized handle is only meaningful in the originating process.
    pub addr: usize,
    // Region size in bytes; all accesses are bounds-checked against it.
    pub size: usize,
}
350
impl UnsafeLocalMemory {
    /// Create a handle for the region `[addr, addr + size)`.
    ///
    /// The caller must ensure the underlying allocation remains valid
    /// for as long as the handle (or any clone of it) is used.
    pub fn new(addr: usize, size: usize) -> Self {
        Self { addr, size }
    }
}
356
357impl RdmaLocalMemory for UnsafeLocalMemory {
358    fn addr(&self) -> usize {
359        self.addr
360    }
361
362    fn size(&self) -> usize {
363        self.size
364    }
365
366    fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
367        check_bounds(offset, dst.len(), self.size)?;
368        // SAFETY: The caller is responsible for ensuring the allocation is
369        // live; check_bounds verified the access is in range.
370        unsafe {
371            if is_device_ptr(self.addr) {
372                read_gpu(self.addr, offset, dst)
373            } else {
374                read_cpu(self.addr, offset, dst);
375                Ok(())
376            }
377        }
378    }
379
380    fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
381        check_bounds(offset, src.len(), self.size)?;
382        // SAFETY: The caller is responsible for ensuring the allocation is
383        // live; check_bounds verified the access is in range.
384        unsafe {
385            if is_device_ptr(self.addr) {
386                write_gpu(self.addr, offset, src)
387            } else {
388                write_cpu(self.addr, offset, src);
389                Ok(())
390            }
391        }
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    // -- KeepaliveLocalMemory (host) --

    impl Keepalive for Vec<u8> {}

    /// Build a host-backed handle whose backing `Vec` doubles as the
    /// keepalive.  Moving the `Vec` into the `Arc` does not move its
    /// heap buffer, so the captured address stays valid.
    fn make_host_mem(bytes: Vec<u8>) -> KeepaliveLocalMemory {
        let (ptr, len) = (bytes.as_ptr() as usize, bytes.len());
        KeepaliveLocalMemory::new(ptr, len, Arc::new(bytes))
    }

    #[test]
    fn keepalive_host_read_at() {
        let mem = make_host_mem(vec![1, 2, 3, 4, 5]);
        let mut out = [0u8; 3];
        mem.read_at(1, &mut out).unwrap();
        assert_eq!(out, [2, 3, 4]);
    }

    #[test]
    fn keepalive_host_write_then_read() {
        let mem = make_host_mem(vec![0; 5]);
        mem.write_at(1, &[7, 8, 9]).unwrap();
        let mut out = [0u8; 5];
        mem.read_at(0, &mut out).unwrap();
        assert_eq!(out, [0, 7, 8, 9, 0]);
    }

    #[test]
    fn keepalive_host_out_of_bounds() {
        let mem = make_host_mem(vec![0; 3]);
        let mut out = [0u8; 3];
        assert!(mem.read_at(1, &mut out).is_err());
        assert!(mem.write_at(1, &[7, 8, 9]).is_err());
    }
}
433}