monarch_rdma/local_memory.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Local memory abstractions for RDMA operations.
10//!
11//! [`KeepaliveLocalMemory`] wraps a raw pointer with a [`Keepalive`]
12//! guard and dispatches reads/writes to CPU or CUDA paths.
13
14use std::fmt::Debug;
15use std::sync::Arc;
16use std::sync::Condvar;
17use std::sync::Mutex;
18
19/// Returns `true` when `addr` is a CUDA device pointer.
20///
21/// Probes the CUDA driver via `cuPointerGetAttribute`; returns `false`
22/// when CUDA is unavailable or the pointer is not device memory.
23pub fn is_device_ptr(addr: usize) -> bool {
24 // SAFETY: FFI call that queries pointer metadata without accessing
25 // the pointed-to memory.
26 unsafe {
27 let mut mem_type: u32 = 0;
28 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
29 &mut mem_type as *mut _ as *mut std::ffi::c_void,
30 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
31 addr as rdmaxcel_sys::CUdeviceptr,
32 );
33 err == rdmaxcel_sys::CUDA_SUCCESS && mem_type == rdmaxcel_sys::CU_MEMORYTYPE_DEVICE
34 }
35}
36
37/// RAII guard that restores the previous CUDA context on drop and, if a
38/// primary context was retained, releases it.
39pub(crate) struct CudaCtxGuard {
40 prev: rdmaxcel_sys::CUcontext,
41 /// Set when the fallback path called `cuDevicePrimaryCtxRetain`.
42 retained_device: Option<rdmaxcel_sys::CUdevice>,
43}
44
45impl Drop for CudaCtxGuard {
46 fn drop(&mut self) {
47 unsafe {
48 rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.prev);
49 if let Some(device) = self.retained_device {
50 rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRelease(device);
51 }
52 }
53 }
54}
55
56/// Make the CUDA context that owns `addr` current on the calling
57/// thread, returning a guard that restores the previous context on
58/// drop.
59///
60/// First tries `CU_POINTER_ATTRIBUTE_CONTEXT` to get the exact context
61/// the allocation belongs to. When that returns null (runtime-API or
62/// memory-pool allocations such as PyTorch's caching allocator), falls
63/// back to the device's primary context via
64/// `CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL` + `cuDevicePrimaryCtxRetain`.
65///
66/// # Safety
67///
68/// `addr` must be a valid CUDA device pointer.
69pub(crate) unsafe fn set_ctx_for_ptr(addr: usize) -> Result<CudaCtxGuard, anyhow::Error> {
70 let mut prev: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
71 unsafe {
72 rdmaxcel_sys::rdmaxcel_cuCtxGetCurrent(&mut prev);
73 }
74
75 let mut ctx: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
76 let rc = unsafe {
77 rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
78 &mut ctx as *mut _ as *mut std::ffi::c_void,
79 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_CONTEXT,
80 addr as rdmaxcel_sys::CUdeviceptr,
81 )
82 };
83
84 // Null context: allocation came from the runtime API or a memory
85 // pool. Fall back to the owning device's primary context.
86 let mut retained_device = None;
87 if rc != rdmaxcel_sys::CUDA_SUCCESS || ctx.is_null() {
88 let mut ordinal: i32 = -1;
89 let rc = unsafe {
90 rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
91 &mut ordinal as *mut _ as *mut std::ffi::c_void,
92 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
93 addr as rdmaxcel_sys::CUdeviceptr,
94 )
95 };
96 anyhow::ensure!(
97 rc == rdmaxcel_sys::CUDA_SUCCESS,
98 "cuPointerGetAttribute(DEVICE_ORDINAL) failed with error code {rc}"
99 );
100
101 let mut device: rdmaxcel_sys::CUdevice = 0;
102 let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, ordinal) };
103 anyhow::ensure!(
104 rc == rdmaxcel_sys::CUDA_SUCCESS,
105 "cuDeviceGet({ordinal}) failed with error code {rc}"
106 );
107
108 let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRetain(&mut ctx, device) };
109 anyhow::ensure!(
110 rc == rdmaxcel_sys::CUDA_SUCCESS,
111 "cuDevicePrimaryCtxRetain failed with error code {rc}"
112 );
113 retained_device = Some(device);
114 }
115
116 let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(ctx) };
117 anyhow::ensure!(
118 rc == rdmaxcel_sys::CUDA_SUCCESS,
119 "cuCtxSetCurrent failed with error code {rc}"
120 );
121
122 Ok(CudaCtxGuard {
123 prev,
124 retained_device,
125 })
126}
127
128/// Verify that an access at `offset` with `len` bytes fits within `size`.
129fn check_bounds(offset: usize, len: usize, size: usize) -> Result<(), anyhow::Error> {
130 anyhow::ensure!(
131 offset.checked_add(len).is_some_and(|end| end <= size),
132 "access at offset {offset} with length {len} exceeds region size {size}"
133 );
134 Ok(())
135}
136
/// Copy `dst.len()` bytes from host memory at `addr + offset` into `dst`.
///
/// A zero-length copy returns immediately without turning `addr` into a
/// pointer. An empty region can therefore be read with an empty buffer
/// even when its `addr` is 0.
///
/// # Safety
///
/// When `dst` is non-empty, the caller must ensure that `addr` points to a
/// valid host allocation of at least `offset + dst.len()` bytes, and that
/// this range does not overlap `dst`.
unsafe fn read_cpu(addr: usize, offset: usize, dst: &mut [u8]) {
    // `copy_nonoverlapping` requires non-null pointers even when the
    // count is zero, so an empty copy must not reach it with `addr == 0`.
    if dst.is_empty() {
        return;
    }
    // SAFETY: the caller guarantees `addr + offset .. + dst.len()` is a
    // valid host range that does not overlap `dst`, and `dst` is an
    // exclusive borrow of exactly `dst.len()` bytes.
    unsafe {
        std::ptr::copy_nonoverlapping((addr + offset) as *const u8, dst.as_mut_ptr(), dst.len());
    }
}
148
/// Copy `src.len()` bytes from `src` into host memory at `addr + offset`.
///
/// A zero-length copy returns immediately without turning `addr` into a
/// pointer. An empty region can therefore be written with an empty buffer
/// even when its `addr` is 0.
///
/// # Safety
///
/// When `src` is non-empty, the caller must ensure that `addr` points to a
/// valid, writable host allocation of at least `offset + src.len()` bytes,
/// and that this range does not overlap `src`.
unsafe fn write_cpu(addr: usize, offset: usize, src: &[u8]) {
    // `copy_nonoverlapping` requires non-null pointers even when the
    // count is zero, so an empty copy must not reach it with `addr == 0`.
    if src.is_empty() {
        return;
    }
    // SAFETY: the caller guarantees `addr + offset .. + src.len()` is a
    // valid, writable host range that does not overlap `src`, and `src` is
    // a live borrow of exactly `src.len()` bytes.
    unsafe {
        std::ptr::copy_nonoverlapping(src.as_ptr(), (addr + offset) as *mut u8, src.len());
    }
}
160
161/// Copy `dst.len()` bytes from device memory at `addr + offset` into `dst`.
162///
163/// # Safety
164///
165/// The caller must ensure that `addr` is a valid CUDA device pointer to an
166/// allocation of at least `offset + dst.len()` bytes.
167unsafe fn read_gpu(addr: usize, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
168 let _guard = unsafe { set_ctx_for_ptr(addr)? };
169 let rc = unsafe {
170 rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
171 dst.as_mut_ptr() as *mut std::ffi::c_void,
172 (addr + offset) as rdmaxcel_sys::CUdeviceptr,
173 dst.len(),
174 )
175 };
176 anyhow::ensure!(
177 rc == rdmaxcel_sys::CUDA_SUCCESS,
178 "cuMemcpyDtoH failed with error code {rc}"
179 );
180 Ok(())
181}
182
183/// Copy `src.len()` bytes from `src` into device memory at `addr + offset`.
184///
185/// # Safety
186///
187/// The caller must ensure that `addr` is a valid CUDA device pointer to an
188/// allocation of at least `offset + src.len()` bytes.
189unsafe fn write_gpu(addr: usize, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
190 let _guard = unsafe { set_ctx_for_ptr(addr)? };
191 let rc = unsafe {
192 rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
193 (addr + offset) as rdmaxcel_sys::CUdeviceptr,
194 src.as_ptr() as *const std::ffi::c_void,
195 src.len(),
196 )
197 };
198 anyhow::ensure!(
199 rc == rdmaxcel_sys::CUDA_SUCCESS,
200 "cuMemcpyHtoD failed with error code {rc}"
201 );
202 Ok(())
203}
204
205/// Three-mode access lock used by [`KeepaliveLocalMemory`] to coordinate
206/// concurrent reads, exclusive writes, and parallel "disjoint" writes
207/// (writers that the caller has promised target disjoint ranges).
208///
209/// - [`AccessLock::read`] returns when no exclusive writer and no
210/// disjoint writer is active. Multiple readers are permitted to hold
211/// the lock at the same time.
212/// - [`AccessLock::disjoint_write`] returns when no reader and no
213/// exclusive writer is active. Multiple disjoint writers are
214/// permitted to hold the lock at the same time.
215/// - [`AccessLock::exclusive`] returns only when no one else holds the
216/// lock.
217///
218/// Read mode and disjoint-write mode are mutually exclusive, which is
219/// what gives readers a torn-free view of memory in the presence of
220/// disjoint parallel writers.
221#[derive(Debug, Default)]
222struct AccessLock {
223 state: Mutex<AccessState>,
224 cond: Condvar,
225}
226
/// Mode an `AccessLock` is currently in.
#[derive(Default, Debug)]
enum AccessState {
    /// No reader or writer holds the lock.
    #[default]
    Idle,
    /// The given number of readers share the lock.
    Reading(usize),
    /// The given number of disjoint writers share the lock.
    DisjointWriting(usize),
    /// A single exclusive writer holds the lock.
    Exclusive,
}
235
236impl AccessLock {
237 fn new() -> Self {
238 Self::default()
239 }
240
241 fn read(&self) -> AccessReadGuard<'_> {
242 let mut state = self.state.lock().expect("AccessLock poisoned");
243 loop {
244 match &mut *state {
245 AccessState::Idle => {
246 *state = AccessState::Reading(1);
247 return AccessReadGuard(self);
248 }
249 AccessState::Reading(n) => {
250 *n += 1;
251 return AccessReadGuard(self);
252 }
253 AccessState::DisjointWriting(_) | AccessState::Exclusive => {
254 state = self.cond.wait(state).expect("AccessLock poisoned");
255 }
256 }
257 }
258 }
259
260 fn disjoint_write(&self) -> AccessDisjointWriteGuard<'_> {
261 let mut state = self.state.lock().expect("AccessLock poisoned");
262 loop {
263 match &mut *state {
264 AccessState::Idle => {
265 *state = AccessState::DisjointWriting(1);
266 return AccessDisjointWriteGuard(self);
267 }
268 AccessState::DisjointWriting(n) => {
269 *n += 1;
270 return AccessDisjointWriteGuard(self);
271 }
272 AccessState::Reading(_) | AccessState::Exclusive => {
273 state = self.cond.wait(state).expect("AccessLock poisoned");
274 }
275 }
276 }
277 }
278
279 fn exclusive(&self) -> AccessExclusiveGuard<'_> {
280 let mut state = self.state.lock().expect("AccessLock poisoned");
281 loop {
282 if matches!(*state, AccessState::Idle) {
283 *state = AccessState::Exclusive;
284 return AccessExclusiveGuard(self);
285 }
286 state = self.cond.wait(state).expect("AccessLock poisoned");
287 }
288 }
289}
290
291struct AccessReadGuard<'a>(&'a AccessLock);
292impl Drop for AccessReadGuard<'_> {
293 fn drop(&mut self) {
294 let mut state = self.0.state.lock().expect("AccessLock poisoned");
295 match &mut *state {
296 AccessState::Reading(1) => {
297 *state = AccessState::Idle;
298 self.0.cond.notify_all();
299 }
300 AccessState::Reading(n) => *n -= 1,
301 other => unreachable!("AccessReadGuard dropped in non-Reading state: {other:?}"),
302 }
303 }
304}
305
306struct AccessDisjointWriteGuard<'a>(&'a AccessLock);
307impl Drop for AccessDisjointWriteGuard<'_> {
308 fn drop(&mut self) {
309 let mut state = self.0.state.lock().expect("AccessLock poisoned");
310 match &mut *state {
311 AccessState::DisjointWriting(1) => {
312 *state = AccessState::Idle;
313 self.0.cond.notify_all();
314 }
315 AccessState::DisjointWriting(n) => *n -= 1,
316 other => unreachable!(
317 "AccessDisjointWriteGuard dropped in non-DisjointWriting state: {other:?}"
318 ),
319 }
320 }
321}
322
323struct AccessExclusiveGuard<'a>(&'a AccessLock);
324impl Drop for AccessExclusiveGuard<'_> {
325 fn drop(&mut self) {
326 let mut state = self.0.state.lock().expect("AccessLock poisoned");
327 debug_assert!(matches!(*state, AccessState::Exclusive));
328 *state = AccessState::Idle;
329 self.0.cond.notify_all();
330 }
331}
332
/// Marker for values that keep a backing memory allocation alive.
///
/// For as long as a value implementing this trait exists, the memory region
/// described by the [`KeepaliveLocalMemory`] holding it remains valid. The
/// `Send + Sync` supertraits make `Arc<dyn Keepalive>` itself `Send` and
/// `Sync`.
pub trait Keepalive: Sync + Send {}

