// monarch_rdma/local_memory.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
//! Local memory abstractions for RDMA operations.
//!
//! This module defines the [`RdmaLocalMemory`] trait and its implementations:
//!
//! - [`KeepaliveLocalMemory`] – wraps a raw pointer with a keepalive guard
//!   and dispatches reads/writes to CPU or CUDA paths.
//! - [`UnsafeLocalMemory`] – raw pointer-based handle where the caller is
//!   responsible for lifetime management.
17
18use std::fmt::Debug;
19use std::sync::Arc;
20use std::sync::RwLock;
21
22use serde::Deserialize;
23use serde::Serialize;
24
25/// Returns `true` when `addr` is a CUDA device pointer.
26///
27/// Probes the CUDA driver via `cuPointerGetAttribute`; returns `false`
28/// when CUDA is unavailable or the pointer is not device memory.
29pub fn is_device_ptr(addr: usize) -> bool {
30    // SAFETY: FFI call that queries pointer metadata without accessing
31    // the pointed-to memory.
32    unsafe {
33        let mut mem_type: u32 = 0;
34        let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
35            &mut mem_type as *mut _ as *mut std::ffi::c_void,
36            rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
37            addr as rdmaxcel_sys::CUdeviceptr,
38        );
39        err == rdmaxcel_sys::CUDA_SUCCESS && mem_type == rdmaxcel_sys::CU_MEMORYTYPE_DEVICE
40    }
41}
42
/// Handle to a contiguous region of local memory.
///
/// Implementations must guarantee the underlying allocation is valid for the
/// lifetime of the implementor.
pub trait RdmaLocalMemory: Send + Sync + Debug {
    /// Starting virtual address of the memory region.
    fn addr(&self) -> usize;

    /// Size of the memory region in bytes.
    fn size(&self) -> usize;

    /// Copy `dst.len()` bytes from this memory region starting at `offset` into `dst`.
    ///
    /// Both implementations in this module return an error when
    /// `offset + dst.len()` exceeds [`Self::size`].
    fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error>;

