Skip to main content

monarch_rdma/backend/ibverbs/
queue_pair.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! ibverbs queue pair, doorbell, and completion polling.
10//!
11//! An [`IbvQueuePair`] encapsulates the send and receive queues, completion
12//! queues, and other resources needed for RDMA communication. It provides
13//! methods for establishing connections and performing RDMA operations.
14
/// Largest number of bytes moved by a single RDMA work request: 1 GiB (2^30).
///
/// `put()` and `get()` split larger transfers into consecutive chunks of at
/// most this many bytes.
const MAX_RDMA_MSG_SIZE: usize = 1 << 30;
17
18use std::io::Error;
19use std::result::Result;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorHandle;
25use hyperactor::ActorId;
26use hyperactor::ActorRef;
27use hyperactor::Context;
28use hyperactor::Endpoint as _;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::PortRef;
32use hyperactor::actor::Binds;
33use hyperactor::actor::Referable;
34use hyperactor::actor::RemoteHandles;
35use hyperactor::mailbox::MessageEnvelope;
36use hyperactor::mailbox::Undeliverable;
37use serde::Deserialize;
38use serde::Serialize;
39use typeuri::Named;
40
41use super::IbvBuffer;
42use super::manager_actor::EnsureQueuePair;
43use super::manager_actor::QpInitializerDone;
44use super::manager_actor::QpInitializerFailed;
45use super::primitives::Gid;
46use super::primitives::IbvConfig;
47use super::primitives::IbvOperation;
48use super::primitives::IbvQpInfo;
49use super::primitives::IbvWc;
50use super::primitives::resolve_qp_type;
51
/// A structured error from [`IbvQueuePair::poll_completion`].
///
/// Carries the `ibv_wc_status` and vendor error code (when available) so
/// callers can match on specific completion statuses without string parsing.
/// The `Display` output of this error is its stored message.
#[derive(Debug)]
pub struct PollCompletionError {
    /// Completion status (`ibv_wc_status`) of the failed work completion, or
    /// `None` when the failure did not come with a status.
    pub status: Option<rdmaxcel_sys::ibv_wc_status::Type>,
    /// Vendor-specific error code reported alongside the completion, if any.
    pub vendor_err: Option<u32>,
    /// Human-readable description of the failure; this is what `Display` prints.
    message: String,
}
62
63impl std::fmt::Display for PollCompletionError {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.write_str(&self.message)
66    }
67}
68
69impl std::error::Error for PollCompletionError {}
70
71impl PollCompletionError {
72    /// Returns `true` when the completion status is `IBV_WC_WR_FLUSH_ERR`,
73    /// which typically indicates a secondary failure after the QP entered
74    /// error state due to a different work request's failure.
75    pub fn is_wr_flush_err(&self) -> bool {
76        self.status == Some(rdmaxcel_sys::ibv_wc_status::IBV_WC_WR_FLUSH_ERR)
77    }
78}
79
/// A doorbell trigger for batched RDMA operations.
///
/// Describes a doorbell write for work requests that were enqueued without
/// ringing the doorbell (`send_wqe`, used by the `enqueue_*` methods, returns
/// this value). The struct itself performs no I/O; it only records the write.
/// NOTE(review): the field meanings below are inferred from their names —
/// confirm against `send_wqe`, which constructs this value.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Source address of the doorbell write (presumably the WQE to announce).
    pub src_ptr: usize,
    /// Destination address of the doorbell write (presumably the device
    /// doorbell register).
    pub dst_ptr: usize,
    /// Presumably the number of bytes to copy from `src_ptr` to `dst_ptr`.
    pub size: usize,
}
wirevalue::register_type!(DoorBell);
90
/// Specifies which completion queue to poll.
///
/// Fieldless and `Copy`; it has total equality and can be used as a map or
/// set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PollTarget {
    /// Poll the send-side completion queue.
    Send,
    /// Poll the receive-side completion queue.
    Recv,
}
97
/// An RDMA Queue Pair (QP) for communication between two endpoints.
///
/// Encapsulates the send/receive queues, completion queues, and (on non-EFA
/// devices) the mlx5dv device-specific structures needed for RDMA
/// communication.
///
/// # Connection Lifecycle
///
/// 1. Create with `new()` from context and protection domain pointers
/// 2. Get connection info with `get_qp_info()`
/// 3. Exchange connection info with remote peer
/// 4. Connect to remote endpoint with `connect()`
/// 5. Perform RDMA operations with `put()` or `get()` (or the `enqueue_*`
///    methods followed by `ring_doorbell()`)
/// 6. Poll for completions with `poll_completion()`
///
/// # Notes
/// - The `qp` field stores a pointer to `rdmaxcel_qp_t` (not `ibv_qp`)
/// - `rdmaxcel_qp_t` contains atomic counters and completion caches internally
/// - This makes IbvQueuePair trivially Clone and Serialize
/// - Multiple clones share the same underlying rdmaxcel_qp_t via the pointer
/// - All handles are stored as raw addresses (`usize`). NOTE(review): these
///   addresses are process-local, so a value deserialized in another process
///   does not refer to valid objects — confirm callers only move it within
///   the creating process.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct IbvQueuePair {
    /// Send completion queue (`*mut rdmaxcel_sys::ibv_cq`).
    pub send_cq: usize,
    /// Receive completion queue (`*mut rdmaxcel_sys::ibv_cq`).
    pub recv_cq: usize,
    /// The rdmaxcel QP wrapper (`*mut rdmaxcel_sys::rdmaxcel_qp_t`).
    pub qp: usize,
    /// mlx5dv view of the QP (`*mut rdmaxcel_sys::mlx5dv_qp`); 0 for EFA QPs.
    pub dv_qp: usize,
    /// mlx5dv view of the send CQ (`*mut rdmaxcel_sys::mlx5dv_cq`); 0 for EFA QPs.
    pub dv_send_cq: usize,
    /// mlx5dv view of the receive CQ (`*mut rdmaxcel_sys::mlx5dv_cq`); 0 for EFA QPs.
    pub dv_recv_cq: usize,
    /// Device context the QP was created on (`*mut rdmaxcel_sys::ibv_context`).
    context: usize,
    /// Configuration used when creating and connecting the QP.
    config: IbvConfig,
    /// Whether `new()` resolved the QP type to EFA (standard ibverbs, no mlx5dv).
    is_efa: bool,
}
wirevalue::register_type!(IbvQueuePair);
130
131impl IbvQueuePair {
    /// Returns `true` when this QP was created as an EFA queue pair.
    ///
    /// EFA QPs have no mlx5dv handles (the `dv_*` fields are 0) and route
    /// connection setup and posting through the rdmaxcel C helpers instead.
    fn is_efa(&self) -> bool {
        self.is_efa
    }
135
136    /// Applies hardware initialization delay if this is the first operation since RTS.
137    fn apply_first_op_delay(&self, wr_id: u64) {
138        unsafe {
139            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
140            if wr_id == 0 {
141                let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
142                assert!(
143                    rts_timestamp != u64::MAX,
144                    "First operation attempted before queue pair reached RTS state! Call connect() first."
145                );
146                let current_nanos = std::time::SystemTime::now()
147                    .duration_since(std::time::UNIX_EPOCH)
148                    .unwrap()
149                    .as_nanos() as u64;
150                let elapsed_nanos = current_nanos - rts_timestamp;
151                let elapsed = Duration::from_nanos(elapsed_nanos);
152                let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
153                if elapsed < init_delay {
154                    let remaining_delay = init_delay - elapsed;
155                    // Sync context within unsafe block; tokio::time::sleep is async
156                    // and converting would require propagating async through the
157                    // entire post_op / ring_doorbell call chain.
158                    std::thread::sleep(remaining_delay);
159                }
160            }
161        }
162    }
163
164    /// Creates a new IbvQueuePair.
165    ///
166    /// Initializes a new Queue Pair (QP) and associated Completion Queues (CQ)
167    /// using the provided context and protection domain. The QP is created in
168    /// the RESET state and must be transitioned via `connect()` before use.
169    ///
170    /// # Errors
171    ///
172    /// Returns errors if CQ or QP creation fails.
173    pub fn new(
174        context: *mut rdmaxcel_sys::ibv_context,
175        pd: *mut rdmaxcel_sys::ibv_pd,
176        config: IbvConfig,
177    ) -> Result<Self, anyhow::Error> {
178        tracing::debug!("creating an IbvQueuePair from config {}", config);
179        unsafe {
180            // Resolve Auto to a concrete QP type based on device capabilities
181            let resolved_qp_type = resolve_qp_type(config.qp_type);
182            let is_efa = resolved_qp_type == rdmaxcel_sys::RDMA_QP_TYPE_EFA;
183            let qp = rdmaxcel_sys::rdmaxcel_qp_create(
184                context,
185                pd,
186                config.cq_entries,
187                config.max_send_wr.try_into().unwrap(),
188                config.max_recv_wr.try_into().unwrap(),
189                config.max_send_sge.try_into().unwrap(),
190                config.max_recv_sge.try_into().unwrap(),
191                resolved_qp_type,
192            );
193
194            if qp.is_null() {
195                let os_error = Error::last_os_error();
196                return Err(anyhow::anyhow!(
197                    "failed to create queue pair (QP): {}",
198                    os_error
199                ));
200            }
201
202            let send_cq = (*(*qp).ibv_qp).send_cq;
203            let recv_cq = (*(*qp).ibv_qp).recv_cq;
204
205            // EFA uses standard ibverbs (not mlx5dv), so skip dv setup
206            if is_efa {
207                return Ok(IbvQueuePair {
208                    send_cq: send_cq as usize,
209                    recv_cq: recv_cq as usize,
210                    qp: qp as usize,
211                    dv_qp: 0,
212                    dv_send_cq: 0,
213                    dv_recv_cq: 0,
214                    context: context as usize,
215                    config,
216                    is_efa: true,
217                });
218            }
219
220            let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
221            let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
222            let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
223
224            if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
225                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
226                rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
227                rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
228                return Err(anyhow::anyhow!(
229                    "failed to init mlx5dv_qp or completion queues"
230                ));
231            }
232
233            if config.use_gpu_direct {
234                let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
235                if ret != 0 {
236                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
237                    rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
238                    rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
239                    return Err(anyhow::anyhow!(
240                        "failed to register GPU Direct RDMA memory: {:?}",
241                        ret
242                    ));
243                }
244            }
245            Ok(IbvQueuePair {
246                send_cq: send_cq as usize,
247                recv_cq: recv_cq as usize,
248                qp: qp as usize,
249                dv_qp: dv_qp as usize,
250                dv_send_cq: dv_send_cq as usize,
251                dv_recv_cq: dv_recv_cq as usize,
252                context: context as usize,
253                config,
254                is_efa: false,
255            })
256        }
257    }
258
259    /// Returns the connection info needed by a remote peer to connect to this QP.
260    pub fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
261        unsafe {
262            let context = self.context as *mut rdmaxcel_sys::ibv_context;
263            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
264            let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
265            let errno = rdmaxcel_sys::ibv_query_port(
266                context,
267                self.config.port_num,
268                &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
269            );
270            if errno != 0 {
271                let os_error = Error::last_os_error();
272                return Err(anyhow::anyhow!(
273                    "Failed to query port attributes: {}",
274                    os_error
275                ));
276            }
277
278            let mut gid = Gid::default();
279            let ret = rdmaxcel_sys::ibv_query_gid(
280                context,
281                self.config.port_num,
282                i32::from(self.config.gid_index),
283                gid.as_mut(),
284            );
285            if ret != 0 {
286                return Err(anyhow::anyhow!("Failed to query GID"));
287            }
288
289            Ok(IbvQpInfo {
290                qp_num: (*(*qp).ibv_qp).qp_num,
291                lid: port_attr.lid,
292                gid: Some(gid),
293                psn: self.config.psn,
294            })
295        }
296    }
297
298    /// Returns the current state of the QP.
299    pub fn state(&mut self) -> Result<u32, anyhow::Error> {
300        unsafe {
301            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
302            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
303                ..Default::default()
304            };
305            let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
306                ..Default::default()
307            };
308            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
309            let errno = rdmaxcel_sys::ibv_query_qp(
310                (*qp).ibv_qp,
311                &mut qp_attr,
312                mask.0 as i32,
313                &mut qp_init_attr,
314            );
315            if errno != 0 {
316                let os_error = Error::last_os_error();
317                return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
318            }
319            Ok(qp_attr.qp_state)
320        }
321    }
322
323    /// Transitions the QP through INIT -> RTR -> RTS to establish a connection.
324    ///
325    /// # Arguments
326    ///
327    /// * `connection_info` - The remote connection info to connect to
328    pub fn connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
329        // EFA: use unified C function for QP state transitions
330        if self.is_efa() {
331            return self.efa_connect(connection_info);
332        }
333
334        unsafe {
335            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
336
337            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
338                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
339                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
340                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
341
342            // Transition to INIT
343            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
344                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
345                qp_access_flags: qp_access_flags.0,
346                pkey_index: self.config.pkey_index,
347                port_num: self.config.port_num,
348                ..Default::default()
349            };
350
351            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
352                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
353                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
354                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;
355
356            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
357            if errno != 0 {
358                let os_error = Error::last_os_error();
359                return Err(anyhow::anyhow!(
360                    "failed to transition QP to INIT: {}",
361                    os_error
362                ));
363            }
364
365            // Transition to RTR (Ready to Receive)
366            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
367                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
368                path_mtu: self.config.path_mtu,
369                dest_qp_num: connection_info.qp_num,
370                rq_psn: connection_info.psn,
371                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
372                min_rnr_timer: self.config.min_rnr_timer,
373                ah_attr: rdmaxcel_sys::ibv_ah_attr {
374                    dlid: connection_info.lid,
375                    sl: 0,
376                    src_path_bits: 0,
377                    port_num: self.config.port_num,
378                    grh: Default::default(),
379                    ..Default::default()
380                },
381                ..Default::default()
382            };
383
384            if let Some(gid) = connection_info.gid {
385                qp_attr.ah_attr.is_global = 1;
386                qp_attr.ah_attr.grh.dgid = rdmaxcel_sys::ibv_gid::from(gid);
387                qp_attr.ah_attr.grh.hop_limit = 0xff;
388                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
389            } else {
390                qp_attr.ah_attr.is_global = 0;
391            }
392
393            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
394                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
395                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
396                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
397                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
398                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
399                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;
400
401            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
402            if errno != 0 {
403                let os_error = Error::last_os_error();
404                return Err(anyhow::anyhow!(
405                    "failed to transition QP to RTR: {}",
406                    os_error
407                ));
408            }
409
410            // Transition to RTS (Ready to Send)
411            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
412                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
413                sq_psn: self.config.psn,
414                max_rd_atomic: self.config.max_rd_atomic,
415                retry_cnt: self.config.retry_cnt,
416                rnr_retry: self.config.rnr_retry,
417                timeout: self.config.qp_timeout,
418                ..Default::default()
419            };
420
421            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
422                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
423                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
424                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
425                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
426                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;
427
428            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
429            if errno != 0 {
430                let os_error = Error::last_os_error();
431                return Err(anyhow::anyhow!(
432                    "failed to transition QP to RTS: {}",
433                    os_error
434                ));
435            }
436            tracing::debug!(
437                "connection sequence has successfully completed (qp: {:?})",
438                qp
439            );
440
441            let rts_timestamp_nanos = std::time::SystemTime::now()
442                .duration_since(std::time::UNIX_EPOCH)
443                .unwrap()
444                .as_nanos() as u64;
445            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
446
447            Ok(())
448        }
449    }
450
451    /// Connects via the EFA-specific C function for QP state transitions.
452    fn efa_connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
453        let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
454
455        let gid_ptr = connection_info.gid.as_ref().map_or(std::ptr::null(), |g| {
456            let ibv_gid: &rdmaxcel_sys::ibv_gid = g.as_ref();
457            unsafe { ibv_gid.raw.as_ptr() }
458        });
459
460        unsafe {
461            let ret = rdmaxcel_sys::rdmaxcel_efa_connect(
462                qp,
463                self.config.port_num,
464                self.config.pkey_index,
465                0x4242, // qkey
466                self.config.psn,
467                self.config.gid_index,
468                gid_ptr,
469                connection_info.qp_num,
470            );
471            if ret != 0 {
472                let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
473                    .to_str()
474                    .unwrap_or("unknown");
475                return Err(anyhow::anyhow!("EFA connect failed: {}", msg));
476            }
477        }
478
479        // Store RTS timestamp for first-op delay
480        let rts_timestamp_nanos = std::time::SystemTime::now()
481            .duration_since(std::time::UNIX_EPOCH)
482            .unwrap()
483            .as_nanos() as u64;
484        unsafe {
485            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
486        }
487
488        Ok(())
489    }
490
491    pub fn recv(&mut self, lhandle: IbvBuffer, rhandle: IbvBuffer) -> Result<u64, anyhow::Error> {
492        unsafe {
493            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
494            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
495            self.post_op(
496                0,
497                lhandle.lkey,
498                0,
499                idx,
500                true,
501                IbvOperation::Recv,
502                0,
503                rhandle.rkey,
504            )
505            .unwrap();
506            rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
507            Ok(idx)
508        }
509    }
510
511    pub fn put_with_recv(
512        &mut self,
513        lhandle: IbvBuffer,
514        rhandle: IbvBuffer,
515    ) -> Result<Vec<u64>, anyhow::Error> {
516        unsafe {
517            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
518            let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
519            self.post_op(
520                lhandle.addr,
521                lhandle.lkey,
522                lhandle.size,
523                idx,
524                true,
525                IbvOperation::WriteWithImm,
526                rhandle.addr,
527                rhandle.rkey,
528            )
529            .unwrap();
530            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
531            Ok(vec![idx])
532        }
533    }
534
535    pub fn put(
536        &mut self,
537        lhandle: IbvBuffer,
538        rhandle: IbvBuffer,
539    ) -> Result<Vec<u64>, anyhow::Error> {
540        let total_size = lhandle.size;
541        if rhandle.size < total_size {
542            return Err(anyhow::anyhow!(
543                "Remote buffer size ({}) is smaller than local buffer size ({})",
544                rhandle.size,
545                total_size
546            ));
547        }
548
549        let mut remaining = total_size;
550        let mut offset = 0;
551        let mut wr_ids = Vec::new();
552        while remaining > 0 {
553            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
554            let idx = unsafe {
555                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
556                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
557                )
558            };
559            wr_ids.push(idx);
560            self.post_op(
561                lhandle.addr + offset,
562                lhandle.lkey,
563                chunk_size,
564                idx,
565                true,
566                IbvOperation::Write,
567                rhandle.addr + offset,
568                rhandle.rkey,
569            )?;
570            unsafe {
571                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
572                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
573                );
574            }
575
576            remaining -= chunk_size;
577            offset += chunk_size;
578        }
579
580        Ok(wr_ids)
581    }
582
    /// Rings the doorbell to execute all enqueued operations.
    ///
    /// A WQE is pending when it has been counted by the send WQE index but not
    /// yet by the send doorbell index. For each pending WQE, this computes the
    /// WQE's slot in the mlx5 send-queue buffer (`sq.buf`, `sq.stride`,
    /// modulo `sq.wqe_cnt`) and passes it to `db_ring` together with the
    /// `bf.reg` register. It then advances and stores the doorbell index. On
    /// EFA this is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error, without ringing anything, if more WQEs are pending
    /// than the send queue holds (`wqe_cnt`).
    ///
    /// # Panics
    ///
    /// Through `apply_first_op_delay`, panics if the doorbell index is still 0
    /// and `connect()` has not stored an RTS timestamp.
    pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
        // EFA uses standard ibverbs (not mlx5dv), so skip doorbell ringing
        if self.is_efa() {
            return Ok(());
        }

        // NOTE(review): the load / compare / ring / store sequence below is not
        // atomic as a whole; clones sharing the same rdmaxcel_qp that ring
        // concurrently could ring the same WQEs twice — confirm callers serialize.
        unsafe {
            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
            let base_ptr = (*dv_qp).sq.buf as *mut u8;
            let wqe_cnt = (*dv_qp).sq.wqe_cnt;
            let stride = (*dv_qp).sq.stride;
            let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
            let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
            // More pending WQEs than ring slots means earlier WQEs were overwritten.
            // (Assumes send_db_idx <= send_wqe_idx; otherwise this subtraction underflows.)
            if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
                return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
            }
            // Only sleeps when send_db_idx == 0, i.e. nothing has been rung yet.
            self.apply_first_op_delay(send_db_idx);
            while send_db_idx < send_wqe_idx {
                // Slot of this WQE within the circular send-queue buffer.
                let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
                let src_ptr = base_ptr.wrapping_add(offset as usize);
                rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
                // Publish progress after each ring so the stored index never runs ahead.
                send_db_idx += 1;
                rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
            }
            Ok(())
        }
    }
