1use std::collections::HashMap;
20use std::collections::HashSet;
21use std::sync::Arc;
22use std::sync::OnceLock;
23use std::time::Duration;
24use std::time::Instant;
25
26use anyhow::Result;
27use async_trait::async_trait;
28use futures::lock::Mutex;
29use hyperactor::Actor;
30use hyperactor::ActorHandle;
31use hyperactor::Context;
32use hyperactor::HandleClient;
33use hyperactor::Handler;
34use hyperactor::Instance;
35use hyperactor::OncePortHandle;
36use hyperactor::RefClient;
37use hyperactor::reference;
38use serde::Deserialize;
39use serde::Serialize;
40use typeuri::Named;
41
42use super::IbvBuffer;
43use super::IbvOp;
44use super::domain::IbvDomain;
45use super::primitives::IbvConfig;
46use super::primitives::IbvDevice;
47use super::primitives::IbvMemoryRegionView;
48use super::primitives::IbvQpInfo;
49use super::primitives::ibverbs_supported;
50use super::primitives::mlx5dv_supported;
51use super::primitives::resolve_qp_type;
52use super::queue_pair::IbvQueuePair;
53use super::queue_pair::PollTarget;
54use crate::RdmaOp;
55use crate::RdmaOpType;
56use crate::RdmaTransportLevel;
57use crate::backend::RdmaBackend;
58use crate::rdma_components::get_registered_cuda_segments;
59use crate::rdma_manager_actor::GetIbvActorRefClient;
60use crate::rdma_manager_actor::RdmaManagerActor;
61use crate::rdma_manager_actor::RdmaManagerMessageClient;
62use crate::rdma_manager_actor::get_rdmaxcel_error_message;
63use crate::validate_execution_context;
64
/// Requests understood by [`IbvManagerActor`] for buffer registration and
/// queue-pair management.
///
/// Variants that carry a `#[reply]` port are request/response; `Connect` and
/// `ReleaseQueuePair` carry no reply port.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub enum IbvManagerMessage {
    /// Resolve the ibverbs buffer for `remote_buf_id`.
    ///
    /// The handler registers the memory on the first request and answers from
    /// its cache afterwards. It replies `None` when the owning manager has no
    /// memory for that id.
    RequestBuffer {
        /// Identifier of the buffer, as known to the owning manager.
        remote_buf_id: usize,
        #[reply]
        reply: reference::OncePortRef<Option<IbvBuffer>>,
    },
    /// Drop the cached registration for `remote_buf_id` and deregister its
    /// memory region. Unknown ids are ignored by the handler.
    ReleaseBuffer {
        /// Identifier previously passed to `RequestBuffer`.
        remote_buf_id: usize,
        #[reply]
        reply: reference::OncePortRef<()>,
    },
    /// Obtain the queue pair linking `self_device` on this actor with
    /// `other_device` on `other`, creating and connecting it when absent.
    RequestQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<IbvQueuePair>,
    },
    /// Connect the existing local queue pair for (`other`, `other_device`) on
    /// `self_device` to the remote `endpoint`.
    Connect {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        /// Queue-pair info of the peer side to connect to.
        endpoint: IbvQpInfo,
    },
    /// Create the (not yet connected) local queue pair for (`other`,
    /// `other_device`) on `self_device` if none exists. The handler replies
    /// `true` whether the queue pair already existed or was just created.
    InitializeQP {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<bool>,
    },
    /// Fetch the queue-pair info of the local queue pair for (`other`,
    /// `other_device`) on `self_device`.
    ConnectionInfo {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<IbvQpInfo>,
    },
    /// Hand a queue pair back to the manager. The handler in this file
    /// currently does nothing with it.
    ReleaseQueuePair {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        qp: IbvQueuePair,
    },
    /// Report the state of the local queue pair for (`other`, `other_device`)
    /// on `self_device`; `request_queue_pair` compares the reply against
    /// `IBV_QPS_RTS`.
    GetQpState {
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
        #[reply]
        reply: reference::OncePortRef<u32>,
    },
}
wirevalue::register_type!(IbvManagerMessage);
128
/// Request to execute a batch of ibverbs operations on an
/// [`IbvManagerActor`].
///
/// The handler runs `ops` one after another under a single deadline derived
/// from `timeout` and stops at the first failure. The reply carries `Ok(())`
/// or the stringified error.
#[derive(Handler, HandleClient, Debug)]
pub struct IbvSubmit {
    /// Operations to execute, in order.
    pub ops: Vec<IbvOp>,
    /// Time budget shared by the whole batch.
    pub timeout: Duration,
    /// Port that receives the outcome of the batch.
    #[reply]
    pub reply: OncePortHandle<Result<(), String>>,
}
140
/// Actor that owns the ibverbs resources of one RDMA manager: protection
/// domains, queue pairs, and memory-region registrations.
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        IbvManagerMessage,
    ],
)]
pub struct IbvManagerActor {
    /// Handle to the parent `RdmaManagerActor`; set exactly once in `init`.
    owner: OnceLock<ActorHandle<RdmaManagerActor>>,

