1use std::collections::HashMap;
20use std::collections::HashSet;
21use std::sync::Arc;
22use std::sync::OnceLock;
23use std::time::Duration;
24use std::time::Instant;
25
26use anyhow::Result;
27use async_trait::async_trait;
28use futures::lock::Mutex;
29use hyperactor::Actor;
30use hyperactor::ActorHandle;
31use hyperactor::Context;
32use hyperactor::HandleClient;
33use hyperactor::Handler;
34use hyperactor::Instance;
35use hyperactor::OncePortHandle;
36use hyperactor::RefClient;
37use hyperactor::reference;
38use serde::Deserialize;
39use serde::Serialize;
40use typeuri::Named;
41
42use super::IbvBuffer;
43use super::IbvOp;
44use super::domain::IbvDomain;
45use super::primitives::IbvConfig;
46use super::primitives::IbvDevice;
47use super::primitives::IbvMemoryRegionView;
48use super::primitives::IbvQpInfo;
49use super::primitives::ibverbs_supported;
50use super::primitives::mlx5dv_supported;
51use super::primitives::resolve_qp_type;
52use super::queue_pair::IbvQueuePair;
53use super::queue_pair::PollCompletionError;
54use super::queue_pair::PollTarget;
55use crate::RdmaOp;
56use crate::RdmaOpType;
57use crate::RdmaTransportLevel;
58use crate::backend::RdmaBackend;
59use crate::rdma_components::get_registered_cuda_segments;
60use crate::rdma_manager_actor::GetIbvActorRefClient;
61use crate::rdma_manager_actor::RdmaManagerActor;
62use crate::rdma_manager_actor::RdmaManagerMessageClient;
63use crate::rdma_manager_actor::get_rdmaxcel_error_message;
64use crate::validate_execution_context;
65
/// Serializable messages handled by [`IbvManagerActor`].
///
/// Queue-pair messages identify a connection by the local device name
/// (`self_device`), the peer actor (`other`) and the peer's device name
/// (`other_device`).
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum IbvManagerMessage {
    /// Requests the `IbvBuffer` for the memory identified by `remote_buf_id`.
    ///
    /// The handler returns a cached registration when one exists; otherwise it
    /// asks the owning `RdmaManagerActor` for the memory and registers it.
    /// Replies `None` when the owner has no memory for that id.
    RequestBuffer {
        /// Identifier of the buffer, as understood by the owning `RdmaManagerActor`.
        remote_buf_id: usize,
        #[reply]
        reply: reference::OncePortRef<Option<IbvBuffer>>,
    },
    /// Drops the cached registration for `remote_buf_id` (if any) and
    /// deregisters its memory region.
    ReleaseBuffer { remote_buf_id: usize },
    /// Returns a connected queue pair from local `self_device` to `other`'s
    /// `other_device`, creating and connecting one if none is cached.
    /// Failures are reported as `Err(String)` in the reply.
    RequestQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<Result<IbvQueuePair, String>>,
    },
    /// Connects the existing local QP on `self_device` for (`other`,
    /// `other_device`) to the remote `endpoint`.
    Connect {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        endpoint: IbvQpInfo,
    },
    /// Creates, if absent, the local QP on `self_device` used to talk to
    /// `other`'s `other_device`. The device domain for `self_device` must
    /// already exist. Replies `true` on success.
    InitializeQP {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<bool>,
    },
    /// Replies with the `IbvQpInfo` of the local QP on `self_device` for
    /// (`other`, `other_device`).
    ConnectionInfo {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<IbvQpInfo>,
    },
    /// Handled as a no-op by this actor: queue pairs stay cached until the
    /// actor is dropped.
    ReleaseQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        qp: IbvQueuePair,
    },
    /// Replies with the current state (as returned by `IbvQueuePair::state`)
    /// of the local QP on `self_device` for (`other`, `other_device`).
    GetQpState {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<u32>,
    },
}
// Register the message type with `wirevalue` (presumably for type-erased
// (de)serialization — confirm against the wirevalue crate).
wirevalue::register_type!(IbvManagerMessage);
129
/// In-process messages handled by [`IbvManagerActor`].
///
/// Unlike [`IbvManagerMessage`], this enum does not derive `Serialize` /
/// `Deserialize` and replies through `OncePortHandle`s. It is sent through a
/// local `ActorHandle`, as `IbvBackend::execute_op` does.
#[derive(Handler, HandleClient, Debug)]
pub enum IbvManagerLocalMessage {
    /// Registers `size` bytes starting at `addr` for RDMA.
    ///
    /// Replies with the memory-region view and the name of the RDMA device the
    /// memory was registered on, or an error string.
    RegisterMr {
        addr: usize,
        size: usize,
        #[reply]
        reply: OncePortHandle<Result<(IbvMemoryRegionView, String), String>>,
    },
    /// Deregisters the memory region whose view id is `id`. Unknown ids are
    /// treated as already released.
    DeregisterMr {
        id: usize,
        #[reply]
        reply: OncePortHandle<Result<(), String>>,
    },
}
147
/// Actor that manages ibverbs resources: per-device protection domains, queue
/// pairs to peer `IbvManagerActor`s, and registered memory regions.
///
/// It expects to be spawned as a child of `RdmaManagerActor` (see
/// `Actor::init`) and exports a handler for [`IbvManagerMessage`]. Owned
/// resources are released in its `Drop` implementation.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        IbvManagerMessage,
    ],
)]
pub struct IbvManagerActor {
    /// Parent `RdmaManagerActor`; set exactly once in `Actor::init`.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,