612
613    /// Enqueues a put operation without ringing the doorbell.
614    pub fn enqueue_put(
615        &mut self,
616        lhandle: IbvBuffer,
617        rhandle: IbvBuffer,
618    ) -> Result<Vec<u64>, anyhow::Error> {
619        let idx = unsafe {
620            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
621                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
622            )
623        };
624
625        self.send_wqe(
626            lhandle.addr,
627            lhandle.lkey,
628            lhandle.size,
629            idx,
630            true,
631            IbvOperation::Write,
632            rhandle.addr,
633            rhandle.rkey,
634        )?;
635        Ok(vec![idx])
636    }
637
638    /// Enqueues a put-with-receive operation without ringing the doorbell.
639    pub fn enqueue_put_with_recv(
640        &mut self,
641        lhandle: IbvBuffer,
642        rhandle: IbvBuffer,
643    ) -> Result<Vec<u64>, anyhow::Error> {
644        let idx = unsafe {
645            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
646                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
647            )
648        };
649
650        self.send_wqe(
651            lhandle.addr,
652            lhandle.lkey,
653            lhandle.size,
654            idx,
655            true,
656            IbvOperation::WriteWithImm,
657            rhandle.addr,
658            rhandle.rkey,
659        )?;
660        Ok(vec![idx])
661    }
662
663    /// Enqueues a get operation without ringing the doorbell.
664    pub fn enqueue_get(
665        &mut self,
666        lhandle: IbvBuffer,
667        rhandle: IbvBuffer,
668    ) -> Result<Vec<u64>, anyhow::Error> {
669        let idx = unsafe {
670            rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
671                self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
672            )
673        };
674
675        self.send_wqe(
676            lhandle.addr,
677            lhandle.lkey,
678            lhandle.size,
679            idx,
680            true,
681            IbvOperation::Read,
682            rhandle.addr,
683            rhandle.rkey,
684        )?;
685        Ok(vec![idx])
686    }
687
688    pub fn get(
689        &mut self,
690        lhandle: IbvBuffer,
691        rhandle: IbvBuffer,
692    ) -> Result<Vec<u64>, anyhow::Error> {
693        let total_size = lhandle.size;
694        if rhandle.size < total_size {
695            return Err(anyhow::anyhow!(
696                "Remote buffer size ({}) is smaller than local buffer size ({})",
697                rhandle.size,
698                total_size
699            ));
700        }
701
702        let mut remaining = total_size;
703        let mut offset = 0;
704        let mut wr_ids = Vec::new();
705
706        while remaining > 0 {
707            let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
708            let idx = unsafe {
709                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
710                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
711                )
712            };
713            wr_ids.push(idx);
714
715            self.post_op(
716                lhandle.addr + offset,
717                lhandle.lkey,
718                chunk_size,
719                idx,
720                true,
721                IbvOperation::Read,
722                rhandle.addr + offset,
723                rhandle.rkey,
724            )?;
725            unsafe {
726                rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
727                    self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
728                );
729            }
730
731            remaining -= chunk_size;
732            offset += chunk_size;
733        }
734
735        Ok(wr_ids)
736    }
737
738    /// Posts a request to the queue pair.
739    fn post_op(
740        &mut self,
741        laddr: usize,
742        lkey: u32,
743        length: usize,
744        wr_id: u64,
745        signaled: bool,
746        op_type: IbvOperation,
747        raddr: usize,
748        rkey: u32,
749    ) -> Result<(), anyhow::Error> {
750        // EFA: use unified C function
751        if self.is_efa() {
752            return self.post_op_efa(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey);
753        }
754
755        unsafe {
756            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
757            let context = self.context as *mut rdmaxcel_sys::ibv_context;
758            let ops = &mut (*context).ops;
759            let errno;
760            if op_type == IbvOperation::Recv {
761                let mut sge = rdmaxcel_sys::ibv_sge {
762                    addr: laddr as u64,
763                    length: length as u32,
764                    lkey,
765                };
766                let mut wr = rdmaxcel_sys::ibv_recv_wr {
767                    wr_id,
768                    sg_list: &mut sge as *mut _,
769                    num_sge: 1,
770                    ..Default::default()
771                };
772                let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
773                errno =
774                    ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
775            } else if op_type == IbvOperation::Write
776                || op_type == IbvOperation::Read
777                || op_type == IbvOperation::WriteWithImm
778            {
779                self.apply_first_op_delay(wr_id);
780                let send_flags = if signaled {
781                    rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
782                } else {
783                    0
784                };
785                let mut sge = rdmaxcel_sys::ibv_sge {
786                    addr: laddr as u64,
787                    length: length as u32,
788                    lkey,
789                };
790                let mut wr = rdmaxcel_sys::ibv_send_wr {
791                    wr_id,
792                    next: std::ptr::null_mut(),
793                    sg_list: &mut sge as *mut _,
794                    num_sge: 1,
795                    opcode: op_type.into(),
796                    send_flags,
797                    wr: Default::default(),
798                    qp_type: Default::default(),
799                    __bindgen_anon_1: Default::default(),
800                    __bindgen_anon_2: Default::default(),
801                };
802
803                wr.wr.rdma.remote_addr = raddr as u64;
804                wr.wr.rdma.rkey = rkey;
805                let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
806
807                errno =
808                    ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
809            } else {
810                panic!("Not Implemented");
811            }
812
813            if errno != 0 {
814                let os_error = Error::last_os_error();
815                return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
816            }
817            tracing::debug!(
818                "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
819                op_type,
820                lkey,
821                laddr,
822                length,
823                raddr,
824                rkey,
825            );
826
827            Ok(())
828        }
829    }
830
831    /// Posts an RDMA operation via the EFA-specific C function.
832    fn post_op_efa(
833        &mut self,
834        laddr: usize,
835        lkey: u32,
836        length: usize,
837        wr_id: u64,
838        signaled: bool,
839        op_type: IbvOperation,
840        raddr: usize,
841        rkey: u32,
842    ) -> Result<(), anyhow::Error> {
843        let c_op = match op_type {
844            IbvOperation::Write => 0,
845            IbvOperation::Read => 1,
846            IbvOperation::Recv => 2,
847            IbvOperation::WriteWithImm => 3,
848        };
849
850        unsafe {
851            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
852            let ret = rdmaxcel_sys::rdmaxcel_qp_post_op(
853                qp,
854                laddr as *mut std::ffi::c_void,
855                lkey,
856                length,
857                raddr as *mut std::ffi::c_void,
858                rkey,
859                wr_id,
860                signaled as i32,
861                c_op,
862            );
863            if ret != 0 {
864                let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
865                    .to_str()
866                    .unwrap_or("unknown");
867                return Err(anyhow::anyhow!("EFA post_op failed: {}", msg));
868            }
869        }
870        Ok(())
871    }
872
873    fn send_wqe(
874        &mut self,
875        laddr: usize,
876        lkey: u32,
877        length: usize,
878        wr_id: u64,
879        signaled: bool,
880        op_type: IbvOperation,
881        raddr: usize,
882        rkey: u32,
883    ) -> Result<DoorBell, anyhow::Error> {
884        // Non-mlx5 devices use the unified C post_op path
885        if self.is_efa() {
886            self.post_op(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey)?;
887            return Ok(DoorBell {
888                dst_ptr: 0,
889                src_ptr: 0,
890                size: 0,
891            });
892        }
893
894        unsafe {
895            let op_type_val = match op_type {
896                IbvOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
897                IbvOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
898                IbvOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
899                IbvOperation::Recv => 0,
900            };
901
902            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
903            let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
904            let _dv_cq = if op_type == IbvOperation::Recv {
905                self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
906            } else {
907                self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
908            };
909
910            let buf = if op_type == IbvOperation::Recv {
911                (*dv_qp).rq.buf as *mut u8
912            } else {
913                (*dv_qp).sq.buf as *mut u8
914            };
915
916            let params = rdmaxcel_sys::wqe_params_t {
917                laddr,
918                lkey,
919                length,
920                wr_id,
921                signaled,
922                op_type: op_type_val,
923                raddr,
924                rkey,
925                qp_num: (*(*qp).ibv_qp).qp_num,
926                buf,
927                dbrec: (*dv_qp).dbrec,
928                wqe_cnt: (*dv_qp).sq.wqe_cnt,
929            };
930
931            if op_type == IbvOperation::Recv {
932                rdmaxcel_sys::recv_wqe(params);
933                std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
934            } else {
935                rdmaxcel_sys::send_wqe(params);
936            };
937
938            Ok(DoorBell {
939                dst_ptr: (*dv_qp).bf.reg as usize,
940                src_ptr: (*dv_qp).sq.buf as usize,
941                size: 8,
942            })
943        }
944    }
945
946    /// Polls for work completions by wr_ids.
947    ///
948    /// # Arguments
949    ///
950    /// * `target` - Which completion queue to poll (Send, Receive)
951    /// * `expected_wr_ids` - Slice of work request IDs to wait for
952    ///
953    /// # Returns
954    ///
955    /// * `Ok(Vec<(u64, IbvWc)>)` - Vector of (wr_id, completion) pairs found
956    /// * `Err(e)` - An error occurred
957    pub fn poll_completion(
958        &mut self,
959        target: PollTarget,
960        expected_wr_ids: &[u64],
961    ) -> Result<Vec<(u64, IbvWc)>, PollCompletionError> {
962        if expected_wr_ids.is_empty() {
963            return Ok(Vec::new());
964        }
965
966        unsafe {
967            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
968            let qp_num = (*(*qp).ibv_qp).qp_num;
969
970            let (cq, cache, cq_type) = match target {
971                PollTarget::Send => (
972                    self.send_cq as *mut rdmaxcel_sys::ibv_cq,
973                    rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
974                    "send",
975                ),
976                PollTarget::Recv => (
977                    self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
978                    rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
979                    "recv",
980                ),
981            };
982
983            let mut results = Vec::new();
984
985            for &expected_wr_id in expected_wr_ids {
986                let mut poll_ctx = rdmaxcel_sys::poll_context_t {
987                    expected_wr_id,
988                    expected_qp_num: qp_num,
989                    cache,
990                    cq,
991                };
992
993                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
994                let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);
995
996                match ret {
997                    1 => {
998                        if !wc.is_valid() {
999                            if let Some((status, vendor_err)) = wc.error() {
1000                                return Err(PollCompletionError {
1001                                    status: Some(status),
1002                                    vendor_err: Some(vendor_err),
1003                                    message: format!(
1004                                        "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
1005                                        cq_type, expected_wr_id, status, vendor_err,
1006                                    ),
1007                                });
1008                            }
1009                        }
1010                        results.push((expected_wr_id, IbvWc::from(wc)));
1011                    }
1012                    0 => {
1013                        // Not found yet
1014                    }
1015                    -17 => {
1016                        let error_msg =
1017                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1018                                .to_str()
1019                                .unwrap_or("Unknown error");
1020                        if let Some((status, vendor_err)) = wc.error() {
1021                            return Err(PollCompletionError {
1022                                status: Some(status),
1023                                vendor_err: Some(vendor_err),
1024                                message: format!(
1025                                    "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
1026                                    cq_type,
1027                                    expected_wr_id,
1028                                    error_msg,
1029                                    status,
1030                                    vendor_err,
1031                                    wc.qp_num,
1032                                    wc.len(),
1033                                ),
1034                            });
1035                        } else {
1036                            return Err(PollCompletionError {
1037                                status: None,
1038                                vendor_err: None,
1039                                message: format!(
1040                                    "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
1041                                    cq_type,
1042                                    expected_wr_id,
1043                                    error_msg,
1044                                    wc.qp_num,
1045                                    wc.len(),
1046                                ),
1047                            });
1048                        }
1049                    }
1050                    _ => {
1051                        let error_msg =
1052                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
1053                                .to_str()
1054                                .unwrap_or("Unknown error");
1055                        return Err(PollCompletionError {
1056                            status: None,
1057                            vendor_err: None,
1058                            message: format!(
1059                                "Failed to poll {} CQ for wr_id={}: {}",
1060                                cq_type, expected_wr_id, error_msg,
1061                            ),
1062                        });
1063                    }
1064                }
1065            }
1066
1067            Ok(results)
1068        }
1069    }
1070}
1071
1072// =====================================================================
1073// QueuePairInitializer
1074// =====================================================================
1075//
1076// Drives one local `IbvQueuePair` through `INIT → RTR → RTS` off the
1077// owning `IbvManagerActor`'s mailbox. Each peer spawns its own
1078// initializer; the two converge by exchanging `NotifyRts` directly
1079// after one round-trip through the peer's manager (`EnsureQueuePair`
// → `PeerInfo`). A side declares itself done once it has both sent
// its own `NotifyRts` and observed the peer's; it does not wait for
// the peer to observe its own.
1083//
1084// Progress is tracked by two flags — `our_rts_sent` (we received
1085// `PeerInfo`, connected, and sent our `NotifyRts`) and
1086// `peer_rts_received` (we observed the peer's `NotifyRts`). When
1087// both are true the handshake hands the qp to the manager via
1088// [`QpInitializerDone`]. The `terminal` flag short-circuits any
1089// further handler work after success or failure. The qp is held in
1090// a `QpGuard` so any failure path (or aborted message delivery)
1091// destroys it.
1092
/// Identifies a per-peer queue pair held by one
/// [`super::manager_actor::IbvManagerActor`]. The same conceptual
/// QP is referenced by two distinct keys, one from each side: each
/// manager stores the local view (its own device, the peer's actor
/// id, the peer's device).
#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize, Named)]
pub(super) struct QpKey {
    /// Local device name; sent to the peer as `sender_device` in
    /// [`EnsureQueuePair`].
    pub(super) self_device: String,
    /// Actor id of the peer manager on the other end.
    pub(super) other_id: ActorId,
    /// Peer's device name; sent as `receiver_device` in
    /// [`EnsureQueuePair`].
    pub(super) other_device: String,
}
1104
/// Cross-proc reply payload for [`EnsureQueuePair`]: peer's endpoint
/// plus a `PortRef` to the peer initializer's `NotifyRts` port, or
/// an error string from the peer side.
///
/// Handled by `QueuePairInitializer`'s `Handler<PeerInfo>`. On `Ok`
/// the local qp is connected to the endpoint and a `NotifyRts` is
/// posted to the port. On `Err` the handshake fails with the string.
#[derive(Debug, Serialize, Deserialize, Named)]
pub(super) struct PeerInfo(pub(super) Result<(IbvQpInfo, PortRef<NotifyRts>), String>);
// Register with `wirevalue`; presumably required for this cross-proc
// message type to be (de)serialized.
wirevalue::register_type!(PeerInfo);
1111
/// Cross-proc fire-and-forget. Sent from one initializer to the peer
/// initializer once we hit RTS. A queue pair can begin sending to
/// its peer as soon as it receives this message.
///
/// Posted by `QueuePairInitializer::connect_and_notify` right after
/// the local qp is connected. The receiving initializer records it in
/// `peer_rts_received`.
#[derive(Debug, Serialize, Deserialize, Named)]
pub(super) struct NotifyRts;
// Register with `wirevalue`; presumably required for this cross-proc
// message type to be (de)serialized.
wirevalue::register_type!(NotifyRts);
1118
/// Local-only self-message fired by the timeout task. Triggers
/// the initializer to abort the handshake.
///
/// Posted by the task spawned in `QueuePairInitializer::arm_timeout`
/// once the configured timeout elapses. It derives neither `Serialize`
/// nor `Deserialize` and is not in the initializer's exported handler
/// list, so it only travels through the local `ActorHandle`.
#[derive(Debug)]
struct InitializationFailed;
1123
/// RAII wrapper that destroys the wrapped queue pair on drop.
/// Use `into_inner` to extract the qp without destroying.
///
/// Destruction goes through [`destroy_qp`], so the same caveat
/// applies: no clone of the wrapped `IbvQueuePair` may be used after
/// the guard is dropped.
#[derive(Debug)]
pub(super) struct QpGuard {
    /// `Some` while the guard owns a qp; `None` once it has been taken
    /// (by `into_inner` or by `Drop`).
    qp: Option<IbvQueuePair>,
}
1130
1131impl QpGuard {
1132    pub(super) fn new(qp: IbvQueuePair) -> Self {
1133        Self { qp: Some(qp) }
1134    }
1135
1136    /// Consume the guard and return the qp; suppresses Drop's destroy.
1137    pub(super) fn into_inner(mut self) -> IbvQueuePair {
1138        self.qp.take().expect("QpGuard already drained")
1139    }
1140
1141    /// Delegates to [`IbvQueuePair::connect`].
1142    pub(super) fn connect(&mut self, info: &IbvQpInfo) -> Result<(), anyhow::Error> {
1143        self.qp
1144            .as_mut()
1145            .expect("QpGuard already drained")
1146            .connect(info)
1147    }
1148
1149    /// Delegates to [`IbvQueuePair::get_qp_info`].
1150    pub(super) fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
1151        self.qp
1152            .as_mut()
1153            .expect("QpGuard already drained")
1154            .get_qp_info()
1155    }
1156}
1157
1158impl Drop for QpGuard {
1159    fn drop(&mut self) {
1160        if let Some(qp) = self.qp.take() {
1161            // SAFETY: `QpGuard` owns the `IbvQueuePair` and exposes
1162            // no API that hands out a reference to it, so safe code
1163            // cannot have cloned the underlying `rdmaxcel_qp_t`
1164            // pointer out from under us. The only way to extract a
1165            // live clone is `into_inner`, which consumes `self` and
1166            // skips this `Drop`; reaching here means `into_inner`
1167            // was never called.
1168            unsafe { destroy_qp(&qp) };
1169        }
1170    }
1171}
1172
/// Bundle of trait bounds for an actor type that can play the role
/// of [`QueuePairInitializer`]'s owner/peer manager.
///
/// As owner it must handle [`QpInitializerDone`] and
/// [`QpInitializerFailed`], which the initializer posts to its
/// `owner` handle. As peer it must accept [`EnsureQueuePair`], which
/// the initializer posts to `other` from `init`.
pub(super) trait QpOwner:
    Actor
    + Referable
    + Binds<Self>
    + RemoteHandles<EnsureQueuePair<Self>>
    + Handler<QpInitializerDone>
    + Handler<QpInitializerFailed>
{
}