    /// Copy `src.len()` bytes from `src` into this memory region starting at `offset`.
    ///
    /// Both implementations in this module return an error when
    /// `offset + src.len()` exceeds [`Self::size`].
    fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error>;
}
60
61/// Verify that an access at `offset` with `len` bytes fits within `size`.
62fn check_bounds(offset: usize, len: usize, size: usize) -> Result<(), anyhow::Error> {
63    anyhow::ensure!(
64        offset.checked_add(len).is_some_and(|end| end <= size),
65        "access at offset {offset} with length {len} exceeds region size {size}"
66    );
67    Ok(())
68}
69
/// Copy `dst.len()` bytes from host memory at `addr + offset` into `dst`.
///
/// # Safety
///
/// The caller must ensure that `addr` points to a valid host allocation of
/// at least `offset + dst.len()` bytes.
unsafe fn read_cpu(addr: usize, offset: usize, dst: &mut [u8]) {
    let src = (addr + offset) as *const u8;
    // SAFETY: per this function's contract, `src` is valid for reads of
    // `dst.len()` bytes; `dst` is a distinct, live Rust slice.
    unsafe { std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len()) };
}
81
/// Copy `src.len()` bytes from `src` into host memory at `addr + offset`.
///
/// # Safety
///
/// The caller must ensure that `addr` points to a valid host allocation of
/// at least `offset + src.len()` bytes.
unsafe fn write_cpu(addr: usize, offset: usize, src: &[u8]) {
    let dst = (addr + offset) as *mut u8;
    // SAFETY: per this function's contract, `dst` is valid for writes of
    // `src.len()` bytes; `src` is a distinct, live Rust slice.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()) };
}
93
94/// Copy `dst.len()` bytes from device memory at `addr + offset` into `dst`.
95///
96/// # Safety
97///
98/// The caller must ensure that `addr` is a valid CUDA device pointer to an
99/// allocation of at least `offset + dst.len()` bytes.
100unsafe fn read_gpu(addr: usize, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
101    let rc = unsafe {
102        rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
103            dst.as_mut_ptr() as *mut std::ffi::c_void,
104            (addr + offset) as rdmaxcel_sys::CUdeviceptr,
105            dst.len(),
106        )
107    };
108    anyhow::ensure!(
109        rc == rdmaxcel_sys::CUDA_SUCCESS,
110        "cuMemcpyDtoH failed with error code {rc}"
111    );
112    Ok(())
113}
114
115/// Copy `src.len()` bytes from `src` into device memory at `addr + offset`.
116///
117/// # Safety
118///
119/// The caller must ensure that `addr` is a valid CUDA device pointer to an
120/// allocation of at least `offset + src.len()` bytes.
121unsafe fn write_gpu(addr: usize, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
122    let rc = unsafe {
123        rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
124            (addr + offset) as rdmaxcel_sys::CUdeviceptr,
125            src.as_ptr() as *const std::ffi::c_void,
126            src.len(),
127        )
128    };
129    anyhow::ensure!(
130        rc == rdmaxcel_sys::CUDA_SUCCESS,
131        "cuMemcpyHtoD failed with error code {rc}"
132    );
133    Ok(())
134}
135
/// Marker trait: the implementor keeps a backing memory allocation alive.
///
/// As long as a value implementing this trait exists, the memory region
/// described by the containing [`KeepaliveLocalMemory`] is guaranteed to
/// remain valid. Implementors typically own the allocation directly (the
/// test module, for example, implements it for `Vec<u8>`).
pub trait Keepalive: Send + Sync {}
142
/// Local memory handle that keeps its backing allocation alive via an
/// [`Arc<dyn Keepalive>`].
///
/// Detects at construction time whether the address is a CUDA device
/// pointer and dispatches `read_at`/`write_at` accordingly.
///
/// The `direct_access_host_bandwidth` and `direct_access_device_bandwidth`
/// fields indicate the speed of reading the memory via pointer dereference
/// on a host or device thread, respectively. A value of `None` means the
/// memory is not directly accessible from that context.
#[derive(Clone)]
pub struct KeepaliveLocalMemory {
    // Starting virtual address of the region.
    addr: usize,
    // Region size in bytes; accesses are bounds-checked against it.
    size: usize,
    /// Bandwidth (bytes/s) for direct host-thread pointer access, or `None`
    /// if the memory is not host-accessible.
    direct_access_host_bandwidth: Option<u64>,
    /// Bandwidth (bytes/s) for direct device-thread pointer access, or
    /// `None` if the memory is not device-accessible.
    direct_access_device_bandwidth: Option<u64>,
    // Held only to keep the backing allocation alive; never read.
    _keepalive: Arc<dyn Keepalive>,
    // Read/write lock taken by read_at/write_at; shared across clones
    // because it sits behind an Arc.
    guard: Arc<RwLock<()>>,
}
166
167impl Debug for KeepaliveLocalMemory {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("KeepaliveLocalMemory")
170            .field("addr", &self.addr)
171            .field("size", &self.size)
172            .field(
173                "direct_access_host_bandwidth",
174                &self.direct_access_host_bandwidth,
175            )
176            .field(
177                "direct_access_device_bandwidth",
178                &self.direct_access_device_bandwidth,
179            )
180            .finish_non_exhaustive()
181    }
182}
183
184impl KeepaliveLocalMemory {
185    /// Create a new handle. Probes the CUDA driver to determine whether
186    /// `addr` is a device pointer and sets the bandwidth fields
187    /// accordingly.
188    pub fn new(addr: usize, size: usize, keepalive: Arc<dyn Keepalive>) -> Self {
189        // TODO(slurye): Using placeholder values for now. Fill in with real values.
190        let (host_bw, device_bw) = if is_device_ptr(addr) {
191            (None, Some(1))
192        } else {
193            (Some(1), None)
194        };
195        Self {
196            addr,
197            size,
198            direct_access_host_bandwidth: host_bw,
199            direct_access_device_bandwidth: device_bw,
200            _keepalive: keepalive,
201            guard: Arc::new(RwLock::new(())),
202        }
203    }
204}
205
206impl RdmaLocalMemory for KeepaliveLocalMemory {
207    fn addr(&self) -> usize {
208        self.addr
209    }
210
211    fn size(&self) -> usize {
212        self.size
213    }
214
215    fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
216        let _lock = self.guard.read().expect("lock poisoned");
217        check_bounds(offset, dst.len(), self.size)?;
218        // SAFETY: The keepalive guard guarantees the allocation is live, and
219        // check_bounds verified the access is in range.
220        unsafe {
221            if self.direct_access_host_bandwidth.is_some() {
222                read_cpu(self.addr, offset, dst);
223                Ok(())
224            } else {
225                read_gpu(self.addr, offset, dst)
226            }
227        }
228    }
229
230    fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
231        let _lock = self.guard.write().expect("lock poisoned");
232        check_bounds(offset, src.len(), self.size)?;
233        // SAFETY: The keepalive guard guarantees the allocation is live, and
234        // check_bounds verified the access is in range.
235        unsafe {
236            if self.direct_access_host_bandwidth.is_some() {
237                write_cpu(self.addr, offset, src);
238                Ok(())
239            } else {
240                write_gpu(self.addr, offset, src)
241            }
242        }
243    }
244}
245
/// Raw pointer-based local memory handle that supports both CPU and GPU memory.
///
/// Wraps a virtual address and size. The caller is responsible for
/// ensuring the underlying allocation outlives this handle. Uses
/// `is_device_ptr` to dispatch reads/writes to the appropriate CPU or CUDA
/// path, just like [`KeepaliveLocalMemory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnsafeLocalMemory {
    /// Starting virtual address of the region (host or device pointer).
    pub addr: usize,
    /// Region size in bytes; accesses are bounds-checked against it.
    pub size: usize,
}
257
258impl UnsafeLocalMemory {
259    pub fn new(addr: usize, size: usize) -> Self {
260        Self { addr, size }
261    }
262}
263
264impl RdmaLocalMemory for UnsafeLocalMemory {
265    fn addr(&self) -> usize {
266        self.addr
267    }
268
269    fn size(&self) -> usize {
270        self.size
271    }
272
273    fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
274        check_bounds(offset, dst.len(), self.size)?;
275        // SAFETY: The caller is responsible for ensuring the allocation is
276        // live; check_bounds verified the access is in range.
277        unsafe {
278            if is_device_ptr(self.addr) {
279                read_gpu(self.addr, offset, dst)
280            } else {
281                read_cpu(self.addr, offset, dst);
282                Ok(())
283            }
284        }
285    }
286
287    fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
288        check_bounds(offset, src.len(), self.size)?;
289        // SAFETY: The caller is responsible for ensuring the allocation is
290        // live; check_bounds verified the access is in range.
291        unsafe {
292            if is_device_ptr(self.addr) {
293                write_gpu(self.addr, offset, src)
294            } else {
295                write_cpu(self.addr, offset, src);
296                Ok(())
297            }
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    // -- KeepaliveLocalMemory (host) --

    impl Keepalive for Vec<u8> {}

    /// Build a host-backed handle whose keepalive owns `data`, so the
    /// allocation outlives the returned handle.
    fn host_keepalive_mem(data: Vec<u8>) -> KeepaliveLocalMemory {
        let (addr, size) = (data.as_ptr() as usize, data.len());
        KeepaliveLocalMemory::new(addr, size, Arc::new(data))
    }

    #[test]
    fn keepalive_host_read_at() {
        let region = host_keepalive_mem(vec![1, 2, 3, 4, 5]);
        let mut scratch = [0u8; 3];
        region.read_at(1, &mut scratch).unwrap();
        assert_eq!(scratch, [2, 3, 4]);
    }

    #[test]
    fn keepalive_host_write_then_read() {
        let region = host_keepalive_mem(vec![0; 5]);
        region.write_at(1, &[7, 8, 9]).unwrap();
        let mut scratch = [0u8; 5];
        region.read_at(0, &mut scratch).unwrap();
        assert_eq!(scratch, [0, 7, 8, 9, 0]);
    }

    #[test]
    fn keepalive_host_out_of_bounds() {
        let region = host_keepalive_mem(vec![0; 3]);
        let mut scratch = [0u8; 3];
        assert!(region.read_at(1, &mut scratch).is_err());
        assert!(region.write_at(1, &[7, 8, 9]).is_err());
    }
}