    /// Connection queue pairs keyed by local device name, then by
    /// (peer actor id, peer device name).
    device_qps: HashMap<String, HashMap<(reference::ActorId, String), IbvQueuePair>>,

    /// (local device, peer actor id, peer device) keys whose queue pair is
    /// currently being created by `request_queue_pair_impl`.
    pending_qp_creation: Arc<Mutex<HashSet<(String, reference::ActorId, String)>>>,

    /// Per-device domain plus an optional loopback queue pair, created lazily by
    /// `get_or_create_device_domain`.
    device_domains: HashMap<String, (IbvDomain, Option<IbvQueuePair>)>,

    /// ibverbs configuration; `use_gpu_direct` may have been turned off in `new`.
    config: IbvConfig,

    /// True when the resolved QP type is `RDMA_QP_TYPE_MLX5DV`; enables CUDA
    /// segment registration and deregistration.
    mlx5dv_enabled: bool,

    /// Memory-region view id -> raw `ibv_mr` pointer stored as `usize`.
    /// The value is 0 for views backed by a registered CUDA segment, which own no MR.
    mr_map: HashMap<usize, usize>,

    /// Next id to assign to an `IbvMemoryRegionView`.
    mrv_id: usize,

    /// Buffers registered through `RequestBuffer`, keyed by `remote_buf_id`.
    buffer_registrations: HashMap<usize, IbvBuffer>,
}
186
187#[async_trait]
188impl Actor for IbvManagerActor {
189 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
190 let owner = if let Some(owner) = this.parent_handle() {
191 owner
192 } else {
193 anyhow::bail!("RdmaManagerActor not found as parent of IbvManagerActor");
194 };
195 self.owner
196 .set(owner)
197 .expect("owner should only be set once during init");
198 Ok(())
199 }
200}
201
202impl Drop for IbvManagerActor {
203 fn drop(&mut self) {
204 fn destroy_queue_pair(qp: &IbvQueuePair, _context: &str) {
209 unsafe {
210 if qp.qp != 0 {
211 let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
212 rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
213 }
214 }
215 }
216
217 for (_device_name, device_map) in self.device_qps.drain() {
219 for ((actor_id, _remote_device), qp) in device_map {
220 destroy_queue_pair(&qp, &format!("actor {:?}", actor_id));
221 }
222 }
223
224 for (device_name, (domain, qp)) in self.device_domains.drain() {
226 if let Some(qp) = qp {
227 destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name));
228 }
229 drop(domain);
230 }
231
232 let _mr_count = self.mr_map.len();
234 for (id, mr_ptr) in self.mr_map.drain() {
235 if mr_ptr != 0 {
236 unsafe {
237 let result = rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
238 if result != 0 {
239 tracing::error!(
240 "Failed to deregister MR with id {}: error code {}",
241 id,
242 result
243 );
244 }
245 }
246 }
247 }
248
249 if self.mlx5dv_enabled {
252 unsafe {
253 let result = rdmaxcel_sys::deregister_segments();
254 if result != 0 {
255 let error_msg = get_rdmaxcel_error_message(result);
256 tracing::error!(
257 "Failed to deregister CUDA segments: {} (error code: {})",
258 error_msg,
259 result
260 );
261 }
262 }
263 }
264 }
265}
266
267impl IbvManagerActor {
268 pub async fn local_handle(
271 client: &(impl hyperactor::context::Actor + Send + Sync),
272 ) -> Result<ActorHandle<Self>, anyhow::Error> {
273 let rdma_handle = RdmaManagerActor::local_handle(client);
274 let ibv_ref = rdma_handle
275 .get_ibv_actor_ref(client)
276 .await?
277 .ok_or_else(|| anyhow::anyhow!("local RdmaManagerActor has no ibverbs backend"))?;
278 ibv_ref
279 .downcast_handle(client)
280 .ok_or_else(|| anyhow::anyhow!("IbvManagerActor is not in the local process"))
281 }
282
283 pub async fn new(params: Option<IbvConfig>) -> Result<Self, anyhow::Error> {
285 if !ibverbs_supported() {
286 return Err(anyhow::anyhow!(
287 "Cannot create IbvManagerActor because RDMA is not supported on this machine"
288 ));
289 }
290
291 let mut config = params.unwrap_or_default();
293 tracing::debug!("rdma is enabled, config device hint: {}", config.device);
294
295 let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
296
297 if config.use_gpu_direct {
299 match validate_execution_context().await {
300 Ok(_) => {
301 tracing::info!("GPU Direct RDMA execution context validated successfully");
302 }
303 Err(e) => {
304 tracing::warn!(
305 "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
306 e
307 );
308 config.use_gpu_direct = false;
309 }
310 }
311 }
312
313 let actor = Self {
314 owner: OnceLock::new(),
315 device_qps: HashMap::new(),
316 pending_qp_creation: Arc::new(Mutex::new(HashSet::new())),
317 device_domains: HashMap::new(),
318 config,
319 mlx5dv_enabled,
320 mr_map: HashMap::new(),
321 mrv_id: 0,
322 buffer_registrations: HashMap::new(),
323 };
324
325 Ok(actor)
326 }
327
328 fn get_or_create_device_domain(
330 &mut self,
331 device_name: &str,
332 rdma_device: &IbvDevice,
333 ) -> Result<(IbvDomain, Option<IbvQueuePair>), anyhow::Error> {
334 if let Some((domain, qp)) = self.device_domains.get(device_name) {
335 return Ok((domain.clone(), qp.clone()));
336 }
337
338 let domain = IbvDomain::new(rdma_device.clone()).map_err(|e| {
340 anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
341 })?;
342
343 crate::print_device_info_if_debug_enabled(domain.context);
345
346 let qp = if mlx5dv_supported() && !crate::efa::is_efa_device() {
349 let mut qp = IbvQueuePair::new(domain.context, domain.pd, self.config.clone())
350 .map_err(|e| {
351 anyhow::anyhow!(
352 "could not create loopback QP for device {}: {}",
353 device_name,
354 e
355 )
356 })?;
357
358 let endpoint = qp.get_qp_info().map_err(|e| {
360 anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
361 })?;
362
363 qp.connect(&endpoint).map_err(|e| {
364 anyhow::anyhow!(
365 "could not connect loopback QP for device {}: {}",
366 device_name,
367 e
368 )
369 })?;
370
371 Some(qp)
372 } else {
373 None
374 };
375
376 self.device_domains
377 .insert(device_name.to_string(), (domain.clone(), qp.clone()));
378 Ok((domain, qp))
379 }
380
381 fn build_per_device_pd_qp_arrays(
384 &self,
385 ) -> (
386 Vec<*mut rdmaxcel_sys::ibv_pd>,
387 Vec<*mut rdmaxcel_sys::rdmaxcel_qp_t>,
388 ) {
389 let cuda_map = super::device_selection::get_cuda_device_to_ibv_device();
390 let mut pds = Vec::with_capacity(cuda_map.len());
391 let mut qps = Vec::with_capacity(cuda_map.len());
392 for maybe_device in cuda_map {
393 if let Some(device) = maybe_device {
394 if let Some((domain, qp)) = self.device_domains.get(device.name()) {
395 pds.push(domain.pd);
396 qps.push(
397 qp.as_ref()
398 .map(|q| q.qp as *mut rdmaxcel_sys::rdmaxcel_qp_t)
399 .unwrap_or(std::ptr::null_mut()),
400 );
401 } else {
402 pds.push(std::ptr::null_mut());
403 qps.push(std::ptr::null_mut());
404 }
405 } else {
406 pds.push(std::ptr::null_mut());
407 qps.push(std::ptr::null_mut());
408 }
409 }
410 (pds, qps)
411 }
412
413 fn find_cuda_segment_for_address(
414 &mut self,
415 addr: usize,
416 size: usize,
417 pd: *mut rdmaxcel_sys::ibv_pd,
418 ) -> Option<IbvMemoryRegionView> {
419 let registered_segments = get_registered_cuda_segments(pd);
420 for segment in registered_segments {
421 let start_addr = segment.phys_address;
422 let end_addr = start_addr + segment.phys_size;
423 if start_addr <= addr && addr + size <= end_addr {
424 let offset = addr - start_addr;
425 let rdma_addr = segment.mr_addr + offset;
426
427 let mrv = IbvMemoryRegionView {
428 id: self.mrv_id,
429 virtual_addr: addr,
430 rdma_addr,
431 size,
432 lkey: segment.lkey,
433 rkey: segment.rkey,
434 };
435 self.mrv_id += 1;
436 return Some(mrv);
437 }
438 }
439 None
440 }
441
442 fn register_mr_impl(
443 &mut self,
444 addr: usize,
445 size: usize,
446 ) -> Result<(IbvMemoryRegionView, String), anyhow::Error> {
447 unsafe {
448 let mut mem_type: i32 = 0;
449 let ptr = addr as rdmaxcel_sys::CUdeviceptr;
450 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
451 &mut mem_type as *mut _ as *mut std::ffi::c_void,
452 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
453 ptr,
454 );
455 let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
456
457 let mut selected_rdma_device = None;
458
459 if is_cuda {
460 let mut device_ordinal: i32 = -1;
462 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
463 &mut device_ordinal as *mut _ as *mut std::ffi::c_void,
464 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
465 ptr,
466 );
467 if err == rdmaxcel_sys::CUDA_SUCCESS && device_ordinal >= 0 {
468 selected_rdma_device = super::device_selection::get_cuda_device_to_ibv_device()
469 .get(device_ordinal as usize)
470 .and_then(|d| d.clone());
471 }
472 }
473
474 let rdma_device = if let Some(device) = selected_rdma_device {
476 device
477 } else {
478 self.config.device.clone()
480 };
481
482 let device_name = rdma_device.name().clone();
483 tracing::debug!(
484 "Using RDMA device: {} for memory at 0x{:x}",
485 device_name,
486 addr
487 );
488
489 let (domain, _qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
491
492 let access = if crate::efa::is_efa_device() {
493 crate::efa::mr_access_flags()
494 } else {
495 rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
496 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
497 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
498 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC
499 };
500
501 let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
502 let mrv;
503
504 if is_cuda {
505 let mut segment_mrv = None;
507 if self.mlx5dv_enabled {
508 segment_mrv = self.find_cuda_segment_for_address(addr, size, domain.pd);
510
511 if segment_mrv.is_none() {
513 let (mut pds, mut qps) = self.build_per_device_pd_qp_arrays();
514 let err = rdmaxcel_sys::register_segments(
515 pds.as_mut_ptr(),
516 qps.as_mut_ptr(),
517 pds.len() as i32,
518 );
519 if err == 0 {
522 segment_mrv = self.find_cuda_segment_for_address(addr, size, domain.pd);
523 }
524 }
525 }
526
527 if let Some(mrv_from_segment) = segment_mrv {
529 mrv = mrv_from_segment;
530 } else {
531 let mut fd: i32 = -1;
533 let cu_err = rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
534 &mut fd,
535 addr as rdmaxcel_sys::CUdeviceptr,
536 size,
537 rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
538 0,
539 );
540 if cu_err != rdmaxcel_sys::CUDA_SUCCESS || fd < 0 {
541 return Err(anyhow::anyhow!(
542 "failed to get dmabuf handle for CUDA memory (addr: 0x{:x}, size: {}, cu_err: {}, fd: {})",
543 addr,
544 size,
545 cu_err,
546 fd
547 ));
548 }
549 mr =
550 rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
551 if mr.is_null() {
552 return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
553 }
554 mrv = IbvMemoryRegionView {
555 id: self.mrv_id,
556 virtual_addr: addr,
557 rdma_addr: (*mr).addr as usize,
558 size,
559 lkey: (*mr).lkey,
560 rkey: (*mr).rkey,
561 };
562 self.mrv_id += 1;
563 }
564 } else {
565 mr = rdmaxcel_sys::ibv_reg_mr(
567 domain.pd,
568 addr as *mut std::ffi::c_void,
569 size,
570 access.0 as i32,
571 );
572
573 if mr.is_null() {
574 return Err(anyhow::anyhow!("failed to register standard MR"));
575 }
576
577 mrv = IbvMemoryRegionView {
578 id: self.mrv_id,
579 virtual_addr: addr,
580 rdma_addr: (*mr).addr as usize,
581 size,
582 lkey: (*mr).lkey,
583 rkey: (*mr).rkey,
584 };
585 self.mrv_id += 1;
586 }
587 self.mr_map.insert(mrv.id, mr as usize);
588 Ok((mrv, device_name))
589 }
590 }
591
592 fn deregister_mr_impl(&mut self, id: usize) -> Result<(), anyhow::Error> {
593 if let Some(mr_ptr) = self.mr_map.remove(&id) {
594 if mr_ptr != 0 {
595 unsafe {
596 rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
597 }
598 }
599 }
600 Ok(())
601 }
602
603 async fn request_queue_pair_impl(
604 &mut self,
605 cx: &Context<'_, Self>,
606 other: reference::ActorRef<IbvManagerActor>,
607 self_device: String,
608 other_device: String,
609 ) -> Result<IbvQueuePair, anyhow::Error> {
610 let self_ref: reference::ActorRef<IbvManagerActor> = cx.bind();
611 let other_id = other.actor_id().clone();
612
613 let inner_key = (other_id.clone(), other_device.clone());
615
616 if let Some(device_map) = self.device_qps.get(&self_device) {
618 if let Some(qp) = device_map.get(&inner_key) {
619 return Ok(qp.clone());
620 }
621 }
622
623 let pending_key = (self_device.clone(), other_id.clone(), other_device.clone());
625 let mut pending = self.pending_qp_creation.lock().await;
626
627 if pending.contains(&pending_key) {
628 drop(pending);
630
631 let start = Instant::now();
634 let timeout = Duration::from_secs(1);
635
636 loop {
637 tokio::time::sleep(Duration::from_micros(200)).await;
638
639 if let Some(device_map) = self.device_qps.get(&self_device) {
641 if let Some(qp) = device_map.get(&inner_key) {
642 return Ok(qp.clone());
643 }
644 }
645
646 if start.elapsed() >= timeout {
648 return Err(anyhow::anyhow!(
649 "Timeout waiting for QP creation (device {} -> actor {} device {}). \
650 Another task is creating it but hasn't completed in 1 second",
651 self_device,
652 other_id,
653 other_device
654 ));
655 }
656 }
657 } else {
658 pending.insert(pending_key.clone());
660 drop(pending);
661 }
663
664 let result = async {
666 let is_loopback = other_id == *self_ref.actor_id() && self_device == other_device;
667
668 if is_loopback {
669 self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
671 .await?;
672 let endpoint = self
673 .connection_info(cx, other.clone(), other_device.clone(), self_device.clone())
674 .await?;
675 self.connect(
676 cx,
677 other.clone(),
678 self_device.clone(),
679 other_device.clone(),
680 endpoint,
681 )
682 .await?;
683 } else {
684 self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
686 .await?;
687 other
688 .initialize_qp(
689 cx,
690 self_ref.clone(),
691 other_device.clone(),
692 self_device.clone(),
693 )
694 .await?;
695 let other_endpoint: IbvQpInfo = other
696 .connection_info(
697 cx,
698 self_ref.clone(),
699 other_device.clone(),
700 self_device.clone(),
701 )
702 .await?;
703 self.connect(
704 cx,
705 other.clone(),
706 self_device.clone(),
707 other_device.clone(),
708 other_endpoint,
709 )
710 .await?;
711 let local_endpoint = self
712 .connection_info(cx, other.clone(), self_device.clone(), other_device.clone())
713 .await?;
714 other
715 .connect(
716 cx,
717 self_ref.clone(),
718 other_device.clone(),
719 self_device.clone(),
720 local_endpoint,
721 )
722 .await?;
723
724 let remote_state = other
726 .get_qp_state(
727 cx,
728 self_ref.clone(),
729 other_device.clone(),
730 self_device.clone(),
731 )
732 .await?;
733
734 if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS {
735 return Err(anyhow::anyhow!(
736 "Remote QP not in RTS state after connection setup. \
737 Local is ready but remote is in state {}. \
738 This indicates a synchronization issue in connection setup.",
739 remote_state
740 ));
741 }
742 }
743
744 if let Some(device_map) = self.device_qps.get(&self_device) {
746 if let Some(qp) = device_map.get(&inner_key) {
747 Ok(qp.clone())
748 } else {
749 Err(anyhow::anyhow!(
750 "Failed to create connection for actor {} on device {}",
751 other_id,
752 other_device
753 ))
754 }
755 } else {
756 Err(anyhow::anyhow!(
757 "Failed to create connection for actor {} on device {} - no device map",
758 other_id,
759 other_device
760 ))
761 }
762 }
763 .await;
764
765 let mut pending = self.pending_qp_creation.lock().await;
767 pending.remove(&pending_key);
768 drop(pending);
769
770 result
771 }
772}
773
774#[async_trait]
775#[hyperactor::handle(IbvManagerMessage)]
776impl IbvManagerMessageHandler for IbvManagerActor {
777 async fn request_buffer(
778 &mut self,
779 cx: &Context<Self>,
780 remote_buf_id: usize,
781 ) -> Result<Option<IbvBuffer>, anyhow::Error> {
782 if let Some(buf) = self.buffer_registrations.get(&remote_buf_id) {
784 return Ok(Some(buf.clone()));
785 }
786
787 let owner = self.owner.get().unwrap();
791 let mem = match owner.request_local_memory(cx, remote_buf_id).await? {
792 Some(mem) => mem,
793 None => return Ok(None),
794 };
795
796 let (mrv, device_name) = self.register_mr_impl(mem.addr(), mem.size())?;
797
798 let buf = IbvBuffer {
799 mr_id: mrv.id,
800 lkey: mrv.lkey,
801 rkey: mrv.rkey,
802 addr: mrv.rdma_addr,
803 size: mrv.size,
804 device_name,
805 };
806
807 self.buffer_registrations.insert(remote_buf_id, buf.clone());
808
809 Ok(Some(buf))
810 }
811
812 async fn release_buffer(
813 &mut self,
814 _cx: &Context<Self>,
815 remote_buf_id: usize,
816 ) -> Result<(), anyhow::Error> {
817 if let Some(buf) = self.buffer_registrations.remove(&remote_buf_id) {
818 self.deregister_mr_impl(buf.mr_id)
819 .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
820 }
821 Ok(())
822 }
823
824 async fn request_queue_pair(
825 &mut self,
826 cx: &Context<Self>,
827 other: reference::ActorRef<IbvManagerActor>,
828 self_device: String,
829 other_device: String,
830 ) -> Result<Result<IbvQueuePair, String>, anyhow::Error> {
831 Ok(self
832 .request_queue_pair_impl(cx, other, self_device, other_device)
833 .await
834 .map_err(|e| e.to_string()))
835 }
836
837 async fn connect(
838 &mut self,
839 _cx: &Context<Self>,
840 other: reference::ActorRef<IbvManagerActor>,
841 self_device: String,
842 other_device: String,
843 endpoint: IbvQpInfo,
844 ) -> Result<(), anyhow::Error> {
845 tracing::debug!("connecting with {:?}", other);
846 let other_id = other.actor_id().clone();
847
848 let inner_key = (other_id.clone(), other_device.clone());
849
850 if let Some(device_map) = self.device_qps.get_mut(&self_device) {
851 match device_map.get_mut(&inner_key) {
852 Some(qp) => {
853 qp.connect(&endpoint).map_err(|e| {
854 anyhow::anyhow!("could not connect to RDMA endpoint: {}", e)
855 })?;
856 Ok(())
857 }
858 None => Err(anyhow::anyhow!(
859 "No connection found for actor {}",
860 other_id
861 )),
862 }
863 } else {
864 Err(anyhow::anyhow!(
865 "No device map found for device {}",
866 self_device
867 ))
868 }
869 }
870
871 async fn initialize_qp(
872 &mut self,
873 _cx: &Context<Self>,
874 other: reference::ActorRef<IbvManagerActor>,
875 self_device: String,
876 other_device: String,
877 ) -> Result<bool, anyhow::Error> {
878 let other_id = other.actor_id().clone();
879 let inner_key = (other_id.clone(), other_device.clone());
880
881 if let Some(device_map) = self.device_qps.get(&self_device) {
883 if device_map.contains_key(&inner_key) {
884 return Ok(true);
885 }
886 }
887
888 let (domain, _) = self.device_domains.get(&self_device).ok_or_else(|| {
892 anyhow::anyhow!(
893 "device domain for '{}' not found; register_mr must be called before initialize_qp",
894 self_device
895 )
896 })?;
897 let (domain_context, domain_pd) = (domain.context, domain.pd);
898
899 let qp = IbvQueuePair::new(domain_context, domain_pd, self.config.clone())
900 .map_err(|e| anyhow::anyhow!("could not create IbvQueuePair: {}", e))?;
901
902 self.device_qps
904 .entry(self_device.clone())
905 .or_insert_with(HashMap::new)
906 .insert(inner_key, qp);
907
908 tracing::debug!(
909 "successfully created a connection with {:?} for local device {} -> remote device {}",
910 other,
911 self_device,
912 other_device
913 );
914
915 Ok(true)
916 }
917
918 async fn connection_info(
919 &mut self,
920 _cx: &Context<Self>,
921 other: reference::ActorRef<IbvManagerActor>,
922 self_device: String,
923 other_device: String,
924 ) -> Result<IbvQpInfo, anyhow::Error> {
925 tracing::debug!("getting connection info with {:?}", other);
926 let other_id = other.actor_id().clone();
927
928 let inner_key = (other_id.clone(), other_device.clone());
929
930 if let Some(device_map) = self.device_qps.get_mut(&self_device) {
931 match device_map.get_mut(&inner_key) {
932 Some(qp) => {
933 let connection_info = qp.get_qp_info()?;
934 Ok(connection_info)
935 }
936 None => Err(anyhow::anyhow!(
937 "No connection found for actor {}",
938 other_id
939 )),
940 }
941 } else {
942 Err(anyhow::anyhow!(
943 "No device map found for self device {}",
944 self_device
945 ))
946 }
947 }
948
949 async fn release_queue_pair(
950 &mut self,
951 _cx: &Context<Self>,
952 _other: reference::ActorRef<IbvManagerActor>,
953 _self_device: String,
954 _other_device: String,
955 _qp: IbvQueuePair,
956 ) -> Result<(), anyhow::Error> {
957 Ok(())
958 }
959
960 async fn get_qp_state(
961 &mut self,
962 _cx: &Context<Self>,
963 other: reference::ActorRef<IbvManagerActor>,
964 self_device: String,
965 other_device: String,
966 ) -> Result<u32, anyhow::Error> {
967 let other_id = other.actor_id().clone();
968 let inner_key = (other_id.clone(), other_device.clone());
969
970 if let Some(device_map) = self.device_qps.get_mut(&self_device) {
971 match device_map.get_mut(&inner_key) {
972 Some(qp) => qp.state(),
973 None => Err(anyhow::anyhow!(
974 "No connection found for actor {} on device {}",
975 other_id,
976 other_device
977 )),
978 }
979 } else {
980 Err(anyhow::anyhow!(
981 "No device map found for self device {}",
982 self_device
983 ))
984 }
985 }
986}
987
988#[async_trait]
989#[hyperactor::handle(IbvManagerLocalMessage)]
990impl IbvManagerLocalMessageHandler for IbvManagerActor {
991 async fn register_mr(
992 &mut self,
993 _cx: &Context<Self>,
994 addr: usize,
995 size: usize,
996 ) -> Result<Result<(IbvMemoryRegionView, String), String>, anyhow::Error> {
997 Ok(self.register_mr_impl(addr, size).map_err(|e| e.to_string()))
998 }
999
1000 async fn deregister_mr(
1001 &mut self,
1002 _cx: &Context<Self>,
1003 id: usize,
1004 ) -> Result<Result<(), String>, anyhow::Error> {
1005 Ok(self.deregister_mr_impl(id).map_err(|e| e.to_string()))
1006 }
1007}
1008
/// ibverbs implementation of [`RdmaBackend`], wrapping a handle to an
/// [`IbvManagerActor`] that performs registration and queue-pair management.
#[derive(Debug, Clone)]
pub struct IbvBackend(pub ActorHandle<IbvManagerActor>);
1015
1016impl std::ops::Deref for IbvBackend {
1017 type Target = ActorHandle<IbvManagerActor>;
1018 fn deref(&self) -> &Self::Target {
1019 &self.0
1020 }
1021}
1022
1023impl IbvBackend {
1024 async fn wait_for_completion(
1029 local_buf: &IbvBuffer,
1030 qp: &mut IbvQueuePair,
1031 poll_target: PollTarget,
1032 expected_wr_ids: &[u64],
1033 timeout: Duration,
1034 ) -> Result<(), anyhow::Error> {
1035 let start_time = std::time::Instant::now();
1036
1037 let mut remaining: std::collections::HashSet<u64> =
1038 expected_wr_ids.iter().copied().collect();
1039
1040 while start_time.elapsed() < timeout {
1041 if remaining.is_empty() {
1042 return Ok(());
1043 }
1044
1045 let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
1046 match qp.poll_completion(poll_target, &wr_ids_to_poll) {
1047 Ok(completions) => {
1048 for (wr_id, _wc) in completions {
1049 remaining.remove(&wr_id);
1050 }
1051 if remaining.is_empty() {
1052 return Ok(());
1053 }
1054 tokio::time::sleep(Duration::from_millis(1)).await;
1055 }
1056 Err(e) => {
1057 let mut root_cause: Option<PollCompletionError> = None;
1063 if e.is_wr_flush_err() {
1064 for &wr_id in &wr_ids_to_poll {
1065 if let Err(inner_err) = qp.poll_completion(poll_target, &[wr_id]) {
1066 if !inner_err.is_wr_flush_err() {
1067 root_cause = Some(inner_err);
1068 break;
1069 }
1070 }
1071 }
1072 }
1073 let error_detail = if let Some(cause) = root_cause {
1074 format!(
1075 "RDMA polling completion failed: {} (root cause: {})",
1076 e, cause
1077 )
1078 } else {
1079 format!("RDMA polling completion failed: {}", e)
1080 };
1081 return Err(anyhow::anyhow!(
1082 "{} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
1083 error_detail,
1084 local_buf.lkey,
1085 local_buf.rkey,
1086 local_buf.addr,
1087 local_buf.size
1088 ));
1089 }
1090 }
1091 }
1092 tracing::error!(
1093 "timed out while waiting on request completion for wr_ids={:?}",
1094 remaining
1095 );
1096 Err(anyhow::anyhow!(
1097 "[ibv_buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
1098 local_buf,
1099 expected_wr_ids
1100 ))
1101 }
1102
1103 async fn execute_op(
1106 &self,
1107 cx: &(impl hyperactor::context::Actor + Send + Sync),
1108 op: IbvOp,
1109 timeout: Duration,
1110 ) -> Result<(), anyhow::Error> {
1111 let (local_mrv, local_device_name) = self
1113 .register_mr(cx, op.local_memory.addr(), op.local_memory.size())
1114 .await?
1115 .map_err(|e| anyhow::anyhow!(e))?;
1116
1117 let local_buffer = IbvBuffer {
1118 mr_id: local_mrv.id,
1119 lkey: local_mrv.lkey,
1120 rkey: local_mrv.rkey,
1121 addr: local_mrv.rdma_addr,
1122 size: local_mrv.size,
1123 device_name: local_device_name,
1124 };
1125
1126 let op_result = async {
1127 let mut qp = self
1128 .request_queue_pair(
1129 cx,
1130 op.remote_manager.clone(),
1131 local_buffer.device_name.clone(),
1132 op.remote_buffer.device_name.clone(),
1133 )
1134 .await?
1135 .map_err(|e| anyhow::anyhow!(e))?;
1136
1137 let wr_id = match op.op_type {
1138 RdmaOpType::WriteFromLocal => qp.put(local_buffer.clone(), op.remote_buffer)?,
1139 RdmaOpType::ReadIntoLocal => qp.get(local_buffer.clone(), op.remote_buffer)?,
1140 };
1141
1142 Self::wait_for_completion(&local_buffer, &mut qp, PollTarget::Send, &wr_id, timeout)
1143 .await
1144 }
1145 .await;
1146
1147 let dereg_result = self
1149 .deregister_mr(cx, local_buffer.mr_id)
1150 .await?
1151 .map_err(|e| anyhow::anyhow!(e));
1152
1153 match (op_result, dereg_result) {
1154 (Ok(()), Ok(())) => Ok(()),
1155 (Err(e), Ok(())) => Err(e),
1156 (Ok(()), Err(e)) => Err(e),
1157 (Err(op_err), Err(dereg_err)) => Err(anyhow::anyhow!(
1158 "deregister MR error: {}; op error: {}",
1159 dereg_err,
1160 op_err
1161 )),
1162 }
1163 }
1164}
1165
1166#[async_trait]
1167impl RdmaBackend for IbvBackend {
1168 type TransportInfo = ();
1169
1170 async fn submit(
1175 &mut self,
1176 cx: &(impl hyperactor::context::Actor + Send + Sync),
1177 ops: Vec<RdmaOp>,
1178 timeout: Duration,
1179 ) -> Result<(), anyhow::Error> {
1180 let mut ibv_ops = Vec::with_capacity(ops.len());
1181 for op in ops {
1182 let (remote_ibv_mgr, remote_ibv_buffer) =
1183 op.remote.resolve_ibv(cx).await.ok_or_else(|| {
1184 anyhow::anyhow!("ibverbs backend not found for buffer: {:?}", op.remote)
1185 })??;
1186
1187 ibv_ops.push(IbvOp {
1188 op_type: op.op_type,
1189 local_memory: op.local.clone(),
1190 remote_buffer: remote_ibv_buffer,
1191 remote_manager: remote_ibv_mgr,
1192 });
1193 }
1194
1195 let deadline = Instant::now() + timeout;
1196 for op in ibv_ops {
1197 let remaining = deadline.saturating_duration_since(Instant::now());
1198 if remaining.is_zero() {
1199 return Err(anyhow::anyhow!("submit timed out"));
1200 }
1201 self.execute_op(cx, op, remaining).await?;
1202 }
1203 Ok(())
1204 }
1205
1206 fn transport_level(&self) -> RdmaTransportLevel {
1207 RdmaTransportLevel::Nic
1208 }
1209
1210 fn transport_info(&self) -> Option<Self::TransportInfo> {
1211 None
1212 }
1213}