// Blanket impl: every actor type meeting the bounds above is a
// `QpOwner`, so neither the real manager nor test mocks need an
// explicit impl.
impl<T> QpOwner for T where
    T: Actor
        + Referable
        + Binds<T>
        + RemoteHandles<EnsureQueuePair<T>>
        + Handler<QpInitializerDone>
        + Handler<QpInitializerFailed>
{
}
1194
/// Per-peer queue-pair handshake actor.
///
/// Protocol:
/// * `init` sends [`EnsureQueuePair`] to the peer manager.
/// * When the [`PeerInfo`] reply arrives, the local qp is connected and
///   [`NotifyRts`] is posted to the peer initializer.
/// * Once this side has both sent its own `NotifyRts` and received the
///   peer's, the qp is handed to the owner via [`QpInitializerDone`].
/// * Any error, undeliverable message, or timeout reports
///   [`QpInitializerFailed`] instead.
///
/// The `QueuePairInitializer` section comment in this file describes
/// the protocol in more detail.
///
/// Generic over the manager actor type `A` so tests can swap in a
/// mock.
#[derive(Debug)]
#[hyperactor::export(handlers = [PeerInfo, NotifyRts])]
pub(super) struct QueuePairInitializer<A: QpOwner> {
    /// Local manager that receives the terminal
    /// [`QpInitializerDone`] / [`QpInitializerFailed`] report.
    owner: ActorHandle<A>,
    /// Peer manager; `init` sends it [`EnsureQueuePair`].
    other: ActorRef<A>,
    /// This side's key for the qp; echoed in the terminal report.
    qp_key: QpKey,
    /// Held until the handshake succeeds (handed to the manager
    /// via [`QpInitializerDone`]) or fails (dropped here, which
    /// destroys the qp via `QpGuard::drop`).
    qp: Option<QpGuard>,
    /// Per-side handshake budget pulled from
    /// `RDMA_QP_INIT_TIMEOUT` at construction.
    timeout: Duration,
    /// Set in `Handler<PeerInfo>` after we connect the qp and send
    /// our `NotifyRts` to the peer.
    our_rts_sent: bool,
    /// Set in `Handler<NotifyRts>` when the peer's `NotifyRts`
    /// arrives.
    peer_rts_received: bool,
    /// Set by `done`/`fail` once a terminal report has been
    /// dispatched to the owner. All further handler work
    /// short-circuits.
    terminal: bool,
    /// Currently-armed timeout. `arm_timeout` aborts any prior one.
    timeout_handle: Option<tokio::task::JoinHandle<()>>,
}
1225
1226impl<A> QueuePairInitializer<A>
1227where
1228    A: QpOwner,
1229{
1230    pub(super) fn new(
1231        owner: ActorHandle<A>,
1232        other: ActorRef<A>,
1233        qp_key: QpKey,
1234        qp: QpGuard,
1235    ) -> Self {
1236        let timeout = hyperactor_config::global::get(crate::config::RDMA_QP_INIT_TIMEOUT);
1237        Self {
1238            owner,
1239            other,
1240            qp_key,
1241            qp: Some(qp),
1242            timeout,
1243            our_rts_sent: false,
1244            peer_rts_received: false,
1245            terminal: false,
1246            timeout_handle: None,
1247        }
1248    }
1249
1250    /// Arm a fresh `InitializationFailed` timer, aborting any prior one.
1251    fn arm_timeout(&mut self, this: &Instance<Self>) {
1252        if let Some(h) = self.timeout_handle.take() {
1253            h.abort();
1254        }
1255        let self_handle: ActorHandle<Self> = this.handle();
1256        let timeout = self.timeout;
1257        let task = tokio::spawn(async move {
1258            tokio::time::sleep(timeout).await;
1259            self_handle.post(Instance::<Self>::self_client(), InitializationFailed);
1260        });
1261        self.timeout_handle = Some(task);
1262    }
1263
1264    /// Transition to the terminal failed state, drop the qp guard
1265    /// (destroying any qp held), and report failure to the owning
1266    /// manager.
1267    fn fail(&mut self, this: &Instance<Self>, error: String) -> Result<(), anyhow::Error> {
1268        if let Some(h) = self.timeout_handle.take() {
1269            h.abort();
1270        }
1271        self.qp = None;
1272        self.terminal = true;
1273        self.owner.post(
1274            this,
1275            QpInitializerFailed {
1276                qp_key: self.qp_key.clone(),
1277                error,
1278            },
1279        );
1280        Ok(())
1281    }
1282
1283    /// Transition to the terminal success state and hand the qp
1284    /// guard to the owning manager via [`QpInitializerDone`].
1285    fn done(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
1286        if let Some(h) = self.timeout_handle.take() {
1287            h.abort();
1288        }
1289        let qp = self.qp.take().expect("qp present in done()");
1290        self.terminal = true;
1291        self.owner.post(
1292            this,
1293            QpInitializerDone {
1294                qp_key: self.qp_key.clone(),
1295                qp,
1296            },
1297        );
1298        Ok(())
1299    }
1300
1301    /// Connect our qp to the peer endpoint, then notify the peer
1302    /// that we've reached RTS. Returns the failure string for
1303    /// [`Self::fail`] on error.
1304    fn connect_and_notify(
1305        &mut self,
1306        cx: &Context<Self>,
1307        info: Result<(IbvQpInfo, PortRef<NotifyRts>), String>,
1308    ) -> Result<(), String> {
1309        let (peer_endpoint, peer_notify_rts) = info?;
1310        self.qp
1311            .as_mut()
1312            .expect("qp present pre-terminal")
1313            .connect(&peer_endpoint)
1314            .map_err(|e| format!("QpGuard::connect failed: {e}"))?;
1315        peer_notify_rts.post(cx, NotifyRts);
1316        Ok(())
1317    }
1318}
1319
1320#[async_trait]
1321impl<A> Actor for QueuePairInitializer<A>
1322where
1323    A: QpOwner,
1324{
1325    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
1326        // Send the QueuePairInitializer's PeerInfo actor port so that the reply
1327        // is routed back to this actor's handler automatically.
1328        let reply = this.bind::<Self>().port();
1329        let sender = self.owner.bind();
1330        let sender_device = self.qp_key.self_device.clone();
1331        let receiver_device = self.qp_key.other_device.clone();
1332        self.other.post(
1333            this,
1334            EnsureQueuePair {
1335                sender,
1336                sender_device,
1337                receiver_device,
1338                reply,
1339            },
1340        );
1341
1342        self.arm_timeout(this);
1343        Ok(())
1344    }
1345
1346    async fn cleanup(
1347        &mut self,
1348        _this: &Instance<Self>,
1349        _err: Option<&hyperactor::actor::ActorError>,
1350    ) -> Result<(), anyhow::Error> {
1351        if let Some(h) = self.timeout_handle.take() {
1352            h.abort();
1353        }
1354        Ok(())
1355    }
1356
1357    async fn handle_undeliverable_message(
1358        &mut self,
1359        this: &Instance<Self>,
1360        undeliverable: Undeliverable<MessageEnvelope>,
1361    ) -> Result<(), anyhow::Error> {
1362        let error = match undeliverable {
1363            Undeliverable::Message(envelope) => envelope.error_msg().unwrap_or_default(),
1364            Undeliverable::Lost(lost) => lost.error,
1365        };
1366        if self.terminal {
1367            tracing::warn!(
1368                "undeliverable message after handshake terminated: {}",
1369                error
1370            );
1371            return Ok(());
1372        }
1373        self.fail(this, error)
1374    }
1375}
1376
1377impl<A> Drop for QueuePairInitializer<A>
1378where
1379    A: QpOwner,
1380{
1381    fn drop(&mut self) {
1382        if let Some(h) = self.timeout_handle.take() {
1383            h.abort();
1384        }
1385    }
1386}
1387
1388/// Destroy the underlying `rdmaxcel_qp_t`.
1389///
1390/// # Safety
1391///
1392/// `IbvQueuePair` derives [`Clone`] but the wrapped `rdmaxcel_qp_t`
1393/// pointer is shared by all clones; this call frees that pointer. The
1394/// caller must guarantee no remaining clones of `qp` are in use (no
1395/// other code is reading from or posting to `qp.qp`, and no future
1396/// code will), since accessing a freed `rdmaxcel_qp_t` is undefined
1397/// behavior.
1398pub(super) unsafe fn destroy_qp(qp: &IbvQueuePair) {
1399    // SAFETY: The caller has guaranteed no other live clone of `qp`
1400    // observes `qp.qp` (see this function's `# Safety` section). This
1401    // is truly unsafe -- the current implementation does not properly
1402    // track outstanding clones. An imminent change will fix this, but
1403    // for now it isn't a regression.
1404    unsafe {
1405        if qp.qp != 0 {
1406            let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1407            rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
1408        }
1409    }
1410}
1411
1412#[async_trait]
1413impl<A> Handler<PeerInfo> for QueuePairInitializer<A>
1414where
1415    A: QpOwner,
1416{
1417    async fn handle(&mut self, cx: &Context<Self>, msg: PeerInfo) -> Result<(), anyhow::Error> {
1418        if self.terminal {
1419            tracing::warn!("PeerInfo received after queue pair already terminal");
1420            return Ok(());
1421        }
1422        debug_assert!(!self.our_rts_sent, "duplicate PeerInfo");
1423        if let Err(e) = self.connect_and_notify(cx, msg.0) {
1424            return self.fail(cx, e);
1425        }
1426        self.our_rts_sent = true;
1427        if self.peer_rts_received {
1428            return self.done(cx);
1429        }
1430        // Rearm the timeout for the remaining wait on the peer's
1431        // `NotifyRts` so a hang past this point still surfaces as a
1432        // failure.
1433        self.arm_timeout(cx);
1434        Ok(())
1435    }
1436}
1437
1438#[async_trait]
1439impl<A> Handler<NotifyRts> for QueuePairInitializer<A>
1440where
1441    A: QpOwner,
1442{
1443    async fn handle(&mut self, cx: &Context<Self>, _msg: NotifyRts) -> Result<(), anyhow::Error> {
1444        if self.terminal {
1445            tracing::warn!("NotifyRts received after queue pair already terminal");
1446            return Ok(());
1447        }
1448        debug_assert!(!self.peer_rts_received, "duplicate NotifyRts");
1449        self.peer_rts_received = true;
1450        if self.our_rts_sent {
1451            return self.done(cx);
1452        }
1453        // Rearm the timeout for the remaining wait on our own
1454        // `PeerInfo` reply.
1455        self.arm_timeout(cx);
1456        Ok(())
1457    }
1458}
1459
1460#[async_trait]
1461impl<A> Handler<InitializationFailed> for QueuePairInitializer<A>
1462where
1463    A: QpOwner,
1464{
1465    async fn handle(
1466        &mut self,
1467        cx: &Context<Self>,
1468        _msg: InitializationFailed,
1469    ) -> Result<(), anyhow::Error> {
1470        if self.terminal {
1471            return Ok(());
1472        }
1473        self.fail(cx, "QP initialization timed out".into())
1474    }
1475}
1476
1477#[cfg(test)]
1478mod tests {
1479    use std::sync::Arc;
1480    use std::sync::Mutex;
1481    use std::time::Duration;
1482    use std::time::Instant;
1483
1484    use anyhow::Result;
1485    use async_trait::async_trait;
1486    use hyperactor::Context;
1487    use hyperactor::Handler;
1488    use hyperactor::PortRef;
1489    use hyperactor::mailbox::DeliveryError;
1490    use hyperactor::mailbox::MessageEnvelope;
1491    use hyperactor::mailbox::Undeliverable;
1492    use hyperactor::port::Port;
1493    use hyperactor::proc::Proc;
1494    use hyperactor_config::Flattrs;
1495
1496    use super::*;
1497    use crate::backend::ibverbs::domain::IbvDomain;
1498    use crate::backend::ibverbs::manager_actor::EnsureQueuePair;
1499    use crate::backend::ibverbs::primitives::IbvConfig;
1500    use crate::backend::ibverbs::primitives::get_all_devices;
1501
1502    #[test]
1503    fn test_create_connection() {
1504        if get_all_devices().is_empty() {
1505            println!("Skipping test: RDMA devices not available");
1506            return;
1507        }
1508
1509        let config = IbvConfig {
1510            use_gpu_direct: false,
1511            ..Default::default()
1512        };
1513        let domain = IbvDomain::new(config.device.clone());
1514        assert!(domain.is_ok());
1515
1516        let domain = domain.unwrap();
1517        let queue_pair = IbvQueuePair::new(domain.context, domain.pd, config.clone());
1518        assert!(queue_pair.is_ok());
1519    }
1520
1521    #[test]
1522    fn test_loopback_connection() {
1523        if get_all_devices().is_empty() {
1524            println!("Skipping test: RDMA devices not available");
1525            return;
1526        }
1527
1528        let server_config = IbvConfig {
1529            use_gpu_direct: false,
1530            ..Default::default()
1531        };
1532        let client_config = IbvConfig {
1533            use_gpu_direct: false,
1534            ..Default::default()
1535        };
1536
1537        let server_domain = IbvDomain::new(server_config.device.clone()).unwrap();
1538        let client_domain = IbvDomain::new(client_config.device.clone()).unwrap();
1539
1540        let mut server_qp = IbvQueuePair::new(
1541            server_domain.context,
1542            server_domain.pd,
1543            server_config.clone(),
1544        )
1545        .unwrap();
1546        let mut client_qp = IbvQueuePair::new(
1547            client_domain.context,
1548            client_domain.pd,
1549            client_config.clone(),
1550        )
1551        .unwrap();
1552
1553        let server_connection_info = server_qp.get_qp_info().unwrap();
1554        let client_connection_info = client_qp.get_qp_info().unwrap();
1555
1556        assert!(server_qp.connect(&client_connection_info).is_ok());
1557        assert!(client_qp.connect(&server_connection_info).is_ok());
1558    }
1559
    /// Outcomes recorded by [`MockManager`] for assertions.
    #[derive(Default, Debug)]
    struct MockState {
        /// Keys reported via `QpInitializerDone`, in arrival order.
        done: Vec<QpKey>,
        /// Keys and error strings reported via `QpInitializerFailed`,
        /// in arrival order.
        failed: Vec<(QpKey, String)>,
        /// Number of `NotifyRts` messages the mock received from the
        /// initializer (i.e., how many times the initializer reached
        /// the "we've hit RTS" point and sent us a notification).
        notify_rts: usize,
    }
1570
    /// Scripted reply for the next `EnsureQueuePair` the mock sees.
    /// After a single reply the mock disarms back to `DropReply`.
    #[derive(Debug)]
    enum MockResponse {
        /// Reply with `PeerInfo(Ok((info, mock_notify_rts_port)))`. The
        /// caller must drive the initializer's `NotifyRts` port from
        /// the test for the initializer to finish (`QpInitializerDone`).
        Success(IbvQpInfo),
        /// Like `Success`, but the `PortRef<NotifyRts>` handed back is
        /// attested to an unreachable address in the mock's own proc
        /// so the initializer's `NotifyRts` send bounces back as
        /// undeliverable.
        SuccessWithBogusNotifyRts(IbvQpInfo),
        /// Reply with `PeerInfo(Err(..))` carrying this string.
        Error(String),
        /// Send no reply, leaving the initializer waiting on `PeerInfo`.
        DropReply,
    }
1587
    /// Zero-initialized [`IbvQueuePair`]. `qp == 0` so `QpGuard::Drop`
    /// is a no-op; tests using this must not exercise [`IbvQueuePair::connect`]
    /// (it would deref a null pointer).
    ///
    /// Every raw handle (CQs, qp, direct-verbs objects, context) is 0,
    /// `is_efa` is false, and the config is `IbvConfig::default()`.
    /// No device is touched.
    fn fake_qp() -> IbvQueuePair {
        IbvQueuePair {
            send_cq: 0,
            recv_cq: 0,
            qp: 0,
            dv_qp: 0,
            dv_send_cq: 0,
            dv_recv_cq: 0,
            context: 0,
            config: IbvConfig::default(),
            is_efa: false,
        }
    }
1604
1605    /// A real (loopback) `IbvQueuePair` and its `IbvQpInfo`. Returns
1606    /// `None` when no RDMA device is present.
1607    fn loopback_qp() -> Option<(QpGuard, IbvQpInfo)> {
1608        if get_all_devices().is_empty() {
1609            return None;
1610        }
1611        let config = IbvConfig::default();
1612        let domain = IbvDomain::new(config.device.clone()).ok()?;
1613        let mut qp = QpGuard::new(IbvQueuePair::new(domain.context, domain.pd, config).ok()?);
1614        let info = qp.get_qp_info().ok()?;
1615        Some((qp, info))
1616    }
1617
    /// Stand-in for the real manager actor. It plays both the
    /// initializer's owner (receives `QpInitializerDone` /
    /// `QpInitializerFailed`) and its peer (answers `EnsureQueuePair`
    /// according to `response` and counts `NotifyRts`).
    #[derive(Debug)]
    #[hyperactor::export(handlers = [EnsureQueuePair<MockManager>, NotifyRts])]
    struct MockManager {
        /// Shared with the test so it can inspect recorded outcomes.
        state: Arc<Mutex<MockState>>,
        /// Scripted reply for the next `EnsureQueuePair`.
        response: MockResponse,
    }
1624
    // The mock relies entirely on the default `Actor` methods.
    #[async_trait]
    impl Actor for MockManager {}
1627
1628    #[async_trait]
1629    impl Handler<EnsureQueuePair<MockManager>> for MockManager {
1630        async fn handle(
1631            &mut self,
1632            cx: &Context<Self>,
1633            msg: EnsureQueuePair<MockManager>,
1634        ) -> Result<()> {
1635            let response = std::mem::replace(&mut self.response, MockResponse::DropReply);
1636            match response {
1637                MockResponse::Success(info) => {
1638                    let notify_rts = cx.bind::<MockManager>().port::<NotifyRts>();
1639                    msg.reply.post(cx, PeerInfo(Ok((info, notify_rts))));
1640                }
1641                MockResponse::SuccessWithBogusNotifyRts(info) => {
1642                    let bogus = hyperactor::context::Mailbox::mailbox(cx)
1643                        .actor_addr()
1644                        .proc_addr()
1645                        .actor_addr("bogus")
1646                        .port_addr(Port::from(0u64));
1647                    let notify_rts = PortRef::<NotifyRts>::attest(bogus);
1648                    msg.reply.post(cx, PeerInfo(Ok((info, notify_rts))));
1649                }
1650                MockResponse::Error(e) => {
1651                    msg.reply.post(cx, PeerInfo(Err(e)));
1652                }
1653                MockResponse::DropReply => {}
1654            }
1655            Ok(())
1656        }
1657    }
1658
1659    #[async_trait]
1660    impl Handler<NotifyRts> for MockManager {
1661        async fn handle(&mut self, _cx: &Context<Self>, _msg: NotifyRts) -> Result<()> {
1662            self.state.lock().unwrap().notify_rts += 1;
1663            Ok(())
1664        }
1665    }
1666
1667    #[async_trait]
1668    impl Handler<QpInitializerDone> for MockManager {
1669        async fn handle(&mut self, _cx: &Context<Self>, msg: QpInitializerDone) -> Result<()> {
1670            let _ = msg.qp.into_inner();
1671            self.state.lock().unwrap().done.push(msg.qp_key);
1672            Ok(())
1673        }
1674    }
1675
1676    #[async_trait]
1677    impl Handler<QpInitializerFailed> for MockManager {
1678        async fn handle(&mut self, _cx: &Context<Self>, msg: QpInitializerFailed) -> Result<()> {
1679            self.state
1680                .lock()
1681                .unwrap()
1682                .failed
1683                .push((msg.qp_key, msg.error));
1684            Ok(())
1685        }
1686    }
1687
    /// Everything a test needs to drive one `QueuePairInitializer` against a
    /// `MockManager`. Construct it with `Harness::build`.
    struct Harness {
        /// Proc that both actors were spawned on; tests also mint client mailboxes from it.
        proc: Proc,
        /// Handle to the initializer under test.
        init_handle: ActorHandle<QueuePairInitializer<MockManager>>,
        /// Observations recorded by the mock manager.
        state: Arc<Mutex<MockState>>,
        /// Key the initializer was constructed with; tests expect callbacks to report it.
        qp_key: QpKey,
    }
1694
1695    impl Harness {
1696        fn build(qp: QpGuard, response: MockResponse) -> Result<Self> {
1697            let proc = Proc::anonymous();
1698            let state = Arc::new(Mutex::new(MockState::default()));
1699            let mock = MockManager {
1700                state: state.clone(),
1701                response,
1702            };
1703            let mock_handle = proc.spawn("mock", mock)?;
1704            let mock_ref = mock_handle.bind::<MockManager>();
1705            let qp_key = QpKey {
1706                self_device: "mock0".into(),
1707                other_id: mock_ref.actor_addr().id().clone(),
1708                other_device: "mock0".into(),
1709            };
1710            let initializer = QueuePairInitializer::new(mock_handle, mock_ref, qp_key.clone(), qp);
1711            let init_handle = proc.spawn("initializer", initializer)?;
1712            // Bind well-known ports so PeerInfo/NotifyRts can route.
1713            let _ = init_handle.bind::<QueuePairInitializer<MockManager>>();
1714            Ok(Harness {
1715                proc,
1716                init_handle,
1717                state,
1718                qp_key,
1719            })
1720        }
1721
1722        async fn await_done(&self) -> QpKey {
1723            let deadline = Instant::now() + Duration::from_secs(5);
1724            loop {
1725                if let Some(key) = self.state.lock().unwrap().done.first().cloned() {
1726                    return key;
1727                }
1728                if Instant::now() >= deadline {
1729                    panic!(
1730                        "QpInitializerDone not delivered within 5s; state={:?}",
1731                        self.state.lock().unwrap()
1732                    );
1733                }
1734                tokio::time::sleep(Duration::from_millis(10)).await;
1735            }
1736        }
1737
1738        async fn await_failed(&self) -> (QpKey, String) {
1739            let deadline = Instant::now() + Duration::from_secs(5);
1740            loop {
1741                if let Some(entry) = self.state.lock().unwrap().failed.first().cloned() {
1742                    return entry;
1743                }
1744                if Instant::now() >= deadline {
1745                    panic!(
1746                        "QpInitializerFailed was not delivered within 5s; state={:?}",
1747                        self.state.lock().unwrap()
1748                    );
1749                }
1750                tokio::time::sleep(Duration::from_millis(10)).await;
1751            }
1752        }
1753    }
1754
1755    #[tokio::test]
1756    async fn test_peer_info_error_transitions_to_failed() {
1757        let harness = Harness::build(
1758            QpGuard::new(fake_qp()),
1759            MockResponse::Error("peer rejected".into()),
1760        )
1761        .unwrap();
1762        let (key, error) = harness.await_failed().await;
1763        assert_eq!(key, harness.qp_key);
1764        assert_eq!(error, "peer rejected");
1765        // No spurious done callbacks.
1766        assert!(harness.state.lock().unwrap().done.is_empty());
1767    }
1768
1769    #[tokio::test]
1770    async fn test_initial_timeout_transitions_to_failed() {
1771        // Drop the configured per-handshake budget to 200ms so the
1772        // test doesn't sit on the default 30s.
1773        let lock = hyperactor_config::global::lock();
1774        let _guard = lock.override_key(
1775            crate::config::RDMA_QP_INIT_TIMEOUT,
1776            Duration::from_millis(200),
1777        );
1778
1779        let harness = Harness::build(QpGuard::new(fake_qp()), MockResponse::DropReply).unwrap();
1780        let (key, error) = harness.await_failed().await;
1781        assert_eq!(key, harness.qp_key);
1782        assert!(
1783            error.contains("timed out"),
1784            "expected timeout error, got {error}"
1785        );
1786    }
1787
1788    /// Real loopback handshake. The mock replies `Success`, the
1789    /// initializer connects the qp to itself and sends `NotifyRts` to
1790    /// the mock; the test then delivers `NotifyRts` directly to the
1791    /// initializer's well-known port to drive it to success.
1792    #[tokio::test]
1793    async fn test_loopback_handshake_succeeds() -> Result<()> {
1794        let Some((qp, info)) = loopback_qp() else {
1795            panic!("Skipping test: RDMA devices not available");
1796        };
1797        let harness = Harness::build(qp, MockResponse::Success(info))?;
1798
1799        let (peer, _) = harness.proc.client("peer")?;
1800        harness.init_handle.post(&peer, NotifyRts);
1801
1802        let key = harness.await_done().await;
1803        assert_eq!(key, harness.qp_key);
1804        let state = harness.state.lock().unwrap();
1805        assert!(state.failed.is_empty());
1806        assert_eq!(
1807            state.notify_rts, 1,
1808            "initializer must send exactly one NotifyRts to the peer after qp.connect"
1809        );
1810        Ok(())
1811    }
1812
1813    /// Real loopback `qp.connect` succeeds and the initializer
1814    /// flips `our_rts_sent`, but the test never delivers `NotifyRts`
1815    /// back to the initializer's port. The rearmed timer fires and
1816    /// the handshake is reported as failed.
1817    #[tokio::test]
1818    async fn test_notify_rts_timeout_after_peer_info() -> Result<()> {
1819        let Some((qp, info)) = loopback_qp() else {
1820            panic!("Skipping test: RDMA devices not available");
1821        };
1822        let lock = hyperactor_config::global::lock();
1823        let _guard = lock.override_key(
1824            crate::config::RDMA_QP_INIT_TIMEOUT,
1825            Duration::from_millis(200),
1826        );
1827
1828        let harness = Harness::build(qp, MockResponse::Success(info))?;
1829        let (key, error) = harness.await_failed().await;
1830        assert_eq!(key, harness.qp_key);
1831        assert!(
1832            error.contains("timed out"),
1833            "expected timeout error, got {error}"
1834        );
1835        // Receiving exactly one NotifyRts confirms the initializer
1836        // ran `qp.connect` + sent NotifyRts to the peer and was
1837        // waiting on the peer's `NotifyRts` when the rearmed timer
1838        // fired.
1839        assert_eq!(harness.state.lock().unwrap().notify_rts, 1);
1840        Ok(())
1841    }
1842
1843    fn fake_undeliverable(proc: &Proc, error: &str) -> Undeliverable<MessageEnvelope> {
1844        let mut envelope = MessageEnvelope::serialize(
1845            proc.proc_addr().actor_addr("test-sender"),
1846            proc.proc_addr()
1847                .actor_addr("test-dest")
1848                .port_addr(Port::from(0u64)),
1849            &0u64,
1850            Flattrs::default(),
1851        )
1852        .unwrap();
1853        envelope.set_error(DeliveryError::Mailbox(error.into()));
1854        Undeliverable::Message(envelope)
1855    }
1856
1857    /// In an awaiting state, an undeliverable message returned to the
1858    /// initializer trips `handle_undeliverable_message` into `fail()`,
1859    /// which reports `QpInitializerFailed` to the owner with the
1860    /// envelope's error message.
1861    #[tokio::test]
1862    async fn test_undeliverable_in_awaiting_transitions_to_failed() {
1863        let harness = Harness::build(QpGuard::new(fake_qp()), MockResponse::DropReply).unwrap();
1864        let undeliverable = fake_undeliverable(&harness.proc, "simulated bounce");
1865        let (peer, _) = harness.proc.client("peer").unwrap();
1866        harness.init_handle.post(&peer, undeliverable);
1867        let (key, error) = harness.await_failed().await;
1868        assert_eq!(key, harness.qp_key);
1869        assert!(
1870            error.contains("simulated bounce"),
1871            "expected delivery error, got {error}"
1872        );
1873    }
1874
1875    /// `PeerInfo` carries a `PortRef<NotifyRts>` attested to a bogus
1876    /// address; the initializer's send bounces back as undeliverable
1877    /// after `our_rts_sent` is set, and `handle_undeliverable_message`
1878    /// trips `fail()`.
1879    #[tokio::test]
1880    async fn test_notify_rts_undeliverable_transitions_to_failed() -> Result<()> {
1881        let Some((qp, info)) = loopback_qp() else {
1882            panic!("Skipping test: RDMA devices not available");
1883        };
1884        let harness = Harness::build(qp, MockResponse::SuccessWithBogusNotifyRts(info))?;
1885        let (key, error) = harness.await_failed().await;
1886        assert_eq!(key, harness.qp_key);
1887        assert!(
1888            error.contains("address not routable"),
1889            "expected delivery error, got {error:?}"
1890        );
1891        Ok(())
1892    }
1893
1894    /// Once the initializer is terminal, a late undeliverable is
1895    /// just warn-logged and must not produce a second
1896    /// `QpInitializerFailed` callback.
1897    #[tokio::test]
1898    async fn test_undeliverable_after_terminated_does_not_re_fail() {
1899        let harness = Harness::build(
1900            QpGuard::new(fake_qp()),
1901            MockResponse::Error("first fail".into()),
1902        )
1903        .unwrap();
1904        let _ = harness.await_failed().await;
1905
1906        let undeliverable = fake_undeliverable(&harness.proc, "late bounce");
1907        let (peer, _) = harness.proc.client("peer").unwrap();
1908        harness.init_handle.post(&peer, undeliverable);
1909        tokio::time::sleep(Duration::from_millis(50)).await;
1910        assert_eq!(harness.state.lock().unwrap().failed.len(), 1);
1911    }
1912}