1use std::fmt::Debug;
19use std::sync::Arc;
20use std::sync::RwLock;
21
22use serde::Deserialize;
23use serde::Serialize;
24
25pub fn is_device_ptr(addr: usize) -> bool {
30 unsafe {
33 let mut mem_type: u32 = 0;
34 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
35 &mut mem_type as *mut _ as *mut std::ffi::c_void,
36 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
37 addr as rdmaxcel_sys::CUdeviceptr,
38 );
39 err == rdmaxcel_sys::CUDA_SUCCESS && mem_type == rdmaxcel_sys::CU_MEMORYTYPE_DEVICE
40 }
41}
42
43pub trait RdmaLocalMemory: Send + Sync + Debug {
48 fn addr(&self) -> usize;
50
51 fn size(&self) -> usize;
53
54 fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error>;
56
57 fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error>;
59}
60
61fn check_bounds(offset: usize, len: usize, size: usize) -> Result<(), anyhow::Error> {
63 anyhow::ensure!(
64 offset.checked_add(len).is_some_and(|end| end <= size),
65 "access at offset {offset} with length {len} exceeds region size {size}"
66 );
67 Ok(())
68}
69
/// Copies `dst.len()` bytes from host memory at `addr + offset` into `dst`.
///
/// # Safety
/// `addr + offset` must be valid for reads of `dst.len()` bytes, and that range must not
/// overlap `dst` (the requirements of `ptr::copy_nonoverlapping`). `addr + offset` must not
/// overflow `usize`.
unsafe fn read_cpu(addr: usize, offset: usize, dst: &mut [u8]) {
    let source = (addr + offset) as *const u8;
    // SAFETY: upheld by the caller per the contract above.
    unsafe { source.copy_to_nonoverlapping(dst.as_mut_ptr(), dst.len()) }
}
81
/// Copies all of `src` into host memory at `addr + offset`.
///
/// # Safety
/// `addr + offset` must be valid for writes of `src.len()` bytes, and that range must not
/// overlap `src` (the requirements of `ptr::copy_nonoverlapping`). `addr + offset` must not
/// overflow `usize`.
unsafe fn write_cpu(addr: usize, offset: usize, src: &[u8]) {
    let target = (addr + offset) as *mut u8;
    // SAFETY: upheld by the caller per the contract above.
    unsafe { target.copy_from_nonoverlapping(src.as_ptr(), src.len()) }
}
93
94unsafe fn read_gpu(addr: usize, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
101 let rc = unsafe {
102 rdmaxcel_sys::rdmaxcel_cuMemcpyDtoH_v2(
103 dst.as_mut_ptr() as *mut std::ffi::c_void,
104 (addr + offset) as rdmaxcel_sys::CUdeviceptr,
105 dst.len(),
106 )
107 };
108 anyhow::ensure!(
109 rc == rdmaxcel_sys::CUDA_SUCCESS,
110 "cuMemcpyDtoH failed with error code {rc}"
111 );
112 Ok(())
113}
114
115unsafe fn write_gpu(addr: usize, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
122 let rc = unsafe {
123 rdmaxcel_sys::rdmaxcel_cuMemcpyHtoD_v2(
124 (addr + offset) as rdmaxcel_sys::CUdeviceptr,
125 src.as_ptr() as *const std::ffi::c_void,
126 src.len(),
127 )
128 };
129 anyhow::ensure!(
130 rc == rdmaxcel_sys::CUDA_SUCCESS,
131 "cuMemcpyHtoD failed with error code {rc}"
132 );
133 Ok(())
134}
135
/// Marker for opaque handles that a `KeepaliveLocalMemory` retains (as `Arc<dyn Keepalive>`)
/// but never inspects. Presumably the handle owns the memory the region describes; the
/// in-file tests use the backing `Vec<u8>` itself.
pub trait Keepalive: Sync + Send {}
142
143#[derive(Clone)]
154pub struct KeepaliveLocalMemory {
155 addr: usize,
156 size: usize,
157 direct_access_host_bandwidth: Option<u64>,
160 direct_access_device_bandwidth: Option<u64>,
163 _keepalive: Arc<dyn Keepalive>,
164 guard: Arc<RwLock<()>>,
165}
166
167impl Debug for KeepaliveLocalMemory {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("KeepaliveLocalMemory")
170 .field("addr", &self.addr)
171 .field("size", &self.size)
172 .field(
173 "direct_access_host_bandwidth",
174 &self.direct_access_host_bandwidth,
175 )
176 .field(
177 "direct_access_device_bandwidth",
178 &self.direct_access_device_bandwidth,
179 )
180 .finish_non_exhaustive()
181 }
182}
183
184impl KeepaliveLocalMemory {
185 pub fn new(addr: usize, size: usize, keepalive: Arc<dyn Keepalive>) -> Self {
189 let (host_bw, device_bw) = if is_device_ptr(addr) {
191 (None, Some(1))
192 } else {
193 (Some(1), None)
194 };
195 Self {
196 addr,
197 size,
198 direct_access_host_bandwidth: host_bw,
199 direct_access_device_bandwidth: device_bw,
200 _keepalive: keepalive,
201 guard: Arc::new(RwLock::new(())),
202 }
203 }
204}
205
206impl RdmaLocalMemory for KeepaliveLocalMemory {
207 fn addr(&self) -> usize {
208 self.addr
209 }
210
211 fn size(&self) -> usize {
212 self.size
213 }
214
215 fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
216 let _lock = self.guard.read().expect("lock poisoned");
217 check_bounds(offset, dst.len(), self.size)?;
218 unsafe {
221 if self.direct_access_host_bandwidth.is_some() {
222 read_cpu(self.addr, offset, dst);
223 Ok(())
224 } else {
225 read_gpu(self.addr, offset, dst)
226 }
227 }
228 }
229
230 fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
231 let _lock = self.guard.write().expect("lock poisoned");
232 check_bounds(offset, src.len(), self.size)?;
233 unsafe {
236 if self.direct_access_host_bandwidth.is_some() {
237 write_cpu(self.addr, offset, src);
238 Ok(())
239 } else {
240 write_gpu(self.addr, offset, src)
241 }
242 }
243 }
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct UnsafeLocalMemory {
254 pub addr: usize,
255 pub size: usize,
256}
257
258impl UnsafeLocalMemory {
259 pub fn new(addr: usize, size: usize) -> Self {
260 Self { addr, size }
261 }
262}
263
264impl RdmaLocalMemory for UnsafeLocalMemory {
265 fn addr(&self) -> usize {
266 self.addr
267 }
268
269 fn size(&self) -> usize {
270 self.size
271 }
272
273 fn read_at(&self, offset: usize, dst: &mut [u8]) -> Result<(), anyhow::Error> {
274 check_bounds(offset, dst.len(), self.size)?;
275 unsafe {
278 if is_device_ptr(self.addr) {
279 read_gpu(self.addr, offset, dst)
280 } else {
281 read_cpu(self.addr, offset, dst);
282 Ok(())
283 }
284 }
285 }
286
287 fn write_at(&self, offset: usize, src: &[u8]) -> Result<(), anyhow::Error> {
288 check_bounds(offset, src.len(), self.size)?;
289 unsafe {
292 if is_device_ptr(self.addr) {
293 write_gpu(self.addr, offset, src)
294 } else {
295 write_cpu(self.addr, offset, src);
296 Ok(())
297 }
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    // Lets a plain `Vec<u8>` act as the keepalive for the buffer it owns.
    impl Keepalive for Vec<u8> {}

    /// Wraps `data` in a `KeepaliveLocalMemory` that holds `data` itself as the keepalive.
    fn host_keepalive_mem(data: Vec<u8>) -> KeepaliveLocalMemory {
        let addr = data.as_ptr() as usize;
        let size = data.len();
        KeepaliveLocalMemory::new(addr, size, Arc::new(data))
    }

    #[test]
    fn keepalive_host_read_at() {
        let mem = host_keepalive_mem(vec![1, 2, 3, 4, 5]);
        let mut buf = [0u8; 3];
        mem.read_at(1, &mut buf).unwrap();
        assert_eq!(buf, [2, 3, 4]);
    }

    #[test]
    fn keepalive_host_write_then_read() {
        let mem = host_keepalive_mem(vec![0; 5]);
        mem.write_at(1, &[7, 8, 9]).unwrap();
        let mut buf = [0u8; 5];
        mem.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [0, 7, 8, 9, 0]);
    }

    #[test]
    fn keepalive_host_out_of_bounds() {
        let mem = host_keepalive_mem(vec![0; 3]);
        let mut buf = [0u8; 3];
        assert!(mem.read_at(1, &mut buf).is_err());
        assert!(mem.write_at(1, &[7, 8, 9]).is_err());
    }

    #[test]
    fn keepalive_offset_overflow_is_rejected() {
        // `offset + len` overflowing `usize` must be an error, not a wrapped-around access.
        let mem = host_keepalive_mem(vec![0; 4]);
        let mut buf = [0u8; 2];
        assert!(mem.read_at(usize::MAX, &mut buf).is_err());
        assert!(mem.write_at(usize::MAX, &[1]).is_err());
    }

    #[test]
    fn keepalive_zero_length_access_at_end() {
        let mem = host_keepalive_mem(vec![1, 2, 3]);
        let mut empty: [u8; 0] = [];
        mem.read_at(3, &mut empty).unwrap();
        mem.write_at(3, &empty).unwrap();
        // Even an empty access may not start past the end of the region.
        assert!(mem.read_at(4, &mut empty).is_err());
        assert!(mem.write_at(4, &empty).is_err());
    }

    #[test]
    fn keepalive_clone_aliases_same_memory() {
        let mem = host_keepalive_mem(vec![0; 2]);
        let alias = mem.clone();
        alias.write_at(0, &[9, 8]).unwrap();
        let mut buf = [0u8; 2];
        mem.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [9, 8]);
    }

    #[test]
    fn unsafe_host_write_then_read() {
        let mut data = vec![0u8; 4];
        let mem = UnsafeLocalMemory::new(data.as_mut_ptr() as usize, data.len());
        mem.write_at(2, &[5, 6]).unwrap();
        let mut buf = [0u8; 4];
        mem.read_at(0, &mut buf).unwrap();
        assert_eq!(buf, [0, 0, 5, 6]);
        assert!(mem.read_at(1, &mut buf).is_err());
        assert_eq!(data, [0, 0, 5, 6]);
    }
}