monarch_rdma/rdma_components.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Components
10//!
11//! This module provides the core RDMA building blocks for establishing and managing RDMA connections.
12//!
13//! ## Core Components
14//!
15//! * `RdmaDomain` - Manages RDMA resources including context, protection domain, and memory region
16//! * `RdmaQueuePair` - Handles communication between endpoints via queue pairs and completion queues
17//!
18//! ## RDMA Overview
19//!
20//! Remote Direct Memory Access (RDMA) allows direct memory access from the memory of one computer
21//! into the memory of another without involving either computer's operating system. This permits
22//! high-throughput, low-latency networking with minimal CPU overhead.
23//!
24//! ## Connection Architecture
25//!
26//! The module manages the following ibverbs primitives:
27//!
28//! 1. **Queue Pairs (QP)**: Each connection has a send queue and a receive queue
29//! 2. **Completion Queues (CQ)**: Events are reported when operations complete
30//! 3. **Memory Regions (MR)**: Memory must be registered with the RDMA device before use
31//! 4. **Protection Domains (PD)**: Provide isolation between different connections
32//!
33//! ## Connection Lifecycle
34//!
35//! 1. Create an `RdmaDomain` with `new()`
36//! 2. Create an `RdmaQueuePair` from the domain
37//! 3. Exchange connection info with remote peer (application must handle this)
38//! 4. Connect to remote endpoint with `connect()`
39//! 5. Perform RDMA operations (read/write)
40//! 6. Poll for completions
41//! 7. Resources are cleaned up when dropped
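//!
//! A minimal sketch of this lifecycle (not compiled). It assumes a reachable RDMA device, two
//! registered `RdmaBuffer`s (`local_buffer` and `remote_buffer`), and an application-provided
//! side channel, here the hypothetical `exchange_with_peer`, for swapping `RdmaQpInfo`:
//!
//! ```ignore
//! let config = IbverbsConfig::default();
//! let domain = RdmaDomain::new(config.device.clone())?;
//! let mut qp = RdmaQueuePair::new(domain.context, domain.pd, config)?;
//!
//! // Exchange connection info with the remote peer out of band.
//! let local_info = qp.get_qp_info()?;
//! let remote_info: RdmaQpInfo = exchange_with_peer(local_info)?;
//! qp.connect(&remote_info)?;
//!
//! // Issue an RDMA write and poll (single-shot) for the resulting work requests.
//! let wr_ids = qp.put(local_buffer, remote_buffer)?;
//! let completions = qp.poll_completion(PollTarget::Send, &wr_ids)?;
//! ```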
42
43/// Maximum size for a single RDMA operation in bytes (1 GiB)
44const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
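// Transfers larger than this are split by `put()`/`get()` into 1 GiB chunks; for example, a
// 2.5 GiB buffer is posted as three work requests (1 GiB + 1 GiB + 0.5 GiB), one WQE per chunk.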
45
46use std::ffi::CStr;
47use std::fs;
48use std::io::Error;
49use std::result::Result;
50use std::thread::sleep;
51use std::time::Duration;
52
53use hyperactor::ActorRef;
54use hyperactor::Named;
55use hyperactor::clock::Clock;
56use hyperactor::clock::RealClock;
57use hyperactor::context;
58use serde::Deserialize;
59use serde::Serialize;
60
61use crate::RdmaDevice;
62use crate::RdmaManagerActor;
63use crate::RdmaManagerMessageClient;
64use crate::ibverbs_primitives::Gid;
65use crate::ibverbs_primitives::IbvWc;
66use crate::ibverbs_primitives::IbverbsConfig;
67use crate::ibverbs_primitives::RdmaOperation;
68use crate::ibverbs_primitives::RdmaQpInfo;
69use crate::ibverbs_primitives::resolve_qp_type;
70
71#[derive(Debug, Named, Clone, Serialize, Deserialize)]
72pub struct DoorBell {
73 pub src_ptr: usize,
74 pub dst_ptr: usize,
75 pub size: usize,
76}
77
78impl DoorBell {
79 /// Rings the doorbell to trigger the execution of previously enqueued operations.
80 ///
81 /// This method uses unsafe code to directly interact with the RDMA device,
82 /// sending a signal from the source pointer to the destination pointer.
83 ///
84 /// # Returns
85 /// * `Ok(())` if the operation is successful.
86 /// * `Err(anyhow::Error)` if an error occurs during the operation.
87 pub fn ring(&self) -> Result<(), anyhow::Error> {
88 unsafe {
89 let src_ptr = self.src_ptr as *mut std::ffi::c_void;
90 let dst_ptr = self.dst_ptr as *mut std::ffi::c_void;
91 rdmaxcel_sys::db_ring(dst_ptr, src_ptr);
92 Ok(())
93 }
94 }
95}
96
97#[derive(Debug, Serialize, Deserialize, Named, Clone)]
98pub struct RdmaBuffer {
99 pub owner: ActorRef<RdmaManagerActor>,
100 pub mr_id: usize,
101 pub lkey: u32,
102 pub rkey: u32,
103 pub addr: usize,
104 pub size: usize,
105 pub device_name: String,
106}
107
108impl RdmaBuffer {
    /// Read from the RdmaBuffer into the provided memory.
    ///
    /// This method transfers data from this buffer into the memory region described by `remote`
    /// over RDMA. On this buffer's side it is issued as a `Put` (RDMA write) operation.
113 ///
114 /// # Arguments
115 /// * `client` - The actor who is reading.
116 /// * `remote` - RdmaBuffer representing the remote memory region
117 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
118 ///
119 /// # Returns
120 /// `Ok(bool)` indicating if the operation completed successfully.
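    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled; assumes `client` is this actor's context and `local` /
    /// `remote` are `RdmaBuffer`s registered with their respective `RdmaManagerActor`s).
    /// `write_from` is used symmetrically for the opposite direction:
    ///
    /// ```ignore
    /// // Move the contents of `local` into the memory described by `remote`,
    /// // waiting up to 5 seconds for the RDMA operation to complete.
    /// let ok = local.read_into(&client, remote.clone(), 5).await?;
    /// assert!(ok);
    /// ```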
121 pub async fn read_into(
122 &self,
123 client: &impl context::Actor,
124 remote: RdmaBuffer,
125 timeout: u64,
126 ) -> Result<bool, anyhow::Error> {
127 tracing::debug!(
128 "[buffer] reading from {:?} into remote ({:?}) at {:?}",
129 self,
130 remote.owner.actor_id(),
131 remote,
132 );
133 let remote_owner = remote.owner.clone();
134
135 let local_device = self.device_name.clone();
136 let remote_device = remote.device_name.clone();
137 let mut qp = self
138 .owner
139 .request_queue_pair(
140 client,
141 remote_owner.clone(),
142 local_device.clone(),
143 remote_device.clone(),
144 )
145 .await?;
146
147 let wr_id = qp.put(self.clone(), remote)?;
148 let result = self
149 .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
150 .await;
151
152 // Release the queue pair back to the actor
153 self.owner
154 .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
155 .await?;
156
157 result
158 }
159
    /// Write from the provided memory into the RdmaBuffer.
    ///
    /// This method transfers data from the memory region described by `remote` into this buffer
    /// over RDMA. On this buffer's side it is issued as a `Fetch` (RDMA read) operation.
165 ///
166 /// # Arguments
167 /// * `client` - The actor who is writing.
168 /// * `remote` - RdmaBuffer representing the remote memory region
169 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
170 ///
171 /// # Returns
172 /// `Ok(bool)` indicating if the operation completed successfully.
173 pub async fn write_from(
174 &self,
175 client: &impl context::Actor,
176 remote: RdmaBuffer,
177 timeout: u64,
178 ) -> Result<bool, anyhow::Error> {
179 tracing::debug!(
180 "[buffer] writing into {:?} from remote ({:?}) at {:?}",
181 self,
182 remote.owner.actor_id(),
183 remote,
184 );
        let remote_owner = remote.owner.clone(); // clone before `remote` is moved below

        let local_device = self.device_name.clone();
        let remote_device = remote.device_name.clone();
190
191 let mut qp = self
192 .owner
193 .request_queue_pair(
194 client,
195 remote_owner.clone(),
196 local_device.clone(),
197 remote_device.clone(),
198 )
199 .await?;
200 let wr_id = qp.get(self.clone(), remote)?;
201 let result = self
202 .wait_for_completion(&mut qp, PollTarget::Send, &wr_id, timeout)
203 .await;
204
205 // Release the queue pair back to the actor
206 self.owner
207 .release_queue_pair(client, remote_owner, local_device, remote_device, qp)
208 .await?;
209
210 result
211 }
212 /// Waits for the completion of RDMA operations.
213 ///
214 /// This method polls the completion queue until all specified work requests complete
215 /// or until the timeout is reached.
216 ///
217 /// # Arguments
218 /// * `qp` - The RDMA Queue Pair to poll for completion
219 /// * `poll_target` - Which CQ to poll (Send or Recv)
220 /// * `expected_wr_ids` - The work request IDs to wait for
221 /// * `timeout` - Timeout in seconds for the RDMA operation to complete.
222 ///
223 /// # Returns
224 /// `Ok(true)` if all operations complete successfully within the timeout,
225 /// or an error if the timeout is reached
226 async fn wait_for_completion(
227 &self,
228 qp: &mut RdmaQueuePair,
229 poll_target: PollTarget,
230 expected_wr_ids: &[u64],
231 timeout: u64,
232 ) -> Result<bool, anyhow::Error> {
233 let timeout = Duration::from_secs(timeout);
234 let start_time = std::time::Instant::now();
235
236 let mut remaining: std::collections::HashSet<u64> =
237 expected_wr_ids.iter().copied().collect();
238
239 while start_time.elapsed() < timeout {
240 if remaining.is_empty() {
241 return Ok(true);
242 }
243
244 let wr_ids_to_poll: Vec<u64> = remaining.iter().copied().collect();
245 match qp.poll_completion(poll_target, &wr_ids_to_poll) {
246 Ok(completions) => {
247 for (wr_id, _wc) in completions {
248 remaining.remove(&wr_id);
249 }
250 if remaining.is_empty() {
251 return Ok(true);
252 }
253 RealClock.sleep(Duration::from_millis(1)).await;
254 }
255 Err(e) => {
256 return Err(anyhow::anyhow!(
257 "RDMA polling completion failed: {} [lkey={}, rkey={}, addr=0x{:x}, size={}]",
258 e,
259 self.lkey,
260 self.rkey,
261 self.addr,
262 self.size
263 ));
264 }
265 }
266 }
267 tracing::error!(
268 "timed out while waiting on request completion for wr_ids={:?}",
269 remaining
270 );
271 Err(anyhow::anyhow!(
272 "[buffer({:?})] rdma operation did not complete in time (expected wr_ids={:?})",
273 self,
274 expected_wr_ids
275 ))
276 }
277
278 /// Drop the buffer and release remote handles.
279 ///
280 /// This method calls the owning RdmaManagerActor to release the buffer and clean up
281 /// associated memory regions. This is typically called when the buffer is no longer
282 /// needed and resources should be freed.
283 ///
284 /// # Arguments
    /// * `client` - The actor context used for communication
286 ///
287 /// # Returns
288 /// `Ok(())` if the operation completed successfully.
289 pub async fn drop_buffer(&self, client: &impl context::Actor) -> Result<(), anyhow::Error> {
290 tracing::debug!("[buffer] dropping buffer {:?}", self);
291 self.owner.release_buffer(client, self.clone()).await?;
292 Ok(())
293 }
294}
295
296/// Represents a domain for RDMA operations, encapsulating the necessary resources
297/// for establishing and managing RDMA connections.
298///
299/// An `RdmaDomain` manages the context, protection domain (PD), and memory region (MR)
300/// required for RDMA operations. It provides the foundation for creating queue pairs
301/// and establishing connections between RDMA devices.
302///
303/// # Fields
304///
305/// * `context`: A pointer to the RDMA device context, representing the connection to the RDMA device.
306/// * `pd`: A pointer to the protection domain, which provides isolation between different connections.
307#[derive(Clone)]
308pub struct RdmaDomain {
309 pub context: *mut rdmaxcel_sys::ibv_context,
310 pub pd: *mut rdmaxcel_sys::ibv_pd,
311}
312
313impl std::fmt::Debug for RdmaDomain {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("RdmaDomain")
316 .field("context", &format!("{:p}", self.context))
317 .field("pd", &format!("{:p}", self.pd))
318 .finish()
319 }
320}
321
// SAFETY:
// RdmaDomain interacts with the RDMA device through raw pointers obtained from rdmaxcel_sys calls.
// RdmaDomain is `Send` because the raw pointers to ibverbs structs can be
// accessed from any thread, and it is safe to drop `RdmaDomain` (and run the
// ibverbs destructors) from any thread.
327unsafe impl Send for RdmaDomain {}
328
// SAFETY:
// RdmaDomain is `Sync` because the underlying ibverbs APIs, reached through rdmaxcel_sys calls,
// are thread-safe.
332unsafe impl Sync for RdmaDomain {}
333
334impl Drop for RdmaDomain {
335 fn drop(&mut self) {
336 unsafe {
337 rdmaxcel_sys::ibv_dealloc_pd(self.pd);
338 }
339 }
340}
341
342impl RdmaDomain {
    /// Creates a new RdmaDomain.
    ///
    /// This function opens the RDMA device context and allocates a protection domain (PD)
    /// for the given device.
    ///
    /// SAFETY:
    /// Memory region (MR) registration in this crate uses implicit ODP for RDMA access, which maps
    /// large virtual address ranges without explicit pinning. This is convenient, but it broadens
    /// the memory footprint exposed to the NIC and introduces a security liability.
    ///
    /// We currently assume a trusted, single-tenant environment and are not enforcing finer-grained
    /// memory isolation at this layer. We plan to investigate mitigations, such as memory windows or
    /// tighter registration boundaries, in future follow-ups.
    ///
    /// # Arguments
    ///
    /// * `device` - The RDMA device to open and to allocate the protection domain on
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * No RDMA devices are found
    /// * The specified device cannot be found
    /// * Device context creation fails
    /// * Protection domain allocation fails
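    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled; assumes at least one RDMA device is available):
    ///
    /// ```ignore
    /// let config = IbverbsConfig::default();
    /// let domain = RdmaDomain::new(config.device.clone())?;
    /// let qp = RdmaQueuePair::new(domain.context, domain.pd, config)?;
    /// ```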
369 pub fn new(device: RdmaDevice) -> Result<Self, anyhow::Error> {
370 tracing::debug!("creating RdmaDomain for device {}", device.name());
371 // SAFETY:
372 // This code uses unsafe rdmaxcel_sys calls to interact with the RDMA device, but is safe because:
373 // - All pointers are properly initialized and checked for null before use
374 // - Memory registration follows the ibverbs API contract with proper access flags
375 // - Resources are properly cleaned up in error cases to prevent leaks
376 // - The operations follow the documented RDMA protocol for device initialization
377 unsafe {
378 // Get the device based on the provided RdmaDevice
379 let device_name = device.name();
380 let mut num_devices = 0i32;
381 let devices = rdmaxcel_sys::ibv_get_device_list(&mut num_devices as *mut _);
382
383 if devices.is_null() || num_devices == 0 {
384 return Err(anyhow::anyhow!("no RDMA devices found"));
385 }
386
387 // Find the device with the matching name
388 let mut device_ptr = std::ptr::null_mut();
389 for i in 0..num_devices {
390 let dev = *devices.offset(i as isize);
391 let dev_name =
392 CStr::from_ptr(rdmaxcel_sys::ibv_get_device_name(dev)).to_string_lossy();
393
394 if dev_name == *device_name {
395 device_ptr = dev;
396 break;
397 }
398 }
399
400 // If we didn't find the device, return an error
401 if device_ptr.is_null() {
402 rdmaxcel_sys::ibv_free_device_list(devices);
403 return Err(anyhow::anyhow!("device '{}' not found", device_name));
404 }
405 tracing::info!("using RDMA device: {}", device_name);
406
407 // Open device
408 let context = rdmaxcel_sys::ibv_open_device(device_ptr);
409 if context.is_null() {
410 rdmaxcel_sys::ibv_free_device_list(devices);
411 let os_error = Error::last_os_error();
412 return Err(anyhow::anyhow!("failed to create context: {}", os_error));
413 }
414
415 // Create protection domain
416 let pd = rdmaxcel_sys::ibv_alloc_pd(context);
417 if pd.is_null() {
418 rdmaxcel_sys::ibv_close_device(context);
419 rdmaxcel_sys::ibv_free_device_list(devices);
420 let os_error = Error::last_os_error();
421 return Err(anyhow::anyhow!(
422 "failed to create protection domain (PD): {}",
423 os_error
424 ));
425 }
426
427 // Avoids memory leaks
428 rdmaxcel_sys::ibv_free_device_list(devices);
429
430 let domain = RdmaDomain { context, pd };
431
432 Ok(domain)
433 }
434 }
435}
436/// Enum to specify which completion queue to poll
437#[derive(Debug, Clone, Copy, PartialEq)]
438pub enum PollTarget {
439 Send,
440 Recv,
441}
442
443/// Represents an RDMA Queue Pair (QP) that enables communication between two endpoints.
444///
445/// An `RdmaQueuePair` encapsulates the send and receive queues, completion queue,
446/// and other resources needed for RDMA communication. It provides methods for
447/// establishing connections and performing RDMA operations like read and write.
448///
449/// # Fields
450///
451/// * `send_cq` - Send Completion Queue pointer for tracking send operation completions
452/// * `recv_cq` - Receive Completion Queue pointer for tracking receive operation completions
453/// * `qp` - Queue Pair pointer that manages send and receive operations
454/// * `dv_qp` - Pointer to the mlx5 device-specific queue pair structure
455/// * `dv_send_cq` - Pointer to the mlx5 device-specific send completion queue structure
456/// * `dv_recv_cq` - Pointer to the mlx5 device-specific receive completion queue structure
457/// * `context` - RDMA device context pointer
458/// * `config` - Configuration settings for the queue pair
459///
460/// # Connection Lifecycle
461///
462/// 1. Create with `new()` from an `RdmaDomain`
463/// 2. Get connection info with `get_qp_info()`
464/// 3. Exchange connection info with remote peer (application must handle this)
465/// 4. Connect to remote endpoint with `connect()`
466/// 5. Perform RDMA operations with `put()` or `get()`
/// 6. Poll for completions with `poll_completion(PollTarget::Send, ...)` or `poll_completion(PollTarget::Recv, ...)`
468///
469/// # Notes
470/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
471/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
472/// - This makes RdmaQueuePair trivially Clone and Serialize
473/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
474#[derive(Debug, Serialize, Deserialize, Named, Clone)]
475pub struct RdmaQueuePair {
476 pub send_cq: usize, // *mut rdmaxcel_sys::ibv_cq,
477 pub recv_cq: usize, // *mut rdmaxcel_sys::ibv_cq,
478 pub qp: usize, // *mut rdmaxcel_sys::rdmaxcel_qp_t
479 pub dv_qp: usize, // *mut rdmaxcel_sys::mlx5dv_qp,
480 pub dv_send_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
481 pub dv_recv_cq: usize, // *mut rdmaxcel_sys::mlx5dv_cq,
482 context: usize, // *mut rdmaxcel_sys::ibv_context,
483 config: IbverbsConfig,
484}
485
486impl RdmaQueuePair {
487 /// Applies hardware initialization delay if this is the first operation since RTS.
488 ///
489 /// This ensures the hardware has sufficient time to settle after reaching
490 /// Ready-to-Send state before the first actual operation.
491 fn apply_first_op_delay(&self, wr_id: u64) {
492 unsafe {
493 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
494 if wr_id == 0 {
495 let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
496 assert!(
497 rts_timestamp != u64::MAX,
498 "First operation attempted before queue pair reached RTS state! Call connect() first."
499 );
500 let current_nanos = RealClock
501 .system_time_now()
502 .duration_since(std::time::UNIX_EPOCH)
503 .unwrap()
504 .as_nanos() as u64;
505 let elapsed_nanos = current_nanos - rts_timestamp;
506 let elapsed = Duration::from_nanos(elapsed_nanos);
507 let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
508 if elapsed < init_delay {
509 let remaining_delay = init_delay - elapsed;
510 sleep(remaining_delay);
511 }
512 }
513 }
514 }
515
    /// Creates a new RdmaQueuePair from an RDMA device context and protection domain.
    ///
    /// This function initializes a new Queue Pair (QP) and its associated Completion Queues (CQs)
    /// using the resources of the owning `RdmaDomain`. The QP is created in the RESET state
    /// and must be transitioned to other states via the `connect()` method before use.
    ///
    /// # Arguments
    ///
    /// * `context` - RDMA device context, typically `RdmaDomain::context`
    /// * `pd` - Protection domain, typically `RdmaDomain::pd`
    /// * `config` - Configuration settings for the queue pair
    ///
    /// # Returns
    ///
    /// * `Result<Self>` - A new RdmaQueuePair instance or an error if creation fails
    ///
    /// # Errors
    ///
    /// This function may return errors if:
    /// * Completion queue (CQ) creation fails
    /// * Queue pair (QP) creation fails
    /// * mlx5dv provider structures cannot be initialized
536 pub fn new(
537 context: *mut rdmaxcel_sys::ibv_context,
538 pd: *mut rdmaxcel_sys::ibv_pd,
539 config: IbverbsConfig,
540 ) -> Result<Self, anyhow::Error> {
541 tracing::debug!("creating an RdmaQueuePair from config {}", config);
542 unsafe {
543 // Resolve Auto to a concrete QP type based on device capabilities
544 let resolved_qp_type = resolve_qp_type(config.qp_type);
545 let qp = rdmaxcel_sys::rdmaxcel_qp_create(
546 context,
547 pd,
548 config.cq_entries,
549 config.max_send_wr.try_into().unwrap(),
550 config.max_recv_wr.try_into().unwrap(),
551 config.max_send_sge.try_into().unwrap(),
552 config.max_recv_sge.try_into().unwrap(),
553 resolved_qp_type,
554 );
555
556 if qp.is_null() {
557 let os_error = Error::last_os_error();
558 return Err(anyhow::anyhow!(
559 "failed to create queue pair (QP): {}",
560 os_error
561 ));
562 }
563
564 let send_cq = (*(*qp).ibv_qp).send_cq;
565 let recv_cq = (*(*qp).ibv_qp).recv_cq;
566
567 // mlx5dv provider APIs
568 let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
569 let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
570 let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
571
572 if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
573 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
574 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
575 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
576 return Err(anyhow::anyhow!(
577 "failed to init mlx5dv_qp or completion queues"
578 ));
579 }
580
581 // GPU Direct RDMA specific registrations
582 if config.use_gpu_direct {
583 let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
584 if ret != 0 {
585 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
586 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
587 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
588 return Err(anyhow::anyhow!(
589 "failed to register GPU Direct RDMA memory: {:?}",
590 ret
591 ));
592 }
593 }
594 Ok(RdmaQueuePair {
595 send_cq: send_cq as usize,
596 recv_cq: recv_cq as usize,
597 qp: qp as usize,
                dv_qp: dv_qp as usize,
599 dv_send_cq: dv_send_cq as usize,
600 dv_recv_cq: dv_recv_cq as usize,
601 context: context as usize,
602 config,
603 })
604 }
605 }
    }

607 /// Returns the information required for a remote peer to connect to this queue pair.
608 ///
609 /// This method retrieves the local queue pair attributes and port information needed by
610 /// a remote peer to establish an RDMA connection. The returned `RdmaQpInfo` contains
611 /// the queue pair number, LID, GID, and other necessary connection parameters.
612 ///
613 /// # Returns
614 ///
615 /// * `Result<RdmaQpInfo>` - Connection information for the remote peer or an error
616 ///
617 /// # Errors
618 ///
619 /// This function may return errors if:
620 /// * Port attribute query fails
621 /// * GID query fails
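    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled); how the info reaches the peer is left to the application:
    ///
    /// ```ignore
    /// let local_info = qp.get_qp_info()?;
    /// // Send `local_info` to the remote peer over any side channel (TCP, actor message, ...)
    /// // and receive the peer's `RdmaQpInfo` in return before calling `connect()`.
    /// ```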
622 pub fn get_qp_info(&mut self) -> Result<RdmaQpInfo, anyhow::Error> {
623 // SAFETY:
624 // This code uses unsafe rdmaxcel_sys calls to query RDMA device information, but is safe because:
625 // - All pointers are properly initialized before use
626 // - Port and GID queries follow the documented ibverbs API contract
627 // - Error handling properly checks return codes from ibverbs functions
628 // - The memory address provided is only stored, not dereferenced in this function
629 unsafe {
630 let context = self.context as *mut rdmaxcel_sys::ibv_context;
631 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
632 let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
633 let errno = rdmaxcel_sys::ibv_query_port(
634 context,
635 self.config.port_num,
636 &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
637 );
638 if errno != 0 {
639 let os_error = Error::last_os_error();
640 return Err(anyhow::anyhow!(
641 "Failed to query port attributes: {}",
642 os_error
643 ));
644 }
645
646 let mut gid = Gid::default();
647 let ret = rdmaxcel_sys::ibv_query_gid(
648 context,
649 self.config.port_num,
650 i32::from(self.config.gid_index),
651 gid.as_mut(),
652 );
653 if ret != 0 {
654 return Err(anyhow::anyhow!("Failed to query GID"));
655 }
656
657 Ok(RdmaQpInfo {
658 qp_num: (*(*qp).ibv_qp).qp_num,
659 lid: port_attr.lid,
660 gid: Some(gid),
661 psn: self.config.psn,
662 })
663 }
664 }
665
666 pub fn state(&mut self) -> Result<u32, anyhow::Error> {
667 // SAFETY: This block interacts with the RDMA device through rdmaxcel_sys calls.
668 unsafe {
669 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
670 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
671 ..Default::default()
672 };
673 let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
674 ..Default::default()
675 };
676 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
677 let errno = rdmaxcel_sys::ibv_query_qp(
678 (*qp).ibv_qp,
679 &mut qp_attr,
680 mask.0 as i32,
681 &mut qp_init_attr,
682 );
683 if errno != 0 {
684 let os_error = Error::last_os_error();
685 return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
686 }
687 Ok(qp_attr.qp_state)
688 }
689 }
690 /// Connect to a remote Rdma connection point.
691 ///
692 /// This performs the necessary QP state transitions (INIT->RTR->RTS) to establish a connection.
693 ///
694 /// # Arguments
695 ///
696 /// * `connection_info` - The remote connection info to connect to
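    ///
    /// # Example
    ///
    /// A minimal loopback sketch, mirroring the unit test at the bottom of this file (not compiled):
    ///
    /// ```ignore
    /// let server_info = server_qp.get_qp_info()?;
    /// let client_info = client_qp.get_qp_info()?;
    /// server_qp.connect(&client_info)?;
    /// client_qp.connect(&server_info)?;
    /// ```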
697 pub fn connect(&mut self, connection_info: &RdmaQpInfo) -> Result<(), anyhow::Error> {
698 // SAFETY:
699 // This unsafe block is necessary because we're interacting with the RDMA device through rdmaxcel_sys calls.
700 // The operations are safe because:
701 // 1. We're following the documented ibverbs API contract
702 // 2. All pointers used are properly initialized and owned by this struct
703 // 3. The QP state transitions (INIT->RTR->RTS) follow the required RDMA connection protocol
704 // 4. Memory access is properly bounded by the registered memory regions
705 unsafe {
706 // Transition to INIT
707 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
708
709 let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
710 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
711 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
712 | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
713
714 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
715 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
716 qp_access_flags: qp_access_flags.0,
717 pkey_index: self.config.pkey_index,
718 port_num: self.config.port_num,
719 ..Default::default()
720 };
721
722 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
723 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
724 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
725 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
726
727 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
728 if errno != 0 {
729 let os_error = Error::last_os_error();
730 return Err(anyhow::anyhow!(
731 "failed to transition QP to INIT: {}",
732 os_error
733 ));
734 }
735
736 // Transition to RTR (Ready to Receive)
737 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
738 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
739 path_mtu: self.config.path_mtu,
740 dest_qp_num: connection_info.qp_num,
741 rq_psn: connection_info.psn,
742 max_dest_rd_atomic: self.config.max_dest_rd_atomic,
743 min_rnr_timer: self.config.min_rnr_timer,
744 ah_attr: rdmaxcel_sys::ibv_ah_attr {
745 dlid: connection_info.lid,
746 sl: 0,
747 src_path_bits: 0,
748 port_num: self.config.port_num,
749 grh: Default::default(),
750 ..Default::default()
751 },
752 ..Default::default()
753 };
754
755 // If the remote connection info contains a Gid, the routing will be global.
756 // Otherwise, it will be local, i.e. using LID.
757 if let Some(gid) = connection_info.gid {
758 qp_attr.ah_attr.is_global = 1;
759 qp_attr.ah_attr.grh.dgid = gid.into();
760 qp_attr.ah_attr.grh.hop_limit = 0xff;
761 qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
762 } else {
763 // Use LID-based routing, e.g. for Infiniband/RoCEv1
764 qp_attr.ah_attr.is_global = 0;
765 }
766
767 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
768 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
769 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
770 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
771 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
772 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
773 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
774
775 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
776 if errno != 0 {
777 let os_error = Error::last_os_error();
778 return Err(anyhow::anyhow!(
779 "failed to transition QP to RTR: {}",
780 os_error
781 ));
782 }
783
784 // Transition to RTS (Ready to Send)
785 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
786 qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
787 sq_psn: self.config.psn,
788 max_rd_atomic: self.config.max_rd_atomic,
789 retry_cnt: self.config.retry_cnt,
790 rnr_retry: self.config.rnr_retry,
791 timeout: self.config.qp_timeout,
792 ..Default::default()
793 };
794
795 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
796 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
797 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
798 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
799 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
800 | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
801
802 let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
803 if errno != 0 {
804 let os_error = Error::last_os_error();
805 return Err(anyhow::anyhow!(
806 "failed to transition QP to RTS: {}",
807 os_error
808 ));
809 }
810 tracing::debug!(
811 "connection sequence has successfully completed (qp: {:?})",
812 qp
813 );
814
815 // Record RTS time now that the queue pair is ready to send
816 let rts_timestamp_nanos = RealClock
817 .system_time_now()
818 .duration_since(std::time::UNIX_EPOCH)
819 .unwrap()
820 .as_nanos() as u64;
821 rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
822
823 Ok(())
824 }
825 }
826
827 pub fn recv(&mut self, lhandle: RdmaBuffer, rhandle: RdmaBuffer) -> Result<u64, anyhow::Error> {
828 unsafe {
829 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
830 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
831 self.post_op(
832 0,
833 lhandle.lkey,
834 0,
835 idx,
836 true,
837 RdmaOperation::Recv,
838 0,
839 rhandle.rkey,
840 )
841 .unwrap();
842 rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
843 Ok(idx)
844 }
845 }
846
847 pub fn put_with_recv(
848 &mut self,
849 lhandle: RdmaBuffer,
850 rhandle: RdmaBuffer,
851 ) -> Result<Vec<u64>, anyhow::Error> {
852 unsafe {
853 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
854 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
855 self.post_op(
856 lhandle.addr,
857 lhandle.lkey,
858 lhandle.size,
859 idx,
860 true,
861 RdmaOperation::WriteWithImm,
862 rhandle.addr,
863 rhandle.rkey,
864 )
865 .unwrap();
866 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
867 Ok(vec![idx])
868 }
869 }
870
871 pub fn put(
872 &mut self,
873 lhandle: RdmaBuffer,
874 rhandle: RdmaBuffer,
875 ) -> Result<Vec<u64>, anyhow::Error> {
876 let total_size = lhandle.size;
877 if rhandle.size < total_size {
878 return Err(anyhow::anyhow!(
879 "Remote buffer size ({}) is smaller than local buffer size ({})",
880 rhandle.size,
881 total_size
882 ));
883 }
884
885 let mut remaining = total_size;
886 let mut offset = 0;
887 let mut wr_ids = Vec::new();
888 while remaining > 0 {
889 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
890 let idx = unsafe {
891 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
892 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
893 )
894 };
895 wr_ids.push(idx);
896 self.post_op(
897 lhandle.addr + offset,
898 lhandle.lkey,
899 chunk_size,
900 idx,
901 true,
902 RdmaOperation::Write,
903 rhandle.addr + offset,
904 rhandle.rkey,
905 )?;
906 unsafe {
907 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
908 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
909 );
910 }
911
912 remaining -= chunk_size;
913 offset += chunk_size;
914 }
915
916 Ok(wr_ids)
917 }
918
    /// Rings the doorbell for all enqueued but not yet submitted send work requests.
    ///
    /// This method walks the send queue from the last submitted WQE up to the most recently
    /// enqueued one, ringing the doorbell for each so the device executes the pending operations.
    ///
    /// # Returns
    ///
    /// * `Result<(), anyhow::Error>` - `Ok(())` once all pending WQEs have been submitted
927 pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
928 unsafe {
929 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
930 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
931 let base_ptr = (*dv_qp).sq.buf as *mut u8;
932 let wqe_cnt = (*dv_qp).sq.wqe_cnt;
933 let stride = (*dv_qp).sq.stride;
934 let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
935 let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
936 if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
937 return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
938 }
939 self.apply_first_op_delay(send_db_idx);
940 while send_db_idx < send_wqe_idx {
941 let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
942 let src_ptr = (base_ptr as *mut u8).wrapping_add(offset as usize);
943 rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
944 send_db_idx += 1;
945 rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
946 }
947 Ok(())
948 }
949 }
950
951 /// Enqueues a put operation without ringing the doorbell.
952 ///
953 /// This method prepares a put operation but does not execute it.
    /// Use `ring_doorbell()` to execute the enqueued operation.
955 ///
956 /// # Arguments
957 ///
958 /// * `lhandle` - Local buffer handle
959 /// * `rhandle` - Remote buffer handle
960 ///
961 /// # Returns
962 ///
963 /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
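    ///
    /// # Example
    ///
    /// A minimal sketch of the two-step path (not compiled; `lhandle` and `rhandle` are
    /// `RdmaBuffer`s on a connected queue pair):
    ///
    /// ```ignore
    /// let wr_ids = qp.enqueue_put(lhandle, rhandle)?; // build the WQE only
    /// qp.ring_doorbell()?;                            // submit it to the NIC
    /// let completions = qp.poll_completion(PollTarget::Send, &wr_ids)?;
    /// ```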
964 pub fn enqueue_put(
965 &mut self,
966 lhandle: RdmaBuffer,
967 rhandle: RdmaBuffer,
968 ) -> Result<Vec<u64>, anyhow::Error> {
969 let idx = unsafe {
970 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
971 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
972 )
973 };
974
975 self.send_wqe(
976 lhandle.addr,
977 lhandle.lkey,
978 lhandle.size,
979 idx,
980 true,
981 RdmaOperation::Write,
982 rhandle.addr,
983 rhandle.rkey,
984 )?;
985 Ok(vec![idx])
986 }
987
988 /// Enqueues a put with receive operation without ringing the doorbell.
989 ///
990 /// This method prepares a put with receive operation but does not execute it.
    /// Use `ring_doorbell()` to execute the enqueued operation.
992 ///
993 /// # Arguments
994 ///
995 /// * `lhandle` - Local buffer handle
996 /// * `rhandle` - Remote buffer handle
997 ///
998 /// # Returns
999 ///
1000 /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1001 pub fn enqueue_put_with_recv(
1002 &mut self,
1003 lhandle: RdmaBuffer,
1004 rhandle: RdmaBuffer,
1005 ) -> Result<Vec<u64>, anyhow::Error> {
1006 let idx = unsafe {
1007 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1008 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1009 )
1010 };
1011
1012 self.send_wqe(
1013 lhandle.addr,
1014 lhandle.lkey,
1015 lhandle.size,
1016 idx,
1017 true,
1018 RdmaOperation::WriteWithImm,
1019 rhandle.addr,
1020 rhandle.rkey,
1021 )?;
1022 Ok(vec![idx])
1023 }
1024
1025 /// Enqueues a get operation without ringing the doorbell.
1026 ///
1027 /// This method prepares a get operation but does not execute it.
    /// Use `ring_doorbell()` to execute the enqueued operation.
1029 ///
1030 /// # Arguments
1031 ///
1032 /// * `lhandle` - Local buffer handle
1033 /// * `rhandle` - Remote buffer handle
1034 ///
1035 /// # Returns
1036 ///
1037 /// * `Result<Vec<u64>, anyhow::Error>` - The work request IDs or error
1038 pub fn enqueue_get(
1039 &mut self,
1040 lhandle: RdmaBuffer,
1041 rhandle: RdmaBuffer,
1042 ) -> Result<Vec<u64>, anyhow::Error> {
1043 let idx = unsafe {
1044 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1045 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1046 )
1047 };
1048
1049 self.send_wqe(
1050 lhandle.addr,
1051 lhandle.lkey,
1052 lhandle.size,
1053 idx,
1054 true,
1055 RdmaOperation::Read,
1056 rhandle.addr,
1057 rhandle.rkey,
1058 )?;
1059 Ok(vec![idx])
1060 }
1061
1062 pub fn get(
1063 &mut self,
1064 lhandle: RdmaBuffer,
1065 rhandle: RdmaBuffer,
1066 ) -> Result<Vec<u64>, anyhow::Error> {
1067 let total_size = lhandle.size;
1068 if rhandle.size < total_size {
1069 return Err(anyhow::anyhow!(
1070 "Remote buffer size ({}) is smaller than local buffer size ({})",
1071 rhandle.size,
1072 total_size
1073 ));
1074 }
1075
1076 let mut remaining = total_size;
1077 let mut offset = 0;
1078 let mut wr_ids = Vec::new();
1079
1080 while remaining > 0 {
1081 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
1082 let idx = unsafe {
1083 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
1084 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1085 )
1086 };
1087 wr_ids.push(idx);
1088
1089 self.post_op(
1090 lhandle.addr + offset,
1091 lhandle.lkey,
1092 chunk_size,
1093 idx,
1094 true,
1095 RdmaOperation::Read,
1096 rhandle.addr + offset,
1097 rhandle.rkey,
1098 )?;
1099 unsafe {
1100 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
1101 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
1102 );
1103 }
1104
1105 remaining -= chunk_size;
1106 offset += chunk_size;
1107 }
1108
1109 Ok(wr_ids)
1110 }
1111
1112 /// Posts a request to the queue pair.
1113 ///
1114 /// # Arguments
1115 ///
    /// * `laddr` - The local address containing the data to send or receive into
    /// * `lkey` - The local key of the registered memory region at `laddr`
    /// * `length` - Length of the data in bytes
    /// * `wr_id` - Work request ID for completion identification
    /// * `signaled` - Whether to generate a completion event
    /// * `op_type` - The RDMA operation to post (write, read, write-with-immediate, or recv)
    /// * `raddr` - The remote address, representing the memory location on the remote peer
    /// * `rkey` - The remote key required to access the remote memory region
1123 fn post_op(
1124 &mut self,
1125 laddr: usize,
1126 lkey: u32,
1127 length: usize,
1128 wr_id: u64,
1129 signaled: bool,
1130 op_type: RdmaOperation,
1131 raddr: usize,
1132 rkey: u32,
1133 ) -> Result<(), anyhow::Error> {
1134 // SAFETY:
1135 // This code uses unsafe rdmaxcel_sys calls to post work requests to the RDMA device, but is safe because:
1136 // - All pointers (send_sge, send_wr) are properly initialized on the stack before use
1137 // - The memory address in `local_addr` is not dereferenced, only passed to the device
1138 // - The remote connection info is verified to exist before accessing
1139 // - The ibverbs post_send operation follows the documented API contract
1140 // - Error codes from the device are properly checked and propagated
1141 unsafe {
1142 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1143 let context = self.context as *mut rdmaxcel_sys::ibv_context;
1144 let ops = &mut (*context).ops;
1145 let errno;
1146 if op_type == RdmaOperation::Recv {
1147 let mut sge = rdmaxcel_sys::ibv_sge {
1148 addr: laddr as u64,
1149 length: length as u32,
1150 lkey,
1151 };
1152 let mut wr = rdmaxcel_sys::ibv_recv_wr {
1153 wr_id: wr_id.into(),
1154 sg_list: &mut sge as *mut _,
1155 num_sge: 1,
1156 ..Default::default()
1157 };
1158 let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
1159 errno =
1160 ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1161 } else if op_type == RdmaOperation::Write
1162 || op_type == RdmaOperation::Read
1163 || op_type == RdmaOperation::WriteWithImm
1164 {
1165 // Apply hardware initialization delay if this is the first operation
1166 self.apply_first_op_delay(wr_id);
1167 let send_flags = if signaled {
1168 rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
1169 } else {
1170 0
1171 };
1172 let mut sge = rdmaxcel_sys::ibv_sge {
1173 addr: laddr as u64,
1174 length: length as u32,
1175 lkey,
1176 };
1177 let mut wr = rdmaxcel_sys::ibv_send_wr {
1178 wr_id: wr_id.into(),
1179 next: std::ptr::null_mut(),
1180 sg_list: &mut sge as *mut _,
1181 num_sge: 1,
1182 opcode: op_type.into(),
1183 send_flags,
1184 wr: Default::default(),
1185 qp_type: Default::default(),
1186 __bindgen_anon_1: Default::default(),
1187 __bindgen_anon_2: Default::default(),
1188 };
1189
1190 wr.wr.rdma.remote_addr = raddr as u64;
1191 wr.wr.rdma.rkey = rkey;
1192 let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
1193
1194 errno =
1195 ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
1196 } else {
1197 panic!("Not Implemented");
1198 }
1199
1200 if errno != 0 {
1201 let os_error = Error::last_os_error();
1202 return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
1203 }
1204 tracing::debug!(
1205 "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
1206 op_type,
1207 lkey,
1208 laddr,
1209 length,
1210 raddr,
1211 rkey,
1212 );
1213
1214 Ok(())
1215 }
1216 }
1217
1218 fn send_wqe(
1219 &mut self,
1220 laddr: usize,
1221 lkey: u32,
1222 length: usize,
1223 wr_id: u64,
1224 signaled: bool,
1225 op_type: RdmaOperation,
1226 raddr: usize,
1227 rkey: u32,
1228 ) -> Result<DoorBell, anyhow::Error> {
1229 unsafe {
1230 let op_type_val = match op_type {
1231 RdmaOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
1232 RdmaOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
1233 RdmaOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
1234 RdmaOperation::Recv => 0,
1235 };
1236
1237 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1238 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
1239 let _dv_cq = if op_type == RdmaOperation::Recv {
1240 self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
1241 } else {
1242 self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
1243 };
1244
1245 // Create the WQE parameters struct
1246
1247 let buf = if op_type == RdmaOperation::Recv {
1248 (*dv_qp).rq.buf as *mut u8
1249 } else {
1250 (*dv_qp).sq.buf as *mut u8
1251 };
1252
1253 let params = rdmaxcel_sys::wqe_params_t {
1254 laddr,
1255 lkey,
1256 length,
1257 wr_id: wr_id.into(),
1258 signaled,
1259 op_type: op_type_val,
1260 raddr,
1261 rkey,
1262 qp_num: (*(*qp).ibv_qp).qp_num,
1263 buf,
1264 dbrec: (*dv_qp).dbrec,
1265 wqe_cnt: (*dv_qp).sq.wqe_cnt,
1266 };
1267
1268 // Call the C function to post the WQE
1269 if op_type == RdmaOperation::Recv {
1270 rdmaxcel_sys::recv_wqe(params);
1271 std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
1272 } else {
1273 rdmaxcel_sys::send_wqe(params);
1274 };
1275
1276 // Create and return a DoorBell struct
1277 Ok(DoorBell {
1278 dst_ptr: (*dv_qp).bf.reg as usize,
1279 src_ptr: (*dv_qp).sq.buf as usize,
1280 size: 8,
1281 })
1282 }
1283 }
1284
1285 /// Poll for work completions by wr_ids.
1286 ///
1287 /// # Arguments
1288 ///
1289 /// * `target` - Which completion queue to poll (Send, Receive)
1290 /// * `expected_wr_ids` - Slice of work request IDs to wait for
1291 ///
1292 /// # Returns
1293 ///
1294 /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs that were found
1295 /// * `Err(e)` - An error occurred
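    ///
    /// # Example
    ///
    /// A minimal polling loop (not compiled); `RdmaBuffer::wait_for_completion` wraps the same
    /// pattern with a timeout:
    ///
    /// ```ignore
    /// let mut pending: std::collections::HashSet<u64> = wr_ids.iter().copied().collect();
    /// while !pending.is_empty() {
    ///     let ids: Vec<u64> = pending.iter().copied().collect();
    ///     for (wr_id, _wc) in qp.poll_completion(PollTarget::Send, &ids)? {
    ///         pending.remove(&wr_id);
    ///     }
    ///     // A real caller should also sleep between polls and enforce a timeout.
    /// }
    /// ```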
1296 pub fn poll_completion(
1297 &mut self,
1298 target: PollTarget,
1299 expected_wr_ids: &[u64],
1300 ) -> Result<Vec<(u64, IbvWc)>, anyhow::Error> {
1301 if expected_wr_ids.is_empty() {
1302 return Ok(Vec::new());
1303 }
1304
1305 unsafe {
1306 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1307 let qp_num = (*(*qp).ibv_qp).qp_num;
1308
1309 let (cq, cache, cq_type) = match target {
1310 PollTarget::Send => (
1311 self.send_cq as *mut rdmaxcel_sys::ibv_cq,
1312 rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
1313 "send",
1314 ),
1315 PollTarget::Recv => (
1316 self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
1317 rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
1318 "recv",
1319 ),
1320 };
1321
1322 let mut results = Vec::new();
1323
1324 // Single-shot poll: check each wr_id once and return what we find
1325 for &expected_wr_id in expected_wr_ids {
1326 let mut poll_ctx = rdmaxcel_sys::poll_context_t {
1327 expected_wr_id,
1328 expected_qp_num: qp_num,
1329 cache,
1330 cq,
1331 };
1332
1333 let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
1334 let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
1335
1336 match ret {
1337 1 => {
1338 // Found completion
1339 if !wc.is_valid() {
1340 if let Some((status, vendor_err)) = wc.error() {
1341 return Err(anyhow::anyhow!(
1342 "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
1343 cq_type,
1344 expected_wr_id,
1345 status,
1346 vendor_err,
1347 ));
1348 }
1349 }
1350 results.push((expected_wr_id, IbvWc::from(wc)));
1351 }
1352 0 => {
1353 // Not found yet - this is fine for single-shot poll
1354 }
1355 -17 => {
1356 // RDMAXCEL_COMPLETION_FAILED: Completion found but failed - wc contains the error details
1357 let error_msg =
1358 std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1359 .to_str()
1360 .unwrap_or("Unknown error");
1361 if let Some((status, vendor_err)) = wc.error() {
1362 return Err(anyhow::anyhow!(
1363 "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1364 cq_type,
1365 expected_wr_id,
1366 error_msg,
1367 status,
1368 vendor_err,
1369 wc.qp_num,
1370 wc.len(),
1371 ));
1372 } else {
1373 return Err(anyhow::anyhow!(
1374 "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1375 cq_type,
1376 expected_wr_id,
1377 error_msg,
1378 wc.qp_num,
1379 wc.len(),
1380 ));
1381 }
1382 }
1383 _ => {
1384 // Other errors
1385 let error_msg =
1386 std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1387 .to_str()
1388 .unwrap_or("Unknown error");
1389 return Err(anyhow::anyhow!(
1390 "Failed to poll {} CQ for wr_id={}: {}",
1391 cq_type,
1392 expected_wr_id,
1393 error_msg
1394 ));
1395 }
1396 }
1397 }
1398
1399 Ok(results)
1400 }
1401 }
1402}
1403
/// Utility to validate the execution context.
///
/// Remote execution environments do not always load the nvidia_peermem module
/// and/or set the PeerMappingOverride parameter, for security reasons. This function can be
/// used to validate the execution context before running operations that need this
/// functionality (i.e. cudaHostRegisterIoMemory).
1410///
1411/// # Returns
1412///
1413/// * `Ok(())` if the execution context is valid
1414/// * `Err(anyhow::Error)` if the execution context is invalid
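///
/// # Example
///
/// A minimal sketch (not compiled):
///
/// ```ignore
/// validate_execution_context()
///     .await
///     .expect("execution context is not set up for GPU Direct RDMA");
/// ```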
1415pub async fn validate_execution_context() -> Result<(), anyhow::Error> {
1416 // Check for nvidia peermem
1417 match fs::read_to_string("/proc/modules") {
1418 Ok(contents) => {
1419 if !contents.contains("nvidia_peermem") {
1420 return Err(anyhow::anyhow!(
1421 "nvidia_peermem module not found in /proc/modules"
1422 ));
1423 }
1424 }
1425 Err(e) => {
1426 return Err(anyhow::anyhow!(e));
1427 }
1428 }
1429
1430 // Test file access to nvidia params
1431 match fs::read_to_string("/proc/driver/nvidia/params") {
1432 Ok(contents) => {
1433 if !contents.contains("PeerMappingOverride=1") {
1434 return Err(anyhow::anyhow!(
1435 "PeerMappingOverride=1 not found in /proc/driver/nvidia/params"
1436 ));
1437 }
1438 }
1439 Err(e) => {
1440 return Err(anyhow::anyhow!(e));
1441 }
1442 }
1443 Ok(())
1444}
1445
/// Get all CUDA segments that have been registered with memory regions (MRs).
///
/// # Returns
/// * `Vec<rdmaxcel_sys::rdma_segment_info_t>` - Vector containing all registered segment information
1450pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t> {
1451 unsafe {
1452 let segment_count = rdmaxcel_sys::rdma_get_active_segment_count();
1453 if segment_count <= 0 {
1454 return Vec::new();
1455 }
1456
1457 let mut segments = vec![
1458 std::mem::MaybeUninit::<rdmaxcel_sys::rdma_segment_info_t>::zeroed()
1459 .assume_init();
1460 segment_count as usize
1461 ];
1462 let actual_count =
1463 rdmaxcel_sys::rdma_get_all_segment_info(segments.as_mut_ptr(), segment_count);
1464
1465 if actual_count > 0 {
1466 segments.truncate(actual_count as usize);
1467 segments
1468 } else {
1469 Vec::new()
1470 }
1471 }
1472}
1473
1474/// Check if PyTorch CUDA caching allocator has expandable segments enabled.
1475///
1476/// This function calls the C++ implementation that directly accesses the
1477/// PyTorch C10 CUDA allocator configuration to check if expandable segments
1478/// are enabled, which is required for RDMA operations with CUDA tensors.
1479///
1480/// # Returns
1481///
1482/// `true` if both CUDA caching allocator is enabled AND expandable segments are enabled,
1483/// `false` otherwise.
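///
/// # Example
///
/// A minimal guard (not compiled):
///
/// ```ignore
/// if !pt_cuda_allocator_compatibility() {
///     anyhow::bail!("PyTorch CUDA caching allocator must have expandable segments enabled");
/// }
/// ```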
1484pub fn pt_cuda_allocator_compatibility() -> bool {
1485 // SAFETY: We are calling a C++ function from rdmaxcel that accesses PyTorch C10 APIs.
1486 unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() }
1487}
1488
1489#[cfg(test)]
1490mod tests {
1491 use super::*;
1492
1493 #[test]
1494 fn test_create_connection() {
1495 // Skip test if RDMA devices are not available
1496 if crate::ibverbs_primitives::get_all_devices().is_empty() {
1497 println!("Skipping test: RDMA devices not available");
1498 return;
1499 }
1500
1501 let config = IbverbsConfig {
1502 use_gpu_direct: false,
1503 ..Default::default()
1504 };
1505 let domain = RdmaDomain::new(config.device.clone());
1506 assert!(domain.is_ok());
1507
1508 let domain = domain.unwrap();
1509 let queue_pair = RdmaQueuePair::new(domain.context, domain.pd, config.clone());
1510 assert!(queue_pair.is_ok());
1511 }
1512
1513 #[test]
1514 fn test_loopback_connection() {
1515 // Skip test if RDMA devices are not available
1516 if crate::ibverbs_primitives::get_all_devices().is_empty() {
1517 println!("Skipping test: RDMA devices not available");
1518 return;
1519 }
1520
1521 let server_config = IbverbsConfig {
1522 use_gpu_direct: false,
1523 ..Default::default()
1524 };
1525 let client_config = IbverbsConfig {
1526 use_gpu_direct: false,
1527 ..Default::default()
1528 };
1529
1530 let server_domain = RdmaDomain::new(server_config.device.clone()).unwrap();
1531 let client_domain = RdmaDomain::new(client_config.device.clone()).unwrap();
1532
1533 let mut server_qp = RdmaQueuePair::new(
1534 server_domain.context,
1535 server_domain.pd,
1536 server_config.clone(),
1537 )
1538 .unwrap();
1539 let mut client_qp = RdmaQueuePair::new(
1540 client_domain.context,
1541 client_domain.pd,
1542 client_config.clone(),
1543 )
1544 .unwrap();
1545
1546 let server_connection_info = server_qp.get_qp_info().unwrap();
1547 let client_connection_info = client_qp.get_qp_info().unwrap();
1548
1549 assert!(server_qp.connect(&client_connection_info).is_ok());
1550 assert!(client_qp.connect(&server_connection_info).is_ok());
1551 }
1552}