/// Largest number of bytes posted in a single RDMA work request (1 GiB).
///
/// `IbvQueuePair::put` and `IbvQueuePair::get` split larger transfers into
/// consecutive chunks of at most this many bytes.
const MAX_RDMA_MSG_SIZE: usize = 1024 * 1024 * 1024;
17
18use std::io::Error;
19use std::result::Result;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use hyperactor::Actor;
24use hyperactor::ActorHandle;
25use hyperactor::ActorId;
26use hyperactor::ActorRef;
27use hyperactor::Context;
28use hyperactor::Endpoint as _;
29use hyperactor::Handler;
30use hyperactor::Instance;
31use hyperactor::PortRef;
32use hyperactor::actor::Binds;
33use hyperactor::actor::Referable;
34use hyperactor::actor::RemoteHandles;
35use hyperactor::mailbox::MessageEnvelope;
36use hyperactor::mailbox::Undeliverable;
37use serde::Deserialize;
38use serde::Serialize;
39use typeuri::Named;
40
41use super::IbvBuffer;
42use super::manager_actor::EnsureQueuePair;
43use super::manager_actor::QpInitializerDone;
44use super::manager_actor::QpInitializerFailed;
45use super::primitives::Gid;
46use super::primitives::IbvConfig;
47use super::primitives::IbvOperation;
48use super::primitives::IbvQpInfo;
49use super::primitives::IbvWc;
50use super::primitives::resolve_qp_type;
51
/// Error produced by `IbvQueuePair::poll_completion`.
///
/// `status` and `vendor_err` are populated when the failing work completion
/// exposed them; both are `None` when only a polling error code was available.
#[derive(Debug)]
pub struct PollCompletionError {
    /// Work-completion status reported with the failure, if any.
    pub status: Option<rdmaxcel_sys::ibv_wc_status::Type>,
    /// Vendor-specific error code reported alongside `status`, if any.
    pub vendor_err: Option<u32>,
    /// Human-readable description; this is exactly what `Display` prints.
    message: String,
}
62
63impl std::fmt::Display for PollCompletionError {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.write_str(&self.message)
66 }
67}
68
// Marker impl: uses the default `std::error::Error` methods (no `source`).
impl std::error::Error for PollCompletionError {}
70
71impl PollCompletionError {
72 pub fn is_wr_flush_err(&self) -> bool {
76 self.status == Some(rdmaxcel_sys::ibv_wc_status::IBV_WC_WR_FLUSH_ERR)
77 }
78}
79
/// Addresses describing a doorbell for a queued work request.
///
/// `IbvQueuePair::send_wqe` fills this with the queue pair's `bf.reg` address
/// as `dst_ptr`, the send-queue buffer base as `src_ptr` and `size == 8`; on
/// the EFA path all three fields are zero.
#[derive(Debug, Named, Clone, Serialize, Deserialize)]
pub struct DoorBell {
    /// Source address, stored as an integer.
    pub src_ptr: usize,
    /// Destination address, stored as an integer.
    pub dst_ptr: usize,
    /// Size in bytes.
    pub size: usize,
}
wirevalue::register_type!(DoorBell);
90
/// Selects which completion queue `IbvQueuePair::poll_completion` polls.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PollTarget {
    /// Poll the send completion queue (and its completion cache).
    Send,
    /// Poll the receive completion queue (and its completion cache).
    Recv,
}
97
/// A queue pair together with its completion queues.
///
/// Every handle is a raw pointer stored as a `usize`; the methods cast them
/// back to the corresponding `rdmaxcel_sys` pointer types before use.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct IbvQueuePair {
    /// Address of the send `ibv_cq`.
    pub send_cq: usize,
    /// Address of the receive `ibv_cq`.
    pub recv_cq: usize,
    /// Address of the `rdmaxcel_qp` wrapper.
    pub qp: usize,
    /// Address of the `mlx5dv_qp`; `0` for EFA queue pairs.
    pub dv_qp: usize,
    /// Address of the mlx5dv send CQ; `0` for EFA queue pairs.
    pub dv_send_cq: usize,
    /// Address of the mlx5dv receive CQ; `0` for EFA queue pairs.
    pub dv_recv_cq: usize,
    /// Address of the `ibv_context` the queue pair was created on.
    context: usize,
    /// Configuration the queue pair was created with.
    config: IbvConfig,
    /// `true` when the resolved QP type was `RDMA_QP_TYPE_EFA`.
    is_efa: bool,
}
wirevalue::register_type!(IbvQueuePair);
130
131impl IbvQueuePair {
132 fn is_efa(&self) -> bool {
133 self.is_efa
134 }
135
136 fn apply_first_op_delay(&self, wr_id: u64) {
138 unsafe {
139 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
140 if wr_id == 0 {
141 let rts_timestamp = rdmaxcel_sys::rdmaxcel_qp_load_rts_timestamp(qp);
142 assert!(
143 rts_timestamp != u64::MAX,
144 "First operation attempted before queue pair reached RTS state! Call connect() first."
145 );
146 let current_nanos = std::time::SystemTime::now()
147 .duration_since(std::time::UNIX_EPOCH)
148 .unwrap()
149 .as_nanos() as u64;
150 let elapsed_nanos = current_nanos - rts_timestamp;
151 let elapsed = Duration::from_nanos(elapsed_nanos);
152 let init_delay = Duration::from_millis(self.config.hw_init_delay_ms);
153 if elapsed < init_delay {
154 let remaining_delay = init_delay - elapsed;
155 std::thread::sleep(remaining_delay);
159 }
160 }
161 }
162 }
163
164 pub fn new(
174 context: *mut rdmaxcel_sys::ibv_context,
175 pd: *mut rdmaxcel_sys::ibv_pd,
176 config: IbvConfig,
177 ) -> Result<Self, anyhow::Error> {
178 tracing::debug!("creating an IbvQueuePair from config {}", config);
179 unsafe {
180 let resolved_qp_type = resolve_qp_type(config.qp_type);
182 let is_efa = resolved_qp_type == rdmaxcel_sys::RDMA_QP_TYPE_EFA;
183 let qp = rdmaxcel_sys::rdmaxcel_qp_create(
184 context,
185 pd,
186 config.cq_entries,
187 config.max_send_wr.try_into().unwrap(),
188 config.max_recv_wr.try_into().unwrap(),
189 config.max_send_sge.try_into().unwrap(),
190 config.max_recv_sge.try_into().unwrap(),
191 resolved_qp_type,
192 );
193
194 if qp.is_null() {
195 let os_error = Error::last_os_error();
196 return Err(anyhow::anyhow!(
197 "failed to create queue pair (QP): {}",
198 os_error
199 ));
200 }
201
202 let send_cq = (*(*qp).ibv_qp).send_cq;
203 let recv_cq = (*(*qp).ibv_qp).recv_cq;
204
205 if is_efa {
207 return Ok(IbvQueuePair {
208 send_cq: send_cq as usize,
209 recv_cq: recv_cq as usize,
210 qp: qp as usize,
211 dv_qp: 0,
212 dv_send_cq: 0,
213 dv_recv_cq: 0,
214 context: context as usize,
215 config,
216 is_efa: true,
217 });
218 }
219
220 let dv_qp = rdmaxcel_sys::create_mlx5dv_qp((*qp).ibv_qp);
221 let dv_send_cq = rdmaxcel_sys::create_mlx5dv_send_cq((*qp).ibv_qp);
222 let dv_recv_cq = rdmaxcel_sys::create_mlx5dv_recv_cq((*qp).ibv_qp);
223
224 if dv_qp.is_null() || dv_send_cq.is_null() || dv_recv_cq.is_null() {
225 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
226 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
227 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
228 return Err(anyhow::anyhow!(
229 "failed to init mlx5dv_qp or completion queues"
230 ));
231 }
232
233 if config.use_gpu_direct {
234 let ret = rdmaxcel_sys::register_cuda_memory(dv_qp, dv_recv_cq, dv_send_cq);
235 if ret != 0 {
236 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).recv_cq);
237 rdmaxcel_sys::ibv_destroy_cq((*(*qp).ibv_qp).send_cq);
238 rdmaxcel_sys::ibv_destroy_qp((*qp).ibv_qp);
239 return Err(anyhow::anyhow!(
240 "failed to register GPU Direct RDMA memory: {:?}",
241 ret
242 ));
243 }
244 }
245 Ok(IbvQueuePair {
246 send_cq: send_cq as usize,
247 recv_cq: recv_cq as usize,
248 qp: qp as usize,
249 dv_qp: dv_qp as usize,
250 dv_send_cq: dv_send_cq as usize,
251 dv_recv_cq: dv_recv_cq as usize,
252 context: context as usize,
253 config,
254 is_efa: false,
255 })
256 }
257 }
258
259 pub fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
261 unsafe {
262 let context = self.context as *mut rdmaxcel_sys::ibv_context;
263 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
264 let mut port_attr = rdmaxcel_sys::ibv_port_attr::default();
265 let errno = rdmaxcel_sys::ibv_query_port(
266 context,
267 self.config.port_num,
268 &mut port_attr as *mut rdmaxcel_sys::ibv_port_attr as *mut _,
269 );
270 if errno != 0 {
271 let os_error = Error::last_os_error();
272 return Err(anyhow::anyhow!(
273 "Failed to query port attributes: {}",
274 os_error
275 ));
276 }
277
278 let mut gid = Gid::default();
279 let ret = rdmaxcel_sys::ibv_query_gid(
280 context,
281 self.config.port_num,
282 i32::from(self.config.gid_index),
283 gid.as_mut(),
284 );
285 if ret != 0 {
286 return Err(anyhow::anyhow!("Failed to query GID"));
287 }
288
289 Ok(IbvQpInfo {
290 qp_num: (*(*qp).ibv_qp).qp_num,
291 lid: port_attr.lid,
292 gid: Some(gid),
293 psn: self.config.psn,
294 })
295 }
296 }
297
298 pub fn state(&mut self) -> Result<u32, anyhow::Error> {
300 unsafe {
301 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
302 let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
303 ..Default::default()
304 };
305 let mut qp_init_attr = rdmaxcel_sys::ibv_qp_init_attr {
306 ..Default::default()
307 };
308 let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE;
309 let errno = rdmaxcel_sys::ibv_query_qp(
310 (*qp).ibv_qp,
311 &mut qp_attr,
312 mask.0 as i32,
313 &mut qp_init_attr,
314 );
315 if errno != 0 {
316 let os_error = Error::last_os_error();
317 return Err(anyhow::anyhow!("failed to query QP state: {}", os_error));
318 }
319 Ok(qp_attr.qp_state)
320 }
321 }
322
    /// Connects this queue pair to the remote endpoint described by
    /// `connection_info`.
    ///
    /// EFA queue pairs are delegated to `efa_connect`. Otherwise the queue pair
    /// is moved through three `ibv_modify_qp` transitions (INIT, then RTR, then
    /// RTS), after which the current wall-clock time in nanoseconds is stored
    /// on the queue pair as its RTS timestamp.
    ///
    /// # Errors
    /// Returns an error naming the failed transition when any `ibv_modify_qp`
    /// call returns a non-zero status.
    ///
    /// # Panics
    /// Panics if the system clock is before the UNIX epoch.
    pub fn connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
        if self.is_efa() {
            return self.efa_connect(connection_info);
        }

        unsafe {
            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;

            // Access rights granted on this queue pair.
            let qp_access_flags = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;

            // Transition 1: INIT (port, pkey index and access flags).
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_INIT,
                qp_access_flags: qp_access_flags.0,
                pkey_index: self.config.pkey_index,
                port_num: self.config.port_num,
                ..Default::default()
            };

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PKEY_INDEX
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PORT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_ACCESS_FLAGS;

            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to INIT: {}",
                    os_error
                ));
            }

            // Transition 2: RTR, addressed at the remote QP number / PSN / LID.
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTR,
                path_mtu: self.config.path_mtu,
                dest_qp_num: connection_info.qp_num,
                rq_psn: connection_info.psn,
                max_dest_rd_atomic: self.config.max_dest_rd_atomic,
                min_rnr_timer: self.config.min_rnr_timer,
                ah_attr: rdmaxcel_sys::ibv_ah_attr {
                    dlid: connection_info.lid,
                    sl: 0,
                    src_path_bits: 0,
                    port_num: self.config.port_num,
                    grh: Default::default(),
                    ..Default::default()
                },
                ..Default::default()
            };

            // A remote GID switches the address handle to global routing.
            if let Some(gid) = connection_info.gid {
                qp_attr.ah_attr.is_global = 1;
                qp_attr.ah_attr.grh.dgid = rdmaxcel_sys::ibv_gid::from(gid);
                qp_attr.ah_attr.grh.hop_limit = 0xff;
                qp_attr.ah_attr.grh.sgid_index = self.config.gid_index;
            } else {
                qp_attr.ah_attr.is_global = 0;
            }

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_AV
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_PATH_MTU
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_DEST_QPN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RQ_PSN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_DEST_RD_ATOMIC
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MIN_RNR_TIMER;

            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to RTR: {}",
                    os_error
                ));
            }

            // Transition 3: RTS, with the local PSN and retry/timeout settings.
            let mut qp_attr = rdmaxcel_sys::ibv_qp_attr {
                qp_state: rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS,
                sq_psn: self.config.psn,
                max_rd_atomic: self.config.max_rd_atomic,
                retry_cnt: self.config.retry_cnt,
                rnr_retry: self.config.rnr_retry,
                timeout: self.config.qp_timeout,
                ..Default::default()
            };

            let mask = rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_STATE
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_TIMEOUT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RETRY_CNT
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_SQ_PSN
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_RNR_RETRY
                | rdmaxcel_sys::ibv_qp_attr_mask::IBV_QP_MAX_QP_RD_ATOMIC;

            let errno = rdmaxcel_sys::ibv_modify_qp((*qp).ibv_qp, &mut qp_attr, mask.0 as i32);
            if errno != 0 {
                let os_error = Error::last_os_error();
                return Err(anyhow::anyhow!(
                    "failed to transition QP to RTS: {}",
                    os_error
                ));
            }
            tracing::debug!(
                "connection sequence has successfully completed (qp: {:?})",
                qp
            );

            // Record when RTS was reached; `apply_first_op_delay` reads this back.
            let rts_timestamp_nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos() as u64;
            rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);

            Ok(())
        }
    }
