monarch_rdma/rdma_components.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
13//! ## Core Components
14//!
15//! * `RdmaDomain` - Manages RDMA resources including context, protection domain, and memory region
16//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `RdmaDomain` with `new()`
36//! 2. Create an `RdmaQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
42
/// Maximum size for a single RDMA operation in bytes (1 GiB).
///
/// `put`/`get` split transfers larger than this into multiple work requests,
/// each covering at most this many bytes.
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
45
46use std::ffi::CStr;
47use std::fs;
48use std::io::Error;
49use std::result::Result;
50use std::thread::sleep;
51use std::time::Duration;
52
53use hyperactor::ActorRef;
54use hyperactor::Named;
55use hyperactor::clock::Clock;
56use hyperactor::clock::RealClock;
57use hyperactor::context;
58use serde::Deserialize;
59use serde::Serialize;
60
61use crate::RdmaDevice;
62use crate::RdmaManagerActor;
63use crate::RdmaManagerMessageClient;
64use crate::ibverbs_primitives::Gid;
65use crate::ibverbs_primitives::IbvWc;
66use crate::ibverbs_primitives::IbverbsConfig;
67use crate::ibverbs_primitives::RdmaOperation;
68use crate::ibverbs_primitives::RdmaQpInfo;
69
/// A serializable handle used to trigger ("ring") previously enqueued RDMA work.
///
/// Pointers are stored as `usize` so the struct can derive `Serialize`/`Deserialize`.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Address passed as the source argument to `db_ring` — presumably the
    /// work-queue entry to submit; confirm against the creator of this struct.
    pub src_ptr: usize,
    /// Address passed as the destination argument to `db_ring` — presumably
    /// the device doorbell register.
    pub dst_ptr: usize,
    // NOTE(review): `size` is not read by `ring()`; verify intended use.
    pub size: usize,
}
76
77impl DoorBell {
78 /// Rings the doorbell to trigger the execution of previously enqueued operations.
79 ///
80 /// This method uses unsafe code to directly interact with the RDMA device,
81 /// sending a signal from the source pointer to the destination pointer.
82 ///
83 /// # Returns
84 /// * `Ok(())` if the operation is successful.
85 /// * `Err(anyhow::Error)` if an error occurs during the operation.
86 pub fn ring(&self) -> Result<(), anyhow::Error> {
87 unsafe {
88 let src_ptr = self.src_ptr as *mut std::ffi::c_void;
89 let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
90 rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
91 Ok(())
92 }
93 }
94}
95
/// Handle describing a registered, remotely accessible memory region.
///
/// Carries the addressing and protection-key information peers need to target
/// this region, plus the owning actor used to broker queue pairs.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct RdmaBuffer {
    /// Actor that owns the underlying memory registration.
    pub owner: ActorRef<RdmaManagerActor>,
    /// Identifier of the memory region within the owner.
    pub mr_id: usize,
    /// Local protection key (ibverbs lkey) of the registered region.
    pub lkey: u32,
    /// Remote protection key (ibverbs rkey) peers use to access this region.
    pub rkey: u32,
    /// Starting virtual address of the region.
    pub addr: usize,
    /// Length of the region in bytes.
    pub size: usize,
}
105
106impl RdmaBuffer {
107 /// Read from the RdmaBuffer into the provided memory.
108 ///
109 /// This method transfers data from the buffer into the local memory region provided over RDMA.
110 /// This involves calling a `Put` operation on the RdmaBuffer's actor side.
111 ///
112 /// # Arguments
113 /// * `client` - The actor who is reading.
114 /// * `remote` - RdmaBuffer representing the remote memory region
115 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
116 ///
117 /// # Returns
118 /// `Ok(bool)` indicating if the operation completed successfully.
119 pub async fn read_into(
120 &self,
121 client: &impl context::Actor,
122 remote: RdmaBuffer,
123 timeout: u64,
124 ) -> Result<bool, anyhow::Error> {
125 tracing::debug!(
126 "[buffer] reading from {:?} into remote ({:?}) at {:?}",
127 self,
128 remote.owner.actor_id(),
129 remote,
130 );
131 let remote_owner = remote.owner.clone(); // Clone before the move!
132 let mut qp = self
133 .owner
134 .request_queue_pair_deprecated(client, remote_owner.clone())
135 .await?;
136
137 qp.put(self.clone(), remote)?;
138 let result = self
139 .wait_for_completion(&mut qp, PollTarget::Send, timeout)
140 .await;
141
142 // Release the queue pair back to the actor
143 self.owner
144 .release_queue_pair_deprecated(client, remote_owner, qp)
145 .await?;
146
147 result
148 }
149
150 /// Write from the provided memory into the RdmaBuffer.
151 ///
152 /// This method performs an RDMA write operation, transferring data from the caller's
153 /// memory region to this buffer.
154 /// This involves calling a `Fetch` operation on the RdmaBuffer's actor side.
155 ///
156 /// # Arguments
157 /// * `client` - The actor who is writing.
158 /// * `remote` - RdmaBuffer representing the remote memory region
159 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
160 ///
161 /// # Returns
162 /// `Ok(bool)` indicating if the operation completed successfully.
163 pub async fn write_from(
164 &self,
165 client: &impl context::Actor,
166 remote: RdmaBuffer,
167 timeout: u64,
168 ) -> Result<bool, anyhow::Error> {
169 tracing::debug!(
170 "[buffer] writing into {:?} from remote ({:?}) at {:?}",
171 self,
172 remote.owner.actor_id(),
173 remote,
174 );
175 let remote_owner = remote.owner.clone(); // Clone before the move!
176 let mut qp = self
177 .owner
178 .request_queue_pair_deprecated(client, remote_owner.clone())
179 .await?;
180 qp.get(self.clone(), remote)?;
181 let result = self
182 .wait_for_completion(&mut qp, PollTarget::Send, timeout)
183 .await;
184
185 // Release the queue pair back to the actor
186 self.owner
187 .release_queue_pair_deprecated(client, remote_owner, qp)
188 .await?;
189
190 result?;
191 Ok(true)
192 }
193 /// Waits for the completion of an RDMA operation.
194 ///
195 /// This method polls the completion queue until the specified work request completes
196 /// or until the timeout is reached.
197 ///
198 /// # Arguments
199 /// * `qp` - The RDMA Queue Pair to poll for completion
200 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
201 ///
202 /// # Returns
203 /// `Ok(true)` if the operation completes successfully within the timeout,
204 /// or an error if the timeout is reached
205 async fn wait_for_completion(
206 &self,
207 qp: &mut RdmaQueuePair,
208 poll_target: PollTarget,
209 timeout: u64,
210 ) -> Result<bool, anyhow::Error> {
211 let timeout = Duration::from_secs(timeout);
212 let start_time = std::time::Instant::now();
213
214 while start_time.elapsed() < timeout {
215 match qp.poll_completion_target(poll_target) {
216 Ok(Some(_wc)) => {
217 tracing::debug!("work completed");
218 return Ok(true);
219 }
220 Ok(None) => {
221 RealClock.sleep(Duration::from_millis(1)).await;
222 }
223 Err(e) => {
224 return Err(anyhow::anyhow!(
225 "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
226 e,
227 self.lkey,
228 self.rkey,
229 self.addr,
230 self.size
231 ));
232 }
233 }
234 }
235 tracing::error!("timed out while waiting on request completion");
236 Err(anyhow::anyhow!(
237 "[buffer({:?})] rdma operation did not complete in time",
238 self
239 ))
240 }
241
242 /// Drop the buffer and release remote handles.
243 ///
244 /// This method calls the owning RdmaManagerActor to release the buffer and clean up
245 /// associated memory regions. This is typically called when the buffer is no longer
246 /// needed and resources should be freed.
247 ///
248 /// # Arguments
249 /// * `client` - Mailbox used for communication
250 ///
251 /// # Returns
252 /// `Ok(())` if the operation completed successfully.
253 pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
254 tracing::debug!("[buffer] dropping buffer {:?}", self);
255 self.owner
256 .release_buffer_deprecated(client, self.clone())
257 .await?;
258 Ok(())
259 }
260}
261
/// Represents a domain for RDMA operations, encapsulating the necessary resources
/// for establishing and managing RDMA connections.
///
/// An `RdmaDomain` manages the device context and protection domain (PD)
/// required for RDMA operations. It provides the foundation for creating queue pairs
/// and establishing connections between RDMA devices.
///
/// # Fields
///
/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
pub struct RdmaDomain {
    pub context: *mut rdmaxcel_sys::ibv_context,
    pub pd: *mut rdmaxcel_sys::ibv_pd,
}
277
278impl std::fmt::Debug for RdmaDomain {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 f.debug_struct("RdmaDomain")
281 .field("context", &format!("{:p}", self.context))
282 .field("pd", &format!("{:p}", self.pd))
283 .finish()
284 }
285}
286
// SAFETY:
// RdmaDomain is `Send` because the raw pointers to ibverbs structs can be
// accessed from any thread, and it is safe to drop `RdmaDomain` (and run the
// ibverbs destructors) from any thread.
unsafe impl Send for RdmaDomain {}
293
// SAFETY:
// RdmaDomain is `Sync` because the underlying ibverbs APIs are thread-safe,
// so shared references may be used concurrently from multiple threads.
unsafe impl Sync for RdmaDomain {}
298
299impl Drop for RdmaDomain {
300 fn drop(&mut self) {
301 unsafe {
302 rdmaxcel_sys::ibv_dealloc_pd(self.pd);
303 }
304 }
305}
306
impl RdmaDomain {
    /// Creates a new RdmaDomain.
    ///
    /// This function opens the RDMA device context for the named device and
    /// allocates a protection domain on it. (Memory regions are registered
    /// elsewhere, not by this constructor.)
    ///
    /// SAFETY:
    /// Memory region (MR) registration in this crate uses implicit ODP for RDMA access, which maps
    /// large virtual address ranges without explicit pinning. This is convenient, but it broadens the
    /// memory footprint exposed to the NIC and introduces a security liability.
    ///
    /// We currently assume a trusted, single-environment and are not enforcing finer-grained memory isolation
    /// at this layer. We plan to investigate mitigations - such as memory windows or tighter registration
    /// boundaries in future follow-ups.
    ///
    /// # Arguments
    ///
    /// * `device` - The RDMA device to open, selected by name
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * No RDMA devices are found
    /// * The specified device cannot be found
    /// * Device context creation fails
    /// * Protection domain allocation fails
    pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
        tracing::debug!("creating RdmaDomain for device {}", device.name());
        // SAFETY:
        // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
        // - All pointers are properly initialized and checked for null before use
        // - Resources are properly cleaned up in error cases to prevent leaks
        // - The operations follow the documented RDMA protocol for device initialization
        unsafe {
            // Get the device based on the provided RdmaDevice
            let device_name = device.name();
            let mut num_devices = 0i32;
            let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);

