monarch_rdma/rdma_components.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
13//! ## Core Components
14//!
//! * `RdmaDomain` - Manages RDMA resources including the device context and protection domain
16//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `RdmaDomain` with `new()`
36//! 2. Create an `RdmaQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
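//!
//! A minimal, single-process sketch of this lifecycle (assuming at least one RDMA
//! device is available and that `RdmaBuffer` registration is handled elsewhere, e.g.
//! by `RdmaManagerActor`; error handling elided):
//!
//! ```ignore
//! let config = IbverbsConfig::default();
//! let domain = RdmaDomain::new(config.device.clone())?;
//! let mut qp = RdmaQueuePair::new(domain.context, domain.pd, config)?;
//!
//! // Exchange `RdmaQpInfo` with the remote peer out of band, then connect.
//! let local_info = qp.get_qp_info()?; // send this to the peer
//! let remote_info = /* RdmaQpInfo received from the peer */;
//! qp.connect(&remote_info)?;
//!
//! // Post an RDMA write from `local_buf` to `remote_buf` (both `RdmaBuffer`s),
//! // then poll the send completion queue for the returned work request IDs.
//! let wr_ids = qp.put(local_buf, remote_buf)?;
//! let completions = qp.poll_completion(PollTarget::Send, &wr_ids)?;
//! ```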
42
43/// Maximum size for a single RDMA operation in bytes (1 GiB)
44const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
45
46use std::ffi::CStr;
47use std::fs;
48use std::io::Error;
49use std::result::Result;
50use std::thread::sleep;
51use std::time::Duration;
52
53use hyperactor::ActorRef;
54use hyperactor::clock::Clock;
55use hyperactor::clock::RealClock;
56use hyperactor::context;
57use serde::Deserialize;
58use serde::Serialize;
59use typeuri::Named;
60
61use crate::RdmaDevice;
62use crate::RdmaManagerActor;
63use crate::RdmaManagerMessageClient;
64use crate::ibverbs_primitives::Gid;
65use crate::ibverbs_primitives::IbvWc;
66use crate::ibverbs_primitives::IbverbsConfig;
67use crate::ibverbs_primitives::RdmaOperation;
68use crate::ibverbs_primitives::RdmaQpInfo;
69use crate::ibverbs_primitives::resolve_qp_type;
70
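/// A serializable handle to a device doorbell.
///
/// `src_ptr` is the address of a pending work queue entry and `dst_ptr` is the
/// doorbell register it is rung against; see `ring()`.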
71#[derive(Debug, Named, Clone, Serialize, Deserialize)]
72pub struct DoorBell {
73 pub src_ptr: usize,
74 pub dst_ptr: usize,
75 pub size: usize,
76}
77wirevalue::register_type!(DoorBell);
78
79impl DoorBell {
80 /// Rings the doorbell to trigger the execution of previously enqueued operations.
81 ///
82 /// This method uses unsafe code to directly interact with the RDMA device,
83 /// sending a signal from the source pointer to the destination pointer.
84 ///
85 /// # Returns
86 /// * `Ok(())` if the operation is successful.
87 /// * `Err(anyhow::Error)` if an error occurs during the operation.
88 pub fn ring(&self) -> Result<(), anyhow::Error> {
89 unsafe {
90 let src_ptr = self.src_ptr as *mut std::ffi::c_void;
91 let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
92 rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
93 Ok(())
94 }
95 }
96}
97
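/// A serializable handle describing a registered RDMA memory region.
///
/// It records the owning `RdmaManagerActor`, the region's local and remote keys
/// (`lkey`/`rkey`), its virtual address and size, and the device it was registered
/// on, so that peers can target the region in RDMA operations.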
98#[derive(Debug, Serialize, Deserialize, Named, Clone)]
99pub struct RdmaBuffer {
100 pub owner: ActorRef<RdmaManagerActor>,
101 pub mr_id: usize,
102 pub lkey: u32,
103 pub rkey: u32,
104 pub addr: usize,
105 pub size: usize,
106 pub device_name: String,
107}
108wirevalue::register_type!(RdmaBuffer);
109
110impl RdmaBuffer {
    /// Read from this RdmaBuffer into the provided destination buffer.
    ///
    /// This method transfers data from this buffer into the caller-provided memory region
    /// described by `remote`, by issuing a `put` (RDMA write) on a queue pair obtained from
    /// this buffer's owning actor.
115 ///
116 /// # Arguments
117 /// * `client` - The actor who is reading.
118 /// * `remote` - RdmaBuffer representing the remote memory region
119 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
120 ///
121 /// # Returns
122 /// `Ok(bool)` indicating if the operation completed successfully.
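    ///
    /// # Example
    ///
    /// A minimal sketch (assuming `local` and `remote` are `RdmaBuffer`s obtained from
    /// their owning `RdmaManagerActor`s and `client` is the calling actor's context):
    ///
    /// ```ignore
    /// let ok = local.read_into(&client, remote.clone(), 5).await?;
    /// assert!(ok);
    /// ```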
123 pub async fn read_into(
124 &self,
125 client: &impl context::Actor,
126 remote: RdmaBuffer,
127 timeout: u64,
128 ) -> Result<bool, anyhow::Error> {
129 tracing::debug!(
130 "[buffer] reading from {:?} into remote ({:?}) at {:?}",
131 self,
132 remote.owner.actor_id(),
133 remote,
134 );
135 let remote_owner = remote.owner.clone();
136
137 let local_device = self.device_name.clone();
138 let remote_device = remote.device_name.clone();
139 let mut qp = self
140 .owner
141 .request_queue_pair(
142 client,
143 remote_owner.clone(),
144 local_device.clone(),
145 remote_device.clone(),
146 )
147 .await?;
148
149 let wr_id = qp.put(self.clone(), remote)?;
150 let result = self
151 .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
152 .await;
153
154 // Release the queue pair back to the actor
155 self.owner
156 .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
157 .await?;
158
159 result
160 }
161
    /// Write from the provided source buffer into this RdmaBuffer.
    ///
    /// This method transfers data from the caller-provided memory region described by
    /// `remote` into this buffer, by issuing a `get` (RDMA read) on a queue pair obtained
    /// from this buffer's owning actor.
167 ///
168 /// # Arguments
169 /// * `client` - The actor who is writing.
170 /// * `remote` - RdmaBuffer representing the remote memory region
171 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
172 ///
173 /// # Returns
174 /// `Ok(bool)` indicating if the operation completed successfully.
175 pub async fn write_from(
176 &self,
177 client: &impl context::Actor,
178 remote: RdmaBuffer,
179 timeout: u64,
180 ) -> Result<bool, anyhow::Error> {
181 tracing::debug!(
182 "[buffer] writing into {:?} from remote ({:?}) at {:?}",
183 self,
184 remote.owner.actor_id(),
185 remote,
186 );
187 let remote_owner = remote.owner.clone(); // Clone before the move!
188
        // Device names for the local and remote sides of the transfer
190 let local_device = self.device_name.clone();
191 let remote_device = remote.device_name.clone();
192
193 let mut qp = self
194 .owner
195 .request_queue_pair(
196 client,
197 remote_owner.clone(),
198 local_device.clone(),
199 remote_device.clone(),
200 )
201 .await?;
202 let wr_id = qp.get(self.clone(), remote)?;
203 let result = self
204 .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
205 .await;
206
207 // Release the queue pair back to the actor
208 self.owner
209 .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
210 .await?;
211
212 result
213 }

    /// Waits for the completion of RDMA operations.
215 ///
216 /// This method polls the completion queue until all specified work requests complete
217 /// or until the timeout is reached.
218 ///
219 /// # Arguments
220 /// * `qp` - The RDMA Queue Pair to poll for completion
221 /// * `poll_target` - Which CQ to poll (Send or Recv)
222 /// * `expected_wr_ids` - The work request IDs to wait for
223 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
224 ///
225 /// # Returns
226 /// `Ok(true)` if all operations complete successfully within the timeout,
227 /// or an error if the timeout is reached
228 async fn wait_for_completion(
229 &self,
230 qp: &mut RdmaQueuePair,
231 poll_target: PollTarget,
232 expected_wr_ids: &[u64],
233 timeout: u64,
234 ) -> Result<bool, anyhow::Error> {
235 let timeout = Duration::from_secs(timeout);
236 let start_time = std::time::Instant::now();
237
238 let mut remaining: std::collections::HashSet<u64> =
239 expected_wr_ids.iter().copied().collect();
240
241 while start_time.elapsed() < timeout {
242 if remaining.is_empty() {
243 return Ok(true);
244 }
245
246 let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
247 match qp.poll_completion(poll_target, &wr_ids_to_poll) {
248 Ok(completions) => {
249 for (wr_id, _wc) in completions {
250 remaining.remove(&wr_id);
251 }
252 if remaining.is_empty() {
253 return Ok(true);
254 }
255 RealClock.sleep(Duration::from_millis(1)).await;
256 }
257 Err(e) => {
258 return Err(anyhow::anyhow!(
259 "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
260 e,
261 self.lkey,
262 self.rkey,
263 self.addr,
264 self.size
265 ));
266 }
267 }
268 }
269 tracing::error!(
270 "timed out while waiting on request completion for wr_ids={:?}",
271 remaining
272 );
273 Err(anyhow::anyhow!(
274 "[buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
275 self,
276 expected_wr_ids
277 ))
278 }
279
280 /// Drop the buffer and release remote handles.
281 ///
282 /// This method calls the owning RdmaManagerActor to release the buffer and clean up
283 /// associated memory regions. This is typically called when the buffer is no longer
284 /// needed and resources should be freed.
285 ///
286 /// # Arguments
287 /// * `client` - Mailbox used for communication
288 ///
289 /// # Returns
290 /// `Ok(())` if the operation completed successfully.
291 pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
292 tracing::debug!("[buffer] dropping buffer {:?}", self);
293 self.owner.release_buffer(client, self.clone()).await?;
294 Ok(())
295 }
296}
297
298/// Represents a domain for RDMA operations, encapsulating the necessary resources
299/// for establishing and managing RDMA connections.
300///
/// An `RdmaDomain` manages the device context and protection domain (PD)
/// required for RDMA operations. It provides the foundation for creating queue pairs
/// and establishing connections between RDMA devices.
304///
305/// # Fields
306///
307/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
308/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
309#[derive(Clone)]
310pub struct RdmaDomain {
311 pub context: *mut rdmaxcel_sys::ibv_context,
312 pub pd: *mut rdmaxcel_sys::ibv_pd,
313}
314
315impl std::fmt::Debug for RdmaDomain {
316 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317 f.debug_struct("RdmaDomain")
318 .field("context", &format!("{:p}", self.context))
319 .field("pd", &format!("{:p}", self.pd))
320 .finish()
321 }
322}
323
// SAFETY:
// RdmaDomain is `Send` because the raw pointers to ibverbs structs can be
// accessed from any thread, and it is safe to drop `RdmaDomain` (and run the
// ibverbs destructors) from any thread.
329unsafe impl Send for RdmaDomain {}
330
// SAFETY:
// RdmaDomain is `Sync` because the underlying ibverbs APIs are thread-safe.
334unsafe impl Sync for RdmaDomain {}
335
336impl Drop for RdmaDomain {
337 fn drop(&mut self) {
338 unsafe {
339 rdmaxcel_sys::ibv_dealloc_pd(self.pd);
340 }
341 }
342}
343
344impl RdmaDomain {
    /// Creates a new RdmaDomain.
    ///
    /// This function initializes the RDMA device context and creates a protection domain
    /// that subsequent queue pair creation and memory registration build on.
    ///
    /// SAFETY:
    /// Our memory region (MR) registration uses implicit ODP for RDMA access, which maps large virtual
    /// address ranges without explicit pinning. This is convenient, but it broadens the memory footprint
    /// exposed to the NIC and introduces a security liability.
    ///
    /// We currently assume a trusted, single-tenant environment and are not enforcing finer-grained memory
    /// isolation at this layer. We plan to investigate mitigations - such as memory windows or tighter
    /// registration boundaries - in future follow-ups.
    ///
    /// # Arguments
    ///
    /// * `device` - The RDMA device to open
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * No RDMA devices are found
    /// * The specified device cannot be found
    /// * Device context creation fails
    /// * Protection domain allocation fails
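    ///
    /// # Example
    ///
    /// A minimal sketch (assuming at least one RDMA device is present; error handling
    /// elided):
    ///
    /// ```ignore
    /// let config = IbverbsConfig::default();
    /// let domain = RdmaDomain::new(config.device.clone())?;
    /// ```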
371 pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
372 tracing::debug!("creating RdmaDomain for device {}", device.name());
373 // SAFETY:
374 // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
375 // - All pointers are properly initialized and checked for null before use
376 // - Memory registration follows the ibverbs API contract with proper access flags
377 // - Resources are properly cleaned up in error cases to prevent leaks
378 // - The operations follow the documented RDMA protocol for device initialization
379 unsafe {
380 // Get the device based on the provided RdmaDevice
381 let device_name = device.name();
382 let mut num_devices = 0i32;
383 let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
384
385 if devices.is_null() || num_devices == 0 {
386 return Err(anyhow::anyhow!("no RDMA devices found"));
387 }
388
389 // Find the device with the matching name
390 let mut device_ptr = std::ptr::null_mut();
391 for i in 0..num_devices {
392 let dev = *devices.offset(i as isize);
393 let dev_name =
394 CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
395
396 if dev_name == *device_name {
397 device_ptr = dev;
398 break;
399 }
400 }
401
402 // If we didn't find the device, return an error
403 if device_ptr.is_null() {
404 rdmaxcel_sys::ibv_free_device_list(devices);
405 return Err(anyhow::anyhow!("device '{}' not found", device_name));
406 }
407 tracing::info!("using RDMA device: {}", device_name);
408
409 // Open device
410 let context = rdmaxcel_sys::ibv_open_device(device_ptr);
411 if context.is_null() {
412 rdmaxcel_sys::ibv_free_device_list(devices);
413 let os_error = Error::last_os_error();
414 return Err(anyhow::anyhow!("failed to create context: {}", os_error));
415 }
416
417 // Create protection domain
418 let pd = rdmaxcel_sys::ibv_alloc_pd(context);
419 if pd.is_null() {
420 rdmaxcel_sys::ibv_close_device(context);
421 rdmaxcel_sys::ibv_free_device_list(devices);
422 let os_error = Error::last_os_error();
423 return Err(anyhow::anyhow!(
424 "failed to create protection domain (PD): {}",
425 os_error
426 ));
427 }
428
429 // Avoids memory leaks
430 rdmaxcel_sys::ibv_free_device_list(devices);
431
432 let domain = RdmaDomain { context, pd };
433
434 Ok(domain)
435 }
436 }
437}

/// Enum to specify which completion queue to poll
439#[derive(Debug, Clone, Copy, PartialEq)]
440pub enum PollTarget {
441 Send,
442 Recv,
443}
444
445/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
446///
447/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
448/// and other resources needed for RDMA communication. It provides methods for
449/// establishing connections and performing RDMA operations like read and write.
450///
451/// # Fields
452///
453/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
454/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
455/// * `qp` - Queue Pair pointer that manages send and receive operations
456/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
457/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
458/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
459/// * `context` - RDMA device context pointer
460/// * `config` - Configuration settings for the queue pair
461///
462/// # Connection Lifecycle
463///
464/// 1. Create with `new()` from an `RdmaDomain`
465/// 2. Get connection info with `get_qp_info()`
466/// 3. Exchange connection info with remote peer (application must handle this)
467/// 4. Connect to remote endpoint with `connect()`
468/// 5. Perform RDMA operations with `put()` or `get()`
/// 6. Poll for completions with `poll_completion(PollTarget::Send, &wr_ids)`
470///
471/// # Notes
472/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
473/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
474/// - This makes RdmaQueuePair trivially Clone and Serialize
475/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
476#[derive(Debug, Serialize, Deserialize, Named, Clone)]
477pub struct RdmaQueuePair {
478 pub send_cq: usize, // *mut rdmaxcel_sys::ibv_cq,
479 pub recv_cq: usize, // *mut rdmaxcel_sys::ibv_cq,
480 pub qp: usize, // *mut rdmaxcel_sys::rdmaxcel_qp_t
481 pub dv_qp: usize, // *mut rdmaxcel_sys::mlx5dv_qp,
482 pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
483 pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
484 context: usize, // *mut rdmaxcel_sys::ibv_context,
485 config: IbverbsConfig,
486}
487wirevalue::register_type!(RdmaQueuePair);
488
489impl RdmaQueuePair {
490 /// Applies hardware initialization delay if this is the first operation since RTS.
491 ///
492 /// This ensures the hardware has sufficient time to settle after reaching
493 /// Ready-to-Send state before the first actual operation.
494 fn apply_first_op_delay(&self, wr_id: u64) {
495 unsafe {
496 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
497 if wr_id == 0 {
498 let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
499 assert!(
500 rts_timestamp != u64::MAX,
501 "First operation attempted before queue pair reached RTS state! Call connect() first."
502 );
503 let current_nanos = RealClock
504 .system_time_now()
505 .duration_since(std::time::UNIX_EPOCH)
506 .unwrap()
507 .as_nanos() as u64;
508 let elapsed_nanos = current_nanos - rts_timestamp;
509 let elapsed = Duration::from_nanos(elapsed_nanos);
510 let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
511 if elapsed < init_delay {
512 let remaining_delay = init_delay - elapsed;
513 sleep(remaining_delay);
514 }
515 }
516 }
517 }
518
519 /// Creates a new RdmaQueuePair from a given RdmaDomain.
520 ///
521 /// This function initializes a new Queue Pair (QP) and associated Completion Queue (CQ)
522 /// using the resources from the provided RdmaDomain. The QP is created in the RESET state
523 /// and must be transitioned to other states via the `connect()` method before use.
524 ///
525 /// # Arguments
526 ///
527 /// * `domain` - Reference to an RdmaDomain that provides the context, protection domain,
528 /// and memory region for this queue pair
529 ///
530 /// # Returns
531 ///
532 /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
533 ///
534 /// # Errors
535 ///
536 /// This function may return errors if:
537 /// * Completion queue (CQ) creation fails
538 /// * Queue pair (QP) creation fails
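    ///
    /// # Example
    ///
    /// A sketch (assuming a domain created as in `RdmaDomain::new` and a default,
    /// CPU-only configuration as used in the tests below):
    ///
    /// ```ignore
    /// let config = IbverbsConfig {
    ///     use_gpu_direct: false,
    ///     ..Default::default()
    /// };
    /// let domain = RdmaDomain::new(config.device.clone())?;
    /// let mut qp = RdmaQueuePair::new(domain.context, domain.pd, config)?;
    /// ```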
539 pub fn new(
540 context: *mut rdmaxcel_sys::ibv_context,
541 pd: *mut rdmaxcel_sys::ibv_pd,
542 config: IbverbsConfig,
543 ) -> Result<Self, anyhow::Error> {
544 tracing::debug!("creating an RdmaQueuePair from config {}", config);
545 unsafe {
546 // Resolve Auto to a concrete QP type based on device capabilities
547 let resolved_qp_type = resolve_qp_type(config.qp_type);
548 let qp = rdmaxcel_sys::rdmaxcel_qp_create(
549 context,
550 pd,
551 config.cq_entries,
552 config.max_send_wr.try_into().unwrap(),
553 config.max_recv_wr.try_into().unwrap(),
554 config.max_send_sge.try_into().unwrap(),
555 config.max_recv_sge.try_into().unwrap(),
556 resolved_qp_type,
557 );
558
559 if qp.is_null() {
560 let os_error = Error::last_os_error();
561 return Err(anyhow::anyhow!(
562 "failed to create queue pair (QP): {}",
563 os_error
564 ));
565 }
566
567 let send_cq = (*(*qp).ibv_qp).send_cq;
568 let recv_cq = (*(*qp).ibv_qp).recv_cq;
569
570 // mlx5dv provider APIs
571 let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
572 let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
573 let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
574
575 if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
576 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
577 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
578 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
579 return Err(anyhow::anyhow!(
580 "failed to init mlx5dv_qp or completion queues"
581 ));
582 }
583
584 // GPU Direct RDMA specific registrations
585 if config.use_gpu_direct {
586 let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
587 if ret != 0 {
588 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
589 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
590 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
591 return Err(anyhow::anyhow!(
592 "failed to register GPU Direct RDMA memory: {:?}",
593 ret
594 ));
595 }
596 }
597 Ok(RdmaQueuePair {
598 send_cq: send_cq as usize,
599 recv_cq: recv_cq as usize,
600 qp: qp as usize,
                dv_qp: dv_qp as usize,
602 dv_send_cq: dv_send_cq as usize,
603 dv_recv_cq: dv_recv_cq as usize,
604 context: context as usize,
605 config,
606 })
607 }
608 }
609
610 /// Returns the information required for a remote peer to connect to this queue pair.
611 ///
612 /// This method retrieves the local queue pair attributes and port information needed by
613 /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
614 /// the queue pair number, LID, GID, and other necessary connection parameters.
615 ///
616 /// # Returns
617 ///
618 /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
619 ///
620 /// # Errors
621 ///
622 /// This function may return errors if:
623 /// * Port attribute query fails
624 /// * GID query fails
625 pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
626 // SAFETY:
627 // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
628 // - All pointers are properly initialized before use
629 // - Port and GID queries follow the documented ibverbs API contract
630 // - Error handling properly checks return codes from ibverbs functions
        // - Query results are written into stack-allocated attribute structs owned by this function
632 unsafe {
633 let context = self.context as *mut rdmaxcel_sys::ibv_context;
634 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
635 let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
636 let errno = rdmaxcel_sys::ibv_query_port(
637 context,
638 self.config.port_num,
639 &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
640 );
641 if errno != 0 {
642 let os_error = Error::last_os_error();
643 return Err(anyhow::anyhow!(
644 "Failed to query port attributes: {}",
645 os_error
646 ));
647 }
648
649 let mut gid = Gid::default();
650 let ret = rdmaxcel_sys::ibv_query_gid(
651 context,
652 self.config.port_num,
653 i32::from(self.config.gid_index),
654 gid.as_mut(),
655 );
656 if ret != 0 {
657 return Err(anyhow::anyhow!("Failed to query GID"));
658 }
659
660 Ok(RdmaQpInfo {
661 qp_num: (*(*qp).ibv_qp).qp_num,
662 lid: port_attr.lid,
663 gid: Some(gid),
664 psn: self.config.psn,
665 })
666 }
667 }
668
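    /// Queries and returns the current ibverbs state of the underlying queue pair
    /// (e.g. `IBV_QPS_RTS`).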
669 pub fn state(&mut self) -> Result<u32, anyhow::Error> {
670 // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
671 unsafe {
672 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
673 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
674 ..Default::default()
675 };
676 let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
677 ..Default::default()
678 };
679 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
680 let errno = rdmaxcel_sys::ibv_query_qp(
681 (*qp).ibv_qp,
682 &mut qp_attr,
683 mask.0 as i32,
684 &mut qp_init_attr,
685 );
686 if errno != 0 {
687 let os_error = Error::last_os_error();
688 return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
689 }
690 Ok(qp_attr.qp_state)
691 }
692 }

    /// Connect to a remote RDMA endpoint.
694 ///
695 /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection.
696 ///
697 /// # Arguments
698 ///
699 /// * `connection_info` - The remote connection info to connect to
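    ///
    /// # Example
    ///
    /// A loopback sketch with two queue pairs on the same host exchanging their
    /// `RdmaQpInfo` directly (adapted from the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let info_a = qp_a.get_qp_info()?;
    /// let info_b = qp_b.get_qp_info()?;
    /// qp_a.connect(&info_b)?;
    /// qp_b.connect(&info_a)?;
    /// ```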
700 pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
701 // SAFETY:
702 // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
703 // The operations are safe because:
704 // 1. We're following the documented ibverbs API contract
705 // 2. All pointers used are properly initialized and owned by this struct
706 // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
707 // 4. Memory access is properly bounded by the registered memory regions
708 unsafe {
709 // Transition to INIT
710 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
711
712 let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
713 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
714 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
715 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
716
717 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
718 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
719 qp_access_flags: qp_access_flags.0,
720 pkey_index: self.config.pkey_index,
721 port_num: self.config.port_num,
722 ..Default::default()
723 };
724
725 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
726 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
727 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
728 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
729
730 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
731 if errno != 0 {
732 let os_error = Error::last_os_error();
733 return Err(anyhow::anyhow!(
734 "failed to transition QP to INIT: {}",
735 os_error
736 ));
737 }
738
739 // Transition to RTR (Ready to Receive)
740 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
741 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
742 path_mtu: self.config.path_mtu,
743 dest_qp_num: connection_info.qp_num,
744 rq_psn: connection_info.psn,
745 max_dest_rd_atomic: self.config.max_dest_rd_atomic,
746 min_rnr_timer: self.config.min_rnr_timer,
747 ah_attr: rdmaxcel_sys::ibv_ah_attr {
748 dlid: connection_info.lid,
749 sl: 0,
750 src_path_bits: 0,
751 port_num: self.config.port_num,
752 grh: Default::default(),
753 ..Default::default()
754 },
755 ..Default::default()
756 };
757
758 // If the remote connection info contains a Gid, the routing will be global.
759 // Otherwise, it will be local, i.e. using LID.
760 if let Some(gid) = connection_info.gid {
761 qp_attr.ah_attr.is_global = 1;
762 qp_attr.ah_attr.grh.dgid = gid.into();
763 qp_attr.ah_attr.grh.hop_limit = 0xff;
764 qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
765 } else {
766 // Use LID-based routing, e.g. for Infiniband/RoCEv1
767 qp_attr.ah_attr.is_global = 0;
768 }
769
770 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
771 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
772 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
773 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
774 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
775 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
776 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
777
778 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
779 if errno != 0 {
780 let os_error = Error::last_os_error();
781 return Err(anyhow::anyhow!(
782 "failed to transition QP to RTR: {}",
783 os_error
784 ));
785 }
786
787 // Transition to RTS (Ready to Send)
788 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
789 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
790 sq_psn: self.config.psn,
791 max_rd_atomic: self.config.max_rd_atomic,
792 retry_cnt: self.config.retry_cnt,
793 rnr_retry: self.config.rnr_retry,
794 timeout: self.config.qp_timeout,
795 ..Default::default()
796 };
797
798 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
799 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
800 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
801 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
802 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
803 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
804
805 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
806 if errno != 0 {
807 let os_error = Error::last_os_error();
808 return Err(anyhow::anyhow!(
809 "failed to transition QP to RTS: {}",
810 os_error
811 ));
812 }
813 tracing::debug!(
814 "connection sequence has successfully completed (qp: {:?})",
815 qp
816 );
817
818 // Record RTS time now that the queue pair is ready to send
819 let rts_timestamp_nanos = RealClock
820 .system_time_now()
821 .duration_since(std::time::UNIX_EPOCH)
822 .unwrap()
823 .as_nanos() as u64;
824 rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
825
826 Ok(())
827 }
828 }
829
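    /// Posts a zero-length receive work request, typically paired with a remote
    /// `put_with_recv` (write-with-immediate), and returns its work request ID.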
830 pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<u64, anyhow::Error> {
831 unsafe {
832 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
833 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
834 self.post_op(
835 0,
836 lhandle.lkey,
837 0,
838 idx,
839 true,
840 RdmaOperation::Recv,
841 0,
842 rhandle.rkey,
            )?;
845 rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
846 Ok(idx)
847 }
848 }
849
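    /// Posts a single write-with-immediate work request from `lhandle` to `rhandle`;
    /// the remote side should post a matching receive (e.g. via `recv()`). Returns the
    /// work request ID.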
850 pub fn put_with_recv(
851 &mut self,
852 lhandle: RdmaBuffer,
853 rhandle: RdmaBuffer,
854 ) -> Result<Vec<u64>, anyhow::Error> {
855 unsafe {
856 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
857 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
858 self.post_op(
859 lhandle.addr,
860 lhandle.lkey,
861 lhandle.size,
862 idx,
863 true,
864 RdmaOperation::WriteWithImm,
865 rhandle.addr,
866 rhandle.rkey,
            )?;
869 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
870 Ok(vec![idx])
871 }
872 }
873
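    /// Writes the contents of `lhandle` into `rhandle` using one or more RDMA write
    /// work requests, splitting transfers larger than `MAX_RDMA_MSG_SIZE` into chunks.
    /// Returns the work request IDs of the posted operations.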
874 pub fn put(
875 &mut self,
876 lhandle: RdmaBuffer,
877 rhandle: RdmaBuffer,
878 ) -> Result<Vec<u64>, anyhow::Error> {
879 let total_size = lhandle.size;
880 if rhandle.size < total_size {
881 return Err(anyhow::anyhow!(
882 "Remote buffer size ({}) is smaller than local buffer size ({})",
883 rhandle.size,
884 total_size
885 ));
886 }
887
888 let mut remaining = total_size;
889 let mut offset = 0;
890 let mut wr_ids = Vec::new();
891 while remaining > 0 {
892 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
893 let idx = unsafe {
894 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
895 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
896 )
897 };
898 wr_ids.push(idx);
899 self.post_op(
900 lhandle.addr + offset,
901 lhandle.lkey,
902 chunk_size,
903 idx,
904 true,
905 RdmaOperation::Write,
906 rhandle.addr + offset,
907 rhandle.rkey,
908 )?;
909 unsafe {
910 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
911 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
912 );
913 }
914
915 remaining -= chunk_size;
916 offset += chunk_size;
917 }
918
919 Ok(wr_ids)
920 }
921
    /// Rings the doorbell for the queue pair.
    ///
    /// This method submits every enqueued but not-yet-rung send WQE to the device,
    /// triggering the execution of previously enqueued operations.
    ///
    /// # Returns
    ///
    /// * `Result<(), anyhow::Error>` - `Ok(())` once all pending WQEs have been rung
930 pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
931 unsafe {
932 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
933 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
934 let base_ptr = (*dv_qp).sq.buf as *mut u8;
935 let wqe_cnt = (*dv_qp).sq.wqe_cnt;
936 let stride = (*dv_qp).sq.stride;
937 let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
938 let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
939 if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
940 return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
941 }
942 self.apply_first_op_delay(send_db_idx);
943 while send_db_idx < send_wqe_idx {
944 let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
945 let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
946 rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
947 send_db_idx += 1;
948 rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
949 }
950 Ok(())
951 }
952 }
953
954 /// Enqueues a put operation without ringing the doorbell.
955 ///
956 /// This method prepares a put operation but does not execute it.
    /// Ring the doorbell (e.g. via `ring_doorbell()`) to execute the operation.
958 ///
959 /// # Arguments
960 ///
961 /// * `lhandle` - Local buffer handle
962 /// * `rhandle` - Remote buffer handle
963 ///
964 /// # Returns
965 ///
966 /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
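    ///
    /// # Example
    ///
    /// A sketch of the two-step submission path (assuming `lbuf` and `rbuf` are valid
    /// `RdmaBuffer`s and `qp` is a connected `RdmaQueuePair`):
    ///
    /// ```ignore
    /// let wr_ids = qp.enqueue_put(lbuf, rbuf)?;
    /// qp.ring_doorbell()?;
    /// let completions = qp.poll_completion(PollTarget::Send, &wr_ids)?;
    /// ```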
967 pub fn enqueue_put(
968 &mut self,
969 lhandle: RdmaBuffer,
970 rhandle: RdmaBuffer,
971 ) -> Result<Vec<u64>, anyhow::Error> {
972 let idx = unsafe {
973 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
974 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
975 )
976 };
977
978 self.send_wqe(
979 lhandle.addr,
980 lhandle.lkey,
981 lhandle.size,
982 idx,
983 true,
984 RdmaOperation::Write,
985 rhandle.addr,
986 rhandle.rkey,
987 )?;
988 Ok(vec![idx])
989 }
990
991 /// Enqueues a put with receive operation without ringing the doorbell.
992 ///
993 /// This method prepares a put with receive operation but does not execute it.
    /// Ring the doorbell (e.g. via `ring_doorbell()`) to execute the operation.
995 ///
996 /// # Arguments
997 ///
998 /// * `lhandle` - Local buffer handle
999 /// * `rhandle` - Remote buffer handle
1000 ///
1001 /// # Returns
1002 ///
1003 /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1004 pub fn enqueue_put_with_recv(
1005 &mut self,
1006 lhandle: RdmaBuffer,
1007 rhandle: RdmaBuffer,
1008 ) -> Result<Vec<u64>, anyhow::Error> {
1009 let idx = unsafe {
1010 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1011 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1012 )
1013 };
1014
1015 self.send_wqe(
1016 lhandle.addr,
1017 lhandle.lkey,
1018 lhandle.size,
1019 idx,
1020 true,
1021 RdmaOperation::WriteWithImm,
1022 rhandle.addr,
1023 rhandle.rkey,
1024 )?;
1025 Ok(vec![idx])
1026 }
1027
1028 /// Enqueues a get operation without ringing the doorbell.
1029 ///
1030 /// This method prepares a get operation but does not execute it.
    /// Ring the doorbell (e.g. via `ring_doorbell()`) to execute the operation.
1032 ///
1033 /// # Arguments
1034 ///
1035 /// * `lhandle` - Local buffer handle
1036 /// * `rhandle` - Remote buffer handle
1037 ///
1038 /// # Returns
1039 ///
1040 /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1041 pub fn enqueue_get(
1042 &mut self,
1043 lhandle: RdmaBuffer,
1044 rhandle: RdmaBuffer,
1045 ) -> Result<Vec<u64>, anyhow::Error> {
1046 let idx = unsafe {
1047 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1048 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1049 )
1050 };
1051
1052 self.send_wqe(
1053 lhandle.addr,
1054 lhandle.lkey,
1055 lhandle.size,
1056 idx,
1057 true,
1058 RdmaOperation::Read,
1059 rhandle.addr,
1060 rhandle.rkey,
1061 )?;
1062 Ok(vec![idx])
1063 }
1064
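    /// Reads the contents of `rhandle` into `lhandle` using one or more RDMA read
    /// work requests, splitting transfers larger than `MAX_RDMA_MSG_SIZE` into chunks.
    /// Returns the work request IDs of the posted operations.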
1065 pub fn get(
1066 &mut self,
1067 lhandle: RdmaBuffer,
1068 rhandle: RdmaBuffer,
1069 ) -> Result<Vec<u64>, anyhow::Error> {
1070 let total_size = lhandle.size;
1071 if rhandle.size < total_size {
1072 return Err(anyhow::anyhow!(
1073 "Remote buffer size ({}) is smaller than local buffer size ({})",
1074 rhandle.size,
1075 total_size
1076 ));
1077 }
1078
1079 let mut remaining = total_size;
1080 let mut offset = 0;
1081 let mut wr_ids = Vec::new();
1082
1083 while remaining > 0 {
1084 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
1085 let idx = unsafe {
1086 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1087 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1088 )
1089 };
1090 wr_ids.push(idx);
1091
1092 self.post_op(
1093 lhandle.addr + offset,
1094 lhandle.lkey,
1095 chunk_size,
1096 idx,
1097 true,
1098 RdmaOperation::Read,
1099 rhandle.addr + offset,
1100 rhandle.rkey,
1101 )?;
1102 unsafe {
1103 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
1104 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1105 );
1106 }
1107
1108 remaining -= chunk_size;
1109 offset += chunk_size;
1110 }
1111
1112 Ok(wr_ids)
1113 }
1114
    /// Posts a work request to the queue pair.
    ///
    /// # Arguments
    ///
    /// * `laddr` - The local address containing the data to send (or receive into)
    /// * `lkey` - The local key of the registered memory region backing `laddr`
    /// * `length` - Length of the data in bytes
    /// * `wr_id` - Work request ID for completion identification
    /// * `signaled` - Whether to generate a completion event
    /// * `op_type` - The RDMA operation to perform (write, read, write-with-imm, or recv)
    /// * `raddr` - The remote address, representing the memory location on the remote peer
    /// * `rkey` - The remote key required to access the remote memory region
1126 fn post_op(
1127 &mut self,
1128 laddr: usize,
1129 lkey: u32,
1130 length: usize,
1131 wr_id: u64,
1132 signaled: bool,
1133 op_type: RdmaOperation,
1134 raddr: usize,
1135 rkey: u32,
1136 ) -> Result<(), anyhow::Error> {
1137 // SAFETY:
1138 // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
1139 // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
1140 // - The memory address in `local_addr` is not dereferenced, only passed to the device
1141 // - The remote connection info is verified to exist before accessing
1142 // - The ibverbs post_send operation follows the documented API contract
1143 // - Error codes from the device are properly checked and propagated
1144 unsafe {
1145 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1146 let context = self.context as *mut rdmaxcel_sys::ibv_context;
1147 let ops = &mut (*context).ops;
1148 let errno;
1149 if op_type == RdmaOperation::Recv {
1150 let mut sge = rdmaxcel_sys::ibv_sge {
1151 addr: laddr as u64,
1152 length: length as u32,
1153 lkey,
1154 };
1155 let mut wr = rdmaxcel_sys::ibv_recv_wr {
1156 wr_id,
1157 sg_list: &mut sge as *mut _,
1158 num_sge: 1,
1159 ..Default::default()
1160 };
1161 let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
1162 errno =
1163 ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1164 } else if op_type == RdmaOperation::Write
1165 || op_type == RdmaOperation::Read
1166 || op_type == RdmaOperation::WriteWithImm
1167 {
1168 // Apply hardware initialization delay if this is the first operation
1169 self.apply_first_op_delay(wr_id);
1170 let send_flags = if signaled {
1171 rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
1172 } else {
1173 0
1174 };
1175 let mut sge = rdmaxcel_sys::ibv_sge {
1176 addr: laddr as u64,
1177 length: length as u32,
1178 lkey,
1179 };
1180 let mut wr = rdmaxcel_sys::ibv_send_wr {
1181 wr_id,
1182 next: std::ptr::null_mut(),
1183 sg_list: &mut sge as *mut _,
1184 num_sge: 1,
1185 opcode: op_type.into(),
1186 send_flags,
1187 wr: Default::default(),
1188 qp_type: Default::default(),
1189 __bindgen_anon_1: Default::default(),
1190 __bindgen_anon_2: Default::default(),
1191 };
1192
1193 wr.wr.rdma.remote_addr = raddr as u64;
1194 wr.wr.rdma.rkey = rkey;
1195 let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
1196
1197 errno =
1198 ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1199 } else {
1200 panic!("Not Implemented");
1201 }
1202
1203 if errno != 0 {
1204 let os_error = Error::last_os_error();
1205 return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
1206 }
1207 tracing::debug!(
1208 "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
1209 op_type,
1210 lkey,
1211 laddr,
1212 length,
1213 raddr,
1214 rkey,
1215 );
1216
1217 Ok(())
1218 }
1219 }
1220
1221 fn send_wqe(
1222 &mut self,
1223 laddr: usize,
1224 lkey: u32,
1225 length: usize,
1226 wr_id: u64,
1227 signaled: bool,
1228 op_type: RdmaOperation,
1229 raddr: usize,
1230 rkey: u32,
1231 ) -> Result<DoorBell, anyhow::Error> {
1232 unsafe {
1233 let op_type_val = match op_type {
1234 RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
1235 RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
1236 RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
1237 RdmaOperation::Recv => 0,
1238 };
1239
1240 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1241 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
1242 let _dv_cq = if op_type == RdmaOperation::Recv {
1243 self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
1244 } else {
1245 self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
1246 };
1247
1248 // Create the WQE parameters struct
1249
1250 let buf = if op_type == RdmaOperation::Recv {
1251 (*dv_qp).rq.buf as *mut u8
1252 } else {
1253 (*dv_qp).sq.buf as *mut u8
1254 };
1255
1256 let params = rdmaxcel_sys::wqe_params_t {
1257 laddr,
1258 lkey,
1259 length,
1260 wr_id,
1261 signaled,
1262 op_type: op_type_val,
1263 raddr,
1264 rkey,
1265 qp_num: (*(*qp).ibv_qp).qp_num,
1266 buf,
1267 dbrec: (*dv_qp).dbrec,
1268 wqe_cnt: (*dv_qp).sq.wqe_cnt,
1269 };
1270
1271 // Call the C function to post the WQE
1272 if op_type == RdmaOperation::Recv {
1273 rdmaxcel_sys::recv_wqe(params);
1274 std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
1275 } else {
1276 rdmaxcel_sys::send_wqe(params);
1277 };
1278
1279 // Create and return a DoorBell struct
1280 Ok(DoorBell {
1281 dst_ptr: (*dv_qp).bf.reg as usize,
1282 src_ptr: (*dv_qp).sq.buf as usize,
1283 size: 8,
1284 })
1285 }
1286 }
1287
1288 /// Poll for work completions by wr_ids.
1289 ///
1290 /// # Arguments
1291 ///
1292 /// * `target` - Which completion queue to poll (Send, Receive)
1293 /// * `expected_wr_ids` - Slice of work request IDs to wait for
1294 ///
1295 /// # Returns
1296 ///
1297 /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs that were found
1298 /// * `Err(e)` - An error occurred
1299 pub fn poll_completion(
1300 &mut self,
1301 target: PollTarget,
1302 expected_wr_ids: &[u64],
1303 ) -> Result<Vec<(u64, IbvWc)>, anyhow::Error> {
1304 if expected_wr_ids.is_empty() {
1305 return Ok(Vec::new());
1306 }
1307
1308 unsafe {
1309 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1310 let qp_num = (*(*qp).ibv_qp).qp_num;
1311
1312 let (cq, cache, cq_type) = match target {
1313 PollTarget::Send => (
1314 self.send_cq as *mut rdmaxcel_sys::ibv_cq,
1315 rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
1316 "send",
1317 ),
1318 PollTarget::Recv => (
1319 self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
1320 rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
1321 "recv",
1322 ),
1323 };
1324
1325 let mut results = Vec::new();
1326
1327 // Single-shot poll: check each wr_id once and return what we find
1328 for &expected_wr_id in expected_wr_ids {
1329 let mut poll_ctx = rdmaxcel_sys::poll_context_t {
1330 expected_wr_id,
1331 expected_qp_num: qp_num,
1332 cache,
1333 cq,
1334 };
1335
1336 let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1337 let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
1338
1339 match ret {
1340 1 => {
1341 // Found completion
1342 if !wc.is_valid() {
1343 if let Some((status, vendor_err)) = wc.error() {
1344 return Err(anyhow::anyhow!(
1345 "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
1346 cq_type,
1347 expected_wr_id,
1348 status,
1349 vendor_err,
1350 ));
1351 }
1352 }
1353 results.push((expected_wr_id, IbvWc::from(wc)));
1354 }
1355 0 => {
1356 // Not found yet - this is fine for single-shot poll
1357 }
1358 -17 => {
1359 // RDMAXCEL_COMPLETION_FAILED: Completion found but failed - wc contains the error details
1360 let error_msg =
1361 std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1362 .to_str()
1363 .unwrap_or("Unknown error");
1364 if let Some((status, vendor_err)) = wc.error() {
1365 return Err(anyhow::anyhow!(
1366 "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1367 cq_type,
1368 expected_wr_id,
1369 error_msg,
1370 status,
1371 vendor_err,
1372 wc.qp_num,
1373 wc.len(),
1374 ));
1375 } else {
1376 return Err(anyhow::anyhow!(
1377 "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1378 cq_type,
1379 expected_wr_id,
1380 error_msg,
1381 wc.qp_num,
1382 wc.len(),
1383 ));
1384 }
1385 }
1386 _ => {
1387 // Other errors
1388 let error_msg =
1389 std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1390 .to_str()
1391 .unwrap_or("Unknown error");
1392 return Err(anyhow::anyhow!(
1393 "Failed to poll {} CQ for wr_id={}: {}",
1394 cq_type,
1395 expected_wr_id,
1396 error_msg
1397 ));
1398 }
1399 }
1400 }
1401
1402 Ok(results)
1403 }
1404 }
1405}
1406
1407/// Utility to validate execution context.
1408///
/// Remote execution environments do not always have the nvidia_peermem module loaded
/// and/or the PeerMappingOverride parameter set, due to security policy. This function can be
/// used to validate the execution context before running operations that need this
/// functionality (i.e. cudaHostRegisterIoMemory).
1413///
1414/// # Returns
1415///
1416/// * `Ok(())` if the execution context is valid
1417/// * `Err(anyhow::Error)` if the execution context is invalid
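///
/// # Example
///
/// A sketch of failing fast before attempting GPU Direct RDMA setup:
///
/// ```ignore
/// validate_execution_context().await?;
/// ```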
1418pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1419 // Check for nvidia peermem
1420 match fs::read_to_string("/proc/modules") {
1421 Ok(contents) => {
1422 if !contents.contains("nvidia_peermem") {
1423 return Err(anyhow::anyhow!(
1424 "nvidia_peermem module not found in /proc/modules"
1425 ));
1426 }
1427 }
1428 Err(e) => {
1429 return Err(anyhow::anyhow!(e));
1430 }
1431 }
1432
1433 // Test file access to nvidia params
1434 match fs::read_to_string("/proc/driver/nvidia/params") {
1435 Ok(contents) => {
1436 if !contents.contains("PeerMappingOverride=1") {
1437 return Err(anyhow::anyhow!(
1438 "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1439 ));
1440 }
1441 }
1442 Err(e) => {
1443 return Err(anyhow::anyhow!(e));
1444 }
1445 }
1446 Ok(())
1447}
1448
1449/// Get all segments that have been registered with MRs
1450///
1451/// # Returns
/// * `Vec<rdmaxcel_sys::rdma_segment_info_t>` - Vector containing all registered segment information
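///
/// # Example
///
/// A sketch of inspecting the currently registered segments (the available fields
/// depend on the `rdmaxcel_sys` definition of `rdma_segment_info_t`):
///
/// ```ignore
/// for segment in get_registered_cuda_segments() {
///     tracing::debug!("registered CUDA segment: {:?}", segment);
/// }
/// ```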
1453pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
1454 unsafe {
1455 let segment_count = rdmaxcel_sys::rdma_get_active_segment_count();
1456 if segment_count <= 0 {
1457 return Vec::new();
1458 }
1459
1460 let mut segments = vec![
1461 std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
1462 .assume_init();
1463 segment_count as usize
1464 ];
1465 let actual_count =
1466 rdmaxcel_sys::rdma_get_all_segment_info(segments.as_mut_ptr(), segment_count);
1467
1468 if actual_count > 0 {
1469 segments.truncate(actual_count as usize);
1470 segments
1471 } else {
1472 Vec::new()
1473 }
1474 }
1475}
1476
1477/// Segment scanner callback type alias for convenience.
1478pub type SegmentScannerFn = rdmaxcel_sys::RdmaxcelSegmentScannerFn;
1479
1480/// Register a segment scanner callback.
1481///
1482/// The scanner callback is called during RDMA segment registration to discover
1483/// CUDA memory segments. The callback should fill the provided buffer with
1484/// segment information and return the total count of segments found.
1485///
1486/// If the returned count exceeds the buffer size, the caller will allocate
1487/// a larger buffer and retry.
1488///
1489/// Pass `None` to unregister the scanner.
1490///
1491/// # Safety
1492///
1493/// The provided callback function must be safe to call from C code and must
1494/// properly handle the segment buffer.
1495pub fn register_segment_scanner(scanner: SegmentScannerFn) {
1496 // SAFETY: We are registering a callback function pointer with rdmaxcel.
1497 unsafe { rdmaxcel_sys::rdmaxcel_register_segment_scanner(scanner) }
1498}
1499
1500#[cfg(test)]
1501mod tests {
1502 use super::*;
1503
1504 #[test]
1505 fn test_create_connection() {
1506 // Skip test if RDMA devices are not available
1507 if crate::ibverbs_primitives::get_all_devices().is_empty() {
1508 println!("Skipping test: RDMA devices not available");
1509 return;
1510 }
1511
1512 let config = IbverbsConfig {
1513 use_gpu_direct: false,
1514 ..Default::default()
1515 };
1516 let domain = RdmaDomain::new(config.device.clone());
1517 assert!(domain.is_ok());
1518
1519 let domain = domain.unwrap();
1520 let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config.clone());
1521 assert!(queue_pair.is_ok());
1522 }
1523
1524 #[test]
1525 fn test_loopback_connection() {
1526 // Skip test if RDMA devices are not available
1527 if crate::ibverbs_primitives::get_all_devices().is_empty() {
1528 println!("Skipping test: RDMA devices not available");
1529 return;
1530 }
1531
1532 let server_config = IbverbsConfig {
1533 use_gpu_direct: false,
1534 ..Default::default()
1535 };
1536 let client_config = IbverbsConfig {
1537 use_gpu_direct: false,
1538 ..Default::default()
1539 };
1540
1541 let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
1542 let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();
1543
1544 let mut server_qp = RdmaQueuePair::new(
1545 server_domain.context,
1546 server_domain.pd,
1547 server_config.clone(),
1548 )
1549 .unwrap();
1550 let mut client_qp = RdmaQueuePair::new(
1551 client_domain.context,
1552 client_domain.pd,
1553 client_config.clone(),
1554 )
1555 .unwrap();
1556
1557 let server_connection_info = server_qp.get_qp_info().unwrap();
1558 let client_connection_info = client_qp.get_qp_info().unwrap();
1559
1560 assert!(server_qp.connect(&client_connection_info).is_ok());
1561 assert!(client_qp.connect(&server_connection_info).is_ok());
1562 }
1563}