450
451 fn efa_connect(&mut self, connection_info: &IbvQpInfo) -> Result<(), anyhow::Error> {
453 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
454
455 let gid_ptr = connection_info.gid.as_ref().map_or(std::ptr::null(), |g| {
456 let ibv_gid: &rdmaxcel_sys::ibv_gid = g.as_ref();
457 unsafe { ibv_gid.raw.as_ptr() }
458 });
459
460 unsafe {
461 let ret = rdmaxcel_sys::rdmaxcel_efa_connect(
462 qp,
463 self.config.port_num,
464 self.config.pkey_index,
465 0x4242, self.config.psn,
467 self.config.gid_index,
468 gid_ptr,
469 connection_info.qp_num,
470 );
471 if ret != 0 {
472 let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
473 .to_str()
474 .unwrap_or("unknown");
475 return Err(anyhow::anyhow!("EFA connect failed: {}", msg));
476 }
477 }
478
479 let rts_timestamp_nanos = std::time::SystemTime::now()
481 .duration_since(std::time::UNIX_EPOCH)
482 .unwrap()
483 .as_nanos() as u64;
484 unsafe {
485 rdmaxcel_sys::rdmaxcel_qp_store_rts_timestamp(qp, rts_timestamp_nanos);
486 }
487
488 Ok(())
489 }
490
491 pub fn recv(&mut self, lhandle: IbvBuffer, rhandle: IbvBuffer) -> Result<u64, anyhow::Error> {
492 unsafe {
493 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
494 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_wqe_idx(qp);
495 self.post_op(
496 0,
497 lhandle.lkey,
498 0,
499 idx,
500 true,
501 IbvOperation::Recv,
502 0,
503 rhandle.rkey,
504 )
505 .unwrap();
506 rdmaxcel_sys::rdmaxcel_qp_fetch_add_recv_db_idx(qp);
507 Ok(idx)
508 }
509 }
510
511 pub fn put_with_recv(
512 &mut self,
513 lhandle: IbvBuffer,
514 rhandle: IbvBuffer,
515 ) -> Result<Vec<u64>, anyhow::Error> {
516 unsafe {
517 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
518 let idx = rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(qp);
519 self.post_op(
520 lhandle.addr,
521 lhandle.lkey,
522 lhandle.size,
523 idx,
524 true,
525 IbvOperation::WriteWithImm,
526 rhandle.addr,
527 rhandle.rkey,
528 )
529 .unwrap();
530 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(qp);
531 Ok(vec![idx])
532 }
533 }
534
535 pub fn put(
536 &mut self,
537 lhandle: IbvBuffer,
538 rhandle: IbvBuffer,
539 ) -> Result<Vec<u64>, anyhow::Error> {
540 let total_size = lhandle.size;
541 if rhandle.size < total_size {
542 return Err(anyhow::anyhow!(
543 "Remote buffer size ({}) is smaller than local buffer size ({})",
544 rhandle.size,
545 total_size
546 ));
547 }
548
549 let mut remaining = total_size;
550 let mut offset = 0;
551 let mut wr_ids = Vec::new();
552 while remaining > 0 {
553 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
554 let idx = unsafe {
555 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
556 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
557 )
558 };
559 wr_ids.push(idx);
560 self.post_op(
561 lhandle.addr + offset,
562 lhandle.lkey,
563 chunk_size,
564 idx,
565 true,
566 IbvOperation::Write,
567 rhandle.addr + offset,
568 rhandle.rkey,
569 )?;
570 unsafe {
571 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
572 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
573 );
574 }
575
576 remaining -= chunk_size;
577 offset += chunk_size;
578 }
579
580 Ok(wr_ids)
581 }
582
583 pub fn ring_doorbell(&mut self) -> Result<(), anyhow::Error> {
585 if self.is_efa() {
587 return Ok(());
588 }
589
590 unsafe {
591 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
592 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
593 let base_ptr = (*dv_qp).sq.buf as *mut u8;
594 let wqe_cnt = (*dv_qp).sq.wqe_cnt;
595 let stride = (*dv_qp).sq.stride;
596 let send_wqe_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_wqe_idx(qp);
597 let mut send_db_idx = rdmaxcel_sys::rdmaxcel_qp_load_send_db_idx(qp);
598 if (wqe_cnt as u64) < (send_wqe_idx - send_db_idx) {
599 return Err(anyhow::anyhow!("Overflow of WQE, possible data loss"));
600 }
601 self.apply_first_op_delay(send_db_idx);
602 while send_db_idx < send_wqe_idx {
603 let offset = (send_db_idx % wqe_cnt as u64) * stride as u64;
604 let src_ptr = base_ptr.wrapping_add(offset as usize);
605 rdmaxcel_sys::db_ring((*dv_qp).bf.reg, src_ptr as *mut std::ffi::c_void);
606 send_db_idx += 1;
607 rdmaxcel_sys::rdmaxcel_qp_store_send_db_idx(qp, send_db_idx);
608 }
609 Ok(())
610 }
611 }
612
613 pub fn enqueue_put(
615 &mut self,
616 lhandle: IbvBuffer,
617 rhandle: IbvBuffer,
618 ) -> Result<Vec<u64>, anyhow::Error> {
619 let idx = unsafe {
620 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
621 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
622 )
623 };
624
625 self.send_wqe(
626 lhandle.addr,
627 lhandle.lkey,
628 lhandle.size,
629 idx,
630 true,
631 IbvOperation::Write,
632 rhandle.addr,
633 rhandle.rkey,
634 )?;
635 Ok(vec![idx])
636 }
637
638 pub fn enqueue_put_with_recv(
640 &mut self,
641 lhandle: IbvBuffer,
642 rhandle: IbvBuffer,
643 ) -> Result<Vec<u64>, anyhow::Error> {
644 let idx = unsafe {
645 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
646 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
647 )
648 };
649
650 self.send_wqe(
651 lhandle.addr,
652 lhandle.lkey,
653 lhandle.size,
654 idx,
655 true,
656 IbvOperation::WriteWithImm,
657 rhandle.addr,
658 rhandle.rkey,
659 )?;
660 Ok(vec![idx])
661 }
662
663 pub fn enqueue_get(
665 &mut self,
666 lhandle: IbvBuffer,
667 rhandle: IbvBuffer,
668 ) -> Result<Vec<u64>, anyhow::Error> {
669 let idx = unsafe {
670 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
671 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
672 )
673 };
674
675 self.send_wqe(
676 lhandle.addr,
677 lhandle.lkey,
678 lhandle.size,
679 idx,
680 true,
681 IbvOperation::Read,
682 rhandle.addr,
683 rhandle.rkey,
684 )?;
685 Ok(vec![idx])
686 }
687
688 pub fn get(
689 &mut self,
690 lhandle: IbvBuffer,
691 rhandle: IbvBuffer,
692 ) -> Result<Vec<u64>, anyhow::Error> {
693 let total_size = lhandle.size;
694 if rhandle.size < total_size {
695 return Err(anyhow::anyhow!(
696 "Remote buffer size ({}) is smaller than local buffer size ({})",
697 rhandle.size,
698 total_size
699 ));
700 }
701
702 let mut remaining = total_size;
703 let mut offset = 0;
704 let mut wr_ids = Vec::new();
705
706 while remaining > 0 {
707 let chunk_size = std::cmp::min(remaining, MAX_RDMA_MSG_SIZE);
708 let idx = unsafe {
709 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_wqe_idx(
710 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
711 )
712 };
713 wr_ids.push(idx);
714
715 self.post_op(
716 lhandle.addr + offset,
717 lhandle.lkey,
718 chunk_size,
719 idx,
720 true,
721 IbvOperation::Read,
722 rhandle.addr + offset,
723 rhandle.rkey,
724 )?;
725 unsafe {
726 rdmaxcel_sys::rdmaxcel_qp_fetch_add_send_db_idx(
727 self.qp as *mut rdmaxcel_sys::rdmaxcel_qp,
728 );
729 }
730
731 remaining -= chunk_size;
732 offset += chunk_size;
733 }
734
735 Ok(wr_ids)
736 }
737
738 fn post_op(
740 &mut self,
741 laddr: usize,
742 lkey: u32,
743 length: usize,
744 wr_id: u64,
745 signaled: bool,
746 op_type: IbvOperation,
747 raddr: usize,
748 rkey: u32,
749 ) -> Result<(), anyhow::Error> {
750 if self.is_efa() {
752 return self.post_op_efa(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey);
753 }
754
755 unsafe {
756 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
757 let context = self.context as *mut rdmaxcel_sys::ibv_context;
758 let ops = &mut (*context).ops;
759 let errno;
760 if op_type == IbvOperation::Recv {
761 let mut sge = rdmaxcel_sys::ibv_sge {
762 addr: laddr as u64,
763 length: length as u32,
764 lkey,
765 };
766 let mut wr = rdmaxcel_sys::ibv_recv_wr {
767 wr_id,
768 sg_list: &mut sge as *mut _,
769 num_sge: 1,
770 ..Default::default()
771 };
772 let mut bad_wr: *mut rdmaxcel_sys::ibv_recv_wr = std::ptr::null_mut();
773 errno =
774 ops.post_recv.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
775 } else if op_type == IbvOperation::Write
776 || op_type == IbvOperation::Read
777 || op_type == IbvOperation::WriteWithImm
778 {
779 self.apply_first_op_delay(wr_id);
780 let send_flags = if signaled {
781 rdmaxcel_sys::ibv_send_flags::IBV_SEND_SIGNALED.0
782 } else {
783 0
784 };
785 let mut sge = rdmaxcel_sys::ibv_sge {
786 addr: laddr as u64,
787 length: length as u32,
788 lkey,
789 };
790 let mut wr = rdmaxcel_sys::ibv_send_wr {
791 wr_id,
792 next: std::ptr::null_mut(),
793 sg_list: &mut sge as *mut _,
794 num_sge: 1,
795 opcode: op_type.into(),
796 send_flags,
797 wr: Default::default(),
798 qp_type: Default::default(),
799 __bindgen_anon_1: Default::default(),
800 __bindgen_anon_2: Default::default(),
801 };
802
803 wr.wr.rdma.remote_addr = raddr as u64;
804 wr.wr.rdma.rkey = rkey;
805 let mut bad_wr: *mut rdmaxcel_sys::ibv_send_wr = std::ptr::null_mut();
806
807 errno =
808 ops.post_send.as_mut().unwrap()((*qp).ibv_qp, &mut wr as *mut _, &mut bad_wr);
809 } else {
810 panic!("Not Implemented");
811 }
812
813 if errno != 0 {
814 let os_error = Error::last_os_error();
815 return Err(anyhow::anyhow!("Failed to post send request: {}", os_error));
816 }
817 tracing::debug!(
818 "completed sending {:?} request (lkey: {}, addr: 0x{:x}, length {}) to (raddr 0x{:x}, rkey {})",
819 op_type,
820 lkey,
821 laddr,
822 length,
823 raddr,
824 rkey,
825 );
826
827 Ok(())
828 }
829 }
830
831 fn post_op_efa(
833 &mut self,
834 laddr: usize,
835 lkey: u32,
836 length: usize,
837 wr_id: u64,
838 signaled: bool,
839 op_type: IbvOperation,
840 raddr: usize,
841 rkey: u32,
842 ) -> Result<(), anyhow::Error> {
843 let c_op = match op_type {
844 IbvOperation::Write => 0,
845 IbvOperation::Read => 1,
846 IbvOperation::Recv => 2,
847 IbvOperation::WriteWithImm => 3,
848 };
849
850 unsafe {
851 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
852 let ret = rdmaxcel_sys::rdmaxcel_qp_post_op(
853 qp,
854 laddr as *mut std::ffi::c_void,
855 lkey,
856 length,
857 raddr as *mut std::ffi::c_void,
858 rkey,
859 wr_id,
860 signaled as i32,
861 c_op,
862 );
863 if ret != 0 {
864 let msg = std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
865 .to_str()
866 .unwrap_or("unknown");
867 return Err(anyhow::anyhow!("EFA post_op failed: {}", msg));
868 }
869 }
870 Ok(())
871 }
872
873 fn send_wqe(
874 &mut self,
875 laddr: usize,
876 lkey: u32,
877 length: usize,
878 wr_id: u64,
879 signaled: bool,
880 op_type: IbvOperation,
881 raddr: usize,
882 rkey: u32,
883 ) -> Result<DoorBell, anyhow::Error> {
884 if self.is_efa() {
886 self.post_op(laddr, lkey, length, wr_id, signaled, op_type, raddr, rkey)?;
887 return Ok(DoorBell {
888 dst_ptr: 0,
889 src_ptr: 0,
890 size: 0,
891 });
892 }
893
894 unsafe {
895 let op_type_val = match op_type {
896 IbvOperation::Write => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE,
897 IbvOperation::WriteWithImm => rdmaxcel_sys::MLX5_OPCODE_RDMA_WRITE_IMM,
898 IbvOperation::Read => rdmaxcel_sys::MLX5_OPCODE_RDMA_READ,
899 IbvOperation::Recv => 0,
900 };
901
902 let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
903 let dv_qp = self.dv_qp as *mut rdmaxcel_sys::mlx5dv_qp;
904 let _dv_cq = if op_type == IbvOperation::Recv {
905 self.dv_recv_cq as *mut rdmaxcel_sys::mlx5dv_cq
906 } else {
907 self.dv_send_cq as *mut rdmaxcel_sys::mlx5dv_cq
908 };
909
910 let buf = if op_type == IbvOperation::Recv {
911 (*dv_qp).rq.buf as *mut u8
912 } else {
913 (*dv_qp).sq.buf as *mut u8
914 };
915
916 let params = rdmaxcel_sys::wqe_params_t {
917 laddr,
918 lkey,
919 length,
920 wr_id,
921 signaled,
922 op_type: op_type_val,
923 raddr,
924 rkey,
925 qp_num: (*(*qp).ibv_qp).qp_num,
926 buf,
927 dbrec: (*dv_qp).dbrec,
928 wqe_cnt: (*dv_qp).sq.wqe_cnt,
929 };
930
931 if op_type == IbvOperation::Recv {
932 rdmaxcel_sys::recv_wqe(params);
933 std::ptr::write_volatile((*dv_qp).dbrec, 1_u32.to_be());
934 } else {
935 rdmaxcel_sys::send_wqe(params);
936 };
937
938 Ok(DoorBell {
939 dst_ptr: (*dv_qp).bf.reg as usize,
940 src_ptr: (*dv_qp).sq.buf as usize,
941 size: 8,
942 })
943 }
944 }
945
    /// Polls the selected completion queue once for each id in
    /// `expected_wr_ids` and returns the completions that were found.
    ///
    /// Each id is looked up with `poll_cq_with_cache`, using the send or
    /// receive CQ and completion cache chosen by `target`. Ids with no
    /// completion available yet are simply absent from the result, so the
    /// returned vector may be shorter than `expected_wr_ids`. An empty
    /// `expected_wr_ids` returns an empty vector without touching the queue
    /// pair.
    ///
    /// # Errors
    /// Returns a `PollCompletionError` as soon as one lookup fails; the status
    /// and vendor error are included whenever the work completion exposes them.
    pub fn poll_completion(
        &mut self,
        target: PollTarget,
        expected_wr_ids: &[u64],
    ) -> Result<Vec<(u64, IbvWc)>, PollCompletionError> {
        if expected_wr_ids.is_empty() {
            return Ok(Vec::new());
        }

        unsafe {
            let qp = self.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
            let qp_num = (*(*qp).ibv_qp).qp_num;

            // `cq_type` is only used to label error messages.
            let (cq, cache, cq_type) = match target {
                PollTarget::Send => (
                    self.send_cq as *mut rdmaxcel_sys::ibv_cq,
                    rdmaxcel_sys::rdmaxcel_qp_get_send_cache(qp),
                    "send",
                ),
                PollTarget::Recv => (
                    self.recv_cq as *mut rdmaxcel_sys::ibv_cq,
                    rdmaxcel_sys::rdmaxcel_qp_get_recv_cache(qp),
                    "recv",
                ),
            };

            let mut results = Vec::new();

            for &expected_wr_id in expected_wr_ids {
                let mut poll_ctx = rdmaxcel_sys::poll_context_t {
                    expected_wr_id,
                    expected_qp_num: qp_num,
                    cache,
                    cq,
                };

                // Zero-initialised work completion for the FFI call to fill in.
                let mut wc = std::mem::MaybeUninit::<rdmaxcel_sys::ibv_wc>::zeroed().assume_init();
                let ret = rdmaxcel_sys::poll_cq_with_cache(&mut poll_ctx, &mut wc);

                match ret {
                    // A completion for `expected_wr_id` was returned.
                    1 => {
                        if !wc.is_valid() {
                            if let Some((status, vendor_err)) = wc.error() {
                                return Err(PollCompletionError {
                                    status: Some(status),
                                    vendor_err: Some(vendor_err),
                                    message: format!(
                                        "{} completion failed for wr_id={}: status={:?}, vendor_err={}",
                                        cq_type, expected_wr_id, status, vendor_err,
                                    ),
                                });
                            }
                        }
                        results.push((expected_wr_id, IbvWc::from(wc)));
                    }
                    // Nothing available for this id yet: skip it.
                    0 => {
                    }
                    // NOTE(review): -17 is handled separately so the error can
                    // carry details from `wc`; presumably it is the rdmaxcel code
                    // for a completion that finished with an error status. Confirm
                    // against the rdmaxcel error table.
                    -17 => {
                        let error_msg =
                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
                                .to_str()
                                .unwrap_or("Unknown error");
                        if let Some((status, vendor_err)) = wc.error() {
                            return Err(PollCompletionError {
                                status: Some(status),
                                vendor_err: Some(vendor_err),
                                message: format!(
                                    "Failed to poll {} CQ for wr_id={}: {} [status={:?}, vendor_err={}, qp_num={}, byte_len={}]",
                                    cq_type,
                                    expected_wr_id,
                                    error_msg,
                                    status,
                                    vendor_err,
                                    wc.qp_num,
                                    wc.len(),
                                ),
                            });
                        } else {
                            return Err(PollCompletionError {
                                status: None,
                                vendor_err: None,
                                message: format!(
                                    "Failed to poll {} CQ for wr_id={}: {} [qp_num={}, byte_len={}]",
                                    cq_type,
                                    expected_wr_id,
                                    error_msg,
                                    wc.qp_num,
                                    wc.len(),
                                ),
                            });
                        }
                    }
                    // Any other code: report it without work-completion details.
                    _ => {
                        let error_msg =
                            std::ffi::CStr::from_ptr(rdmaxcel_sys::rdmaxcel_error_string(ret))
                                .to_str()
                                .unwrap_or("Unknown error");
                        return Err(PollCompletionError {
                            status: None,
                            vendor_err: None,
                            message: format!(
                                "Failed to poll {} CQ for wr_id={}: {}",
                                cq_type, expected_wr_id, error_msg,
                            ),
                        });
                    }
                }
            }

            Ok(results)
        }
    }