            if devices.is_null() || num_devices == 0 {
                return Err(anyhow::anyhow!("no RDMA devices found"));
            }

            // Find the device with the matching name
            let mut device_ptr = std::ptr::null_mut();
            for i in 0..num_devices {
                let dev = *devices.offset(i as isize);
                let dev_name =
                    CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();

                if dev_name == *device_name {
                    device_ptr = dev;
                    break;
                }
            }

            // If we didn't find the device, return an error
            if device_ptr.is_null() {
                rdmaxcel_sys::ibv_free_device_list(devices);
                return Err(anyhow::anyhow!("device '{}' not found", device_name));
            }
            tracing::info!("using RDMA device: {}", device_name);

            // Open device
            let context = rdmaxcel_sys::ibv_open_device(device_ptr);
            if context.is_null() {
                rdmaxcel_sys::ibv_free_device_list(devices);
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!("failed to create context: {}", os_error));
            }

            // Create protection domain
            let pd = rdmaxcel_sys::ibv_alloc_pd(context);
            if pd.is_null() {
                rdmaxcel_sys::ibv_close_device(context);
                rdmaxcel_sys::ibv_free_device_list(devices);
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to create protection domain (PD): {}",
                    os_error
                ));
            }

            // Avoids memory leaks
            rdmaxcel_sys::ibv_free_device_list(devices);

            let domain = RdmaDomain { context, pd };

            Ok(domain)
        }
    }
}
/// Enum to specify which completion queue to poll
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue.
    Send,
    /// Poll the receive completion queue.
    Recv,
}
407
408/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
409///
410/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
411/// and other resources needed for RDMA communication. It provides methods for
412/// establishing connections and performing RDMA operations like read and write.
413///
414/// # Fields
415///
416/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
417/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
418/// * `qp` - Queue Pair pointer that manages send and receive operations
419/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
420/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
421/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
422/// * `context` - RDMA device context pointer
423/// * `config` - Configuration settings for the queue pair
424///
425/// # Connection Lifecycle
426///
427/// 1. Create with `new()` from an `RdmaDomain`
428/// 2. Get connection info with `get_qp_info()`
429/// 3. Exchange connection info with remote peer (application must handle this)
430/// 4. Connect to remote endpoint with `connect()`
431/// 5. Perform RDMA operations with `put()` or `get()`
432/// 6. Poll for completions with `poll_send_completion()` or `poll_recv_completion()`
433#[derive(Debug, Serialize, Deserialize, Named, Clone)]
434pub struct RdmaQueuePair {
435 pub send_cq: usize, // *mut rdmaxcel_sys::ibv_cq,
436 pub recv_cq: usize, // *mut rdmaxcel_sys::ibv_cq,
437 pub qp: usize, // *mut rdmaxcel_sys::ibv_qp,
438 pub dv_qp: usize, // *mut rdmaxcel_sys::mlx5dv_qp,
439 pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
440 pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
441 context: usize, // *mut rdmaxcel_sys::ibv_context,
442 config: IbverbsConfig,
443 pub send_wqe_idx: u64,
444 pub send_db_idx: u64,
445 pub send_cq_idx: u64,
446 pub recv_wqe_idx: u64,
447 pub recv_db_idx: u64,
448 pub recv_cq_idx: u64,
449 rts_timestamp: u64,
450}
451
452impl RdmaQueuePair {
453 /// Applies hardware initialization delay if this is the first operation since RTS.
454 ///
455 /// This ensures the hardware has sufficient time to settle after reaching
456 /// Ready-to-Send state before the first actual operation.
457 fn apply_first_op_delay(&self, wr_id: u64) {
458 if wr_id == 0 {
459 assert!(
460 self.rts_timestamp != u64::MAX,
461 "First operation attempted before queue pair reached RTS state! Call connect() first."
462 );
463 let current_nanos = RealClock
464 .system_time_now()
465 .duration_since(std::time::UNIX_EPOCH)
466 .unwrap()
467 .as_nanos() as u64;
468 let elapsed_nanos = current_nanos - self.rts_timestamp;
469 let elapsed = Duration::from_nanos(elapsed_nanos);
470 let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
471 if elapsed < init_delay {
472 let remaining_delay = init_delay - elapsed;
473 sleep(remaining_delay);
474 }
475 }
476 }
477
    /// Creates a new RdmaQueuePair on the given device context and protection domain.
    ///
    /// This function initializes a new Queue Pair (QP) and associated Completion Queues (CQs),
    /// plus the mlx5dv provider structures. The QP is created in the RESET state
    /// and must be transitioned to other states via the `connect()` method before use.
    ///
    /// # Arguments
    ///
    /// * `context` - RDMA device context to create the queue pair on
    /// * `pd` - Protection domain that owns the queue pair's resources
    /// * `config` - Configuration settings (queue depths, GPU Direct, timeouts, ...)
    ///
    /// # Returns
    ///
    /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * Queue pair (QP) creation fails
    /// * mlx5dv queue pair or completion queue initialization fails
    /// * GPU Direct RDMA memory registration fails (when `use_gpu_direct` is set)
    pub fn new(
        context: *mut rdmaxcel_sys::ibv_context,
        pd: *mut rdmaxcel_sys::ibv_pd,
        config: IbverbsConfig,
    ) -> Result<Self, anyhow::Error> {
        tracing::debug!("creating an RdmaQueuePair from config {}", config);
        // SAFETY: unsafe rdmaxcel_sys calls; every returned pointer is
        // null-checked, and partially created resources are destroyed on the
        // error paths before returning.
        unsafe {
            // standard ibverbs QP
            let qp = rdmaxcel_sys::create_qp(
                context,
                pd,
                config.cq_entries,
                config.max_send_wr.try_into().unwrap(),
                config.max_recv_wr.try_into().unwrap(),
                config.max_send_sge.try_into().unwrap(),
                config.max_recv_sge.try_into().unwrap(),
            );

            if qp.is_null() {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to create queue pair (QP): {}",
                    os_error
                ));
            }

            let send_cq = (*qp).send_cq;
            let recv_cq = (*qp).recv_cq;

            // mlx5dv provider APIs
            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp(qp);
            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq(qp);
            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq(qp);

            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
                // NOTE(review): any non-null dv_* structures are not freed
                // here — confirm whether they need explicit cleanup.
                rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
                rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
                rdmaxcel_sys::ibv_destroy_qp(qp);
                return Err(anyhow::anyhow!(
                    "failed to init mlx5dv_qp or completion queues"
                ));
            }

            // GPU Direct RDMA specific registrations
            if config.use_gpu_direct {
                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
                if ret != 0 {
                    rdmaxcel_sys::ibv_destroy_cq((*qp).recv_cq);
                    rdmaxcel_sys::ibv_destroy_cq((*qp).send_cq);
                    rdmaxcel_sys::ibv_destroy_qp(qp);
                    return Err(anyhow::anyhow!(
                        "failed to register GPU Direct RDMA memory: {:?}",
                        ret
                    ));
                }
            }
            Ok(RdmaQueuePair {
                send_cq: send_cq as usize,
                recv_cq: recv_cq as usize,
                qp: qp as usize,
                dv_qp: dv_qp as usize,
                dv_send_cq: dv_send_cq as usize,
                dv_recv_cq: dv_recv_cq as usize,
                context: context as usize,
                config,
                recv_db_idx: 0,
                recv_wqe_idx: 0,
                recv_cq_idx: 0,
                send_db_idx: 0,
                send_wqe_idx: 0,
                send_cq_idx: 0,
                // u64::MAX marks "not yet connected"; see apply_first_op_delay.
                rts_timestamp: u64::MAX,
            })
        }
    }
