1use std::fmt::Debug;
19use std::sync::Arc;
20use std::sync::RwLock;
21
22use serde::Deserialize;
23use serde::Serialize;
24
25pub fn is_device_ptr(addr: usize) -> bool {
30 unsafe {
33 let mut mem_type: u32 = 0;
34 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
35 &mut mem_type as *mut _ as *mut std::ffi::c_void,
36 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
37 addr as rdmaxcel_sys::CUdeviceptr,
38 );
39 err == rdmaxcel_sys::CUDA_SUCCESS && mem_type == rdmaxcel_sys::CU_MEMORYTYPE_DEVICE
40 }
41}
42
43pub(crate) struct CudaCtxGuard {
46 prev: rdmaxcel_sys::CUcontext,
47 retained_device: Option<rdmaxcel_sys::CUdevice>,
49}
50
51impl Drop for CudaCtxGuard {
52 fn drop(&mut self) {
53 unsafe {
54 rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(self.prev);
55 if let Some(device) = self.retained_device {
56 rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRelease(device);
57 }
58 }
59 }
60}
61
62pub(crate) unsafe fn set_ctx_for_ptr(addr: usize) -> Result<CudaCtxGuard, anyhow::Error> {
76 let mut prev: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
77 unsafe {
78 rdmaxcel_sys::rdmaxcel_cuCtxGetCurrent(&mut prev);
79 }
80
81 let mut ctx: rdmaxcel_sys::CUcontext = std::ptr::null_mut();
82 let rc = unsafe {
83 rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
84 &mut ctx as *mut _ as *mut std::ffi::c_void,
85 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_CONTEXT,
86 addr as rdmaxcel_sys::CUdeviceptr,
87 )
88 };
89
90 let mut retained_device = None;
93 if rc != rdmaxcel_sys::CUDA_SUCCESS || ctx.is_null() {
94 let mut ordinal: i32 = -1;
95 let rc = unsafe {
96 rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
97 &mut ordinal as *mut _ as *mut std::ffi::c_void,
98 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
99 addr as rdmaxcel_sys::CUdeviceptr,
100 )
101 };
102 anyhow::ensure!(
103 rc == rdmaxcel_sys::CUDA_SUCCESS,
104 "cuPointerGetAttribute(DEVICE_ORDINAL) failed with error code {rc}"
105 );
106
107 let mut device: rdmaxcel_sys::CUdevice = 0;
108 let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDeviceGet(&mut device, ordinal) };
109 anyhow::ensure!(
110 rc == rdmaxcel_sys::CUDA_SUCCESS,
111 "cuDeviceGet({ordinal}) failed with error code {rc}"
112 );
113
114 let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuDevicePrimaryCtxRetain(&mut ctx, device) };
115 anyhow::ensure!(
116 rc == rdmaxcel_sys::CUDA_SUCCESS,
117 "cuDevicePrimaryCtxRetain failed with error code {rc}"
118 );
119 retained_device = Some(device);
120 }
121
122 let rc = unsafe { rdmaxcel_sys::rdmaxcel_cuCtxSetCurrent(ctx) };
123 anyhow::ensure!(
124 rc == rdmaxcel_sys::CUDA_SUCCESS,
125 "cuCtxSetCurrent failed with error code {rc}"
126 );
127
128 Ok(CudaCtxGuard {
129 prev,
130 retained_device,
131 })
132}
133
134pub trait RdmaLocalMemory: Send + Sync + Debug {
139 fn addr(&self) -> usize;
141
142 fn size(&self) -> usize;
144
145 fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error>;
147
148 fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error>;
150}
151
152fn check_bounds(offset: usize, len: usize, size: usize) -> Result<(), anyhow::Error> {
154 anyhow::ensure!(
155 offset.checked_add(len).is_some_and(|end| end <= size),
156 "access at offset {offset} with length {len} exceeds region size {size}"
157 );
158 Ok(())
159}
160
/// Copies `dst.len()` bytes of host memory starting at `addr + offset` into `dst`.
///
/// # Safety
///
/// `addr + offset` must be a non-null host address that is valid for reads of
/// `dst.len()` bytes, and that range must not overlap `dst`.
unsafe fn read_cpu(addr: usize, offset: usize, dst: &mut [u8]) {
    let source = (addr + offset) as *const u8;
    let count = dst.len();
    // SAFETY: validity and non-overlap of the source range are the caller's
    // obligation, as stated above; `dst` is a unique borrow of `count` bytes.
    unsafe { std::ptr::copy_nonoverlapping(source, dst.as_mut_ptr(), count) };
}
172
/// Copies all of `src` into host memory starting at `addr + offset`.
///
/// # Safety
///
/// `addr + offset` must be a non-null host address that is valid for writes of
/// `src.len()` bytes, and that range must not overlap `src`.
unsafe fn write_cpu(addr: usize, offset: usize, src: &[u8]) {
    let destination = (addr + offset) as *mut u8;
    // SAFETY: validity and non-overlap of the destination range are the
    // caller's obligation, as stated above.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), destination, src.len()) };
}
184
185unsafe fn read_gpu(addr: usize, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
192 let _guard = unsafe { set_ctx_for_ptr(addr)? };
193 let rc = unsafe {
194 rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
195 dst.as_mut_ptr() as *mut std::ffi::c_void,
196 (addr + offset) as rdmaxcel_sys::CUdeviceptr,
197 dst.len(),
198 )
199 };
200 anyhow::ensure!(
201 rc == rdmaxcel_sys::CUDA_SUCCESS,
202 "cuMemcpyDtoH failed with error code {rc}"
203 );
204 Ok(())
205}
206
207unsafe fn write_gpu(addr: usize, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
214 let _guard = unsafe { set_ctx_for_ptr(addr)? };
215 let rc = unsafe {
216 rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
217 (addr + offset) as rdmaxcel_sys::CUdeviceptr,
218 src.as_ptr() as *const std::ffi::c_void,
219 src.len(),
220 )
221 };
222 anyhow::ensure!(
223 rc == rdmaxcel_sys::CUDA_SUCCESS,
224 "cuMemcpyHtoD failed with error code {rc}"
225 );
226 Ok(())
227}
228
/// Marker for an owner whose lifetime keeps a memory region valid.
///
/// Holding an `Arc<dyn Keepalive>` keeps the owner, and therefore the memory
/// it backs, alive. No methods are required; the thread-safety supertraits allow
/// the handle to be shared across threads.
pub trait Keepalive: Sync + Send {}
235
236#[derive(Clone)]
247pub struct KeepaliveLocalMemory {
248 addr: usize,
249 size: usize,
250 direct_access_host_bandwidth: Option<u64>,
253 direct_access_device_bandwidth: Option<u64>,
256 _keepalive: Arc<dyn Keepalive>,
257 guard: Arc<RwLock<()>>,
258}
259
260impl Debug for KeepaliveLocalMemory {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("KeepaliveLocalMemory")
263 .field("addr", &self.addr)
264 .field("size", &self.size)
265 .field(
266 "direct_access_host_bandwidth",
267 &self.direct_access_host_bandwidth,
268 )
269 .field(
270 "direct_access_device_bandwidth",
271 &self.direct_access_device_bandwidth,
272 )
273 .finish_non_exhaustive()
274 }
275}
276
277impl KeepaliveLocalMemory {
278 pub fn new(addr: usize, size: usize, keepalive: Arc<dyn Keepalive>) -> Self {
282 let (host_bw, device_bw) = if is_device_ptr(addr) {
284 (None, Some(1))
285 } else {
286 (Some(1), None)
287 };
288 Self {
289 addr,
290 size,
291 direct_access_host_bandwidth: host_bw,
292 direct_access_device_bandwidth: device_bw,
293 _keepalive: keepalive,
294 guard: Arc::new(RwLock::new(())),
295 }
296 }
297}
298
299impl RdmaLocalMemory for KeepaliveLocalMemory {
300 fn addr(&self) -> usize {
301 self.addr
302 }
303
304 fn size(&self) -> usize {
305 self.size
306 }
307
308 fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
309 let _lock = self.guard.read().expect("lock poisoned");
310 check_bounds(offset, dst.len(), self.size)?;
311 unsafe {
314 if self.direct_access_host_bandwidth.is_some() {
315 read_cpu(self.addr, offset, dst);
316 Ok(())
317 } else {
318 read_gpu(self.addr, offset, dst)
319 }
320 }
321 }
322
323 fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
324 let _lock = self.guard.write().expect("lock poisoned");
325 check_bounds(offset, src.len(), self.size)?;
326 unsafe {
329 if self.direct_access_host_bandwidth.is_some() {
330 write_cpu(self.addr, offset, src);
331 Ok(())
332 } else {
333 write_gpu(self.addr, offset, src)
334 }
335 }
336 }
337}
338
339#[derive(Debug, Clone, Serialize, Deserialize)]
346pub struct UnsafeLocalMemory {
347 pub addr: usize,
348 pub size: usize,
349}
350
351impl UnsafeLocalMemory {
352 pub fn new(addr: usize, size: usize) -> Self {
353 Self { addr, size }
354 }
355}
356
357impl RdmaLocalMemory for UnsafeLocalMemory {
358 fn addr(&self) -> usize {
359 self.addr
360 }
361
362 fn size(&self) -> usize {
363 self.size
364 }
365
366 fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
367 check_bounds(offset, dst.len(), self.size)?;
368 unsafe {
371 if is_device_ptr(self.addr) {
372 read_gpu(self.addr, offset, dst)
373 } else {
374 read_cpu(self.addr, offset, dst);
375 Ok(())
376 }
377 }
378 }
379
380 fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
381 check_bounds(offset, src.len(), self.size)?;
382 unsafe {
385 if is_device_ptr(self.addr) {
386 write_gpu(self.addr, offset, src)
387 } else {
388 write_cpu(self.addr, offset, src);
389 Ok(())
390 }
391 }
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    impl Keepalive for Vec<u8> {}

    /// Builds a host-memory `KeepaliveLocalMemory` that owns `data`.
    fn host_keepalive_mem(data: Vec<u8>) -> KeepaliveLocalMemory {
        let addr = data.as_ptr() as usize;
        let size = data.len();
        KeepaliveLocalMemory::new(addr, size, Arc::new(data))
    }

    #[test]
    fn keepalive_host_read_at() {
        let mem = host_keepalive_mem(vec![1, 2, 3, 4, 5]);
        let mut buf = [0u8; 3];
        mem.read_at(1, &mut buf).unwrap();
        assert_eq!(buf, [2, 3, 4]);
    }

    #[test]
    fn keepalive_host_write_then_read() {
        let mem = host_keepalive_mem(vec![0; 5]);
        mem.write_at(1, &[7, 8, 9]).unwrap();
        let mut buf = [0u8; 5];
        mem.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [0, 7, 8, 9, 0]);
    }

    #[test]
    fn keepalive_host_out_of_bounds() {
        let mem = host_keepalive_mem(vec![0; 3]);
        let mut buf = [0u8; 3];
        assert!(mem.read_at(1, &mut buf).is_err());
        assert!(mem.write_at(1, &[7, 8, 9]).is_err());
    }

    #[test]
    fn keepalive_host_exact_fit_and_empty_access_at_end() {
        let mem = host_keepalive_mem(vec![1, 2, 3]);
        let mut whole = [0u8; 3];
        mem.read_at(0, &mut whole).unwrap();
        assert_eq!(whole, [1, 2, 3]);
        // A zero-length access that ends exactly at the region end is in bounds.
        let mut empty: [u8; 0] = [];
        mem.read_at(3, &mut empty).unwrap();
        mem.write_at(3, &empty).unwrap();
    }

    #[test]
    fn keepalive_host_offset_overflow_is_rejected() {
        // Exercises the `checked_add` overflow branch of `check_bounds`.
        let mem = host_keepalive_mem(vec![0; 3]);
        let mut buf = [0u8; 1];
        assert!(mem.read_at(usize::MAX, &mut buf).is_err());
        assert!(mem.write_at(usize::MAX, &[1]).is_err());
    }

    #[test]
    fn unsafe_host_write_then_read() {
        let mut data = vec![0u8; 4];
        let mem = UnsafeLocalMemory::new(data.as_mut_ptr() as usize, data.len());
        mem.write_at(2, &[5, 6]).unwrap();
        let mut buf = [0u8; 4];
        mem.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [0, 0, 5, 6]);
        assert!(mem.read_at(1, &mut buf).is_err());
        // `data` must outlive every access made through `mem`.
        drop(data);
    }
}