1070}
1071
/// Key identifying one queue pair from the local side: the local device, the
/// remote actor and the remote device.
#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize, Named)]
pub(super) struct QpKey {
    /// Name of the local device.
    pub(super) self_device: String,
    /// Id of the remote actor.
    pub(super) other_id: ActorId,
    /// Name of the remote device.
    pub(super) other_device: String,
}
1104
/// Reply to the `EnsureQueuePair` request sent by `QueuePairInitializer`.
///
/// On success it carries the peer's queue-pair endpoint and the port on which
/// the peer expects `NotifyRts`; on failure it carries an error string.
#[derive(Debug, Serialize, Deserialize, Named)]
pub(super) struct PeerInfo(pub(super) Result<(IbvQpInfo, PortRef<NotifyRts>), String>);
wirevalue::register_type!(PeerInfo);
1111
/// Message posted to the peer once the sender has connected its own queue
/// pair (see `QueuePairInitializer::connect_and_notify`).
#[derive(Debug, Serialize, Deserialize, Named)]
pub(super) struct NotifyRts;
wirevalue::register_type!(NotifyRts);
1118
/// Internal message a `QueuePairInitializer` posts to itself when its
/// handshake timeout fires.
#[derive(Debug)]
struct InitializationFailed;
1123
/// Owning wrapper around an `IbvQueuePair` that destroys the queue pair on
/// drop unless it has been taken out with `into_inner`.
#[derive(Debug)]
pub(super) struct QpGuard {
    /// The guarded queue pair; `None` once it has been taken out.
    qp: Option<IbvQueuePair>,
}
1130
1131impl QpGuard {
1132 pub(super) fn new(qp: IbvQueuePair) -> Self {
1133 Self { qp: Some(qp) }
1134 }
1135
1136 pub(super) fn into_inner(mut self) -> IbvQueuePair {
1138 self.qp.take().expect("QpGuard already drained")
1139 }
1140
1141 pub(super) fn connect(&mut self, info: &IbvQpInfo) -> Result<(), anyhow::Error> {
1143 self.qp
1144 .as_mut()
1145 .expect("QpGuard already drained")
1146 .connect(info)
1147 }
1148
1149 pub(super) fn get_qp_info(&mut self) -> Result<IbvQpInfo, anyhow::Error> {
1151 self.qp
1152 .as_mut()
1153 .expect("QpGuard already drained")
1154 .get_qp_info()
1155 }
1156}
1157
1158impl Drop for QpGuard {
1159 fn drop(&mut self) {
1160 if let Some(qp) = self.qp.take() {
1161 unsafe { destroy_qp(&qp) };
1169 }
1170 }
1171}
1172
/// Bounds required of the actor type that a `QueuePairInitializer` reports to.
///
/// The owner must accept `EnsureQueuePair` remotely and handle the
/// `QpInitializerDone` / `QpInitializerFailed` outcome messages.
pub(super) trait QpOwner:
    Actor
    + Referable
    + Binds<Self>
    + RemoteHandles<EnsureQueuePair<Self>>
    + Handler<QpInitializerDone>
    + Handler<QpInitializerFailed>
{
}
1184
// Blanket impl: every type that satisfies the bounds is a `QpOwner`.
impl<T> QpOwner for T where
    T: Actor
        + Referable
        + Binds<T>
        + RemoteHandles<EnsureQueuePair<T>>
        + Handler<QpInitializerDone>
        + Handler<QpInitializerFailed>
{
}
1194
/// Actor that drives the connection handshake for one queue pair and reports
/// the outcome to its owner.
///
/// On `init` it asks the remote actor for its endpoint (`EnsureQueuePair`).
/// The handshake finishes once both the local connect has succeeded
/// (`our_rts_sent`) and the peer's `NotifyRts` has arrived
/// (`peer_rts_received`); it fails on error, undeliverable message or timeout.
#[derive(Debug)]
#[hyperactor::export(handlers = [PeerInfo, NotifyRts])]
pub(super) struct QueuePairInitializer<A: QpOwner> {
    /// Local owner that receives `QpInitializerDone` / `QpInitializerFailed`.
    owner: ActorHandle<A>,
    /// Remote actor the queue pair connects to.
    other: ActorRef<A>,
    /// Key echoed back to the owner in the outcome message.
    qp_key: QpKey,
    /// The queue pair being connected; `None` after `fail` or `done`.
    qp: Option<QpGuard>,
    /// Timeout applied to each handshake step; read from `RDMA_QP_INIT_TIMEOUT`.
    timeout: Duration,
    /// Set once the local queue pair is connected and `NotifyRts` was posted to the peer.
    our_rts_sent: bool,
    /// Set once the peer's `NotifyRts` has been received.
    peer_rts_received: bool,
    /// Set by `fail` and `done`; later messages are ignored.
    terminal: bool,
    /// Handle of the currently armed timeout task, if any.
    timeout_handle: Option<tokio::task::JoinHandle<()>>,
}
1225
1226impl<A> QueuePairInitializer<A>
1227where
1228 A: QpOwner,
1229{
1230 pub(super) fn new(
1231 owner: ActorHandle<A>,
1232 other: ActorRef<A>,
1233 qp_key: QpKey,
1234 qp: QpGuard,
1235 ) -> Self {
1236 let timeout = hyperactor_config::global::get(crate::config::RDMA_QP_INIT_TIMEOUT);
1237 Self {
1238 owner,
1239 other,
1240 qp_key,
1241 qp: Some(qp),
1242 timeout,
1243 our_rts_sent: false,
1244 peer_rts_received: false,
1245 terminal: false,
1246 timeout_handle: None,
1247 }
1248 }
1249
1250 fn arm_timeout(&mut self, this: &Instance<Self>) {
1252 if let Some(h) = self.timeout_handle.take() {
1253 h.abort();
1254 }
1255 let self_handle: ActorHandle<Self> = this.handle();
1256 let timeout = self.timeout;
1257 let task = tokio::spawn(async move {
1258 tokio::time::sleep(timeout).await;
1259 self_handle.post(Instance::<Self>::self_client(), InitializationFailed);
1260 });
1261 self.timeout_handle = Some(task);
1262 }
1263
1264 fn fail(&mut self, this: &Instance<Self>, error: String) -> Result<(), anyhow::Error> {
1268 if let Some(h) = self.timeout_handle.take() {
1269 h.abort();
1270 }
1271 self.qp = None;
1272 self.terminal = true;
1273 self.owner.post(
1274 this,
1275 QpInitializerFailed {
1276 qp_key: self.qp_key.clone(),
1277 error,
1278 },
1279 );
1280 Ok(())
1281 }
1282
1283 fn done(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
1286 if let Some(h) = self.timeout_handle.take() {
1287 h.abort();
1288 }
1289 let qp = self.qp.take().expect("qp present in done()");
1290 self.terminal = true;
1291 self.owner.post(
1292 this,
1293 QpInitializerDone {
1294 qp_key: self.qp_key.clone(),
1295 qp,
1296 },
1297 );
1298 Ok(())
1299 }
1300
1301 fn connect_and_notify(
1305 &mut self,
1306 cx: &Context<Self>,
1307 info: Result<(IbvQpInfo, PortRef<NotifyRts>), String>,
1308 ) -> Result<(), String> {
1309 let (peer_endpoint, peer_notify_rts) = info?;
1310 self.qp
1311 .as_mut()
1312 .expect("qp present pre-terminal")
1313 .connect(&peer_endpoint)
1314 .map_err(|e| format!("QpGuard::connect failed: {e}"))?;
1315 peer_notify_rts.post(cx, NotifyRts);
1316 Ok(())
1317 }
1318}
1319
1320#[async_trait]
1321impl<A> Actor for QueuePairInitializer<A>
1322where
1323 A: QpOwner,
1324{
1325 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
1326 let reply = this.bind::<Self>().port();
1329 let sender = self.owner.bind();
1330 let sender_device = self.qp_key.self_device.clone();
1331 let receiver_device = self.qp_key.other_device.clone();
1332 self.other.post(
1333 this,
1334 EnsureQueuePair {
1335 sender,
1336 sender_device,
1337 receiver_device,
1338 reply,
1339 },
1340 );
1341
1342 self.arm_timeout(this);
1343 Ok(())
1344 }
1345
1346 async fn cleanup(
1347 &mut self,
1348 _this: &Instance<Self>,
1349 _err: Option<&hyperactor::actor::ActorError>,
1350 ) -> Result<(), anyhow::Error> {
1351 if let Some(h) = self.timeout_handle.take() {
1352 h.abort();
1353 }
1354 Ok(())
1355 }
1356
1357 async fn handle_undeliverable_message(
1358 &mut self,
1359 this: &Instance<Self>,
1360 undeliverable: Undeliverable<MessageEnvelope>,
1361 ) -> Result<(), anyhow::Error> {
1362 let error = match undeliverable {
1363 Undeliverable::Message(envelope) => envelope.error_msg().unwrap_or_default(),
1364 Undeliverable::Lost(lost) => lost.error,
1365 };
1366 if self.terminal {
1367 tracing::warn!(
1368 "undeliverable message after handshake terminated: {}",
1369 error
1370 );
1371 return Ok(());
1372 }
1373 self.fail(this, error)
1374 }
1375}
1376
1377impl<A> Drop for QueuePairInitializer<A>
1378where
1379 A: QpOwner,
1380{
1381 fn drop(&mut self) {
1382 if let Some(h) = self.timeout_handle.take() {
1383 h.abort();
1384 }
1385 }
1386}
1387
1388pub(super) unsafe fn destroy_qp(qp: &IbvQueuePair) {
1399 unsafe {
1405 if qp.qp != 0 {
1406 let rdmaxcel_qp = qp.qp as *mut rdmaxcel_sys::rdmaxcel_qp;
1407 rdmaxcel_sys::rdmaxcel_qp_destroy(rdmaxcel_qp);
1408 }
1409 }
1410}
1411
1412#[async_trait]
1413impl<A> Handler<PeerInfo> for QueuePairInitializer<A>
1414where
1415 A: QpOwner,
1416{
1417 async fn handle(&mut self, cx: &Context<Self>, msg: PeerInfo) -> Result<(), anyhow::Error> {
1418 if self.terminal {
1419 tracing::warn!("PeerInfo received after queue pair already terminal");
1420 return Ok(());
1421 }
1422 debug_assert!(!self.our_rts_sent, "duplicate PeerInfo");
1423 if let Err(e) = self.connect_and_notify(cx, msg.0) {
1424 return self.fail(cx, e);
1425 }
1426 self.our_rts_sent = true;
1427 if self.peer_rts_received {
1428 return self.done(cx);
1429 }
1430 self.arm_timeout(cx);
1434 Ok(())
1435 }
1436}
1437
1438#[async_trait]
1439impl<A> Handler<NotifyRts> for QueuePairInitializer<A>
1440where
1441 A: QpOwner,
1442{
1443 async fn handle(&mut self, cx: &Context<Self>, _msg: NotifyRts) -> Result<(), anyhow::Error> {
1444 if self.terminal {
1445 tracing::warn!("NotifyRts received after queue pair already terminal");
1446 return Ok(());
1447 }
1448 debug_assert!(!self.peer_rts_received, "duplicate NotifyRts");
1449 self.peer_rts_received = true;
1450 if self.our_rts_sent {
1451 return self.done(cx);
1452 }
1453 self.arm_timeout(cx);
1456 Ok(())
1457 }
1458}
1459
1460#[async_trait]
1461impl<A> Handler<InitializationFailed> for QueuePairInitializer<A>
1462where
1463 A: QpOwner,
1464{
1465 async fn handle(
1466 &mut self,
1467 cx: &Context<Self>,
1468 _msg: InitializationFailed,
1469 ) -> Result<(), anyhow::Error> {
1470 if self.terminal {
1471 return Ok(());
1472 }
1473 self.fail(cx, "QP initialization timed out".into())
1474 }
1475}
1476
1477#[cfg(test)]
1478mod tests {
1479 use std::sync::Arc;
1480 use std::sync::Mutex;
1481 use std::time::Duration;
1482 use std::time::Instant;
1483
1484 use anyhow::Result;
1485 use async_trait::async_trait;
1486 use hyperactor::Context;
1487 use hyperactor::Handler;
1488 use hyperactor::PortRef;
1489 use hyperactor::mailbox::DeliveryError;
1490 use hyperactor::mailbox::MessageEnvelope;
1491 use hyperactor::mailbox::Undeliverable;
1492 use hyperactor::port::Port;
1493 use hyperactor::proc::Proc;
1494 use hyperactor_config::Flattrs;
1495
1496 use super::*;
1497 use crate::backend::ibverbs::domain::IbvDomain;
1498 use crate::backend::ibverbs::manager_actor::EnsureQueuePair;
1499 use crate::backend::ibverbs::primitives::IbvConfig;
1500 use crate::backend::ibverbs::primitives::get_all_devices;
1501
/// Creating a domain and a queue pair on the default device succeeds.
/// Skipped when no RDMA device is present.
#[test]
fn test_create_connection() {
    let no_devices = get_all_devices().is_empty();
    if no_devices {
        println!("Skipping test: RDMA devices not available");
        return;
    }

    let config = IbvConfig {
        use_gpu_direct: false,
        ..Default::default()
    };

    let opened = IbvDomain::new(config.device.clone());
    assert!(opened.is_ok());
    let domain = opened.unwrap();

    let created = IbvQueuePair::new(domain.context, domain.pd, config.clone());
    assert!(created.is_ok());
}
1520
/// Two queue pairs created with the default config can each connect using
/// the other's connection info. Skipped when no RDMA device is present.
#[test]
fn test_loopback_connection() {
    let no_devices = get_all_devices().is_empty();
    if no_devices {
        println!("Skipping test: RDMA devices not available");
        return;
    }

    // Both endpoints use the same settings.
    let make_config = || IbvConfig {
        use_gpu_direct: false,
        ..Default::default()
    };
    let server_config = make_config();
    let client_config = make_config();

    let server_domain = IbvDomain::new(server_config.device.clone()).unwrap();
    let client_domain = IbvDomain::new(client_config.device.clone()).unwrap();

    let server_created = IbvQueuePair::new(
        server_domain.context,
        server_domain.pd,
        server_config.clone(),
    );
    let mut server_qp = server_created.unwrap();
    let client_created = IbvQueuePair::new(
        client_domain.context,
        client_domain.pd,
        client_config.clone(),
    );
    let mut client_qp = client_created.unwrap();

    // Exchange connection info, then connect each side to the other.
    let server_info = server_qp.get_qp_info().unwrap();
    let client_info = client_qp.get_qp_info().unwrap();

    assert!(server_qp.connect(&client_info).is_ok());
    assert!(client_qp.connect(&server_info).is_ok());
}
1559
1560 #[derive(Default, Debug)]
1562 struct MockState {
1563 done: Vec<QpKey>,
1564 failed: Vec<(QpKey, String)>,
1565 notify_rts: usize,
1569 }
1570
1571 #[derive(Debug)]
1574 enum MockResponse {
1575 Success(IbvQpInfo),
1579 SuccessWithBogusNotifyRts(IbvQpInfo),
1584 Error(String),
1585 DropReply,
1586 }
1587
1588 fn fake_qp() -> IbvQueuePair {
1592 IbvQueuePair {
1593 send_cq: 0,
1594 recv_cq: 0,
1595 qp: 0,
1596 dv_qp: 0,
1597 dv_send_cq: 0,
1598 dv_recv_cq: 0,
1599 context: 0,
1600 config: IbvConfig::default(),
1601 is_efa: false,
1602 }
1603 }
1604
1605 fn loopback_qp() -> Option<(QpGuard, IbvQpInfo)> {
1608 if get_all_devices().is_empty() {
1609 return None;
1610 }
1611 let config = IbvConfig::default();
1612 let domain = IbvDomain::new(config.device.clone()).ok()?;
1613 let mut qp = QpGuard::new(IbvQueuePair::new(domain.context, domain.pd, config).ok()?);
1614 let info = qp.get_qp_info().ok()?;
1615 Some((qp, info))
1616 }
1617
1618 #[derive(Debug)]
1619 #[hyperactor::export(handlers = [EnsureQueuePair<MockManager>, NotifyRts])]
1620 struct MockManager {
1621 state: Arc<Mutex<MockState>>,
1622 response: MockResponse,
1623 }
1624
1625 #[async_trait]
1626 impl Actor for MockManager {}
1627
1628 #[async_trait]
1629 impl Handler<EnsureQueuePair<MockManager>> for MockManager {
1630 async fn handle(
1631 &mut self,
1632 cx: &Context<Self>,
1633 msg: EnsureQueuePair<MockManager>,
1634 ) -> Result<()> {
1635 let response = std::mem::replace(&mut self.response, MockResponse::DropReply);
1636 match response {
1637 MockResponse::Success(info) => {
1638 let notify_rts = cx.bind::<MockManager>().port::<NotifyRts>();
1639 msg.reply.post(cx, PeerInfo(Ok((info, notify_rts))));
1640 }
1641 MockResponse::SuccessWithBogusNotifyRts(info) => {
1642 let bogus = hyperactor::context::Mailbox::mailbox(cx)
1643 .actor_addr()
1644 .proc_addr()
1645 .actor_addr("bogus")
1646 .port_addr(Port::from(0u64));
1647 let notify_rts = PortRef::<NotifyRts>::attest(bogus);
1648 msg.reply.post(cx, PeerInfo(Ok((info, notify_rts))));
1649 }
1650 MockResponse::Error(e) => {
1651 msg.reply.post(cx, PeerInfo(Err(e)));
1652 }
1653 MockResponse::DropReply => {}
1654 }
1655 Ok(())
1656 }
1657 }
1658
1659 #[async_trait]
1660 impl Handler<NotifyRts> for MockManager {
1661 async fn handle(&mut self, _cx: &Context<Self>, _msg: NotifyRts) -> Result<()> {
1662 self.state.lock().unwrap().notify_rts += 1;
1663 Ok(())
1664 }
1665 }
1666
1667 #[async_trait]
1668 impl Handler<QpInitializerDone> for MockManager {
1669 async fn handle(&mut self, _cx: &Context<Self>, msg: QpInitializerDone) -> Result<()> {
1670 let _ = msg.qp.into_inner();
1671 self.state.lock().unwrap().done.push(msg.qp_key);
1672 Ok(())
1673 }
1674 }
1675
1676 #[async_trait]
1677 impl Handler<QpInitializerFailed> for MockManager {
1678 async fn handle(&mut self, _cx: &Context<Self>, msg: QpInitializerFailed) -> Result<()> {
1679 self.state
1680 .lock()
1681 .unwrap()
1682 .failed
1683 .push((msg.qp_key, msg.error));
1684 Ok(())
1685 }
1686 }
1687
1688 struct Harness {
1689 proc: Proc,
1690 init_handle: ActorHandle<QueuePairInitializer<MockManager>>,
1691 state: Arc<Mutex<MockState>>,
1692 qp_key: QpKey,
1693 }
1694
1695 impl Harness {
1696 fn build(qp: QpGuard, response: MockResponse) -> Result<Self> {
1697 let proc = Proc::anonymous();
1698 let state = Arc::new(Mutex::new(MockState::default()));
1699 let mock = MockManager {
1700 state: state.clone(),
1701 response,
1702 };
1703 let mock_handle = proc.spawn("mock", mock)?;
1704 let mock_ref = mock_handle.bind::<MockManager>();
1705 let qp_key = QpKey {
1706 self_device: "mock0".into(),
1707 other_id: mock_ref.actor_addr().id().clone(),
1708 other_device: "mock0".into(),
1709 };
1710 let initializer = QueuePairInitializer::new(mock_handle, mock_ref, qp_key.clone(), qp);
1711 let init_handle = proc.spawn("initializer", initializer)?;
1712 let _ = init_handle.bind::<QueuePairInitializer<MockManager>>();
1714 Ok(Harness {
1715 proc,
1716 init_handle,
1717 state,
1718 qp_key,
1719 })
1720 }
1721
1722 async fn await_done(&self) -> QpKey {
1723 let deadline = Instant::now() + Duration::from_secs(5);
1724 loop {
1725 if let Some(key) = self.state.lock().unwrap().done.first().cloned() {
1726 return key;
1727 }
1728 if Instant::now() >= deadline {
1729 panic!(
1730 "QpInitializerDone not delivered within 5s; state={:?}",
1731 self.state.lock().unwrap()
1732 );
1733 }
1734 tokio::time::sleep(Duration::from_millis(10)).await;
1735 }
1736 }
1737
1738 async fn await_failed(&self) -> (QpKey, String) {
1739 let deadline = Instant::now() + Duration::from_secs(5);
1740 loop {
1741 if let Some(entry) = self.state.lock().unwrap().failed.first().cloned() {
1742 return entry;
1743 }
1744 if Instant::now() >= deadline {
1745 panic!(
1746 "QpInitializerFailed was not delivered within 5s; state={:?}",
1747 self.state.lock().unwrap()
1748 );
1749 }
1750 tokio::time::sleep(Duration::from_millis(10)).await;
1751 }
1752 }
1753 }
1754
1755 #[tokio::test]
1756 async fn test_peer_info_error_transitions_to_failed() {
1757 let harness = Harness::build(
1758 QpGuard::new(fake_qp()),
1759 MockResponse::Error("peer rejected".into()),
1760 )
1761 .unwrap();
1762 let (key, error) = harness.await_failed().await;
1763 assert_eq!(key, harness.qp_key);
1764 assert_eq!(error, "peer rejected");
1765 assert!(harness.state.lock().unwrap().done.is_empty());
1767 }
1768
1769 #[tokio::test]
1770 async fn test_initial_timeout_transitions_to_failed() {
1771 let lock = hyperactor_config::global::lock();
1774 let _guard = lock.override_key(
1775 crate::config::RDMA_QP_INIT_TIMEOUT,
1776 Duration::from_millis(200),
1777 );
1778
1779 let harness = Harness::build(QpGuard::new(fake_qp()), MockResponse::DropReply).unwrap();
1780 let (key, error) = harness.await_failed().await;
1781 assert_eq!(key, harness.qp_key);
1782 assert!(
1783 error.contains("timed out"),
1784 "expected timeout error, got {error}"
1785 );
1786 }
1787
1788 #[tokio::test]
1793 async fn test_loopback_handshake_succeeds() -> Result<()> {
1794 let Some((qp, info)) = loopback_qp() else {
1795 panic!("Skipping test: RDMA devices not available");
1796 };
1797 let harness = Harness::build(qp, MockResponse::Success(info))?;
1798
1799 let (peer, _) = harness.proc.client("peer")?;
1800 harness.init_handle.post(&peer, NotifyRts);
1801
1802 let key = harness.await_done().await;
1803 assert_eq!(key, harness.qp_key);
1804 let state = harness.state.lock().unwrap();
1805 assert!(state.failed.is_empty());
1806 assert_eq!(
1807 state.notify_rts, 1,
1808 "initializer must send exactly one NotifyRts to the peer after qp.connect"
1809 );
1810 Ok(())
1811 }
1812
1813 #[tokio::test]
1818 async fn test_notify_rts_timeout_after_peer_info() -> Result<()> {
1819 let Some((qp, info)) = loopback_qp() else {
1820 panic!("Skipping test: RDMA devices not available");
1821 };
1822 let lock = hyperactor_config::global::lock();
1823 let _guard = lock.override_key(
1824 crate::config::RDMA_QP_INIT_TIMEOUT,
1825 Duration::from_millis(200),
1826 );
1827
1828 let harness = Harness::build(qp, MockResponse::Success(info))?;
1829 let (key, error) = harness.await_failed().await;
1830 assert_eq!(key, harness.qp_key);
1831 assert!(
1832 error.contains("timed out"),
1833 "expected timeout error, got {error}"
1834 );
1835 assert_eq!(harness.state.lock().unwrap().notify_rts, 1);
1840 Ok(())
1841 }
1842
1843 fn fake_undeliverable(proc: &Proc, error: &str) -> Undeliverable<MessageEnvelope> {
1844 let mut envelope = MessageEnvelope::serialize(
1845 proc.proc_addr().actor_addr("test-sender"),
1846 proc.proc_addr()
1847 .actor_addr("test-dest")
1848 .port_addr(Port::from(0u64)),
1849 &0u64,
1850 Flattrs::default(),
1851 )
1852 .unwrap();
1853 envelope.set_error(DeliveryError::Mailbox(error.into()));
1854 Undeliverable::Message(envelope)
1855 }
1856
1857 #[tokio::test]
1862 async fn test_undeliverable_in_awaiting_transitions_to_failed() {
1863 let harness = Harness::build(QpGuard::new(fake_qp()), MockResponse::DropReply).unwrap();
1864 let undeliverable = fake_undeliverable(&harness.proc, "simulated bounce");
1865 let (peer, _) = harness.proc.client("peer").unwrap();
1866 harness.init_handle.post(&peer, undeliverable);
1867 let (key, error) = harness.await_failed().await;
1868 assert_eq!(key, harness.qp_key);
1869 assert!(
1870 error.contains("simulated bounce"),
1871 "expected delivery error, got {error}"
1872 );
1873 }
1874
1875 #[tokio::test]
1880 async fn test_notify_rts_undeliverable_transitions_to_failed() -> Result<()> {
1881 let Some((qp, info)) = loopback_qp() else {
1882 panic!("Skipping test: RDMA devices not available");
1883 };
1884 let harness = Harness::build(qp, MockResponse::SuccessWithBogusNotifyRts(info))?;
1885 let (key, error) = harness.await_failed().await;
1886 assert_eq!(key, harness.qp_key);
1887 assert!(
1888 error.contains("address not routable"),
1889 "expected delivery error, got {error:?}"
1890 );
1891 Ok(())
1892 }
1893
1894 #[tokio::test]
1898 async fn test_undeliverable_after_terminated_does_not_re_fail() {
1899 let harness = Harness::build(
1900 QpGuard::new(fake_qp()),
1901 MockResponse::Error("first fail".into()),
1902 )
1903 .unwrap();
1904 let _ = harness.await_failed().await;
1905
1906 let undeliverable = fake_undeliverable(&harness.proc, "late bounce");
1907 let (peer, _) = harness.proc.client("peer").unwrap();
1908 harness.init_handle.post(&peer, undeliverable);
1909 tokio::time::sleep(Duration::from_millis(50)).await;
1910 assert_eq!(harness.state.lock().unwrap().failed.len(), 1);
1911 }
1912}