1use std::collections::HashMap;
31
32use async_trait::async_trait;
33use hyperactor::Actor;
34use hyperactor::ActorId;
35use hyperactor::ActorRef;
36use hyperactor::Context;
37use hyperactor::HandleClient;
38use hyperactor::Handler;
39use hyperactor::Instance;
40use hyperactor::Named;
41use hyperactor::OncePortRef;
42use hyperactor::RefClient;
43use hyperactor::supervision::ActorSupervisionEvent;
44use serde::Deserialize;
45use serde::Serialize;
46
47use crate::ibverbs_primitives::IbverbsConfig;
48use crate::ibverbs_primitives::RdmaMemoryRegionView;
49use crate::ibverbs_primitives::RdmaQpInfo;
50use crate::ibverbs_primitives::ibverbs_supported;
51use crate::rdma_components::RdmaBuffer;
52use crate::rdma_components::RdmaDomain;
53use crate::rdma_components::RdmaQueuePair;
54use crate::rdma_components::get_registered_cuda_segments;
55use crate::validate_execution_context;
56
/// Checkout state of the queue pair that `RdmaManagerActor::qp_map` keeps for one peer.
#[derive(Debug, Clone)]
pub enum QueuePairState {
    /// The queue pair is stored in the map and may be handed out (or connected / queried).
    Available(RdmaQueuePair),
    /// The queue pair has been handed out by `request_queue_pair` and has not been returned
    /// through `release_queue_pair` yet.
    CheckedOut,
}
63
64pub fn get_rdmaxcel_error_message(error_code: i32) -> String {
66 unsafe {
67 let c_str = rdmaxcel_sys::rdmaxcel_error_string(error_code);
68 std::ffi::CStr::from_ptr(c_str)
69 .to_string_lossy()
70 .into_owned()
71 }
72}
73
/// Messages handled by `RdmaManagerActor`.
///
/// Variants carrying a `#[reply]` port answer the sender with the handler's return value.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum RdmaManagerMessage {
    /// Register the memory range starting at `addr` and spanning `size` bytes, and reply with
    /// the `RdmaBuffer` handle describing the registration.
    RequestBuffer {
        addr: usize,
        size: usize,
        #[reply]
        reply: OncePortRef<RdmaBuffer>,
    },
    /// Release the registration identified by `buffer.mr_id`.
    ReleaseBuffer {
        buffer: RdmaBuffer,
    },
    /// Check out the queue pair associated with `remote`, creating and connecting one first
    /// if the manager has none for that peer.
    RequestQueuePair {
        remote: ActorRef<RdmaManagerActor>,
        #[reply]
        reply: OncePortRef<RdmaQueuePair>,
    },
    /// Connect the local queue pair kept for `other` to the given `endpoint`.
    Connect {
        other: ActorRef<RdmaManagerActor>,
        endpoint: RdmaQpInfo,
    },
    /// Create a queue pair for `remote` if none exists yet; the handler replies `true`.
    InitializeQP {
        remote: ActorRef<RdmaManagerActor>,
        #[reply]
        reply: OncePortRef<bool>,
    },
    /// Reply with the `RdmaQpInfo` of the local queue pair kept for `other`.
    ConnectionInfo {
        other: ActorRef<RdmaManagerActor>,
        #[reply]
        reply: OncePortRef<RdmaQpInfo>,
    },
    /// Return a previously checked-out queue pair `qp` for peer `other` to the manager.
    ReleaseQueuePair {
        other: ActorRef<RdmaManagerActor>,
        qp: RdmaQueuePair,
    },
}
121
/// Actor that owns an RDMA domain, the memory registrations made through it, and one queue
/// pair per peer manager; it is driven by `RdmaManagerMessage`.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        RdmaManagerMessage,
    ],
)]
pub struct RdmaManagerActor {
    /// Queue pair state per peer, keyed by the peer manager's `ActorId`.
    qp_map: HashMap<ActorId, QueuePairState>,

    /// Queue pair connected to its own endpoint. Set in `init`; passed to
    /// `rdmaxcel_sys::register_segments` when CUDA segments are registered.
    loopback_qp: Option<RdmaQueuePair>,

    /// RDMA domain created for `config.device`; supplies the `context` and `pd` used for
    /// queue pair creation and memory registration.
    domain: RdmaDomain,
    /// Configuration passed at construction (or its `Default`), cloned into each queue pair.
    config: IbverbsConfig,

    /// Result of `pt_cuda_allocator_compatibility()`; when true, CUDA addresses are resolved
    /// against registered CUDA segments instead of being registered individually.
    pt_cuda_alloc: bool,

