Skip to main content

monarch_rdma/
local_memory.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
//! Local memory abstractions for RDMA operations.
//!
//! [`KeepaliveLocalMemory`] wraps a raw pointer with a [`Keepalive`]
//! guard and dispatches reads/writes to CPU or CUDA paths.
13
14use std::fmt::Debug;
15use std::sync::Arc;
16use std::sync::Condvar;
17use std::sync::Mutex;
18
19/// Returns `true` when `addr` is a CUDA device pointer.
20///
21/// Probes the CUDA driver via `cuPointerGetAttribute`; returns `false`
22/// when CUDA is unavailable or the pointer is not device memory.
23pub fn is_device_ptr(addr: usize) -> bool {
24    // SAFETY: FFI call that queries pointer metadata without accessing
25    // the pointed-to memory.
26    unsafe {
27        let mut mem_type: u32 = 0;
28        let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
29            &mut mem_type as *mut _ as *mut std::ffi::c_void,
30            rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
31            addr as rdmaxcel_sys::CUdeviceptr,
32        );
33        err == rdmaxcel_sys::CUDA_SUCCESS && mem_type == rdmaxcel_sys::CU_MEMORYTYPE_DEVICE
34    }
35}
36
37/// RAII guard that restores the previous CUDA context on drop and, if a
38/// primary context was retained, releases it.
39pub(crate) struct CudaCtxGuard {
40    prev: rdmaxcel_sys::CUcontext,
41    /// Set when the fallback path called `cuDevicePrimaryCtxRetain`.
42    retained_device: Option<rdmaxcel_sys::CUdevice>,
43}
44
45impl Drop for CudaCtxGuard {
46    fn drop(&mut self) {
47        unsafe {
48            rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.prev);
49            if let Some(device) = self.retained_device {
50                rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRelease(device);
51            }
52        }
53    }
54}
55
56/// Make the CUDA context that owns `addr` current on the calling
57/// thread, returning a guard that restores the previous context on
58/// drop.
59///
60/// First tries `CU_POINTER_ATTRIBUTE_CONTEXT` to get the exact context
61/// the allocation belongs to.  When that returns null (runtime-API or
62/// memory-pool allocations such as PyTorch's caching allocator), falls
63/// back to the device's primary context via
64/// `CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL` + `cuDevicePrimaryCtxRetain`.
65///
66/// # Safety
67///
68/// `addr` must be a valid CUDA device pointer.
69pub(crate) unsafe fn set_ctx_for_ptr(addr: usize) -> Result<CudaCtxGuard, anyhow::Error> {
70    let mut prev: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
71    unsafe {
72        rdmaxcel_sys::rdmaxcel_cuCtxGetCurrent(&mut prev);
73    }
74
75    let mut ctx: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
76    let rc = unsafe {
77        rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
78            &mut ctx as *mut _ as *mut std::ffi::c_void,
79            rdmaxcel_sys::CU_POINTER_ATTRIBUTE_CONTEXT,
80            addr as rdmaxcel_sys::CUdeviceptr,
81        )
82    };
83
84    // Null context: allocation came from the runtime API or a memory
85    // pool.  Fall back to the owning device's primary context.
86    let mut retained_device = None;
87    if rc != rdmaxcel_sys::CUDA_SUCCESS || ctx.is_null() {
88        let mut ordinal: i32 = -1;
89        let rc = unsafe {
90            rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
91                &mut ordinal as *mut _ as *mut std::ffi::c_void,
92                rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
93                addr as rdmaxcel_sys::CUdeviceptr,
94            )
95        };
96        anyhow::ensure!(
97            rc == rdmaxcel_sys::CUDA_SUCCESS,
98            "cuPointerGetAttribute(DEVICE_ORDINAL) failed with error code {rc}"
99        );
100
101        let mut device: rdmaxcel_sys::CUdevice = 0;
102        let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, ordinal) };
103        anyhow::ensure!(
104            rc == rdmaxcel_sys::CUDA_SUCCESS,
105            "cuDeviceGet({ordinal}) failed with error code {rc}"
106        );
107
108        let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRetain(&mut ctx, device) };
109        anyhow::ensure!(
110            rc == rdmaxcel_sys::CUDA_SUCCESS,
111            "cuDevicePrimaryCtxRetain failed with error code {rc}"
112        );
113        retained_device = Some(device);
114    }
115
116    let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(ctx) };
117    anyhow::ensure!(
118        rc == rdmaxcel_sys::CUDA_SUCCESS,
119        "cuCtxSetCurrent failed with error code {rc}"
120    );
121
122    Ok(CudaCtxGuard {
123        prev,
124        retained_device,
125    })
126}
127
128/// Verify that an access at `offset` with `len` bytes fits within `size`.
129fn check_bounds(offset: usize, len: usize, size: usize) -> Result<(), anyhow::Error> {
130    anyhow::ensure!(
131        offset.checked_add(len).is_some_and(|end| end <= size),
132        "access at offset {offset} with length {len} exceeds region size {size}"
133    );
134    Ok(())
135}
136
/// Copy `dst.len()` bytes from host memory at `addr + offset` into `dst`.
///
/// # Safety
///
/// The caller must ensure that `addr` points to a valid, readable host
/// allocation of at least `offset + dst.len()` bytes, and that the bytes
/// being read do not overlap `dst`.
unsafe fn read_cpu(addr: usize, offset: usize, dst: &mut [u8]) {
    let base = addr as *const u8;
    // SAFETY: per this function's contract `base` addresses an allocation
    // of at least `offset + dst.len()` bytes, so `base.add(offset)` stays
    // inside it and `dst.len()` bytes are readable from there; `dst` is a
    // live exclusive slice, so the destination is writable and distinct
    // from the source.
    unsafe {
        let src = base.add(offset);
        std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len());
    }
}
148
/// Copy `src.len()` bytes from `src` into host memory at `addr + offset`.
///
/// # Safety
///
/// The caller must ensure that `addr` points to a valid, writable host
/// allocation of at least `offset + src.len()` bytes, and that the bytes
/// being written do not overlap `src`.
unsafe fn write_cpu(addr: usize, offset: usize, src: &[u8]) {
    let base = addr as *mut u8;
    // SAFETY: per this function's contract `base` addresses an allocation
    // of at least `offset + src.len()` bytes, so `base.add(offset)` stays
    // inside it and `src.len()` bytes are writable from there; `src` is a
    // live shared slice that the caller keeps disjoint from the target.
    unsafe {
        let dst = base.add(offset);
        std::ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len());
    }
}
160
161/// Copy `dst.len()` bytes from device memory at `addr + offset` into `dst`.
162///
163/// # Safety
164///
165/// The caller must ensure that `addr` is a valid CUDA device pointer to an
166/// allocation of at least `offset + dst.len()` bytes.
167unsafe fn read_gpu(addr: usize, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
168    let _guard = unsafe { set_ctx_for_ptr(addr)? };
169    let rc = unsafe {
170        rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
171            dst.as_mut_ptr() as *mut std::ffi::c_void,
172            (addr + offset) as rdmaxcel_sys::CUdeviceptr,
173            dst.len(),
174        )
175    };
176    anyhow::ensure!(
177        rc == rdmaxcel_sys::CUDA_SUCCESS,
178        "cuMemcpyDtoH failed with error code {rc}"
179    );
180    Ok(())
181}
182
183/// Copy `src.len()` bytes from `src` into device memory at `addr + offset`.
184///
185/// # Safety
186///
187/// The caller must ensure that `addr` is a valid CUDA device pointer to an
188/// allocation of at least `offset + src.len()` bytes.
189unsafe fn write_gpu(addr: usize, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
190    let _guard = unsafe { set_ctx_for_ptr(addr)? };
191    let rc = unsafe {
192        rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
193            (addr + offset) as rdmaxcel_sys::CUdeviceptr,
194            src.as_ptr() as *const std::ffi::c_void,
195            src.len(),
196        )
197    };
198    anyhow::ensure!(
199        rc == rdmaxcel_sys::CUDA_SUCCESS,
200        "cuMemcpyHtoD failed with error code {rc}"
201    );
202    Ok(())
203}
204
205/// Three-mode access lock used by [`KeepaliveLocalMemory`] to coordinate
206/// concurrent reads, exclusive writes, and parallel "disjoint" writes
207/// (writers that the caller has promised target disjoint ranges).
208///
209/// - [`AccessLock::read`] returns when no exclusive writer and no
210///   disjoint writer is active. Multiple readers are permitted to hold
211///   the lock at the same time.
212/// - [`AccessLock::disjoint_write`] returns when no reader and no
213///   exclusive writer is active. Multiple disjoint writers are
214///   permitted to hold the lock at the same time.
215/// - [`AccessLock::exclusive`] returns only when no one else holds the
216///   lock.
217///
218/// Read mode and disjoint-write mode are mutually exclusive, which is
219/// what gives readers a torn-free view of memory in the presence of
220/// disjoint parallel writers.
221#[derive(Debug, Default)]
222struct AccessLock {
223    state: Mutex<AccessState>,
224    cond: Condvar,
225}
226
/// Mode an [`AccessLock`] is currently in.
#[derive(Default, Debug)]
enum AccessState {
    /// Nobody holds the lock.
    #[default]
    Idle,
    /// Held in read mode by the given number of readers.
    Reading(usize),
    /// Held in disjoint-write mode by the given number of writers.
    DisjointWriting(usize),
    /// Held by a single exclusive writer.
    Exclusive,
}
235
236impl AccessLock {
237    fn new() -> Self {
238        Self::default()
239    }
240
241    fn read(&self) -> AccessReadGuard<'_> {
242        let mut state = self.state.lock().expect("AccessLock poisoned");
243        loop {
244            match &mut *state {
245                AccessState::Idle => {
246                    *state = AccessState::Reading(1);
247                    return AccessReadGuard(self);
248                }
249                AccessState::Reading(n) => {
250                    *n += 1;
251                    return AccessReadGuard(self);
252                }
253                AccessState::DisjointWriting(_) | AccessState::Exclusive => {
254                    state = self.cond.wait(state).expect("AccessLock poisoned");
255                }
256            }
257        }
258    }
259
260    fn disjoint_write(&self) -> AccessDisjointWriteGuard<'_> {
261        let mut state = self.state.lock().expect("AccessLock poisoned");
262        loop {
263            match &mut *state {
264                AccessState::Idle => {
265                    *state = AccessState::DisjointWriting(1);
266                    return AccessDisjointWriteGuard(self);
267                }
268                AccessState::DisjointWriting(n) => {
269                    *n += 1;
270                    return AccessDisjointWriteGuard(self);
271                }
272                AccessState::Reading(_) | AccessState::Exclusive => {
273                    state = self.cond.wait(state).expect("AccessLock poisoned");
274                }
275            }
276        }
277    }
278
279    fn exclusive(&self) -> AccessExclusiveGuard<'_> {
280        let mut state = self.state.lock().expect("AccessLock poisoned");
281        loop {
282            if matches!(*state, AccessState::Idle) {
283                *state = AccessState::Exclusive;
284                return AccessExclusiveGuard(self);
285            }
286            state = self.cond.wait(state).expect("AccessLock poisoned");
287        }
288    }
289}
290
291struct AccessReadGuard<'a>(&'a AccessLock);
292impl Drop for AccessReadGuard<'_> {
293    fn drop(&mut self) {
294        let mut state = self.0.state.lock().expect("AccessLock poisoned");
295        match &mut *state {
296            AccessState::Reading(1) => {
297                *state = AccessState::Idle;
298                self.0.cond.notify_all();
299            }
300            AccessState::Reading(n) => *n -= 1,
301            other => unreachable!("AccessReadGuard dropped in non-Reading state: {other:?}"),
302        }
303    }
304}
305
306struct AccessDisjointWriteGuard<'a>(&'a AccessLock);
307impl Drop for AccessDisjointWriteGuard<'_> {
308    fn drop(&mut self) {
309        let mut state = self.0.state.lock().expect("AccessLock poisoned");
310        match &mut *state {
311            AccessState::DisjointWriting(1) => {
312                *state = AccessState::Idle;
313                self.0.cond.notify_all();
314            }
315            AccessState::DisjointWriting(n) => *n -= 1,
316            other => unreachable!(
317                "AccessDisjointWriteGuard dropped in non-DisjointWriting state: {other:?}"
318            ),
319        }
320    }
321}
322
323struct AccessExclusiveGuard<'a>(&'a AccessLock);
324impl Drop for AccessExclusiveGuard<'_> {
325    fn drop(&mut self) {
326        let mut state = self.0.state.lock().expect("AccessLock poisoned");
327        debug_assert!(matches!(*state, AccessState::Exclusive));
328        *state = AccessState::Idle;
329        self.0.cond.notify_all();
330    }
331}
332
/// Marker for types whose existence keeps a backing allocation alive.
///
/// The contract is that the memory region described by the enclosing
/// [`KeepaliveLocalMemory`] stays valid for as long as a value of the
/// implementing type exists. The trait has no methods and is not
/// `unsafe`, so this is a promise made by implementors rather than
/// something the compiler checks.
pub trait Keepalive: Send + Sync {}