    /// Queue pairs, keyed first by local device name and then by
    /// (remote actor id, remote device name).
    device_qps: HashMap<String, HashMap<(reference::ActorId, String), IbvQueuePair>>,

    /// Keys (local device, remote actor id, remote device) whose queue pair is
    /// currently being created by `request_queue_pair`.
    pending_qp_creation: Arc<Mutex<HashSet<(String, reference::ActorId, String)>>>,

    /// Per-device domain plus an optional loopback queue pair, created lazily
    /// by `get_or_create_device_domain`.
    device_domains: HashMap<String, (IbvDomain, Option<IbvQueuePair>)>,

    /// Configuration supplied at construction (or its default).
    config: IbvConfig,

    /// True when the configured QP type resolves to `RDMA_QP_TYPE_MLX5DV`.
    mlx5dv_enabled: bool,

    /// Memory-region view id -> raw `ibv_mr` pointer stored as an integer.
    /// The value is 0 for views served from a registered CUDA segment, which
    /// have no `ibv_mr` of their own.
    mr_map: HashMap<usize, usize>,

    /// Next id to assign to a memory-region view.
    mrv_id: usize,

    /// Devices looked up by the PCI address string of a CUDA pointer; built by
    /// `create_cuda_to_ibv_mapping` at construction.
    pci_to_device: HashMap<String, IbvDevice>,