573
    /// Returns the information required for a remote peer to connect to this queue pair.
    ///
    /// This method retrieves the local queue pair attributes and port information needed by
    /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
    /// the queue pair number, LID, GID, and the configured PSN.
    ///
    /// # Returns
    ///
    /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * Port attribute query fails
    /// * GID query fails
    pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
        // SAFETY:
        // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
        // - All pointers are properly initialized before use
        // - Port and GID queries follow the documented ibverbs API contract
        // - Error handling properly checks return codes from ibverbs functions
        unsafe {
            let context = self.context as *mut rdmaxcel_sys::ibv_context;
            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
            let errno = rdmaxcel_sys::ibv_query_port(
                context,
                self.config.port_num,
                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
            );
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "Failed to query port attributes: {}",
                    os_error
                ));
            }

            let mut gid = Gid::default();
            let ret = rdmaxcel_sys::ibv_query_gid(
                context,
                self.config.port_num,
                i32::from(self.config.gid_index),
                gid.as_mut(),
            );
            if ret != 0 {
                return Err(anyhow::anyhow!("Failed to query GID"));
            }

            Ok(RdmaQpInfo {
                qp_num: (*qp).qp_num,
                lid: port_attr.lid,
                gid: Some(gid),
                psn: self.config.psn,
            })
        }
    }