/// An owned byte buffer can serve as the keepalive for a region.
impl Keepalive for Box<[u8]> {}
341
342/// Local memory handle that keeps its backing allocation alive via an
343/// [`Arc<dyn Keepalive>`].
344///
345/// Detects at construction time whether the address is a CUDA device
346/// pointer and dispatches `read_at`/`write_at` accordingly.
347///
348/// All three access methods are `unsafe`: the [`Keepalive`] only
349/// guarantees the allocation stays mapped, not that this handle has
350/// unique ownership. The internal [`AccessLock`] coordinates concurrent
351/// callers that share the same clone of this handle (readers run in
352/// parallel, exclusive writers run alone, disjoint writers run in
353/// parallel with one another but exclude readers and exclusive
354/// writers), but callers must additionally rule out concurrent access
355/// through other views of the same allocation.
356///
357/// The `direct_access_host_bandwidth` and `direct_access_device_bandwidth`
358/// fields indicate the speed of reading the memory via pointer dereference
359/// on a host or device thread, respectively. A value of `None` means the
360/// memory is not directly accessible from that context.
361#[derive(Clone)]
362pub struct KeepaliveLocalMemory {
363    addr: usize,
364    size: usize,
365    /// Bandwidth (bytes/s) for direct host-thread pointer access, or `None`
366    /// if the memory is not host-accessible.
367    direct_access_host_bandwidth: Option<u64>,
368    /// Bandwidth (bytes/s) for direct device-thread pointer access, or
369    /// `None` if the memory is not device-accessible.
370    direct_access_device_bandwidth: Option<u64>,
371    _keepalive: Arc<dyn Keepalive>,
372    /// Coordinates concurrent reads, exclusive writes, and parallel
373    /// disjoint writes against this region.
374    access: Arc<AccessLock>,
375}
376
377impl Debug for KeepaliveLocalMemory {
378    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379        f.debug_struct("KeepaliveLocalMemory")
380            .field("addr", &self.addr)
381            .field("size", &self.size)
382            .field(
383                "direct_access_host_bandwidth",
384                &self.direct_access_host_bandwidth,
385            )
386            .field(
387                "direct_access_device_bandwidth",
388                &self.direct_access_device_bandwidth,
389            )
390            .finish_non_exhaustive()
391    }
392}
393
394impl KeepaliveLocalMemory {
395    /// Create a new handle. Probes the CUDA driver to determine whether
396    /// `addr` is a device pointer and sets the bandwidth fields
397    /// accordingly.
398    pub fn new(addr: usize, size: usize, keepalive: Arc<dyn Keepalive>) -> Self {
399        // TODO(slurye): Using placeholder values for now. Fill in with real values.
400        let (host_bw, device_bw) = if is_device_ptr(addr) {
401            (None, Some(1))
402        } else {
403            (Some(1), None)
404        };
405        Self {
406            addr,
407            size,
408            direct_access_host_bandwidth: host_bw,
409            direct_access_device_bandwidth: device_bw,
410            _keepalive: keepalive,
411            access: Arc::new(AccessLock::new()),
412        }
413    }
414
415    /// Starting virtual address of the memory region.
416    pub fn addr(&self) -> usize {
417        self.addr
418    }
419
420    /// Size of the memory region in bytes.
421    pub fn size(&self) -> usize {
422        self.size
423    }
424
425    /// Copy `dst.len()` bytes from this memory region starting at `offset`
426    /// into `dst`.
427    ///
428    /// Mutually exclusive with both `write_at` and `write_at_disjoint`
429    /// *across clones of this handle*: the [`AccessLock`] guarantees a
430    /// reader and any writer (exclusive or disjoint) that share the
431    /// same lock never observe each other's partial state. Multiple
432    /// concurrent `read_at` calls on shared clones are permitted and
433    /// run in parallel.
434    ///
435    /// # Safety
436    ///
437    /// The [`Keepalive`] guarantees the allocation stays mapped, but it
438    /// does *not* imply unique ownership: another component may hold its
439    /// own view of the same allocation and read or write it concurrently
440    /// outside this handle's [`AccessLock`]. The caller must ensure that
441    /// no such external access produces a torn read of
442    /// `offset..offset + dst.len()` for the duration of this call.
443    pub unsafe fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
444        let _guard = self.access.read();
445        check_bounds(offset, dst.len(), self.size)?;
446        // SAFETY: the `_keepalive` field keeps the allocation live, the
447        // read guard above excludes concurrent exclusive and disjoint
448        // writers that share this lock, `check_bounds` verified the access
449        // is in range, and the caller upholds the no-external-writer
450        // obligation documented on this method.
451        unsafe {
452            if self.direct_access_host_bandwidth.is_some() {
453                read_cpu(self.addr, offset, dst);
454                Ok(())
455            } else {
456                read_gpu(self.addr, offset, dst)
457            }
458        }
459    }
460
461    /// Copy `src.len()` bytes from `src` into this memory region starting
462    /// at `offset`.
463    ///
464    /// Mutually exclusive with every other read and write against this
465    /// region *across clones of this handle*: the [`AccessLock`] blocks
466    /// concurrent readers and writers that share the same lock. Use
467    /// [`KeepaliveLocalMemory::write_at_disjoint`] when multiple writers
468    /// can be proven to target disjoint byte ranges.
469    ///
470    /// # Safety
471    ///
472    /// See [`KeepaliveLocalMemory::read_at`]. The [`Keepalive`] guarantee
473    /// covers liveness only; the caller must ensure no concurrent
474    /// external reader or writer observes an overlapping byte range.
475    pub unsafe fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
476        let _guard = self.access.exclusive();
477        check_bounds(offset, src.len(), self.size)?;
478        // SAFETY: the `_keepalive` field keeps the allocation live, the
479        // exclusive guard above excludes every other reader and writer
480        // that shares this lock, `check_bounds` verified the access is
481        // in range, and the caller upholds the no-external-access
482        // obligation documented on this method.
483        unsafe {
484            if self.direct_access_host_bandwidth.is_some() {
485                write_cpu(self.addr, offset, src);
486                Ok(())
487            } else {
488                write_gpu(self.addr, offset, src)
489            }
490        }
491    }
492
493    /// Like [`KeepaliveLocalMemory::write_at`], but allows other
494    /// concurrent `write_at_disjoint` calls (across clones of this
495    /// handle) to proceed in parallel. Still mutually exclusive with
496    /// `read_at` and `write_at` through the [`AccessLock`].
497    ///
498    /// # Safety
499    ///
500    /// In addition to the obligations of
501    /// [`KeepaliveLocalMemory::write_at`] (no external concurrent
502    /// reader or writer of the same byte range), the caller must
503    /// ensure that no other concurrent call to this method targets a
504    /// byte range that overlaps `offset..offset + src.len()`. Disjoint
505    /// byte ranges across concurrent disjoint callers are sound.
506    pub unsafe fn write_at_disjoint(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
507        let _guard = self.access.disjoint_write();
508        check_bounds(offset, src.len(), self.size)?;
509        // SAFETY: the `_keepalive` field keeps the allocation live, the
510        // disjoint-write guard above excludes concurrent readers and
511        // exclusive writers that share this lock, `check_bounds`
512        // verified the access is in range, and the caller upholds both
513        // safety obligations documented on this method (no external access,
514        // no overlap with other concurrent disjoint writers).
515        unsafe {
516            if self.direct_access_host_bandwidth.is_some() {
517                write_cpu(self.addr, offset, src);
518                Ok(())
519            } else {
520                write_gpu(self.addr, offset, src)
521            }
522        }
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    // -- KeepaliveLocalMemory (host) --

    /// Wrap an owned byte buffer in a handle; the buffer itself becomes
    /// the keepalive, so its heap storage outlives the handle.
    fn host_keepalive_mem(data: Box<[u8]>) -> KeepaliveLocalMemory {
        let addr = data.as_ptr() as usize;
        let size = data.len();
        KeepaliveLocalMemory::new(addr, size, Arc::new(data))
    }

    #[test]
    fn keepalive_host_read_at() {
        let mem = host_keepalive_mem(Box::from([1, 2, 3, 4, 5]));
        let mut buf = [0u8; 3];
        // SAFETY: `mem` is the sole handle to the allocation, no other
        // thread or component holds a view of it.
        unsafe { mem.read_at(1, &mut buf) }.unwrap();
        assert_eq!(buf, [2, 3, 4]);
    }

    #[test]
    fn keepalive_host_write_then_read() {
        let mem = host_keepalive_mem(vec![0; 5].into_boxed_slice());
        // SAFETY: `mem` is the sole handle to the allocation, no other
        // thread or component holds a view of it.
        unsafe { mem.write_at(1, &[7, 8, 9]) }.unwrap();
        let mut buf = [0u8; 5];
        // SAFETY: same as above.
        unsafe { mem.read_at(0, &mut buf) }.unwrap();
        assert_eq!(buf, [0, 7, 8, 9, 0]);
    }

    #[test]
    fn keepalive_host_write_at_disjoint_then_read() {
        let mem = host_keepalive_mem(vec![0; 5].into_boxed_slice());
        // SAFETY: `mem` is the sole handle to the allocation and the two
        // writes run one after the other on this thread.
        unsafe { mem.write_at_disjoint(0, &[1, 2]) }.unwrap();
        // SAFETY: same as above.
        unsafe { mem.write_at_disjoint(3, &[4, 5]) }.unwrap();
        let mut buf = [0u8; 5];
        // SAFETY: same as above.
        unsafe { mem.read_at(0, &mut buf) }.unwrap();
        assert_eq!(buf, [1, 2, 0, 4, 5]);
    }

    #[test]
    fn keepalive_host_out_of_bounds() {
        let mem = host_keepalive_mem(vec![0; 3].into_boxed_slice());
        let mut buf = [0u8; 3];
        // SAFETY: `mem` is the sole handle to the allocation; the
        // bounds check fires before any pointer dereference.
        assert!(unsafe { mem.read_at(1, &mut buf) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at(1, &[7, 8, 9]) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at_disjoint(1, &[7, 8, 9]) }.is_err());
    }

    #[test]
    fn keepalive_host_offset_overflow_is_rejected() {
        let mem = host_keepalive_mem(vec![0; 3].into_boxed_slice());
        let mut buf = [0u8; 1];
        // SAFETY: `mem` is the sole handle to the allocation; `offset +
        // len` overflows, so the bounds check rejects the call before any
        // pointer dereference.
        assert!(unsafe { mem.read_at(usize::MAX, &mut buf) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at(usize::MAX, &[1]) }.is_err());
    }
}