    /// Cache of buffers registered through `request_buffer`, keyed by the
    /// remote buffer id.
    buffer_registrations: HashMap<usize, IbvBuffer>,
}
183
184#[async_trait]
185impl Actor for IbvManagerActor {
186 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
187 let owner = if let Some(owner) = this.parent_handle() {
188 owner
189 } else {
190 anyhow::bail!("RdmaManagerActor not found as parent of IbvManagerActor");
191 };
192 self.owner
193 .set(owner)
194 .expect("owner should only be set once during init");
195 Ok(())
196 }
197}
198
199impl Drop for IbvManagerActor {
200 fn drop(&mut self) {
201 fn destroy_queue_pair(qp: &IbvQueuePair, _context: &str) {
206 unsafe {
207 if qp.qp != 0 {
208 let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
209 rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
210 }
211 }
212 }
213
214 for (_device_name, device_map) in self.device_qps.drain() {
216 for ((actor_id, _remote_device), qp) in device_map {
217 destroy_queue_pair(&qp, &format!("actor {:?}", actor_id));
218 }
219 }
220
221 for (device_name, (domain, qp)) in self.device_domains.drain() {
223 if let Some(qp) = qp {
224 destroy_queue_pair(&qp, &format!("loopback QP on device {}", device_name));
225 }
226 drop(domain);
227 }
228
229 let _mr_count = self.mr_map.len();
231 for (id, mr_ptr) in self.mr_map.drain() {
232 if mr_ptr != 0 {
233 unsafe {
234 let result = rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
235 if result != 0 {
236 tracing::error!(
237 "Failed to deregister MR with id {}: error code {}",
238 id,
239 result
240 );
241 }
242 }
243 }
244 }
245
246 if self.mlx5dv_enabled {
249 unsafe {
250 let result = rdmaxcel_sys::deregister_segments();
251 if result != 0 {
252 let error_msg = get_rdmaxcel_error_message(result);
253 tracing::error!(
254 "Failed to deregister CUDA segments: {} (error code: {})",
255 error_msg,
256 result
257 );
258 }
259 }
260 }
261 }
262}
263
264impl IbvManagerActor {
265 pub async fn local_handle(
268 client: &(impl hyperactor::context::Actor + Send + Sync),
269 ) -> Result<ActorHandle<Self>, anyhow::Error> {
270 let rdma_handle = RdmaManagerActor::local_handle(client);
271 let ibv_ref = rdma_handle
272 .get_ibv_actor_ref(client)
273 .await?
274 .ok_or_else(|| anyhow::anyhow!("local RdmaManagerActor has no ibverbs backend"))?;
275 ibv_ref
276 .downcast_handle(client)
277 .ok_or_else(|| anyhow::anyhow!("IbvManagerActor is not in the local process"))
278 }
279
280 pub async fn new(params: Option<IbvConfig>) -> Result<Self, anyhow::Error> {
282 if !ibverbs_supported() {
283 return Err(anyhow::anyhow!(
284 "Cannot create IbvManagerActor because RDMA is not supported on this machine"
285 ));
286 }
287
288 let mut config = params.unwrap_or_default();
290 tracing::debug!("rdma is enabled, config device hint: {}", config.device);
291
292 let mlx5dv_enabled = resolve_qp_type(config.qp_type) == rdmaxcel_sys::RDMA_QP_TYPE_MLX5DV;
293
294 if config.use_gpu_direct {
296 match validate_execution_context().await {
297 Ok(_) => {
298 tracing::info!("GPU Direct RDMA execution context validated successfully");
299 }
300 Err(e) => {
301 tracing::warn!(
302 "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
303 e
304 );
305 config.use_gpu_direct = false;
306 }
307 }
308 }
309
310 let pci_to_device = super::device_selection::create_cuda_to_ibv_mapping();
312 tracing::debug!(
313 "Built CUDA to RDMA device mapping with {} entries",
314 pci_to_device.len()
315 );
316
317 Ok(Self {
318 owner: OnceLock::new(),
319 device_qps: HashMap::new(),
320 pending_qp_creation: Arc::new(Mutex::new(HashSet::new())),
321 device_domains: HashMap::new(),
322 config,
323 mlx5dv_enabled,
324 mr_map: HashMap::new(),
325 mrv_id: 0,
326 pci_to_device,
327 buffer_registrations: HashMap::new(),
328 })
329 }
330
331 fn get_or_create_device_domain(
333 &mut self,
334 device_name: &str,
335 rdma_device: &IbvDevice,
336 ) -> Result<(IbvDomain, Option<IbvQueuePair>), anyhow::Error> {
337 if let Some((domain, qp)) = self.device_domains.get(device_name) {
338 return Ok((domain.clone(), qp.clone()));
339 }
340
341 let domain = IbvDomain::new(rdma_device.clone()).map_err(|e| {
343 anyhow::anyhow!("could not create domain for device {}: {}", device_name, e)
344 })?;
345
346 crate::print_device_info_if_debug_enabled(domain.context);
348
349 let qp = if mlx5dv_supported() && !crate::efa::is_efa_device() {
352 let mut qp = IbvQueuePair::new(domain.context, domain.pd, self.config.clone())
353 .map_err(|e| {
354 anyhow::anyhow!(
355 "could not create loopback QP for device {}: {}",
356 device_name,
357 e
358 )
359 })?;
360
361 let endpoint = qp.get_qp_info().map_err(|e| {
363 anyhow::anyhow!("could not get QP info for device {}: {}", device_name, e)
364 })?;
365
366 qp.connect(&endpoint).map_err(|e| {
367 anyhow::anyhow!(
368 "could not connect loopback QP for device {}: {}",
369 device_name,
370 e
371 )
372 })?;
373
374 Some(qp)
375 } else {
376 None
377 };
378
379 self.device_domains
380 .insert(device_name.to_string(), (domain.clone(), qp.clone()));
381 Ok((domain, qp))
382 }
383
384 fn find_cuda_segment_for_address(
385 &mut self,
386 addr: usize,
387 size: usize,
388 ) -> Option<IbvMemoryRegionView> {
389 let registered_segments = get_registered_cuda_segments();
390 for segment in registered_segments {
391 let start_addr = segment.phys_address;
392 let end_addr = start_addr + segment.phys_size;
393 if start_addr <= addr && addr + size <= end_addr {
394 let offset = addr - start_addr;
395 let rdma_addr = segment.mr_addr + offset;
396
397 let mrv = IbvMemoryRegionView {
398 id: self.mrv_id,
399 virtual_addr: addr,
400 rdma_addr,
401 size,
402 lkey: segment.lkey,
403 rkey: segment.rkey,
404 };
405 self.mrv_id += 1;
406 return Some(mrv);
407 }
408 }
409 None
410 }
411
412 fn register_mr(
413 &mut self,
414 addr: usize,
415 size: usize,
416 ) -> Result<(IbvMemoryRegionView, String), anyhow::Error> {
417 unsafe {
418 let mut mem_type: i32 = 0;
419 let ptr = addr as rdmaxcel_sys::CUdeviceptr;
420 let err = rdmaxcel_sys::rdmaxcel_cuPointerGetAttribute(
421 &mut mem_type as *mut _ as *mut std::ffi::c_void,
422 rdmaxcel_sys::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
423 ptr,
424 );
425 let is_cuda = err == rdmaxcel_sys::CUDA_SUCCESS;
426
427 let mut selected_rdma_device = None;
428
429 if is_cuda {
430 let mut pci_addr_buf: [std::os::raw::c_char; 16] = [0; 16]; let err = rdmaxcel_sys::get_cuda_pci_address_from_ptr(
433 addr as rdmaxcel_sys::CUdeviceptr,
434 pci_addr_buf.as_mut_ptr(),
435 pci_addr_buf.len(),
436 );
437 if err != 0 {
438 let error_msg = get_rdmaxcel_error_message(err);
439 return Err(anyhow::anyhow!(
440 "RdmaXcel get_cuda_pci_address_from_ptr failed (addr: 0x{:x}, size: {}): {}",
441 addr,
442 size,
443 error_msg
444 ));
445 }
446
447 let pci_addr = std::ffi::CStr::from_ptr(pci_addr_buf.as_ptr())
449 .to_str()
450 .unwrap();
451 selected_rdma_device = self.pci_to_device.get(pci_addr).cloned();
452 }
453
454 let rdma_device = if let Some(device) = selected_rdma_device {
456 device
457 } else {
458 self.config.device.clone()
460 };
461
462 let device_name = rdma_device.name().clone();
463 tracing::debug!(
464 "Using RDMA device: {} for memory at 0x{:x}",
465 device_name,
466 addr
467 );
468
469 let (domain, qp) = self.get_or_create_device_domain(&device_name, &rdma_device)?;
471
472 let access = if crate::efa::is_efa_device() {
473 crate::efa::mr_access_flags()
474 } else {
475 rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
476 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
477 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
478 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC
479 };
480
481 let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
482 let mrv;
483
484 if is_cuda {
485 let mut segment_mrv = None;
487 if self.mlx5dv_enabled {
488 segment_mrv = self.find_cuda_segment_for_address(addr, size);
490
491 if segment_mrv.is_none() {
493 let err = rdmaxcel_sys::register_segments(
494 domain.pd,
495 qp.unwrap().qp as *mut rdmaxcel_sys::rdmaxcel_qp_t,
496 );
497 if err == 0 {
500 segment_mrv = self.find_cuda_segment_for_address(addr, size);
501 }
502 }
503 }
504
505 if let Some(mrv_from_segment) = segment_mrv {
507 mrv = mrv_from_segment;
508 } else {
509 let mut fd: i32 = -1;
511 rdmaxcel_sys::rdmaxcel_cuMemGetHandleForAddressRange(
512 &mut fd,
513 addr as rdmaxcel_sys::CUdeviceptr,
514 size,
515 rdmaxcel_sys::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
516 0,
517 );
518 mr =
519 rdmaxcel_sys::ibv_reg_dmabuf_mr(domain.pd, 0, size, 0, fd, access.0 as i32);
520 if mr.is_null() {
521 return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
522 }
523 mrv = IbvMemoryRegionView {
524 id: self.mrv_id,
525 virtual_addr: addr,
526 rdma_addr: (*mr).addr as usize,
527 size,
528 lkey: (*mr).lkey,
529 rkey: (*mr).rkey,
530 };
531 self.mrv_id += 1;
532 }
533 } else {
534 mr = rdmaxcel_sys::ibv_reg_mr(
536 domain.pd,
537 addr as *mut std::ffi::c_void,
538 size,
539 access.0 as i32,
540 );
541
542 if mr.is_null() {
543 return Err(anyhow::anyhow!("failed to register standard MR"));
544 }
545
546 mrv = IbvMemoryRegionView {
547 id: self.mrv_id,
548 virtual_addr: addr,
549 rdma_addr: (*mr).addr as usize,
550 size,
551 lkey: (*mr).lkey,
552 rkey: (*mr).rkey,
553 };
554 self.mrv_id += 1;
555 }
556 self.mr_map.insert(mrv.id, mr as usize);
557 Ok((mrv, device_name))
558 }
559 }
560
561 fn deregister_mr(&mut self, id: usize) -> Result<(), anyhow::Error> {
562 if let Some(mr_ptr) = self.mr_map.remove(&id) {
563 if mr_ptr != 0 {
564 unsafe {
565 rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
566 }
567 }
568 }
569 Ok(())
570 }
571
572 async fn wait_for_completion(
584 &self,
585 local_buf: &IbvBuffer,
586 qp: &mut IbvQueuePair,
587 poll_target: PollTarget,
588 expected_wr_ids: &[u64],
589 timeout: Duration,
590 ) -> Result<(), anyhow::Error> {
591 let start_time = std::time::Instant::now();
592
593 let mut remaining: std::collections::HashSet<u64> =
594 expected_wr_ids.iter().copied().collect();
595
596 while start_time.elapsed() < timeout {
597 if remaining.is_empty() {
598 return Ok(());
599 }
600
601 let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
602 match qp.poll_completion(poll_target, &wr_ids_to_poll) {
603 Ok(completions) => {
604 for (wr_id, _wc) in completions {
605 remaining.remove(&wr_id);
606 }
607 if remaining.is_empty() {
608 return Ok(());
609 }
610 tokio::time::sleep(Duration::from_millis(1)).await;
611 }
612 Err(e) => {
613 return Err(anyhow::anyhow!(
614 "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
615 e,
616 local_buf.lkey,
617 local_buf.rkey,
618 local_buf.addr,
619 local_buf.size
620 ));
621 }
622 }
623 }
624 tracing::error!(
625 "timed out while waiting on request completion for wr_ids={:?}",
626 remaining
627 );
628 Err(anyhow::anyhow!(
629 "[ibv_buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
630 local_buf,
631 expected_wr_ids
632 ))
633 }
634
635 async fn execute_op(
638 &mut self,
639 cx: &Context<'_, Self>,
640 op: IbvOp,
641 timeout: Duration,
642 ) -> Result<(), anyhow::Error> {
643 let (local_mrv, local_device_name) =
645 self.register_mr(op.local_memory.addr(), op.local_memory.size())?;
646 let local_buffer = IbvBuffer {
647 mr_id: local_mrv.id,
648 lkey: local_mrv.lkey,
649 rkey: local_mrv.rkey,
650 addr: local_mrv.rdma_addr,
651 size: local_mrv.size,
652 device_name: local_device_name,
653 };
654
655 let op_result = async {
656 let mut qp = self
657 .request_queue_pair(
658 cx,
659 op.remote_manager.clone(),
660 local_buffer.device_name.clone(),
661 op.remote_buffer.device_name.clone(),
662 )
663 .await?;
664
665 let wr_id = match op.op_type {
666 RdmaOpType::WriteFromLocal => qp.put(local_buffer.clone(), op.remote_buffer)?,
667 RdmaOpType::ReadIntoLocal => qp.get(local_buffer.clone(), op.remote_buffer)?,
668 };
669
670 self.wait_for_completion(&local_buffer, &mut qp, PollTarget::Send, &wr_id, timeout)
671 .await
672 }
673 .await;
674
675 let dereg_result = self.deregister_mr(local_buffer.mr_id);
677
678 match (op_result, dereg_result) {
679 (Ok(()), Ok(())) => Ok(()),
680 (Err(e), Ok(())) => Err(e),
681 (Ok(()), Err(e)) => Err(e),
682 (Err(op_err), Err(dereg_err)) => Err(anyhow::anyhow!(
683 "deregister MR error: {}; op error: {}",
684 dereg_err,
685 op_err
686 )),
687 }
688 }
689}
690
691#[async_trait]
692#[hyperactor::handle(IbvManagerMessage)]
693impl IbvManagerMessageHandler for IbvManagerActor {
694 async fn request_buffer(
695 &mut self,
696 cx: &Context<Self>,
697 remote_buf_id: usize,
698 ) -> Result<Option<IbvBuffer>, anyhow::Error> {
699 if let Some(buf) = self.buffer_registrations.get(&remote_buf_id) {
701 return Ok(Some(buf.clone()));
702 }
703
704 let owner = self.owner.get().unwrap();
708 let mem = match owner.request_local_memory(cx, remote_buf_id).await? {
709 Some(mem) => mem,
710 None => return Ok(None),
711 };
712
713 let (mrv, device_name) = self.register_mr(mem.addr(), mem.size())?;
714
715 let buf = IbvBuffer {
716 mr_id: mrv.id,
717 lkey: mrv.lkey,
718 rkey: mrv.rkey,
719 addr: mrv.rdma_addr,
720 size: mrv.size,
721 device_name,
722 };
723
724 self.buffer_registrations.insert(remote_buf_id, buf.clone());
725
726 Ok(Some(buf))
727 }
728
729 async fn release_buffer(
730 &mut self,
731 _cx: &Context<Self>,
732 remote_buf_id: usize,
733 ) -> Result<(), anyhow::Error> {
734 if let Some(buf) = self.buffer_registrations.remove(&remote_buf_id) {
735 self.deregister_mr(buf.mr_id)
736 .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
737 }
738 Ok(())
739 }
740
    /// Returns the queue pair linking `self_device` on this actor with
    /// `other_device` on `other`, creating and connecting it when no cached
    /// one exists.
    ///
    /// For a loopback request (same actor and same device) all steps run
    /// locally. Otherwise both sides create a queue pair, exchange connection
    /// info, connect to each other, and the remote state is checked to be
    /// `IBV_QPS_RTS`.
    ///
    /// # Errors
    /// Fails when any handshake step fails, when the remote queue pair is not
    /// in RTS afterwards, when the queue pair is missing from the local map at
    /// the end, or when waiting on a creation marked as pending exceeds one
    /// second.
    async fn request_queue_pair(
        &mut self,
        cx: &Context<Self>,
        other: reference::ActorRef<IbvManagerActor>,
        self_device: String,
        other_device: String,
    ) -> Result<IbvQueuePair, anyhow::Error> {
        let self_ref: reference::ActorRef<IbvManagerActor> = cx.bind();
        let other_id = other.actor_id().clone();

        // Key inside the per-device map: (remote actor, remote device).
        let inner_key = (other_id.clone(), other_device.clone());

        // Fast path: the queue pair already exists.
        if let Some(device_map) = self.device_qps.get(&self_device) {
            if let Some(qp) = device_map.get(&inner_key) {
                return Ok(qp.clone());
            }
        }

        let pending_key = (self_device.clone(), other_id.clone(), other_device.clone());
        let mut pending = self.pending_qp_creation.lock().await;

        if pending.contains(&pending_key) {
            // Creation is marked as in progress: release the lock and poll the
            // map until the queue pair shows up or one second has passed.
            // NOTE(review): this method holds `&mut self` while it sleeps, so
            // it is unclear what could insert into `device_qps` meanwhile —
            // confirm against the actor's concurrency model.
            drop(pending);

            let start = Instant::now();
            let timeout = Duration::from_secs(1);

            loop {
                tokio::time::sleep(Duration::from_micros(200)).await;

                if let Some(device_map) = self.device_qps.get(&self_device) {
                    if let Some(qp) = device_map.get(&inner_key) {
                        return Ok(qp.clone());
                    }
                }

                if start.elapsed() >= timeout {
                    return Err(anyhow::anyhow!(
                        "Timeout waiting for QP creation (device {} -> actor {} device {}). \
                         Another task is creating it but hasn't completed in 1 second",
                        self_device,
                        other_id,
                        other_device
                    ));
                }
            }
        } else {
            // Mark this creation as in progress before doing the work.
            pending.insert(pending_key.clone());
            drop(pending);
        }

        // Run the handshake inside its own future so that an early `?` exit
        // still reaches the pending-marker cleanup below.
        let result = async {
            let is_loopback = other_id == *self_ref.actor_id() && self_device == other_device;

            if is_loopback {
                // Same actor and device: create the queue pair and connect it
                // to its own endpoint. The device arguments to
                // `connection_info` are swapped, but they are equal here.
                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                let endpoint = self
                    .connection_info(cx, other.clone(), other_device.clone(), self_device.clone())
                    .await?;
                self.connect(
                    cx,
                    other.clone(),
                    self_device.clone(),
                    other_device.clone(),
                    endpoint,
                )
                .await?;
            } else {
                // Create a queue pair on each side. Calls on `other` pass the
                // devices in the remote's perspective (its own device first).
                self.initialize_qp(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                other
                    .initialize_qp(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;
                // Connect the local side to the remote endpoint...
                let other_endpoint: IbvQpInfo = other
                    .connection_info(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;
                self.connect(
                    cx,
                    other.clone(),
                    self_device.clone(),
                    other_device.clone(),
                    other_endpoint,
                )
                .await?;
                // ...and the remote side to the local endpoint.
                let local_endpoint = self
                    .connection_info(cx, other.clone(), self_device.clone(), other_device.clone())
                    .await?;
                other
                    .connect(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                        local_endpoint,
                    )
                    .await?;

                // Confirm the remote queue pair reached ready-to-send.
                let remote_state = other
                    .get_qp_state(
                        cx,
                        self_ref.clone(),
                        other_device.clone(),
                        self_device.clone(),
                    )
                    .await?;

                if remote_state != rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS {
                    return Err(anyhow::anyhow!(
                        "Remote QP not in RTS state after connection setup. \
                         Local is ready but remote is in state {}. \
                         This indicates a synchronization issue in connection setup.",
                        remote_state
                    ));
                }
            }

            // `initialize_qp` stored the queue pair; hand back a clone of it.
            if let Some(device_map) = self.device_qps.get(&self_device) {
                if let Some(qp) = device_map.get(&inner_key) {
                    Ok(qp.clone())
                } else {
                    Err(anyhow::anyhow!(
                        "Failed to create connection for actor {} on device {}",
                        other_id,
                        other_device
                    ))
                }
            } else {
                Err(anyhow::anyhow!(
                    "Failed to create connection for actor {} on device {} - no device map",
                    other_id,
                    other_device
                ))
            }
        }
        .await;

        // Clear the pending marker on success and on failure alike.
        // NOTE(review): when the handshake fails after `initialize_qp`, the
        // queue pair stays in `device_qps` and the fast path above would hand
        // it out on the next request — confirm whether that is intended.
        let mut pending = self.pending_qp_creation.lock().await;
        pending.remove(&pending_key);
        drop(pending);

        result
    }
910
911 async fn connect(
912 &mut self,
913 _cx: &Context<Self>,
914 other: reference::ActorRef<IbvManagerActor>,
915 self_device: String,
916 other_device: String,
917 endpoint: IbvQpInfo,
918 ) -> Result<(), anyhow::Error> {
919 tracing::debug!("connecting with {:?}", other);
920 let other_id = other.actor_id().clone();
921
922 let inner_key = (other_id.clone(), other_device.clone());
923
924 if let Some(device_map) = self.device_qps.get_mut(&self_device) {
925 match device_map.get_mut(&inner_key) {
926 Some(qp) => {
927 qp.connect(&endpoint).map_err(|e| {
928 anyhow::anyhow!("could not connect to RDMA endpoint: {}", e)
929 })?;
930 Ok(())
931 }
932 None => Err(anyhow::anyhow!(
933 "No connection found for actor {}",
934 other_id
935 )),
936 }
937 } else {
938 Err(anyhow::anyhow!(
939 "No device map found for device {}",
940 self_device
941 ))
942 }
943 }
944
945 async fn initialize_qp(
946 &mut self,
947 _cx: &Context<Self>,
948 other: reference::ActorRef<IbvManagerActor>,
949 self_device: String,
950 other_device: String,
951 ) -> Result<bool, anyhow::Error> {
952 let other_id = other.actor_id().clone();
953 let inner_key = (other_id.clone(), other_device.clone());
954
955 if let Some(device_map) = self.device_qps.get(&self_device) {
957 if device_map.contains_key(&inner_key) {
958 return Ok(true);
959 }
960 }
961
962 let rdma_device = self
964 .pci_to_device
965 .iter()
966 .find(|(_, device)| device.name() == &self_device)
967 .map(|(_, device)| device.clone())
968 .unwrap_or_else(|| {
969 super::device_selection::resolve_ibv_device(&self.config.device)
971 .unwrap_or_else(|| self.config.device.clone())
972 });
973
974 let (domain_context, domain_pd) = {
976 let (domain, _) = self.get_or_create_device_domain(&self_device, &rdma_device)?;
978 (domain.context, domain.pd)
979 };
980
981 let qp = IbvQueuePair::new(domain_context, domain_pd, self.config.clone())
982 .map_err(|e| anyhow::anyhow!("could not create IbvQueuePair: {}", e))?;
983
984 self.device_qps
986 .entry(self_device.clone())
987 .or_insert_with(HashMap::new)
988 .insert(inner_key, qp);
989
990 tracing::debug!(
991 "successfully created a connection with {:?} for local device {} -> remote device {}",
992 other,
993 self_device,
994 other_device
995 );
996
997 Ok(true)
998 }
999
1000 async fn connection_info(
1001 &mut self,
1002 _cx: &Context<Self>,
1003 other: reference::ActorRef<IbvManagerActor>,
1004 self_device: String,
1005 other_device: String,
1006 ) -> Result<IbvQpInfo, anyhow::Error> {
1007 tracing::debug!("getting connection info with {:?}", other);
1008 let other_id = other.actor_id().clone();
1009
1010 let inner_key = (other_id.clone(), other_device.clone());
1011
1012 if let Some(device_map) = self.device_qps.get_mut(&self_device) {
1013 match device_map.get_mut(&inner_key) {
1014 Some(qp) => {
1015 let connection_info = qp.get_qp_info()?;
1016 Ok(connection_info)
1017 }
1018 None => Err(anyhow::anyhow!(
1019 "No connection found for actor {}",
1020 other_id
1021 )),
1022 }
1023 } else {
1024 Err(anyhow::anyhow!(
1025 "No device map found for self device {}",
1026 self_device
1027 ))
1028 }
1029 }
1030
    /// Accepts a queue pair handed back by a caller.
    ///
    /// This handler currently does nothing and always succeeds; every argument
    /// is ignored and the queue pair stays in `device_qps`.
    async fn release_queue_pair(
        &mut self,
        _cx: &Context<Self>,
        _other: reference::ActorRef<IbvManagerActor>,
        _self_device: String,
        _other_device: String,
        _qp: IbvQueuePair,
    ) -> Result<(), anyhow::Error> {
        Ok(())
    }
1041
1042 async fn get_qp_state(
1043 &mut self,
1044 _cx: &Context<Self>,
1045 other: reference::ActorRef<IbvManagerActor>,
1046 self_device: String,
1047 other_device: String,
1048 ) -> Result<u32, anyhow::Error> {
1049 let other_id = other.actor_id().clone();
1050 let inner_key = (other_id.clone(), other_device.clone());
1051
1052 if let Some(device_map) = self.device_qps.get_mut(&self_device) {
1053 match device_map.get_mut(&inner_key) {
1054 Some(qp) => qp.state(),
1055 None => Err(anyhow::anyhow!(
1056 "No connection found for actor {} on device {}",
1057 other_id,
1058 other_device
1059 )),
1060 }
1061 } else {
1062 Err(anyhow::anyhow!(
1063 "No device map found for self device {}",
1064 self_device
1065 ))
1066 }
1067 }
1068}
1069
1070#[async_trait]
1071#[hyperactor::handle(IbvSubmit)]
1072impl IbvSubmitHandler for IbvManagerActor {
1073 async fn ibv_submit(
1074 &mut self,
1075 cx: &Context<Self>,
1076 ops: Vec<IbvOp>,
1077 timeout: Duration,
1078 ) -> Result<Result<(), String>, anyhow::Error> {
1079 let deadline = Instant::now() + timeout;
1080 let mut result = Ok(());
1081 for op in ops {
1082 let remaining = deadline.saturating_duration_since(Instant::now());
1083 if remaining.is_zero() {
1084 result = Err("submit timed out".to_string());
1085 break;
1086 }
1087 if let Err(e) = self.execute_op(cx, op, remaining).await {
1088 result = Err(e.to_string());
1089 break;
1090 }
1091 }
1092 Ok(result)
1093 }
1094}
1095
1096#[async_trait]
1097impl RdmaBackend for ActorHandle<IbvManagerActor> {
1098 type TransportInfo = ();
1099
1100 async fn submit(
1107 &mut self,
1108 cx: &(impl hyperactor::context::Actor + Send + Sync),
1109 ops: Vec<RdmaOp>,
1110 timeout: Duration,
1111 ) -> Result<(), anyhow::Error> {
1112 let mut ibv_ops = Vec::with_capacity(ops.len());
1113 for op in ops {
1114 let (remote_ibv_mgr, remote_ibv_buffer) =
1115 op.remote.resolve_ibv(cx).await.ok_or_else(|| {
1116 anyhow::anyhow!("ibverbs backend not found for buffer: {:?}", op.remote)
1117 })??;
1118
1119 ibv_ops.push(IbvOp {
1120 op_type: op.op_type,
1121 local_memory: op.local.clone(),
1122 remote_buffer: remote_ibv_buffer,
1123 remote_manager: remote_ibv_mgr,
1124 });
1125 }
1126
1127 <Self as IbvSubmitClient>::ibv_submit(self, cx, ibv_ops, timeout)
1128 .await?
1129 .map_err(|e: String| anyhow::anyhow!(e))
1130 }
1131
1132 fn transport_level(&self) -> RdmaTransportLevel {
1133 RdmaTransportLevel::Nic
1134 }
1135
1136 fn transport_info(&self) -> Option<Self::TransportInfo> {
1137 None
1138 }
1139}