1use std::collections::HashMap;
37use std::sync::Arc;
38use std::sync::OnceLock;
39use std::time::Duration;
40use std::time::Instant;
41
42use anyhow::Result;
43use async_trait::async_trait;
44use backoff::ExponentialBackoff;
45use backoff::ExponentialBackoffBuilder;
46use backoff::backoff::Backoff;
47use hyperactor::Actor;
48use hyperactor::ActorHandle;
49use hyperactor::ActorRef;
50use hyperactor::Context;
51use hyperactor::Endpoint as _;
52use hyperactor::HandleClient;
53use hyperactor::Handler;
54use hyperactor::Instance;
55use hyperactor::OncePortHandle;
56use hyperactor::PortRef;
57use hyperactor::RefClient;
58use hyperactor::actor::Referable;
59use serde::Deserialize;
60use serde::Serialize;
61use typeuri::Named;
62
63use super::IbvBuffer;
64use super::IbvOp;
65use super::domain::IbvDomain;
66use super::primitives::IbvConfig;
67use super::primitives::IbvDevice;
68use super::primitives::IbvMemoryRegion;
69use super::primitives::IbvMemoryRegionView;
70use super::primitives::IbvQpInfo;
71use super::primitives::ibverbs_supported;
72use super::primitives::mlx5dv_supported;
73use super::primitives::resolve_qp_type;
74use super::queue_pair::IbvQueuePair;
75use super::queue_pair::PeerInfo;
76use super::queue_pair::PollCompletionError;
77use super::queue_pair::PollTarget;
78use super::queue_pair::QpGuard;
79use super::queue_pair::QpKey;
80use super::queue_pair::QueuePairInitializer;
81use super::queue_pair::destroy_qp;
82use crate::RdmaOp;
83use crate::RdmaOpType;
84use crate::RdmaTransportLevel;
85use crate::backend::RdmaBackend;
86use crate::local_memory::KeepaliveLocalMemory;
87use crate::rdma_components::get_registered_cuda_segments;
88use crate::rdma_manager_actor::GetIbvActorRefClient;
89use crate::rdma_manager_actor::RdmaManagerActor;
90use crate::validate_execution_context;
91
/// Request from a peer manager asking this actor to ensure a queue pair exists
/// between the two managers.
///
/// The receiving `IbvManagerActor` looks up (or creates) the QP keyed by
/// `receiver_device`, the sender's actor id and `sender_device`. It answers on
/// `reply` with a [`PeerInfo`] carrying either its QP info or an error string.
#[derive(Debug, Serialize, Deserialize, Named)]
// Empty bounds suppress serde's inferred `A: Serialize/Deserialize` requirements.
// NOTE(review): presumably `ActorRef<A>` is (de)serializable for any `A: Referable`;
// confirm.
#[serde(bound(serialize = "", deserialize = ""))]
pub(super) struct EnsureQueuePair<A: Referable> {
    /// The manager that initiated the request.
    pub(super) sender: ActorRef<A>,
    /// Name of the RDMA device the sender uses for this QP.
    pub(super) sender_device: String,
    /// Name of the RDMA device the receiver should use for this QP.
    pub(super) receiver_device: String,
    /// Port on which the receiver reports its QP info (or an error).
    pub(super) reply: PortRef<PeerInfo>,
}
wirevalue::register_type!(EnsureQueuePair<IbvManagerActor>);
104
/// Lifecycle of the queue pair toward one peer, stored in `IbvManagerActor::qps`
/// under its [`QpKey`].
#[derive(Debug)]
enum QpState {
    /// A local QP exists and a `QueuePairInitializer` is driving the handshake.
    Pending {
        /// Local QP connection info, sent to the peer in `EnsureQueuePair` replies.
        info: IbvQpInfo,
        /// Child actor driving the handshake. It is drained and stopped once it
        /// reports done or failed.
        initializer: ActorHandle<QueuePairInitializer<IbvManagerActor>>,
        /// Local requesters waiting for the outcome. They are answered when the
        /// handshake reports done or failed.
        waiters: Vec<OncePortHandle<Result<IbvQueuePair, String>>>,
    },
    /// Handshake completed; requesters receive clones of this QP.
    Ready(IbvQueuePair),
    /// Handshake failed. The entry is never recreated, so every later request for
    /// this key receives this error.
    Failed(String),
}
132
/// Serializable messages accepted by [`IbvManagerActor`] from local or remote senders.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum IbvManagerMessage {
    /// Drop the registration cached under `remote_buf_id` by `RegisterRemoteBuffer`.
    ReleaseBuffer { remote_buf_id: usize },
}
wirevalue::register_type!(IbvManagerMessage);
145
/// In-process messages accepted by [`IbvManagerActor`].
///
/// Unlike [`IbvManagerMessage`], this enum does not derive `Serialize`. Its variants
/// carry in-process values such as `Arc<KeepaliveLocalMemory>` and `OncePortHandle`.
#[derive(Handler, HandleClient, Debug)]
pub enum IbvManagerLocalMessage {
    /// Register `[addr, addr + size)` with a NIC. The reply carries the memory-region
    /// view and the name of the device used, or an error string.
    RegisterMr {
        addr: usize,
        size: usize,
        #[reply]
        reply: OncePortHandle<Result<(IbvMemoryRegionView, String), String>>,
    },
    /// Register `local` and cache the resulting [`IbvBuffer`] under `remote_buf_id`.
    /// Repeated requests for the same id return the cached buffer.
    RegisterRemoteBuffer {
        remote_buf_id: usize,
        local: Arc<KeepaliveLocalMemory>,
        #[reply]
        reply: OncePortHandle<Result<IbvBuffer, String>>,
    },
    /// Obtain a connected QP from local device `self_device` to `other` on
    /// `other_device`. `reply` is answered right away if the QP is ready or failed,
    /// and otherwise once the pending handshake finishes.
    RequestQueuePair {
        other: ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        reply: OncePortHandle<Result<IbvQueuePair, String>>,
    },
}
185
/// Notification that the handshake for `qp_key` succeeded (presumably sent by the
/// `QueuePairInitializer`). The manager moves the entry from `Pending` to `Ready`.
#[derive(Debug)]
pub(super) struct QpInitializerDone {
    /// Key of the QP whose handshake completed.
    pub(super) qp_key: QpKey,
    /// The connected QP. The manager unwraps it with `into_inner` on receipt.
    pub(super) qp: QpGuard,
}
194
/// Notification that the handshake for `qp_key` failed (presumably sent by the
/// `QueuePairInitializer`). The manager moves the entry from `Pending` to `Failed`.
#[derive(Debug)]
pub(super) struct QpInitializerFailed {
    /// Key of the QP whose handshake failed.
    pub(super) qp_key: QpKey,
    /// Human-readable failure reason, forwarded to every waiter.
    pub(super) error: String,
}
203
/// Pacing policy used between completion-queue polls.
///
/// It yields cooperatively while inside `yield_window`, measured from the first
/// yield, and switches to exponentially growing sleeps afterwards.
struct PollSleepPolicy {
    /// Busy-yield budget. `None` means always yield and never sleep.
    yield_window: Option<Duration>,
    /// Time of the first `yield_now` call; set lazily.
    started_at: Option<Instant>,
    /// Sleep schedule used once the window is exhausted; built lazily.
    backoff: Option<ExponentialBackoff>,
}
226
227impl PollSleepPolicy {
228 fn new() -> Self {
229 let yield_window = hyperactor_config::global::get(crate::config::RDMA_CQ_BUSY_POLL_WINDOW);
230 Self {
231 yield_window,
232 started_at: None,
233 backoff: None,
234 }
235 }
236
237 async fn yield_now(&mut self) {
242 let Some(window) = self.yield_window else {
243 tokio::task::yield_now().await;
244 return;
245 };
246 let started = *self.started_at.get_or_insert_with(Instant::now);
247 if started.elapsed() < window {
248 tokio::task::yield_now().await;
249 return;
250 }
251 let backoff = self.backoff.get_or_insert_with(|| {
252 ExponentialBackoffBuilder::new()
253 .with_initial_interval(Duration::from_millis(1))
254 .with_max_interval(Duration::from_millis(10))
255 .with_multiplier(2.0)
256 .with_randomization_factor(0.0)
257 .with_max_elapsed_time(None)
258 .build()
259 });
260 match backoff.next_backoff() {
261 Some(delay) => tokio::time::sleep(delay).await,
262 None => tokio::task::yield_now().await,
263 }
264 }
265}
266
267pub(super) fn lookup_segment_for_address(
280 segments: &[rdmaxcel_sys::rdma_segment_info_t],
281 addr: usize,
282 size: usize,
283) -> Option<SegmentInfo> {
284 for segment in segments {
285 let start_addr = segment.phys_address;
286 let end_addr = start_addr + segment.mr_size;
287 if start_addr <= addr && addr + size <= end_addr {
288 let offset = addr - start_addr;
289 let rdma_addr = segment.mr_addr + offset;
290 return Some(SegmentInfo {
291 rdma_addr,
292 size,
293 lkey: segment.lkey,
294 rkey: segment.rkey,
295 });
296 }
297 }
298 None
299}
300
/// Location of an address range inside a registered CUDA segment, as produced by
/// [`lookup_segment_for_address`].
#[derive(Debug)]
pub(super) struct SegmentInfo {
    /// Address to use in RDMA work requests: the segment's `mr_addr` plus the offset
    /// of the queried address within the segment.
    pub(super) rdma_addr: usize,
    /// Length of the queried range (not of the whole segment).
    pub(super) size: usize,
    /// Local key of the segment's registration.
    pub(super) lkey: u32,
    /// Remote key of the segment's registration.
    pub(super) rkey: u32,
}
313
/// Actor owning this process's ibverbs state: per-device protection domains, queue
/// pairs to peers, and memory registrations. It expects to be spawned as a child of
/// [`RdmaManagerActor`]; `init` fails otherwise.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        IbvManagerMessage,
        EnsureQueuePair<IbvManagerActor>,
    ],
)]
pub struct IbvManagerActor {
    /// Handle to the parent `RdmaManagerActor`, set once in `init`.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,

