// monarch_rdma/rdma_manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! # RDMA Manager Actor
//!
//! Manages RDMA connections and operations, using `hyperactor` for asynchronous messaging.
//!
//! ## Architecture
//!
//! `RdmaManagerActor` is a per-host entity that:
//! - Manages connections to multiple remote `RdmaManagerActor`s (i.e. across the hosts in a Monarch cluster)
//! - Handles memory registration, connection setup, and data transfer
//! - Manages all `RdmaBuffer`s on its host
//!
//! ## Core Operations
//!
//! - Connection establishment with partner actors
//! - RDMA operations (put/write, get/read)
//! - Completion polling
//! - Memory region management
//!
//! ## Usage
//!
//! See the test examples `test_rdma_write_loopback` and `test_rdma_read_loopback`.
30use std::collections::HashMap;
31
32use async_trait::async_trait;
33use hyperactor::Actor;
34use hyperactor::ActorId;
35use hyperactor::ActorRef;
36use hyperactor::Context;
37use hyperactor::HandleClient;
38use hyperactor::Handler;
39use hyperactor::Instance;
40use hyperactor::Named;
41use hyperactor::OncePortRef;
42use hyperactor::RefClient;
43use hyperactor::supervision::ActorSupervisionEvent;
44use serde::Deserialize;
45use serde::Serialize;
46
47use crate::ibverbs_primitives::IbverbsConfig;
48use crate::ibverbs_primitives::RdmaMemoryRegionView;
49use crate::ibverbs_primitives::RdmaQpInfo;
50use crate::ibverbs_primitives::ibverbs_supported;
51use crate::rdma_components::RdmaBuffer;
52use crate::rdma_components::RdmaDomain;
53use crate::rdma_components::RdmaQueuePair;
54use crate::rdma_components::get_registered_cuda_segments;
55use crate::validate_execution_context;
56
57/// Represents the state of a queue pair in the manager, either available or checked out.
58#[derive(Debug, Clone)]
59pub enum QueuePairState {
60    Available(RdmaQueuePair),
61    CheckedOut,
62}
63
64/// Helper function to get detailed error messages from RDMAXCEL error codes
65pub fn get_rdmaxcel_error_message(error_code: i32) -> String {
66    unsafe {
67        let c_str = rdmaxcel_sys::rdmaxcel_error_string(error_code);
68        std::ffi::CStr::from_ptr(c_str)
69            .to_string_lossy()
70            .into_owned()
71    }
72}
73
74/// Represents a reference to a remote RDMA buffer that can be accessed via RDMA operations.
75/// This struct encapsulates all the information needed to identify and access a memory region
76/// on a remote host using RDMA.
77#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
78pub enum RdmaManagerMessage {
79    RequestBuffer {
80        addr: usize,
81        size: usize,
82        #[reply]
83        /// `reply` - Reply channel to return the RDMA buffer handle
84        reply: OncePortRef<RdmaBuffer>,
85    },
86    ReleaseBuffer {
87        buffer: RdmaBuffer,
88    },
89    RequestQueuePair {
90        remote: ActorRef<RdmaManagerActor>,
91        #[reply]
92        /// `reply` - Reply channel to return the queue pair for communication
93        reply: OncePortRef<RdmaQueuePair>,
94    },
95    Connect {
96        /// `other` - The ActorId of the actor to connect to
97        other: ActorRef<RdmaManagerActor>,
98        /// `endpoint` - Connection information needed to establish the RDMA connection
99        endpoint: RdmaQpInfo,
100    },
101    InitializeQP {
102        remote: ActorRef<RdmaManagerActor>,
103        #[reply]
104        /// `reply` - Reply channel to return the queue pair for communication
105        reply: OncePortRef<bool>,
106    },
107    ConnectionInfo {
108        /// `other` - The ActorId to get connection info for
109        other: ActorRef<RdmaManagerActor>,
110        #[reply]
111        /// `reply` - Reply channel to return the connection info
112        reply: OncePortRef<RdmaQpInfo>,
113    },
114    ReleaseQueuePair {
115        /// `other` - The ActorId to release queue pair for  
116        other: ActorRef<RdmaManagerActor>,
117        /// `qp` - The queue pair to return (ownership transferred back)
118        qp: RdmaQueuePair,
119    },
120}
121
122#[derive(Debug)]
123#[hyperactor::export(
124    spawn = true,
125    handlers = [
126        RdmaManagerMessage,
127    ],
128)]
129pub struct RdmaManagerActor {
130    // Map between ActorIds and their corresponding RdmaQueuePair
131    qp_map: HashMap<ActorId, QueuePairState>,
132
133    // MR configuration QP for self that cannot be loaned out
134    loopback_qp: Option<RdmaQueuePair>,
135
136    // The RDMA domain associated with this actor.
137    //
138    // The domain is responsible for managing the RDMA resources and configurations
139    // specific to this actor. It encapsulates the context and protection domain
140    // necessary for RDMA operations, ensuring that all RDMA activities are
141    // performed within a consistent and isolated environment.
142    //
143    // This domain is initialized during the creation of the `RdmaManagerActor`
144    // and is used throughout the actor's lifecycle to manage RDMA connections
145    // and operations.
146    domain: RdmaDomain,
147    config: IbverbsConfig,
148
149    // Flag indicating PyTorch CUDA allocator compatibility
150    // True if both C10 CUDA allocator is enabled AND expandable segments are enabled
151    pt_cuda_alloc: bool,
152
153    // Map of unique RdmaMemoryRegionView to ibv_mr*.  In case of cuda w/ pytorch its -1
154    // since its managed independently.  Only used for registration/deregistration purposes
155    mr_map: HashMap<usize, usize>,
156    // Id for next mrv created
157    mrv_id: usize,
158}
159
160impl RdmaManagerActor {
161    fn find_cuda_segment_for_address(
162        &mut self,
163        addr: usize,
164        size: usize,
165    ) -> Option<RdmaMemoryRegionView> {
166        let registered_segments = get_registered_cuda_segments();
167        for segment in registered_segments {
168            let start_addr = segment.phys_address;
169            let end_addr = start_addr + segment.phys_size;
170            if start_addr <= addr && addr + size <= end_addr {
171                let offset = addr - start_addr;
172                let rdma_addr = segment.mr_addr + offset;
173
174                let mrv = RdmaMemoryRegionView {
175                    id: self.mrv_id,
176                    virtual_addr: addr,
177                    rdma_addr,
178                    size,
179                    lkey: segment.lkey,
180                    rkey: segment.rkey,
181                };
182                self.mrv_id += 1;
183                return Some(mrv);
184            }
185        }
186        None
187    }
188
189    fn register_mr(
190        &mut self,
191        addr: usize,
192        size: usize,
193    ) -> Result<RdmaMemoryRegionView, anyhow::Error> {
194        unsafe {
195            let mut mem_type: i32 = 0;
196            let ptr = addr as cuda_sys::CUdeviceptr;
197            let err = cuda_sys::cuPointerGetAttribute(
198                &mut mem_type as *mut _ as *mut std::ffi::c_void,
199                cuda_sys::CUpointer_attribute_enum::CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
200                ptr,
201            );
202            let is_cuda = err == cuda_sys::CUresult::CUDA_SUCCESS;
203
204            let access = rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_LOCAL_WRITE
205                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_WRITE
206                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_READ
207                | rdmaxcel_sys::ibv_access_flags::IBV_ACCESS_REMOTE_ATOMIC;
208
209            let mut mr: *mut rdmaxcel_sys::ibv_mr = std::ptr::null_mut();
210            let mrv;
211
212            if is_cuda && self.pt_cuda_alloc {
213                // Get registered segments and check if our memory range is covered
214                let mut maybe_mrv = self.find_cuda_segment_for_address(addr, size);
215                // not found, lets re-sync with caching allocator  and retry
216                if maybe_mrv.is_none() {
217                    let qp = self.loopback_qp.as_mut().unwrap();
218                    let err = rdmaxcel_sys::register_segments(
219                        self.domain.pd,
220                        qp.qp as *mut rdmaxcel_sys::ibv_qp,
221                    );
222                    if err != 0 {
223                        let error_msg = get_rdmaxcel_error_message(err);
224                        return Err(anyhow::anyhow!(
225                            "RdmaXcel register_sements failed (addr: 0x{:x}, size: {}): {}",
226                            addr,
227                            size,
228                            error_msg
229                        ));
230                    }
231
232                    maybe_mrv = self.find_cuda_segment_for_address(addr, size);
233                }
234                // if still not found, throw exception
235                if maybe_mrv.is_none() {
236                    return Err(anyhow::anyhow!(
237                        "MR registration failed for cuda (addr: 0x{:x}, size: {}), unable to find segment in CudaCachingAllocator",
238                        addr,
239                        size
240                    ));
241                }
242                mrv = maybe_mrv.unwrap();
243            } else if is_cuda {
244                let mut fd: i32 = -1;
245                cuda_sys::cuMemGetHandleForAddressRange(
246                    &mut fd as *mut i32 as *mut std::ffi::c_void,
247                    addr as cuda_sys::CUdeviceptr,
248                    size,
249                    cuda_sys::CUmemRangeHandleType::CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD,
250                    0,
251                );
252                mr = rdmaxcel_sys::ibv_reg_dmabuf_mr(
253                    self.domain.pd,
254                    0,
255                    size,
256                    0,
257                    fd,
258                    access.0 as i32,
259                );
260                if mr.is_null() {
261                    return Err(anyhow::anyhow!("Failed to register dmabuf MR"));
262                }
263                mrv = RdmaMemoryRegionView {
264                    id: self.mrv_id,
265                    virtual_addr: addr,
266                    rdma_addr: (*mr).addr as usize,
267                    size,
268                    lkey: (*mr).lkey,
269                    rkey: (*mr).rkey,
270                };
271                self.mrv_id += 1;
272            } else {
273                // CPU memory path
274                mr = rdmaxcel_sys::ibv_reg_mr(
275                    self.domain.pd,
276                    addr as *mut std::ffi::c_void,
277                    size,
278                    access.0 as i32,
279                );
280
281                if mr.is_null() {
282                    return Err(anyhow::anyhow!("failed to register standard MR"));
283                }
284
285                mrv = RdmaMemoryRegionView {
286                    id: self.mrv_id,
287                    virtual_addr: addr,
288                    rdma_addr: (*mr).addr as usize,
289                    size,
290                    lkey: (*mr).lkey,
291                    rkey: (*mr).rkey,
292                };
293                self.mrv_id += 1;
294            }
295            self.mr_map.insert(mrv.id, mr as usize);
296            Ok(mrv)
297        }
298    }
299
300    fn deregister_mr(&mut self, id: usize) -> Result<(), anyhow::Error> {
301        if let Some(mr_ptr) = self.mr_map.remove(&id) {
302            if mr_ptr != 0 {
303                unsafe {
304                    rdmaxcel_sys::ibv_dereg_mr(mr_ptr as *mut rdmaxcel_sys::ibv_mr);
305                }
306            }
307        }
308        Ok(())
309    }
310}
311
312#[async_trait]
313impl Actor for RdmaManagerActor {
314    type Params = Option<IbverbsConfig>;
315
316    async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
317        if !ibverbs_supported() {
318            return Err(anyhow::anyhow!(
319                "Cannot create RdmaManagerActor because RDMA is not supported on this machine"
320            ));
321        }
322
323        // Use provided config or default if none provided
324        let mut config = params.unwrap_or_default();
325        tracing::debug!("rdma is enabled, using device {}", config.device);
326
327        let pt_cuda_alloc = crate::rdma_components::pt_cuda_allocator_compatibility();
328
329        // check config and hardware support align
330        if config.use_gpu_direct {
331            match validate_execution_context().await {
332                Ok(_) => {
333                    tracing::info!("GPU Direct RDMA execution context validated successfully");
334                }
335                Err(e) => {
336                    tracing::warn!(
337                        "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
338                        e
339                    );
340                    config.use_gpu_direct = false;
341                }
342            }
343        }
344
345        let domain = RdmaDomain::new(config.device.clone())
346            .map_err(|e| anyhow::anyhow!("rdmaManagerActor could not create domain: {}", e))?;
347
348        Ok(Self {
349            qp_map: HashMap::new(),
350            loopback_qp: None,
351            domain,
352            config,
353            pt_cuda_alloc,
354            mr_map: HashMap::new(),
355            mrv_id: 0,
356        })
357    }
358
359    async fn init(&mut self, _this: &Instance<Self>) -> Result<(), anyhow::Error> {
360        // Create a loopback queue pair for self-communication
361        let mut qp = RdmaQueuePair::new(self.domain.context, self.domain.pd, self.config.clone())
362            .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?;
363
364        // Get connection info for loopback
365        let endpoint = qp
366            .get_qp_info()
367            .map_err(|e| anyhow::anyhow!("could not get QP info: {}", e))?;
368
369        // Connect to itself
370        qp.connect(&endpoint)
371            .map_err(|e| anyhow::anyhow!("could not connect to RDMA endpoint: {}", e))?;
372
373        self.loopback_qp = Some(qp);
374        tracing::debug!("successfully created special loopback connection");
375
376        Ok(())
377    }
378
379    async fn handle_supervision_event(
380        &mut self,
381        _cx: &Instance<Self>,
382        _event: &ActorSupervisionEvent,
383    ) -> Result<bool, anyhow::Error> {
384        tracing::error!("rdmaManagerActor supervision event: {:?}", _event);
385        tracing::error!("rdmaManagerActor error occurred, stop the worker process, exit code: 1");
386        std::process::exit(1);
387    }
388}
389
390#[async_trait]
391#[hyperactor::forward(RdmaManagerMessage)]
392impl RdmaManagerMessageHandler for RdmaManagerActor {
393    /// Requests a buffer to be registered with the RDMA domain.
394    ///
395    /// This function registers a memory region with the RDMA domain and returns an `RdmaBuffer`
396    /// that encapsulates the necessary information for RDMA operations.
397    ///
398    /// # Arguments
399    ///
400    /// * `this` - The context of the actor requesting the buffer.
401    /// * `addr` - The starting address of the memory region to be registered.
402    /// * `size` - The size of the memory region to be registered.
403    ///
404    /// # Returns
405    ///
406    /// * `Result<RdmaBuffer, anyhow::Error>` - On success, returns an `RdmaBuffer` containing
407    ///   the registered memory region's details. On failure, returns an error.
408    async fn request_buffer(
409        &mut self,
410        cx: &Context<Self>,
411        addr: usize,
412        size: usize,
413    ) -> Result<RdmaBuffer, anyhow::Error> {
414        let mrv = self.register_mr(addr, size)?;
415
416        Ok(RdmaBuffer {
417            owner: cx.bind().clone(),
418            mr_id: mrv.id,
419            addr: mrv.rdma_addr,
420            size: mrv.size,
421            rkey: mrv.rkey,
422            lkey: mrv.lkey,
423        })
424    }
425
426    /// Deregisters a buffer from the RDMA domain.
427    ///
428    /// This function removes the specified `RdmaBuffer` from the RDMA domain,
429    /// effectively releasing the resources associated with it.
430    ///
431    /// # Arguments
432    ///
433    /// * `_this` - The context of the actor releasing the buffer.
434    /// * `buffer` - The `RdmaBuffer` to be deregistered.
435    ///
436    /// # Returns
437    ///
438    /// * `Result<(), anyhow::Error>` - On success, returns `Ok(())`. On failure, returns an error.
439    async fn release_buffer(
440        &mut self,
441        _cx: &Context<Self>,
442        buffer: RdmaBuffer,
443    ) -> Result<(), anyhow::Error> {
444        self.deregister_mr(buffer.mr_id)
445            .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
446        Ok(())
447    }
448
449    /// Requests a queue pair for communication with a remote RDMA manager actor.
450    ///
451    /// Basic logic: if queue pair exists in map, return it; if None, create connection first.
452    ///
453    /// # Arguments
454    ///
455    /// * `cx` - The context of the actor requesting the queue pair.
456    /// * `remote` - The ActorRef of the remote RDMA manager actor to communicate with.
457    ///
458    /// # Returns
459    ///
460    /// * `Result<RdmaQueuePair, anyhow::Error>` - On success, returns the queue pair for communication.
461    ///   On failure, returns an error.
462    async fn request_queue_pair(
463        &mut self,
464        cx: &Context<Self>,
465        remote: ActorRef<RdmaManagerActor>,
466    ) -> Result<RdmaQueuePair, anyhow::Error> {
467        let remote_id = remote.actor_id().clone();
468
469        // Check if queue pair exists in map.
470        // IMPOTRANT we clone QP here, but its all simple metadata
471        // and subsequent owner will update it and return it.
472        match self.qp_map.get(&remote_id).cloned() {
473            Some(QueuePairState::Available(qp)) => {
474                // Queue pair exists and is available - return it
475                self.qp_map.insert(remote_id, QueuePairState::CheckedOut);
476                Ok(qp)
477            }
478            Some(QueuePairState::CheckedOut) => {
479                // Queue pair exists but is already checked out
480                Err(anyhow::anyhow!(
481                    "Queue pair for actor {} is already checked out",
482                    remote_id
483                ))
484            }
485            None => {
486                // Queue pair doesn't exist - need to create connection
487                let is_loopback = remote_id == cx.bind::<RdmaManagerActor>().actor_id().clone();
488
489                if is_loopback {
490                    // Loopback connection setup
491                    self.initialize_qp(cx, remote.clone()).await?;
492                    let endpoint = self.connection_info(cx, remote.clone()).await?;
493                    self.connect(cx, remote.clone(), endpoint).await?;
494                } else {
495                    // Remote connection setup
496                    self.initialize_qp(cx, remote.clone()).await?;
497                    remote.initialize_qp(cx, cx.bind().clone()).await?;
498                    let remote_endpoint = remote.connection_info(cx, cx.bind().clone()).await?;
499                    self.connect(cx, remote.clone(), remote_endpoint).await?;
500                    let local_endpoint = self.connection_info(cx, remote.clone()).await?;
501                    remote
502                        .connect(cx, cx.bind().clone(), local_endpoint)
503                        .await?;
504                }
505
506                // Now that connection is established, get the queue pair
507                match self.qp_map.get(&remote_id).cloned() {
508                    Some(QueuePairState::Available(qp)) => {
509                        self.qp_map.insert(remote_id, QueuePairState::CheckedOut);
510                        Ok(qp)
511                    }
512                    _ => Err(anyhow::anyhow!(
513                        "Failed to create connection for actor {}",
514                        remote_id
515                    )),
516                }
517            }
518        }
519    }
520
521    /// Convenience utility to create a new RdmaQueuePair.
522    ///
523    /// This function initializes a new RDMA connection with another actor if one doesn't already exist.
524    /// It creates a new RdmaQueuePair associated with the specified actor ID and adds it to the
525    /// connection map.
526    ///
527    /// # Arguments
528    ///
529    /// * `other` - The ActorRef of the remote actor to connect with
530    async fn initialize_qp(
531        &mut self,
532        _cx: &Context<Self>,
533        other: ActorRef<RdmaManagerActor>,
534    ) -> Result<bool, anyhow::Error> {
535        let key = other.actor_id().clone();
536
537        if let std::collections::hash_map::Entry::Vacant(e) = self.qp_map.entry(key) {
538            let qp = RdmaQueuePair::new(self.domain.context, self.domain.pd, self.config.clone())
539                .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?;
540            e.insert(QueuePairState::Available(qp));
541            tracing::debug!("successfully created a connection with {:?}", other);
542        }
543        Ok(true)
544    }
545
546    /// Establishes a connection with another actor
547    ///
548    /// # Arguments
549    /// * `other` - The ActorRef of the actor to connect to
550    /// * `endpoint` - Connection information needed to establish the RDMA connection
551    async fn connect(
552        &mut self,
553        _cx: &Context<Self>,
554        other: ActorRef<RdmaManagerActor>,
555        endpoint: RdmaQpInfo,
556    ) -> Result<(), anyhow::Error> {
557        tracing::debug!("connecting with {:?}", other);
558        let other_id = other.actor_id().clone();
559
560        match self.qp_map.get_mut(&other_id) {
561            Some(QueuePairState::Available(qp)) => {
562                qp.connect(&endpoint)
563                    .map_err(|e| anyhow::anyhow!("could not connect to RDMA endpoint: {}", e))?;
564                Ok(())
565            }
566            Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!(
567                "Cannot connect: queue pair for actor {} is checked out",
568                other_id
569            )),
570            None => Err(anyhow::anyhow!(
571                "On connect, no connection found for actor {}",
572                other_id
573            )),
574        }
575    }
576
577    /// Gets connection information for establishing an RDMA connection
578    ///
579    /// # Arguments
580    /// * `other` - The ActorRef to get connection info for
581    ///
582    /// # Returns
583    /// * `RdmaQpInfo` - Connection information needed for the RDMA connection
584    async fn connection_info(
585        &mut self,
586        _cx: &Context<Self>,
587        other: ActorRef<RdmaManagerActor>,
588    ) -> Result<RdmaQpInfo, anyhow::Error> {
589        tracing::debug!("getting connection info with {:?}", other);
590        let other_id = other.actor_id().clone();
591
592        match self.qp_map.get_mut(&other_id) {
593            Some(QueuePairState::Available(qp)) => {
594                let connection_info = qp.get_qp_info()?;
595                Ok(connection_info)
596            }
597            Some(QueuePairState::CheckedOut) => Err(anyhow::anyhow!(
598                "Cannot get connection info: queue pair for actor {} is checked out",
599                other_id
600            )),
601            None => Err(anyhow::anyhow!(
602                "No connection found for actor {}",
603                other_id
604            )),
605        }
606    }
607
608    /// Releases a queue pair back to the HashMap
609    ///
610    /// This method returns a queue pair to the HashMap after the caller has finished
611    /// using it. This completes the request/release cycle, similar to RdmaBuffer.
612    ///
613    /// # Arguments
614    /// * `other` - The ActorRef to release queue pair for
615    /// * `qp` - The queue pair to return (ownership transferred back)
616    async fn release_queue_pair(
617        &mut self,
618        _cx: &Context<Self>,
619        other: ActorRef<RdmaManagerActor>,
620        qp: RdmaQueuePair,
621    ) -> Result<(), anyhow::Error> {
622        let remote_id = other.actor_id().clone();
623
624        // Check if the queue pair is in the expected CheckedOut state
625        match self.qp_map.get(&remote_id) {
626            Some(QueuePairState::CheckedOut) => {
627                // Restore the queue pair to Available state
628                self.qp_map
629                    .insert(remote_id.clone(), QueuePairState::Available(qp));
630                tracing::debug!("Released queue pair for actor {:?}", remote_id);
631                Ok(())
632            }
633            Some(QueuePairState::Available(_)) => Err(anyhow::anyhow!(
634                "Cannot release queue pair for actor {}: queue pair is not checked out",
635                remote_id
636            )),
637            None => Err(anyhow::anyhow!(
638                "Cannot release queue pair for actor {}: no queue pair found",
639                remote_id
640            )),
641        }
642    }
643}