/// A boxed byte slice keeps its own heap buffer alive.
impl Keepalive for Box<[u8]> {}
341
342/// Local memory handle that keeps its backing allocation alive via an
343/// [`Arc<dyn Keepalive>`].
344///
345/// Detects at construction time whether the address is a CUDA device
346/// pointer and dispatches `read_at`/`write_at` accordingly.
347///
348/// All three access methods are `unsafe`: the [`Keepalive`] only
349/// guarantees the allocation stays mapped, not that this handle has
350/// unique ownership. The internal [`AccessLock`] coordinates concurrent
351/// callers that share the same clone of this handle (readers run in
352/// parallel, exclusive writers run alone, disjoint writers run in
353/// parallel with one another but exclude readers and exclusive
354/// writers), but callers must additionally rule out concurrent access
355/// through other views of the same allocation.
356///
357/// The `direct_access_host_bandwidth` and `direct_access_device_bandwidth`
358/// fields indicate the speed of reading the memory via pointer dereference
359/// on a host or device thread, respectively. A value of `None` means the
360/// memory is not directly accessible from that context.
361#[derive(Clone)]
362pub struct KeepaliveLocalMemory {
363 addr: usize,
364 size: usize,
365 /// Bandwidth (bytes/s) for direct host-thread pointer access, or `None`
366 /// if the memory is not host-accessible.
367 direct_access_host_bandwidth: Option<u64>,
368 /// Bandwidth (bytes/s) for direct device-thread pointer access, or
369 /// `None` if the memory is not device-accessible.
370 direct_access_device_bandwidth: Option<u64>,
371 _keepalive: Arc<dyn Keepalive>,
372 /// Coordinates concurrent reads, exclusive writes, and parallel
373 /// disjoint writes against this region.
374 access: Arc<AccessLock>,
375}
376
377impl Debug for KeepaliveLocalMemory {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 f.debug_struct("KeepaliveLocalMemory")
380 .field("addr", &self.addr)
381 .field("size", &self.size)
382 .field(
383 "direct_access_host_bandwidth",
384 &self.direct_access_host_bandwidth,
385 )
386 .field(
387 "direct_access_device_bandwidth",
388 &self.direct_access_device_bandwidth,
389 )
390 .finish_non_exhaustive()
391 }
392}
393
394impl KeepaliveLocalMemory {
395 /// Create a new handle. Probes the CUDA driver to determine whether
396 /// `addr` is a device pointer and sets the bandwidth fields
397 /// accordingly.
398 pub fn new(addr: usize, size: usize, keepalive: Arc<dyn Keepalive>) -> Self {
399 // TODO(slurye): Using placeholder values for now. Fill in with real values.
400 let (host_bw, device_bw) = if is_device_ptr(addr) {
401 (None, Some(1))
402 } else {
403 (Some(1), None)
404 };
405 Self {
406 addr,
407 size,
408 direct_access_host_bandwidth: host_bw,
409 direct_access_device_bandwidth: device_bw,
410 _keepalive: keepalive,
411 access: Arc::new(AccessLock::new()),
412 }
413 }
414
415 /// Starting virtual address of the memory region.
416 pub fn addr(&self) -> usize {
417 self.addr
418 }
419
420 /// Size of the memory region in bytes.
421 pub fn size(&self) -> usize {
422 self.size
423 }
424
425 /// Copy `dst.len()` bytes from this memory region starting at `offset`
426 /// into `dst`.
427 ///
428 /// Mutually exclusive with both `write_at` and `write_at_disjoint`
429 /// *across clones of this handle*: the [`AccessLock`] guarantees a
430 /// reader and any writer (exclusive or disjoint) that share the
431 /// same lock never observe each other's partial state. Multiple
432 /// concurrent `read_at` calls on shared clones are permitted and
433 /// run in parallel.
434 ///
435 /// # Safety
436 ///
437 /// The [`Keepalive`] guarantees the allocation stays mapped, but it
438 /// does *not* imply unique ownership: another component may hold its
439 /// own view of the same allocation and read or write it concurrently
440 /// outside this handle's [`AccessLock`]. The caller must ensure that
441 /// no such external access produces a torn read of
442 /// `offset..offset + dst.len()` for the duration of this call.
443 pub unsafe fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
444 let _guard = self.access.read();
445 check_bounds(offset, dst.len(), self.size)?;
446 // SAFETY: the `_keepalive` field keeps the allocation live, the
447 // read guard above excludes concurrent exclusive and disjoint
448 // writers that share this lock, `check_bounds` verified the access
449 // is in range, and the caller upholds the no-external-writer
450 // obligation documented on this method.
451 unsafe {
452 if self.direct_access_host_bandwidth.is_some() {
453 read_cpu(self.addr, offset, dst);
454 Ok(())
455 } else {
456 read_gpu(self.addr, offset, dst)
457 }
458 }
459 }
460
461 /// Copy `src.len()` bytes from `src` into this memory region starting
462 /// at `offset`.
463 ///
464 /// Mutually exclusive with every other read and write against this
465 /// region *across clones of this handle*: the [`AccessLock`] blocks
466 /// concurrent readers and writers that share the same lock. Use
467 /// [`KeepaliveLocalMemory::write_at_disjoint`] when multiple writers
468 /// can be proven to target disjoint byte ranges.
469 ///
470 /// # Safety
471 ///
472 /// See [`KeepaliveLocalMemory::read_at`]. The [`Keepalive`] guarantee
473 /// covers liveness only; the caller must ensure no concurrent
474 /// external reader or writer observes an overlapping byte range.
475 pub unsafe fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
476 let _guard = self.access.exclusive();
477 check_bounds(offset, src.len(), self.size)?;
478 // SAFETY: the `_keepalive` field keeps the allocation live, the
479 // exclusive guard above excludes every other reader and writer
480 // that shares this lock, `check_bounds` verified the access is
481 // in range, and the caller upholds the no-external-access
482 // obligation documented on this method.
483 unsafe {
484 if self.direct_access_host_bandwidth.is_some() {
485 write_cpu(self.addr, offset, src);
486 Ok(())
487 } else {
488 write_gpu(self.addr, offset, src)
489 }
490 }
491 }
492
493 /// Like [`KeepaliveLocalMemory::write_at`], but allows other
494 /// concurrent `write_at_disjoint` calls (across clones of this
495 /// handle) to proceed in parallel. Still mutually exclusive with
496 /// `read_at` and `write_at` through the [`AccessLock`].
497 ///
498 /// # Safety
499 ///
500 /// In addition to the obligations of
501 /// [`KeepaliveLocalMemory::write_at`] (no external concurrent
502 /// reader or writer of the same byte range), the caller must
503 /// ensure that no other concurrent call to this method targets a
504 /// byte range that overlaps `offset..offset + src.len()`. Disjoint
505 /// byte ranges across concurrent disjoint callers are sound.
506 pub unsafe fn write_at_disjoint(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
507 let _guard = self.access.disjoint_write();
508 check_bounds(offset, src.len(), self.size)?;
509 // SAFETY: the `_keepalive` field keeps the allocation live, the
510 // disjoint-write guard above excludes concurrent readers and
511 // exclusive writers that share this lock, `check_bounds`
512 // verified the access is in range, and the caller upholds both
513 // safety obligations documented on this method (no external access,
514 // no overlap with other concurrent disjoint writers).
515 unsafe {
516 if self.direct_access_host_bandwidth.is_some() {
517 write_cpu(self.addr, offset, src);
518 Ok(())
519 } else {
520 write_gpu(self.addr, offset, src)
521 }
522 }
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    // -- KeepaliveLocalMemory (host) --

    fn host_keepalive_mem(data: Box<[u8]>) -> KeepaliveLocalMemory {
        let addr = data.as_ptr() as usize;
        let size = data.len();
        KeepaliveLocalMemory::new(addr, size, Arc::new(data))
    }

    #[test]
    fn keepalive_host_read_at() {
        let mem = host_keepalive_mem(Box::from([1, 2, 3, 4, 5]));
        let mut buf = [0u8; 3];
        // SAFETY: `mem` is the sole handle to the allocation, no other
        // thread or component holds a view of it.
        unsafe { mem.read_at(1, &mut buf) }.unwrap();
        assert_eq!(buf, [2, 3, 4]);
    }

    #[test]
    fn keepalive_host_write_then_read() {
        let mem = host_keepalive_mem(vec![0; 5].into_boxed_slice());
        // SAFETY: `mem` is the sole handle to the allocation, no other
        // thread or component holds a view of it.
        unsafe { mem.write_at(1, &[7, 8, 9]) }.unwrap();
        let mut buf = [0u8; 5];
        // SAFETY: same as above.
        unsafe { mem.read_at(0, &mut buf) }.unwrap();
        assert_eq!(buf, [0, 7, 8, 9, 0]);
    }

    #[test]
    fn keepalive_host_out_of_bounds() {
        let mem = host_keepalive_mem(vec![0; 3].into_boxed_slice());
        let mut buf = [0u8; 3];
        // SAFETY: `mem` is the sole handle to the allocation; the
        // bounds check fires before any pointer dereference.
        assert!(unsafe { mem.read_at(1, &mut buf) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at(1, &[7, 8, 9]) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at_disjoint(1, &[7, 8, 9]) }.is_err());
    }

    #[test]
    fn keepalive_host_offset_overflow_is_rejected() {
        let mem = host_keepalive_mem(vec![0; 4].into_boxed_slice());
        let mut buf = [0u8; 1];
        // SAFETY: `mem` is the sole handle; `offset + len` overflows, so
        // the bounds check rejects the access before any dereference.
        assert!(unsafe { mem.read_at(usize::MAX, &mut buf) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at(usize::MAX, &[1]) }.is_err());
        // SAFETY: same as above.
        assert!(unsafe { mem.write_at_disjoint(usize::MAX, &[1]) }.is_err());
    }

    #[test]
    fn keepalive_host_zero_length_at_end_is_in_bounds() {
        let mem = host_keepalive_mem(Box::from([1, 2, 3]));
        let mut empty = [0u8; 0];
        // SAFETY: `mem` is the sole handle; a zero-length access touches
        // no bytes of the allocation.
        unsafe { mem.read_at(3, &mut empty) }.unwrap();
        // SAFETY: same as above.
        unsafe { mem.write_at(3, &[]) }.unwrap();
    }

    #[test]
    fn keepalive_host_disjoint_writes_across_clones() {
        let mem = host_keepalive_mem(vec![0; 8].into_boxed_slice());
        std::thread::scope(|scope| {
            for chunk in 0..4u8 {
                let mem = mem.clone();
                scope.spawn(move || {
                    let offset = usize::from(chunk) * 2;
                    // SAFETY: each thread writes its own 2-byte range, and
                    // no other component holds a view of the allocation.
                    unsafe { mem.write_at_disjoint(offset, &[chunk, chunk]) }.unwrap();
                });
            }
        });
        let mut buf = [0u8; 8];
        // SAFETY: every writer thread has joined; `mem` is the only view.
        unsafe { mem.read_at(0, &mut buf) }.unwrap();
        assert_eq!(buf, [0, 0, 1, 1, 2, 2, 3, 3]);
    }

    // -- AccessLock --

    #[test]
    fn access_lock_mode_transitions() {
        let lock = AccessLock::new();
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Idle));

        let first = lock.read();
        let second = lock.read();
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Reading(2)));
        drop(first);
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Reading(1)));
        drop(second);
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Idle));

        let writer_a = lock.disjoint_write();
        let writer_b = lock.disjoint_write();
        assert!(matches!(
            *lock.state.lock().unwrap(),
            AccessState::DisjointWriting(2)
        ));
        drop(writer_a);
        drop(writer_b);
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Idle));

        let exclusive = lock.exclusive();
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Exclusive));
        drop(exclusive);
        assert!(matches!(*lock.state.lock().unwrap(), AccessState::Idle));
    }
}