    /// Queue pairs to peers, keyed by (local device, peer actor id, peer device).
    qps: HashMap<QpKey, QpState>,

    /// Per-device protection domain plus an optional loopback QP, keyed by device
    /// name. The loopback QP is created only when mlx5dv is supported and the device
    /// is not EFA.
    device_domains: HashMap<String, (Arc<IbvDomain>, Option<IbvQueuePair>)>,

    /// ibverbs configuration. `new` may clear `use_gpu_direct` if validation fails.
    config: IbvConfig,

    /// Whether the configured QP type resolves to mlx5dv. This enables the
    /// CUDA-segment registration path in `register_mr_impl`.
    mlx5dv_enabled: bool,

    /// Shared region handle for memory covered by registered CUDA segments. Created
    /// the first time segments are registered or a segment-backed view is produced.
    segments_mr: Option<Arc<IbvMemoryRegion>>,

    /// Next id to assign to an `IbvMemoryRegionView`.
    mrv_id: usize,

    /// Buffers registered via `RegisterRemoteBuffer`, keyed by remote buffer id,
    /// together with their memory-region view.
    /// NOTE(review): presumably holding the view keeps the registration alive until
    /// `ReleaseBuffer`; confirm.
    buffer_registrations: HashMap<usize, (IbvBuffer, IbvMemoryRegionView)>,
}
359
360#[async_trait]
361impl Actor for IbvManagerActor {
362 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
363 let owner = if let Some(owner) = this.parent_handle() {
364 owner
365 } else {
366 anyhow::bail!("RdmaManagerActor not found as parent of IbvManagerActor");
367 };
368 self.owner
369 .set(owner)
370 .expect("owner should only be set once during init");
371 Ok(())
372 }
373}
374
impl Drop for IbvManagerActor {
    /// Tears down ibverbs state in dependency order:
    /// 1. pending QP handshake actors,
    /// 2. cached buffer registrations and the CUDA-segment region handle,
    /// 3. per-device loopback QPs and domains.
    fn drop(&mut self) {
        for (_key, state) in self.qps.drain() {
            match state {
                // Ready QPs are not destroyed here.
                // NOTE(review): clones of Ready QPs may still be held by requesters;
                // confirm who is responsible for destroying them.
                QpState::Ready(_) => {
                }
                // Best-effort stop of a still-running handshake actor; errors ignored.
                QpState::Pending { initializer, .. } => {
                    let _ = initializer.drain_and_stop("IbvManagerActor dropped");
                }
                QpState::Failed(_) => {}
            }
        }

        // Release cached registrations before the domains below. Direct MR regions
        // hold their own `Arc<IbvDomain>`, so a domain outlives any region still
        // referenced elsewhere.
        self.buffer_registrations.clear();

        // Drop this actor's reference to the CUDA-segment region before destroying the
        // loopback QPs that were passed to `register_segments`.
        self.segments_mr.take();

        for (_device_name, (domain, qp)) in self.device_domains.drain() {
            if let Some(qp) = qp {
                // SAFETY (review): assumes no other clone of this loopback QP is still
                // in use. Clones are handed out by `get_or_create_device_domain`.
                unsafe { destroy_qp(&qp) };
            }
            drop(domain);
        }
    }
}
423
424impl IbvManagerActor {
425 pub async fn local_handle(
428 client: &(impl hyperactor::context::Actor + Send + Sync),
429 ) -> Result<ActorHandle<Self>, anyhow::Error> {
430 let rdma_handle = RdmaManagerActor::local_handle(client);
431 let ibv_ref: ActorRef<IbvManagerActor> = rdma_handle
432 .get_ibv_actor_ref(client)
433 .await?
434 .ok_or_else(|| anyhow::anyhow!("local RdmaManagerActor has no ibverbs backend"))?;
435 ibv_ref
436 .downcast_handle(client)
437 .ok_or_else(|| anyhow::anyhow!("IbvManagerActor is not in the local process"))
438 }
439
440 pub async fn new(params: Option<IbvConfig>) -> Result<Self, anyhow::Error> {
442 if !ibverbs_supported() {
443 return Err(anyhow::anyhow!(
444 "Cannot create IbvManagerActor because RDMA is not supported on this machine"
445 ));
446 }
447
448 let mut config = params.unwrap_or_default();
450 tracing::debug!("rdma is enabled, config device hint: {}", config.device);
451
452 let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
453
454 if config.use_gpu_direct {
456 match validate_execution_context().await {
457 Ok(_) => {
458 tracing::info!("GPU Direct RDMA execution context validated successfully");
459 }
460 Err(e) => {
461 tracing::warn!(
462 "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
463 e
464 );
465 config.use_gpu_direct = false;
466 }
467 }
468 }
469
470 let actor = Self {
471 owner: OnceLock::new(),
472 qps: HashMap::new(),
473 device_domains: HashMap::new(),
474 config,
475 mlx5dv_enabled,
476 segments_mr: None,
477 mrv_id: 0,
478 buffer_registrations: HashMap::new(),
479 };
480
481 Ok(actor)
482 }
483
484 fn get_or_create_device_domain(
486 &mut self,
487 device_name: &str,
488 rdma_device: &IbvDevice,
489 ) -> Result<(Arc<IbvDomain>, Option<IbvQueuePair>), anyhow::Error> {
490 if let Some((domain, qp)) = self.device_domains.get(device_name) {
491 return Ok((Arc::clone(domain), qp.clone()));
492 }
493
494 let domain = Arc::new(IbvDomain::new(rdma_device.clone()).map_err(|e| {
496 anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
497 })?);
498
499 crate::print_device_info_if_debug_enabled(domain.context);
501
502 let qp = if mlx5dv_supported() && !crate::efa::is_efa_device() {
505 let mut qp = QpGuard::new(
506 IbvQueuePair::new(domain.context, domain.pd, self.config.clone()).map_err(|e| {
507 anyhow::anyhow!(
508 "could not create loopback QP for device {}: {}",
509 device_name,
510 e
511 )
512 })?,
513 );
514
515 let endpoint = qp.get_qp_info().map_err(|e| {
517 anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
518 })?;
519
520 qp.connect(&endpoint).map_err(|e| {
521 anyhow::anyhow!(
522 "could not connect loopback QP for device {}: {}",
523 device_name,
524 e
525 )
526 })?;
527
528 Some(qp)
529 } else {
530 None
531 };
532
533 let qp = qp.map(|qp| qp.into_inner());
534 self.device_domains
535 .insert(device_name.to_string(), (Arc::clone(&domain), qp.clone()));
536 Ok((domain, qp))
537 }
538
539 fn build_per_device_pd_qp_arrays(
542 &self,
543 ) -> (
544 Vec<*mut rdmaxcel_sys::ibv_pd>,
545 Vec<*mut rdmaxcel_sys::rdmaxcel_qp_t>,
546 ) {
547 let cuda_map = super::device_selection::get_cuda_device_to_ibv_device();
548 let mut pds = Vec::with_capacity(cuda_map.len());
549 let mut qps = Vec::with_capacity(cuda_map.len());
550 for maybe_device in cuda_map {
551 if let Some(device) = maybe_device {
552 if let Some((domain, qp)) = self.device_domains.get(device.name()) {
553 pds.push(domain.pd);
554 qps.push(
555 qp.as_ref()
556 .map(|q| q.qp as *mut rdmaxcel_sys::rdmaxcel_qp_t)
557 .unwrap_or(std::ptr::null_mut()),
558 );
559 } else {
560 pds.push(std::ptr::null_mut());
561 qps.push(std::ptr::null_mut());
562 }
563 } else {
564 pds.push(std::ptr::null_mut());
565 qps.push(std::ptr::null_mut());
566 }
567 }
568 (pds, qps)
569 }
570
571 fn find_cuda_segment_for_address(
572 &self,
573 addr: usize,
574 size: usize,
575 pd: *mut rdmaxcel_sys::ibv_pd,
576 ) -> Option<SegmentInfo> {
577 lookup_segment_for_address(&get_registered_cuda_segments(pd), addr, size)
578 }
579
580 fn register_mr_impl(
581 &mut self,
582 addr: usize,
583 size: usize,
584 ) -> Result<(IbvMemoryRegionView, String), anyhow::Error> {
585 unsafe {
586 let mut mem_type: i32 = 0;
587 let ptr = addr as rdmaxcel_sys::CUdeviceptr;
588 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
589 &mut mem_type as *mut _ as *mut std::ffi::c_void,
590 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
591 ptr,
592 );
593 let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
594
595 let mut selected_rdma_device = None;
596
597 if is_cuda {
598 let mut device_ordinal: i32 = -1;
600 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
601 &mut device_ordinal as *mut _ as *mut std::ffi::c_void,
602 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
603 ptr,
604 );
605 if err == rdmaxcel_sys::CUDA_SUCCESS && device_ordinal >= 0 {
606 selected_rdma_device = super::device_selection::get_cuda_device_to_ibv_device()
607 .get(device_ordinal as usize)
608 .and_then(|d| d.clone());
609 }
610 }
611
612 let rdma_device = if let Some(device) = selected_rdma_device {
614 device
615 } else {
616 self.config.device.clone()
618 };
619
620 let device_name = rdma_device.name().clone();
621 tracing::debug!(
622 "Using RDMA device: {} for memory at 0x{:x}",
623 device_name,
624 addr
625 );
626
627 let (domain, _qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
629
630 let access = if crate::efa::is_efa_device() {
631 crate::efa::mr_access_flags()
632 } else {
633 rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
634 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
635 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
636 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC
637 };
638
639 let mrv;
640
641 if is_cuda {
642 let mut segment_info = None;
644 if self.mlx5dv_enabled {
645 segment_info = self.find_cuda_segment_for_address(addr, size, domain.pd);
647
648 if segment_info.is_none() {
650 let (mut pds, mut qps) = self.build_per_device_pd_qp_arrays();
651 let err = rdmaxcel_sys::register_segments(
652 pds.as_mut_ptr(),
653 qps.as_mut_ptr(),
654 pds.len() as i32,
655 self.config.max_sge_override,
656 );
657 if err == 0 {
660 self.segments_mr
669 .get_or_insert_with(|| Arc::new(IbvMemoryRegion::Segments));
670 segment_info =
671 self.find_cuda_segment_for_address(addr, size, domain.pd);
672 }
673 }
674 }
675
676 if let Some(info) = segment_info {
678 let segments_mr = Arc::clone(
679 self.segments_mr
680 .get_or_insert_with(|| Arc::new(IbvMemoryRegion::Segments)),
681 );
682 let id = self.mrv_id;
683 self.mrv_id += 1;
684 mrv = IbvMemoryRegionView::new(
685 id,
686 addr,
687 info.rdma_addr,
688 info.size,
689 info.lkey,
690 info.rkey,
691 device_name.clone(),
692 segments_mr,
693 );
694 } else {
695 let mut fd: i32 = -1;
697 let cu_err = rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
698 &mut fd,
699 addr as rdmaxcel_sys::CUdeviceptr,
700 size,
701 rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
702 0,
703 );
704 if cu_err != rdmaxcel_sys::CUDA_SUCCESS || fd < 0 {
705 return Err(anyhow::anyhow!(
706 "failed to get dmabuf handle for CUDA memory (addr: 0x{:x}, size: {}, cu_err: {}, fd: {})",
707 addr,
708 size,
709 cu_err,
710 fd
711 ));
712 }
713 let mr =
714 rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
715 if mr.is_null() {
716 return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
717 }
718 let id = self.mrv_id;
719 self.mrv_id += 1;
720 mrv = IbvMemoryRegionView::new(
721 id,
722 addr,
723 (*mr).addr as usize,
724 size,
725 (*mr).lkey,
726 (*mr).rkey,
727 device_name.clone(),
728 Arc::new(IbvMemoryRegion::Direct {
729 mr,
730 _domain: Arc::clone(&domain),
731 }),
732 );
733 }
734 } else {
735 let mr = rdmaxcel_sys::ibv_reg_mr(
737 domain.pd,
738 addr as *mut std::ffi::c_void,
739 size,
740 access.0 as i32,
741 );
742
743 if mr.is_null() {
744 return Err(anyhow::anyhow!("failed to register standard MR"));
745 }
746
747 let id = self.mrv_id;
748 self.mrv_id += 1;
749 mrv = IbvMemoryRegionView::new(
750 id,
751 addr,
752 (*mr).addr as usize,
753 size,
754 (*mr).lkey,
755 (*mr).rkey,
756 device_name.clone(),
757 Arc::new(IbvMemoryRegion::Direct {
758 mr,
759 _domain: Arc::clone(&domain),
760 }),
761 );
762 }
763 Ok((mrv, device_name))
764 }
765 }
766
767 fn ensure_queue_pair_impl(
773 &mut self,
774 cx: &Context<'_, Self>,
775 other: ActorRef<IbvManagerActor>,
776 qp_key: &QpKey,
777 ) -> Result<&mut QpState, anyhow::Error> {
778 if !self.qps.contains_key(qp_key) {
779 let self_device = &qp_key.self_device;
780 let rdma_device = super::primitives::get_all_devices()
781 .into_iter()
782 .find(|d| d.name() == self_device)
783 .ok_or_else(|| anyhow::anyhow!("RDMA device '{}' not found", self_device))?;
784 let (domain, _) = self.get_or_create_device_domain(self_device, &rdma_device)?;
785 let mut qp = QpGuard::new(
790 IbvQueuePair::new(domain.context, domain.pd, self.config.clone())
791 .map_err(|e| anyhow::anyhow!("could not create IbvQueuePair: {}", e))?,
792 );
793 let info = qp
794 .get_qp_info()
795 .map_err(|e| anyhow::anyhow!("could not extract QP info: {}", e))?;
796 let initializer =
797 QueuePairInitializer::new(Instance::handle(cx), other, qp_key.clone(), qp)
798 .spawn(cx)?;
799 self.qps.insert(
800 qp_key.clone(),
801 QpState::Pending {
802 info,
803 initializer,
804 waiters: Vec::new(),
805 },
806 );
807 }
808 Ok(self
809 .qps
810 .get_mut(qp_key)
811 .expect("entry just inserted or pre-existing"))
812 }
813}
814
815#[async_trait]
816#[hyperactor::handle(IbvManagerMessage)]
817impl IbvManagerMessageHandler for IbvManagerActor {
818 async fn release_buffer(
819 &mut self,
820 _cx: &Context<Self>,
821 remote_buf_id: usize,
822 ) -> Result<(), anyhow::Error> {
823 self.buffer_registrations.remove(&remote_buf_id);
827 Ok(())
828 }
829}
830
831#[async_trait]
832impl Handler<EnsureQueuePair<IbvManagerActor>> for IbvManagerActor {
833 async fn handle(
834 &mut self,
835 cx: &Context<Self>,
836 msg: EnsureQueuePair<IbvManagerActor>,
837 ) -> Result<(), anyhow::Error> {
838 let EnsureQueuePair {
839 sender,
840 sender_device,
841 receiver_device,
842 reply,
843 } = msg;
844 let qp_key = QpKey {
845 self_device: receiver_device,
846 other_id: sender.actor_addr().id().clone(),
847 other_device: sender_device,
848 };
849 let state = match self.ensure_queue_pair_impl(cx, sender, &qp_key) {
850 Ok(state) => state,
851 Err(e) => {
852 reply.post(cx, PeerInfo(Err(e.to_string())));
853 return Ok(());
854 }
855 };
856 match state {
857 QpState::Pending {
858 info, initializer, ..
859 } => {
860 let notify_rts = initializer.bind::<QueuePairInitializer<Self>>().port();
861 reply.post(cx, PeerInfo(Ok((info.clone(), notify_rts))));
862 }
863 QpState::Ready(_) => {
864 reply.post(
870 cx,
871 PeerInfo(Err(format!(
872 "EnsureQueuePair on already-Ready entry {qp_key:?}"
873 ))),
874 );
875 }
876 QpState::Failed(error) => {
877 reply.post(cx, PeerInfo(Err(error.clone())));
878 }
879 }
880 Ok(())
881 }
882}
883
884#[async_trait]
885#[hyperactor::handle(IbvManagerLocalMessage)]
886impl IbvManagerLocalMessageHandler for IbvManagerActor {
887 async fn register_mr(
888 &mut self,
889 _cx: &Context<Self>,
890 addr: usize,
891 size: usize,
892 ) -> Result<Result<(IbvMemoryRegionView, String), String>, anyhow::Error> {
893 Ok(self.register_mr_impl(addr, size).map_err(|e| e.to_string()))
894 }
895
896 async fn register_remote_buffer(
897 &mut self,
898 _cx: &Context<Self>,
899 remote_buf_id: usize,
900 local: Arc<KeepaliveLocalMemory>,
901 ) -> Result<Result<IbvBuffer, String>, anyhow::Error> {
902 if let Some((buf, _)) = self.buffer_registrations.get(&remote_buf_id) {
903 return Ok(Ok(buf.clone()));
904 }
905 let (mrv, device_name) = match self.register_mr_impl(local.addr(), local.size()) {
906 Ok(v) => v,
907 Err(e) => return Ok(Err(e.to_string())),
908 };
909 let buf = IbvBuffer {
910 mr_id: mrv.id,
911 lkey: mrv.lkey,
912 rkey: mrv.rkey,
913 addr: mrv.rdma_addr,
914 size: mrv.size,
915 device_name,
916 };
917 self.buffer_registrations
918 .insert(remote_buf_id, (buf.clone(), mrv));
919 Ok(Ok(buf))
920 }
921
922 async fn request_queue_pair(
923 &mut self,
924 cx: &Context<Self>,
925 other: ActorRef<IbvManagerActor>,
926 self_device: String,
927 other_device: String,
928 reply: OncePortHandle<Result<IbvQueuePair, String>>,
929 ) -> Result<(), anyhow::Error> {
930 let qp_key = QpKey {
931 self_device,
932 other_id: other.actor_addr().id().clone(),
933 other_device,
934 };
935 let state = match self.ensure_queue_pair_impl(cx, other, &qp_key) {
936 Ok(state) => state,
937 Err(e) => {
938 reply.post(cx, Err(e.to_string()));
939 return Ok(());
940 }
941 };
942 match state {
943 QpState::Pending { waiters, .. } => waiters.push(reply),
944 QpState::Ready(qp) => reply.post(cx, Ok(qp.clone())),
945 QpState::Failed(error) => reply.post(cx, Err(error.clone())),
946 }
947 Ok(())
948 }
949}
950
951#[async_trait]
952impl Handler<QpInitializerDone> for IbvManagerActor {
953 async fn handle(
954 &mut self,
955 cx: &Context<Self>,
956 msg: QpInitializerDone,
957 ) -> Result<(), anyhow::Error> {
958 let QpInitializerDone { qp_key, qp } = msg;
959 let qp = qp.into_inner();
960 let initializer = match self.qps.remove(&qp_key) {
963 Some(QpState::Pending {
964 waiters,
965 initializer,
966 ..
967 }) => {
968 for w in waiters {
969 w.post(cx, Ok(qp.clone()));
970 }
971 initializer
972 }
973 other => {
974 unreachable!("QpInitializerDone received but state is {other:?}: {qp_key:?}")
975 }
976 };
977 self.qps.insert(qp_key.clone(), QpState::Ready(qp));
978 initializer.drain_and_stop("QpInitializerDone")?;
979 let status = initializer.await;
980 if status.is_failed() {
981 tracing::error!(
986 "QueuePairInitializer for {qp_key:?} terminated with failure after Done: {status:?}"
987 );
988 }
989 Ok(())
990 }
991}
992
993#[async_trait]
994impl Handler<QpInitializerFailed> for IbvManagerActor {
995 async fn handle(
996 &mut self,
997 cx: &Context<Self>,
998 msg: QpInitializerFailed,
999 ) -> Result<(), anyhow::Error> {
1000 let QpInitializerFailed { qp_key, error } = msg;
1001 let initializer = match self.qps.remove(&qp_key) {
1002 Some(QpState::Pending {
1003 waiters,
1004 initializer,
1005 ..
1006 }) => {
1007 for w in waiters {
1008 w.post(cx, Err(error.clone()));
1009 }
1010 initializer
1011 }
1012 other => {
1013 unreachable!("QpInitializerFailed received but state is {other:?}: {qp_key:?}")
1014 }
1015 };
1016 self.qps.insert(qp_key.clone(), QpState::Failed(error));
1020 initializer.drain_and_stop("QpInitializerFailed")?;
1021 let status = initializer.await;
1022 if status.is_failed() {
1023 tracing::error!(
1024 "QueuePairInitializer for {qp_key:?} terminated with failure after Failed: {status:?}"
1025 );
1026 }
1027 Ok(())
1028 }
1029}
1030
1031pub(super) async fn request_queue_pair(
1037 actor: &ActorHandle<IbvManagerActor>,
1038 cx: &(impl hyperactor::context::Actor + Send + Sync),
1039 other: ActorRef<IbvManagerActor>,
1040 self_device: String,
1041 other_device: String,
1042) -> Result<Result<IbvQueuePair, String>, anyhow::Error> {
1043 let (reply, rx) = cx
1044 .mailbox()
1045 .open_once_port::<Result<IbvQueuePair, String>>();
1046 actor
1047 .request_queue_pair(cx, other, self_device, other_device, reply)
1048 .await?;
1049 rx.recv()
1050 .await
1051 .map_err(|e| anyhow::anyhow!("request_queue_pair port closed: {e}"))
1052}
1053
/// [`RdmaBackend`] implementation backed by a local [`IbvManagerActor`].
#[derive(Debug, Clone)]
pub struct IbvBackend(pub ActorHandle<IbvManagerActor>);