632
633 pub fn state(&mut self) -> Result<u32, anyhow::Error> {
634 // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
635 unsafe {
636 let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
637 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
638 ..Default::default()
639 };
640 let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
641 ..Default::default()
642 };
643 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
644 let errno =
645 rdmaxcel_sys::ibv_query_qp(qp, &mut qp_attr, mask.0 as i32, &mut qp_init_attr);
646 if errno != 0 {
647 let os_error = Error::last_os_error();
648 return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
649 }
650 Ok(qp_attr.qp_state)
651 }
652 }
    /// Connect to a remote Rdma connection point.
    ///
    /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection,
    /// then records the RTS timestamp used by `apply_first_op_delay`.
    ///
    /// # Arguments
    ///
    /// * `connection_info` - The remote connection info to connect to
    ///
    /// # Errors
    ///
    /// Returns an error if any of the three `ibv_modify_qp` transitions
    /// (INIT, RTR, RTS) fails.
    pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
        // SAFETY:
        // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
        // The operations are safe because:
        // 1. We're following the documented ibverbs API contract
        // 2. All pointers used are properly initialized and owned by this struct
        // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
        // 4. Memory access is properly bounded by the registered memory regions
        unsafe {
            // Transition to INIT
            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;

            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;

            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
                qp_access_flags: qp_access_flags.0,
                pkey_index: self.config.pkey_index,
                port_num: self.config.port_num,
                ..Default::default()
            };

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;

            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to INIT: {}",
                    os_error
                ));
            }

            // Transition to RTR (Ready to Receive)
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
                path_mtu: self.config.path_mtu,
                dest_qp_num: connection_info.qp_num,
                rq_psn: connection_info.psn,
                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
                min_rnr_timer: self.config.min_rnr_timer,
                ah_attr: rdmaxcel_sys::ibv_ah_attr {
                    dlid: connection_info.lid,
                    sl: 0,
                    src_path_bits: 0,
                    port_num: self.config.port_num,
                    grh: Default::default(),
                    ..Default::default()
                },
                ..Default::default()
            };

            // If the remote connection info contains a Gid, the routing will be global.
            // Otherwise, it will be local, i.e. using LID.
            if let Some(gid) = connection_info.gid {
                qp_attr.ah_attr.is_global = 1;
                qp_attr.ah_attr.grh.dgid = gid.into();
                // Maximum hop limit so the packet is not dropped en route.
                qp_attr.ah_attr.grh.hop_limit = 0xff;
                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
            } else {
                // Use LID-based routing, e.g. for Infiniband/RoCEv1
                qp_attr.ah_attr.is_global = 0;
            }

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;

            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to RTR: {}",
                    os_error
                ));
            }

            // Transition to RTS (Ready to Send)
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
                sq_psn: self.config.psn,
                max_rd_atomic: self.config.max_rd_atomic,
                retry_cnt: self.config.retry_cnt,
                rnr_retry: self.config.rnr_retry,
                timeout: self.config.qp_timeout,
                ..Default::default()
            };

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;

            let errno = rdmaxcel_sys::ibv_modify_qp(qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to RTS: {}",
                    os_error
                ));
            }
            tracing::debug!(
                "connection sequence has successfully completed (qp: {:?})",
                qp
            );

            // Record RTS time now that the queue pair is ready to send
            self.rts_timestamp = RealClock
                .system_time_now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos() as u64;

            Ok(())
        }
    }