    /// Maps a memory-region view id to the `ibv_mr` pointer (stored as `usize`) backing it;
    /// 0 means the view has no `ibv_mr` of its own to deregister.
    mr_map: HashMap<usize, usize>,
    /// Next id to assign to a memory-region view.
    mrv_id: usize,
}
159
160impl RdmaManagerActor {
161 fn find_cuda_segment_for_address(
162 &mut self,
163 addr: usize,
164 size: usize,
165 ) -> Option<RdmaMemoryRegionView> {
166 let registered_segments = get_registered_cuda_segments();
167 for segment in registered_segments {
168 let start_addr = segment.phys_address;
169 let end_addr = start_addr + segment.phys_size;
170 if start_addr <= addr && addr + size <= end_addr {
171 let offset = addr - start_addr;
172 let rdma_addr = segment.mr_addr + offset;
173
174 let mrv = RdmaMemoryRegionView {
175 id: self.mrv_id,
176 virtual_addr: addr,
177 rdma_addr,
178 size,
179 lkey: segment.lkey,
180 rkey: segment.rkey,
181 };
182 self.mrv_id += 1;
183 return Some(mrv);
184 }
185 }
186 None
187 }
188
189 fn register_mr(
190 &mut self,
191 addr: usize,
192 size: usize,
193 ) -> Result<RdmaMemoryRegionView, anyhow::Error> {
194 unsafe {
195 let mut mem_type: i32 = 0;
196 let ptr = addr as cuda_sys::CUdeviceptr;
197 let err = cuda_sys::cuPointerGetAttribute(
198 &mut mem_type as *mut _ as *mut std::ffi::c_void,
199 cuda_sys::CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
200 ptr,
201 );
202 let is_cuda = err == cuda_sys::CUresult::CUDA_SUCCESS;
203
204 let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
205 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
206 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
207 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
208
209 let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
210 let mrv;
211
212 if is_cuda && self.pt_cuda_alloc {
213 let mut maybe_mrv = self.find_cuda_segment_for_address(addr, size);
215 if maybe_mrv.is_none() {
217 let qp = self.loopback_qp.as_mut().unwrap();
218 let err = rdmaxcel_sys::register_segments(
219 self.domain.pd,
220 qp.qp as *mut rdmaxcel_sys::ibv_qp,
221 );
222 if err != 0 {
223 let error_msg = get_rdmaxcel_error_message(err);
224 return Err(anyhow::anyhow!(
225 "RdmaXcel register_sements failed (addr: 0x{:x}, size: {}): {}",
226 addr,
227 size,
228 error_msg
229 ));
230 }
231
232 maybe_mrv = self.find_cuda_segment_for_address(addr, size);
233 }
234 if maybe_mrv.is_none() {
236 return Err(anyhow::anyhow!(
237 "MR registration failed for cuda (addr: 0x{:x}, size: {}), unable to find segment in CudaCachingAllocator",
238 addr,
239 size
240 ));
241 }
242 mrv = maybe_mrv.unwrap();
243 } else if is_cuda {
244 let mut fd: i32 = -1;
245 cuda_sys::cuMemGetHandleForAddressRange(
246 &mut fd as *mut i32 as *mut std::ffi::c_void,
247 addr as cuda_sys::CUdeviceptr,
248 size,
249 cuda_sys::CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
250 0,
251 );
252 mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(
253 self.domain.pd,
254 0,
255 size,
256 0,
257 fd,
258 access.0 as i32,
259 );
260 if mr.is_null() {
261 return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
262 }
263 mrv = RdmaMemoryRegionView {
264 id: self.mrv_id,
265 virtual_addr: addr,
266 rdma_addr: (*mr).addr as usize,
267 size,
268 lkey: (*mr).lkey,
269 rkey: (*mr).rkey,
270 };
271 self.mrv_id += 1;
272 } else {
273 mr = rdmaxcel_sys::ibv_reg_mr(
275 self.domain.pd,
276 addr as *mut std::ffi::c_void,
277 size,
278 access.0 as i32,
279 );
280
281 if mr.is_null() {
282 return Err(anyhow::anyhow!("failed to register standard MR"));
283 }
284
285 mrv = RdmaMemoryRegionView {
286 id: self.mrv_id,
287 virtual_addr: addr,
288 rdma_addr: (*mr).addr as usize,
289 size,
290 lkey: (*mr).lkey,
291 rkey: (*mr).rkey,
292 };
293 self.mrv_id += 1;
294 }
295 self.mr_map.insert(mrv.id, mr as usize);
296 Ok(mrv)
297 }
298 }
299
300 fn deregister_mr(&mut self, id: usize) -> Result<(), anyhow::Error> {
301 if let Some(mr_ptr) = self.mr_map.remove(&id) {
302 if mr_ptr != 0 {
303 unsafe {
304 rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
305 }
306 }
307 }
308 Ok(())
309 }
310}
311
312#[async_trait]
313impl Actor for RdmaManagerActor {
314 type Params = Option<IbverbsConfig>;
315
316 async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
317 if !ibverbs_supported() {
318 return Err(anyhow::anyhow!(
319 "Cannot create RdmaManagerActor because RDMA is not supported on this machine"
320 ));
321 }
322
323 let mut config = params.unwrap_or_default();
325 tracing::debug!("rdma is enabled, using device {}", config.device);
326
327 let pt_cuda_alloc = crate::rdma_components::pt_cuda_allocator_compatibility();
328
329 if config.use_gpu_direct {
331 match validate_execution_context().await {
332 Ok(_) => {
333 tracing::info!("GPU Direct RDMA execution context validated successfully");
334 }
335 Err(e) => {
336 tracing::warn!(
337 "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
338 e
339 );
340 config.use_gpu_direct = false;
341 }
342 }
343 }
344
345 let domain = RdmaDomain::new(config.device.clone())
346 .map_err(|e| anyhow::anyhow!("rdmaManagerActor could not create domain: {}", e))?;
347
348 Ok(Self {
349 qp_map: HashMap::new(),
350 loopback_qp: None,
351 domain,
352 config,
353 pt_cuda_alloc,
354 mr_map: HashMap::new(),
355 mrv_id: 0,
356 })
357 }
358
359 async fn init(&mut self, _this: &Instance<Self>) -> Result<(), anyhow::Error> {
360 let mut qp = RdmaQueuePair::new(self.domain.context, self.domain.pd, self.config.clone())
362 .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?;
363
364 let endpoint = qp
366 .get_qp_info()
367 .map_err(|e| anyhow::anyhow!("could not get QP info: {}", e))?;
368
369 qp.connect(&endpoint)
371 .map_err(|e| anyhow::anyhow!("could not connect to RDMA endpoint: {}", e))?;
372
373 self.loopback_qp = Some(qp);
374 tracing::debug!("successfully created special loopback connection");
375
376 Ok(())
377 }
378
379 async fn handle_supervision_event(
380 &mut self,
381 _cx: &Instance<Self>,
382 _event: &ActorSupervisionEvent,
383 ) -> Result<bool, anyhow::Error> {
384 tracing::error!("rdmaManagerActor supervision event: {:?}", _event);
385 tracing::error!("rdmaManagerActor error occurred, stop the worker process, exit code: 1");
386 std::process::exit(1);
387 }
388}
389
390#[async_trait]
391#[hyperactor::forward(RdmaManagerMessage)]
392impl RdmaManagerMessageHandler for RdmaManagerActor {
393 async fn request_buffer(
409 &mut self,
410 cx: &Context<Self>,
411 addr: usize,
412 size: usize,
413 ) -> Result<RdmaBuffer, anyhow::Error> {
414 let mrv = self.register_mr(addr, size)?;
415
416 Ok(RdmaBuffer {
417 owner: cx.bind().clone(),
418 mr_id: mrv.id,
419 addr: mrv.rdma_addr,
420 size: mrv.size,
421 rkey: mrv.rkey,
422 lkey: mrv.lkey,
423 })
424 }
425
426 async fn release_buffer(
440 &mut self,
441 _cx: &Context<Self>,
442 buffer: RdmaBuffer,
443 ) -> Result<(), anyhow::Error> {
444 self.deregister_mr(buffer.mr_id)
445 .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
446 Ok(())
447 }
448
449 async fn request_queue_pair(
463 &mut self,
464 cx: &Context<Self>,
465 remote: ActorRef<RdmaManagerActor>,
466 ) -> Result<RdmaQueuePair, anyhow::Error> {
467 let remote_id = remote.actor_id().clone();
468
469 match self.qp_map.get(&remote_id).cloned() {
473 Some(QueuePairState::Available(qp)) => {
474 self.qp_map.insert(remote_id, QueuePairState::CheckedOut);
476 Ok(qp)
477 }
478 Some(QueuePairState::CheckedOut) => {
479 Err(anyhow::anyhow!(
481 "Queue pair for actor {} is already checked out",
482 remote_id
483 ))
484 }
485 None => {
486 let is_loopback = remote_id == cx.bind::<RdmaManagerActor>().actor_id().clone();
488
489 if is_loopback {
490 self.initialize_qp(cx, remote.clone()).await?;
492 let endpoint = self.connection_info(cx, remote.clone()).await?;
493 self.connect(cx, remote.clone(), endpoint).await?;
494 } else {
495 self.initialize_qp(cx, remote.clone()).await?;
497 remote.initialize_qp(cx, cx.bind().clone()).await?;
498 let remote_endpoint = remote.connection_info(cx, cx.bind().clone()).await?;
499 self.connect(cx, remote.clone(), remote_endpoint).await?;
500 let local_endpoint = self.connection_info(cx, remote.clone()).await?;
501 remote
502 .connect(cx, cx.bind().clone(), local_endpoint)
503 .await?;
504 }
505
506 match self.qp_map.get(&remote_id).cloned() {
508 Some(QueuePairState::Available(qp)) => {
509 self.qp_map.insert(remote_id, QueuePairState::CheckedOut);
510 Ok(qp)
511 }
512 _ => Err(anyhow::anyhow!(
513 "Failed to create connection for actor {}",
514 remote_id
515 )),
516 }
517 }
518 }
519 }
520
521 async fn initialize_qp(
531 &mut self,
532 _cx: &Context<Self>,
533 other: ActorRef<RdmaManagerActor>,
534 ) -> Result<bool, anyhow::Error> {
535 let key = other.actor_id().clone();
536
537 if let std::collections::hash_map::Entry::Vacant(e) = self.qp_map.entry(key) {
538 let qp = RdmaQueuePair::new(self.domain.context, self.domain.pd, self.config.clone())
539 .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?;
540 e.insert(QueuePairState::Available(qp));
541 tracing::debug!("successfully created a connection with {:?}", other);
542 }
543 Ok(true)
544 }
545
546 async fn connect(
552 &mut self,
553 _cx: &Context<Self>,
554 other: ActorRef<RdmaManagerActor>,
555 endpoint: RdmaQpInfo,
556 ) -> Result<(), anyhow::Error> {
557 tracing::debug!("connecting with {:?}", other);
558 let other_id = other.actor_id().clone();
559
560 match self.qp_map.get_mut(&other_id) {
561 Some(QueuePairState::Available(qp)) => {
562 qp.connect(&endpoint)
563 .map_err(|e| anyhow::anyhow!("could not connect to RDMA endpoint: {}", e))?;
564 Ok(())
565 }
566 Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!(
567 "Cannot connect: queue pair for actor {} is checked out",
568 other_id
569 )),
570 None => Err(anyhow::anyhow!(
571 "On connect, no connection found for actor {}",
572 other_id
573 )),
574 }
575 }
576
577 async fn connection_info(
585 &mut self,
586 _cx: &Context<Self>,
587 other: ActorRef<RdmaManagerActor>,
588 ) -> Result<RdmaQpInfo, anyhow::Error> {
589 tracing::debug!("getting connection info with {:?}", other);
590 let other_id = other.actor_id().clone();
591
592 match self.qp_map.get_mut(&other_id) {
593 Some(QueuePairState::Available(qp)) => {
594 let connection_info = qp.get_qp_info()?;
595 Ok(connection_info)
596 }
597 Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!(
598 "Cannot get connection info: queue pair for actor {} is checked out",
599 other_id
600 )),
601 None => Err(anyhow::anyhow!(
602 "No connection found for actor {}",
603 other_id
604 )),
605 }
606 }
607
608 async fn release_queue_pair(
617 &mut self,
618 _cx: &Context<Self>,
619 other: ActorRef<RdmaManagerActor>,
620 qp: RdmaQueuePair,
621 ) -> Result<(), anyhow::Error> {
622 let remote_id = other.actor_id().clone();
623
624 match self.qp_map.get(&remote_id) {
626 Some(QueuePairState::CheckedOut) => {
627 self.qp_map
629 .insert(remote_id.clone(), QueuePairState::Available(qp));
630 tracing::debug!("Released queue pair for actor {:?}", remote_id);
631 Ok(())
632 }
633 Some(QueuePairState::Available(_)) => Err(anyhow::anyhow!(
634 "Cannot release queue pair for actor {}: queue pair is not checked out",
635 remote_id
636 )),
637 None => Err(anyhow::anyhow!(
638 "Cannot release queue pair for actor {}: no queue pair found",
639 remote_id
640 )),
641 }
642 }
643}