/// Exposes the wrapped actor handle, so the manager's generated client methods
/// (e.g. `register_mr`) can be called directly on an `IbvBackend`.
impl std::ops::Deref for IbvBackend {
    type Target = ActorHandle<IbvManagerActor>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
1067
1068impl IbvBackend {
1069 async fn wait_for_completion(
1074 local_buf: &IbvBuffer,
1075 qp: &mut IbvQueuePair,
1076 poll_target: PollTarget,
1077 expected_wr_ids: &[u64],
1078 timeout: Duration,
1079 ) -> Result<(), anyhow::Error> {
1080 let start_time = std::time::Instant::now();
1081
1082 let mut remaining: std::collections::HashSet<u64> =
1083 expected_wr_ids.iter().copied().collect();
1084 let mut poll_policy = PollSleepPolicy::new();
1085
1086 while start_time.elapsed() < timeout {
1087 if remaining.is_empty() {
1088 return Ok(());
1089 }
1090
1091 let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
1092 match qp.poll_completion(poll_target, &wr_ids_to_poll) {
1093 Ok(completions) => {
1094 for (wr_id, _wc) in completions {
1095 remaining.remove(&wr_id);
1096 }
1097 if remaining.is_empty() {
1098 return Ok(());
1099 }
1100 poll_policy.yield_now().await;
1101 }
1102 Err(e) => {
1103 let mut root_cause: Option<PollCompletionError> = None;
1109 if e.is_wr_flush_err() {
1110 for &wr_id in &wr_ids_to_poll {
1111 if let Err(inner_err) = qp.poll_completion(poll_target, &[wr_id]) {
1112 if !inner_err.is_wr_flush_err() {
1113 root_cause = Some(inner_err);
1114 break;
1115 }
1116 }
1117 }
1118 }
1119 let error_detail = if let Some(cause) = root_cause {
1120 format!(
1121 "RDMA polling completion failed: {} (root cause: {})",
1122 e, cause
1123 )
1124 } else {
1125 format!("RDMA polling completion failed: {}", e)
1126 };
1127 return Err(anyhow::anyhow!(
1128 "{} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
1129 error_detail,
1130 local_buf.lkey,
1131 local_buf.rkey,
1132 local_buf.addr,
1133 local_buf.size
1134 ));
1135 }
1136 }
1137 }
1138 tracing::error!(
1139 "timed out while waiting on request completion for wr_ids={:?}",
1140 remaining
1141 );
1142 Err(anyhow::anyhow!(
1143 "[ibv_buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
1144 local_buf,
1145 expected_wr_ids
1146 ))
1147 }
1148
1149 async fn execute_op(
1155 &self,
1156 cx: &(impl hyperactor::context::Actor + Send + Sync),
1157 op: IbvOp,
1158 timeout: Duration,
1159 ) -> Result<(), anyhow::Error> {
1160 let (local_mrv, local_device_name) = self
1161 .register_mr(cx, op.local_memory.addr(), op.local_memory.size())
1162 .await?
1163 .map_err(|e| anyhow::anyhow!(e))?;
1164
1165 let local_buffer = IbvBuffer {
1166 mr_id: local_mrv.id,
1167 lkey: local_mrv.lkey,
1168 rkey: local_mrv.rkey,
1169 addr: local_mrv.rdma_addr,
1170 size: local_mrv.size,
1171 device_name: local_device_name,
1172 };
1173
1174 let result = async {
1175 let mut qp = request_queue_pair(
1176 &self.0,
1177 cx,
1178 op.remote_manager.clone(),
1179 local_buffer.device_name.clone(),
1180 op.remote_buffer.device_name.clone(),
1181 )
1182 .await?
1183 .map_err(|e| anyhow::anyhow!(e))?;
1184
1185 let wr_id = match op.op_type {
1186 RdmaOpType::WriteFromLocal => qp.put(local_buffer.clone(), op.remote_buffer)?,
1187 RdmaOpType::ReadIntoLocal => qp.get(local_buffer.clone(), op.remote_buffer)?,
1188 };
1189
1190 Self::wait_for_completion(&local_buffer, &mut qp, PollTarget::Send, &wr_id, timeout)
1191 .await
1192 }
1193 .await;
1194
1195 drop(local_mrv);
1196 result
1197 }
1198}
1199
1200#[async_trait]
1201impl RdmaBackend for IbvBackend {
1202 type TransportInfo = ();
1203
1204 async fn submit(
1209 &mut self,
1210 cx: &(impl hyperactor::context::Actor + Send + Sync),
1211 ops: Vec<RdmaOp>,
1212 timeout: Duration,
1213 ) -> Result<(), anyhow::Error> {
1214 let mut ibv_ops = Vec::with_capacity(ops.len());
1215 for op in ops {
1216 let (remote_manager, remote_buffer) = op.remote.resolve_ibv().ok_or_else(|| {
1217 anyhow::anyhow!("ibverbs backend not found for buffer: {:?}", op.remote)
1218 })?;
1219 ibv_ops.push(IbvOp {
1220 op_type: op.op_type,
1221 local_memory: op.local.clone(),
1222 remote_buffer,
1223 remote_manager,
1224 });
1225 }
1226
1227 let deadline = Instant::now() + timeout;
1228 for op in ibv_ops {
1229 let remaining = deadline.saturating_duration_since(Instant::now());
1230 if remaining.is_zero() {
1231 return Err(anyhow::anyhow!("submit timed out"));
1232 }
1233 self.execute_op(cx, op, remaining).await?;
1234 }
1235 Ok(())
1236 }
1237
1238 fn transport_level(&self) -> RdmaTransportLevel {
1239 RdmaTransportLevel::Nic
1240 }
1241
1242 fn transport_info(&self) -> Option<Self::TransportInfo> {
1243 None
1244 }
1245}