788
789 pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
790 let idx = self.recv_wqe_idx;
791 self.recv_wqe_idx += 1;
792 self.send_wqe(
793 0,
794 lhandle.lkey,
795 0,
796 idx,
797 true,
798 RdmaOperation::Recv,
799 0,
800 rhandle.rkey,
801 )
802 .unwrap();
803 Ok(())
804 }
805
806 pub fn put_with_recv(
807 &mut self,
808 lhandle: RdmaBuffer,
809 rhandle: RdmaBuffer,
810 ) -> Result<(), anyhow::Error> {
811 let idx = self.send_wqe_idx;
812 self.send_wqe_idx += 1;
813 self.post_op(
814 lhandle.addr,
815 lhandle.lkey,
816 lhandle.size,
817 idx,
818 true,
819 RdmaOperation::WriteWithImm,
820 rhandle.addr,
821 rhandle.rkey,
822 )
823 .unwrap();
824 self.send_db_idx += 1;
825 Ok(())
826 }
827
828 pub fn put(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
829 let total_size = lhandle.size;
830 if rhandle.size < total_size {
831 return Err(anyhow::anyhow!(
832 "Remote buffer size ({}) is smaller than local buffer size ({})",
833 rhandle.size,
834 total_size
835 ));
836 }
837
838 let mut remaining = total_size;
839 let mut offset = 0;
840 while remaining > 0 {
841 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
842 let idx = self.send_wqe_idx;
843 self.send_wqe_idx += 1;
844 self.post_op(
845 lhandle.addr + offset,
846 lhandle.lkey,
847 chunk_size,
848 idx,
849 true,
850 RdmaOperation::Write,
851 rhandle.addr + offset,
852 rhandle.rkey,
853 )?;
854 self.send_db_idx += 1;
855
856 remaining -= chunk_size;
857 offset += chunk_size;
858 }
859
860 Ok(())
861 }
862
    /// Rings the doorbell for all posted-but-unsubmitted send work requests.
    ///
    /// Submits every pending WQE between `send_db_idx` and `send_wqe_idx` to
    /// the device doorbell register, advancing `send_db_idx` as it goes.
    ///
    /// # Returns
    ///
    /// * `Result<(), anyhow::Error>` - `Ok(())` on success, or an error if the
    ///   number of pending WQEs exceeds the send queue capacity
    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
        // SAFETY: `dv_qp` was initialized in `new()`; WQE offsets are kept in
        // bounds by the modulo over `wqe_cnt`.
        unsafe {
            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            // More pending WQEs than send-queue slots means earlier entries
            // were overwritten before submission.
            if (wqe_cnt as u64) < (self.send_wqe_idx - self.send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            while self.send_db_idx < self.send_wqe_idx {
                let offset = (self.send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                self.send_db_idx += 1;
            }
            Ok(())
        }
    }
889
    /// Enqueues a put operation without ringing the doorbell.
    ///
    /// This method prepares a put operation but does not execute it.
    /// Call `ring_doorbell()` to submit the enqueued operation(s).
    ///
    /// # Arguments
    ///
    /// * `lhandle` - Local buffer handle
    /// * `rhandle` - Remote buffer handle
    ///
    /// # Returns
    ///
    /// * `Result<(), anyhow::Error>` - Success or error
    pub fn enqueue_put(
        &mut self,
        lhandle: RdmaBuffer,
        rhandle: RdmaBuffer,
    ) -> Result<(), anyhow::Error> {
        let idx = self.send_wqe_idx;
        self.send_wqe_idx += 1;
        self.send_wqe(
            lhandle.addr,
            lhandle.lkey,
            lhandle.size,
            idx,
            true,
            RdmaOperation::Write,
            rhandle.addr,
            rhandle.rkey,
        )?;
        Ok(())
    }
922
    /// Enqueues a put with receive operation without ringing the doorbell.
    ///
    /// This method prepares a put with receive (`WriteWithImm`) operation but
    /// does not execute it. Call `ring_doorbell()` to submit the enqueued
    /// operation(s).
    ///
    /// # Arguments
    ///
    /// * `lhandle` - Local buffer handle
    /// * `rhandle` - Remote buffer handle
    ///
    /// # Returns
    ///
    /// * `Result<(), anyhow::Error>` - Success or error
    pub fn enqueue_put_with_recv(
        &mut self,
        lhandle: RdmaBuffer,
        rhandle: RdmaBuffer,
    ) -> Result<(), anyhow::Error> {
        let idx = self.send_wqe_idx;
        self.send_wqe_idx += 1;
        self.send_wqe(
            lhandle.addr,
            lhandle.lkey,
            lhandle.size,
            idx,
            true,
            RdmaOperation::WriteWithImm,
            rhandle.addr,
            rhandle.rkey,
        )?;
        Ok(())
    }
955
956 /// Enqueues a get operation without ringing the doorbell.
957 ///
958 /// This method prepares a get operation but does not execute it.
959 /// Use `get_doorbell().ring()` to execute the operation.
960 ///
961 /// # Arguments
962 ///
963 /// * `lhandle` - Local buffer handle
964 /// * `rhandle` - Remote buffer handle
965 ///
966 /// # Returns
967 ///
968 /// * `Result<(), anyhow::Error>` - Success or error
969 pub fn enqueue_get(
970 &mut self,
971 lhandle: RdmaBuffer,
972 rhandle: RdmaBuffer,
973 ) -> Result<(), anyhow::Error> {
974 let idx = self.send_wqe_idx;
975 self.send_wqe_idx += 1;
976 self.send_wqe(
977 lhandle.addr,
978 lhandle.lkey,
979 lhandle.size,
980 idx,
981 true,
982 RdmaOperation::Read,
983 rhandle.addr,
984 rhandle.rkey,
985 )?;
986 Ok(())
987 }
988
989 pub fn get(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<(), anyhow::Error> {
990 let total_size = lhandle.size;
991 if rhandle.size < total_size {
992 return Err(anyhow::anyhow!(
993 "Remote buffer size ({}) is smaller than local buffer size ({})",
994 rhandle.size,
995 total_size
996 ));
997 }
998
999 let mut remaining = total_size;
1000 let mut offset = 0;
1001
1002 while remaining > 0 {
1003 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
1004 let idx = self.send_wqe_idx;
1005 self.send_wqe_idx += 1;
1006 self.post_op(
1007 lhandle.addr + offset,
1008 lhandle.lkey,
1009 chunk_size,
1010 idx,
1011 true,
1012 RdmaOperation::Read,
1013 rhandle.addr + offset,
1014 rhandle.rkey,
1015 )?;
1016 self.send_db_idx += 1;
1017
1018 remaining -= chunk_size;
1019 offset += chunk_size;
1020 }
1021
1022 Ok(())
1023 }
1024
    /// Posts a request to the queue pair.
    ///
    /// Builds an ibverbs work request on the stack and hands it to the device
    /// via `post_recv` (for `Recv`) or `post_send` (for `Write`, `Read`, and
    /// `WriteWithImm`).
    ///
    /// # Arguments
    ///
    /// * `laddr` - The local address of the data (source for writes,
    ///   destination for reads/receives)
    /// * `lkey` - Local key of the registered memory region covering `laddr`
    /// * `length` - Length of the data in bytes
    /// * `wr_id` - Work request ID for completion identification
    /// * `signaled` - Whether to generate a completion event
    /// * `op_type` - Operation type (`Recv`, `Write`, `Read`, or `WriteWithImm`)
    /// * `raddr` - the remote address, representing the memory location on the remote peer
    /// * `rkey` - the remote key, representing the key required to access the remote memory region
    ///
    /// # Panics
    ///
    /// Panics with "Not Implemented" for any other operation type.
    fn post_op(
        &mut self,
        laddr: usize,
        lkey: u32,
        length: usize,
        wr_id: u64,
        signaled: bool,
        op_type: RdmaOperation,
        raddr: usize,
        rkey: u32,
    ) -> Result<(), anyhow::Error> {
        // SAFETY:
        // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
        // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
        // - The memory address in `local_addr` is not dereferenced, only passed to the device
        // - The remote connection info is verified to exist before accessing
        // - The ibverbs post_send operation follows the documented API contract
        // - Error codes from the device are properly checked and propagated
        unsafe {
            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
            let context = self.context as *mut rdmaxcel_sys::ibv_context;
            // Device-provided function table; post_send/post_recv are invoked
            // through it rather than via generic ibv_* wrappers.
            let ops = &mut (*context).ops;
            let errno;
            if op_type == RdmaOperation::Recv {
                // Scatter entry describing where received data should land.
                let mut sge = rdmaxcel_sys::ibv_sge {
                    addr: laddr as u64,
                    length: length as u32,
                    lkey,
                };
                let mut wr = rdmaxcel_sys::ibv_recv_wr {
                    wr_id: wr_id.into(),
                    sg_list: &mut sge as *mut _,
                    num_sge: 1,
                    ..Default::default()
                };
                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
                errno = ops.post_recv.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
            } else if op_type == RdmaOperation::Write
                || op_type == RdmaOperation::Read
                || op_type == RdmaOperation::WriteWithImm
            {
                // Apply hardware initialization delay if this is the first operation
                self.apply_first_op_delay(wr_id);
                let send_flags = if signaled {
                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
                } else {
                    0
                };
                let mut sge = rdmaxcel_sys::ibv_sge {
                    addr: laddr as u64,
                    length: length as u32,
                    lkey,
                };
                let mut wr = rdmaxcel_sys::ibv_send_wr {
                    wr_id: wr_id.into(),
                    next: std::ptr::null_mut(),
                    sg_list: &mut sge as *mut _,
                    num_sge: 1,
                    opcode: op_type.into(),
                    send_flags,
                    wr: Default::default(),
                    qp_type: Default::default(),
                    __bindgen_anon_1: Default::default(),
                    __bindgen_anon_2: Default::default(),
                };

                // One-sided RDMA ops carry the remote address/key in the WR itself.
                wr.wr.rdma.remote_addr = raddr as u64;
                wr.wr.rdma.rkey = rkey;
                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();

                errno = ops.post_send.as_mut().unwrap()(qp, &mut wr as *mut _, &mut bad_wr);
            } else {
                panic!("Not Implemented");
            }

            // NOTE(review): this error message is also emitted for failed
            // post_recv calls, not just post_send; consider distinguishing them.
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
            }
            tracing::debug!(
                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
                op_type,
                lkey,
                laddr,
                length,
                raddr,
                rkey,
            );

            Ok(())
        }
    }
1128
    /// Writes a work queue entry (WQE) directly into the queue buffers using
    /// mlx5 direct-verbs (`mlx5dv`) structures, bypassing the generic ibverbs
    /// post path.
    ///
    /// For send-side operations the doorbell is NOT rung here; the returned
    /// `DoorBell` carries the BlueFlame register address and WQE source
    /// address needed to ring it later. Arguments have the same meaning as
    /// `post_op`.
    fn send_wqe(
        &mut self,
        laddr: usize,
        lkey: u32,
        length: usize,
        wr_id: u64,
        signaled: bool,
        op_type: RdmaOperation,
        raddr: usize,
        rkey: u32,
    ) -> Result<DoorBell, anyhow::Error> {
        // SAFETY: dereferences qp/dv_qp raw pointers owned by this queue pair
        // and passes the device-mapped queue buffers to the rdmaxcel C helpers.
        unsafe {
            // Translate the high-level operation into the mlx5 opcode value.
            let op_type_val = match op_type {
                RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
                RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
                RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
                RdmaOperation::Recv => 0,
            };

            let qp = self.qp as *mut rdmaxcel_sys::ibv_qp;
            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let _dv_cq = if op_type == RdmaOperation::Recv {
                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
            } else {
                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
            };

            // Create the WQE parameters struct

            // Receive WQEs go into the receive queue buffer; everything else
            // into the send queue buffer.
            let buf = if op_type == RdmaOperation::Recv {
                (*dv_qp).rq.buf as *mut u8
            } else {
                (*dv_qp).sq.buf as *mut u8
            };

            let params = rdmaxcel_sys::wqe_params_t {
                laddr,
                lkey,
                length,
                wr_id: wr_id.into(),
                signaled,
                op_type: op_type_val,
                raddr,
                rkey,
                qp_num: (*qp).qp_num,
                buf,
                dbrec: (*dv_qp).dbrec,
                // NOTE(review): sq.wqe_cnt is used even for Recv WQEs — TODO
                // confirm whether rq.wqe_cnt was intended in the Recv case.
                wqe_cnt: (*dv_qp).sq.wqe_cnt,
            };

            // Call the C function to post the WQE
            if op_type == RdmaOperation::Recv {
                rdmaxcel_sys::recv_wqe(params);
                // Update the doorbell record so the HCA sees the new recv WQE.
                // NOTE(review): this writes the constant 1 (big-endian) rather
                // than a running receive counter — TODO confirm this is
                // intended when more than one recv can be outstanding.
                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
            } else {
                rdmaxcel_sys::send_wqe(params);
            };

            // Create and return a DoorBell struct
            Ok(DoorBell {
                dst_ptr: (*dv_qp).bf.reg as usize,
                src_ptr: (*dv_qp).sq.buf as usize,
                size: 8,
            })
        }
    }
1195
1196 /// Poll for completions on the specified completion queue(s)
1197 ///
1198 /// # Arguments
1199 ///
1200 /// * `target` - Which completion queue(s) to poll (Send, Receive)
1201 ///
1202 /// # Returns
1203 ///
1204 /// * `Ok(Some(wc))` - A completion was found
1205 /// * `Ok(None)` - No completion was found
1206 /// * `Err(e)` - An error occurred
1207 pub fn poll_completion_target(
1208 &mut self,
1209 target: PollTarget,
1210 ) -> Result<Option<IbvWc>, anyhow::Error> {
1211 unsafe {
1212 let context = self.context as *mut rdmaxcel_sys::ibv_context;
1213 let _outstanding_wqe =
1214 self.send_db_idx + self.recv_db_idx - self.send_cq_idx - self.recv_cq_idx;
1215
1216 // Check for send completions if requested
1217 if (target == PollTarget::Send) && self.send_db_idx > self.send_cq_idx {
1218 let send_cq = self.send_cq as *mut rdmaxcel_sys::ibv_cq;
1219 let ops = &mut (*context).ops;
1220 let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1221 let ret = ops.poll_cq.as_mut().unwrap()(send_cq, 1, &mut wc);
1222
1223 if ret < 0 {
1224 return Err(anyhow::anyhow!(
1225 "Failed to poll send CQ: {}",
1226 Error::last_os_error()
1227 ));
1228 }
1229
1230 if ret > 0 {
1231 if !wc.is_valid() {
1232 if let Some((status, vendor_err)) = wc.error() {
1233 return Err(anyhow::anyhow!(
1234 "Send work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}",
1235 status,
1236 vendor_err,
1237 wc.wr_id(),
1238 self.send_cq_idx,
1239 ));
1240 }
1241 }
1242
1243 // This should be a send completion - verify it's the one we're waiting for
1244 if wc.wr_id() == self.send_cq_idx {
1245 self.send_cq_idx += 1;
1246 }
1247 // finished polling, return the last completion
1248 if self.send_cq_idx == self.send_db_idx {
1249 return Ok(Some(IbvWc::from(wc)));
1250 }
1251 }
1252 }
1253
1254 // Check for receive completions if requested
1255 if (target == PollTarget::Recv) && self.recv_db_idx > self.recv_cq_idx {
1256 let recv_cq = self.recv_cq as *mut rdmaxcel_sys::ibv_cq;
1257 let ops = &mut (*context).ops;
1258 let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1259 let ret = ops.poll_cq.as_mut().unwrap()(recv_cq, 1, &mut wc);
1260
1261 if ret < 0 {
1262 return Err(anyhow::anyhow!(
1263 "Failed to poll receive CQ: {}",
1264 Error::last_os_error()
1265 ));
1266 }
1267
1268 if ret > 0 {
1269 if !wc.is_valid() {
1270 if let Some((status, vendor_err)) = wc.error() {
1271 return Err(anyhow::anyhow!(
1272 "Recv work completion failed with status: {:?}, vendor error: {}, wr_id: {}, send_cq_idx: {}",
1273 status,
1274 vendor_err,
1275 wc.wr_id(),
1276 self.recv_cq_idx,
1277 ));
1278 }
1279 }
1280
1281 // This should be a send completion - verify it's the one we're waiting for
1282 if wc.wr_id() == self.recv_cq_idx {
1283 self.recv_cq_idx += 1;
1284 }
1285 // finished polling, return the last completion
1286 if self.recv_cq_idx == self.recv_db_idx {
1287 return Ok(Some(IbvWc::from(wc)));
1288 }
1289 }
1290 }
1291
1292 // No completion found
1293 Ok(None)
1294 }
1295 }
1296
    /// Convenience wrapper: polls the send completion queue once.
    /// Equivalent to `poll_completion_target(PollTarget::Send)`.
    pub fn poll_send_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
        self.poll_completion_target(PollTarget::Send)
    }
1300
    /// Convenience wrapper: polls the receive completion queue once.
    /// Equivalent to `poll_completion_target(PollTarget::Recv)`.
    pub fn poll_recv_completion(&mut self) -> Result<Option<IbvWc>, anyhow::Error> {
        self.poll_completion_target(PollTarget::Recv)
    }
1304}
1305
1306/// Utility to validate execution context.
1307///
1308/// Remote Execution environments do not always have access to the nvidia_peermem module
1309/// and/or set the PeerMappingOverride parameter due to security. This function can be
1310/// used to validate that the execution context when running operations that need this
1311/// functionality (ie. cudaHostRegisterIoMemory).
1312///
1313/// # Returns
1314///
1315/// * `Ok(())` if the execution context is valid
1316/// * `Err(anyhow::Error)` if the execution context is invalid
1317pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1318 // Check for nvidia peermem
1319 match fs::read_to_string("/proc/modules") {
1320 Ok(contents) => {
1321 if !contents.contains("nvidia_peermem") {
1322 return Err(anyhow::anyhow!(
1323 "nvidia_peermem module not found in /proc/modules"
1324 ));
1325 }
1326 }
1327 Err(e) => {
1328 return Err(anyhow::anyhow!(e));
1329 }
1330 }
1331
1332 // Test file access to nvidia params
1333 match fs::read_to_string("/proc/driver/nvidia/params") {
1334 Ok(contents) => {
1335 if !contents.contains("PeerMappingOverride=1") {
1336 return Err(anyhow::anyhow!(
1337 "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1338 ));
1339 }
1340 }
1341 Err(e) => {
1342 return Err(anyhow::anyhow!(e));
1343 }
1344 }
1345 Ok(())
1346}
1347
/// Get all CUDA segments that have been registered with MRs.
///
/// Queries the rdmaxcel C library for the number of active segments, then
/// fetches the metadata for all of them in one bulk call.
///
/// # Returns
/// * `Vec<rdmaxcel_sys::rdma_segment_info_t>` - all registered segment
///   information; empty if there are no active segments or the bulk query
///   reports a non-positive count.
pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
    // SAFETY: calls into the rdmaxcel C library. The output buffer passed to
    // `rdma_get_all_segment_info` is sized from `rdma_get_active_segment_count`
    // and fully zero-initialized before the C side writes into it.
    unsafe {
        let segment_count = rdmaxcel_sys::rdma_get_active_segment_count();
        if segment_count <= 0 {
            return Vec::new();
        }

        // Zero-initialized placeholder entries for the C call to overwrite.
        // NOTE(review): assumes rdma_segment_info_t is a plain C struct that
        // is valid when all-zero — TODO confirm against the C definition.
        let mut segments = vec![
            std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
                .assume_init();
            segment_count as usize
        ];
        let actual_count =
            rdmaxcel_sys::rdma_get_all_segment_info(segments.as_mut_ptr(), segment_count);

        if actual_count > 0 {
            // Keep only the entries the C call actually populated (the count
            // may have shrunk between the two calls).
            segments.truncate(actual_count as usize);
            segments
        } else {
            Vec::new()
        }
    }
}
1375
/// Check if PyTorch CUDA caching allocator has expandable segments enabled.
///
/// This function calls the C++ implementation that directly accesses the
/// PyTorch C10 CUDA allocator configuration to check if expandable segments
/// are enabled, which is required for RDMA operations with CUDA tensors.
///
/// # Returns
///
/// `true` if both CUDA caching allocator is enabled AND expandable segments are enabled,
/// `false` otherwise.
pub fn pt_cuda_allocator_compatibility() -> bool {
    // SAFETY: We are calling a C++ function from rdmaxcel that accesses PyTorch C10 APIs.
    // The call takes no arguments and returns a plain bool, so no Rust-side
    // invariants are involved beyond the FFI boundary itself.
    unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() }
}
1390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a domain on the first available RDMA device and derives a
    /// queue pair from it.
    #[test]
    fn test_create_connection() {
        // Skip test if RDMA devices are not available
        if crate::ibverbs_primitives::get_all_devices().is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }

        let config = IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        };

        let domain_result = RdmaDomain::new(config.device.clone());
        assert!(domain_result.is_ok());

        let domain = domain_result.unwrap();
        let qp_result = RdmaQueuePair::new(domain.context, domain.pd, config.clone());
        assert!(qp_result.is_ok());
    }

    /// Creates two queue pairs on the same host and connects them to each
    /// other by exchanging connection info in-process (loopback).
    #[test]
    fn test_loopback_connection() {
        // Skip test if RDMA devices are not available
        if crate::ibverbs_primitives::get_all_devices().is_empty() {
            println!("Skipping test: RDMA devices not available");
            return;
        }

        let server_config = IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        };
        let client_config = IbverbsConfig {
            use_gpu_direct: false,
            ..Default::default()
        };

        let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
        let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();

        let mut server_qp =
            RdmaQueuePair::new(server_domain.context, server_domain.pd, server_config).unwrap();
        let mut client_qp =
            RdmaQueuePair::new(client_domain.context, client_domain.pd, client_config).unwrap();

        // Each side learns the other's QP info, then both transition to a
        // connected state.
        let server_info = server_qp.get_qp_info().unwrap();
        let client_info = client_qp.get_qp_info().unwrap();

        assert!(server_qp.connect(&client_info).is_ok());
        assert!(client_qp.connect(&server_info).is_ok());
    }
}